From cae7e001e31c25b980648a5fc5c838b93396df2c Mon Sep 17 00:00:00 2001 From: almelid Date: Fri, 7 May 2010 17:33:49 +0000 Subject: [PATCH] morro --- SQLAlchemy.egg-info/PKG-INFO | 45 + SQLAlchemy.egg-info/SOURCES.txt | 499 ++ SQLAlchemy.egg-info/dependency_links.txt | 1 + SQLAlchemy.egg-info/entry_points.txt | 3 + SQLAlchemy.egg-info/top_level.txt | 1 + sqlalchemy/__init__.py | 119 + sqlalchemy/cextension/processors.c | 393 ++ sqlalchemy/cextension/resultproxy.c | 586 +++ sqlalchemy/connectors/__init__.py | 6 + sqlalchemy/connectors/mxodbc.py | 146 + sqlalchemy/connectors/pyodbc.py | 113 + sqlalchemy/connectors/zxJDBC.py | 48 + sqlalchemy/databases/__init__.py | 31 + sqlalchemy/dialects/__init__.py | 12 + sqlalchemy/dialects/access/__init__.py | 0 sqlalchemy/dialects/access/base.py | 418 ++ sqlalchemy/dialects/firebird/__init__.py | 16 + sqlalchemy/dialects/firebird/base.py | 619 +++ sqlalchemy/dialects/firebird/kinterbasdb.py | 120 + sqlalchemy/dialects/informix/__init__.py | 3 + sqlalchemy/dialects/informix/base.py | 306 ++ sqlalchemy/dialects/informix/informixdb.py | 46 + sqlalchemy/dialects/maxdb/__init__.py | 3 + sqlalchemy/dialects/maxdb/base.py | 1058 ++++ sqlalchemy/dialects/maxdb/sapdb.py | 17 + sqlalchemy/dialects/mssql/__init__.py | 19 + sqlalchemy/dialects/mssql/adodbapi.py | 59 + sqlalchemy/dialects/mssql/base.py | 1297 +++++ .../dialects/mssql/information_schema.py | 83 + sqlalchemy/dialects/mssql/mxodbc.py | 83 + sqlalchemy/dialects/mssql/pymssql.py | 101 + sqlalchemy/dialects/mssql/pyodbc.py | 197 + sqlalchemy/dialects/mssql/zxjdbc.py | 64 + sqlalchemy/dialects/mysql/__init__.py | 17 + sqlalchemy/dialects/mysql/base.py | 2528 ++++++++++ sqlalchemy/dialects/mysql/mysqlconnector.py | 132 + sqlalchemy/dialects/mysql/mysqldb.py | 202 + sqlalchemy/dialects/mysql/oursql.py | 255 + sqlalchemy/dialects/mysql/pyodbc.py | 76 + sqlalchemy/dialects/mysql/zxjdbc.py | 111 + sqlalchemy/dialects/oracle/__init__.py | 17 + sqlalchemy/dialects/oracle/base.py | 1030 ++++ sqlalchemy/dialects/oracle/cx_oracle.py | 529 ++ sqlalchemy/dialects/oracle/zxjdbc.py | 209 + sqlalchemy/dialects/postgres.py | 10 + sqlalchemy/dialects/postgresql/__init__.py | 14 + sqlalchemy/dialects/postgresql/base.py | 1161 +++++ sqlalchemy/dialects/postgresql/pg8000.py | 105 + sqlalchemy/dialects/postgresql/psycopg2.py | 239 + .../dialects/postgresql/pypostgresql.py | 69 + sqlalchemy/dialects/postgresql/zxjdbc.py | 19 + sqlalchemy/dialects/sqlite/__init__.py | 14 + sqlalchemy/dialects/sqlite/base.py | 596 +++ sqlalchemy/dialects/sqlite/pysqlite.py | 236 + sqlalchemy/dialects/sybase/__init__.py | 20 + sqlalchemy/dialects/sybase/base.py | 420 ++ sqlalchemy/dialects/sybase/mxodbc.py | 17 + sqlalchemy/dialects/sybase/pyodbc.py | 75 + sqlalchemy/dialects/sybase/pysybase.py | 98 + .../dialects/type_migration_guidelines.txt | 145 + sqlalchemy/engine/__init__.py | 274 ++ sqlalchemy/engine/base.py | 2422 ++++++++++ sqlalchemy/engine/ddl.py | 128 + sqlalchemy/engine/default.py | 700 +++ sqlalchemy/engine/reflection.py | 370 ++ sqlalchemy/engine/strategies.py | 227 + sqlalchemy/engine/threadlocal.py | 103 + sqlalchemy/engine/url.py | 214 + sqlalchemy/exc.py | 191 + sqlalchemy/ext/__init__.py | 1 + sqlalchemy/ext/associationproxy.py | 878 ++++ sqlalchemy/ext/compiler.py | 194 + sqlalchemy/ext/declarative.py | 940 ++++ sqlalchemy/ext/horizontal_shard.py | 125 + sqlalchemy/ext/orderinglist.py | 315 ++ sqlalchemy/ext/serializer.py | 155 + sqlalchemy/ext/sqlsoup.py | 551 +++ sqlalchemy/interfaces.py | 205 + sqlalchemy/log.py | 119 + sqlalchemy/orm/__init__.py | 1176 +++++ sqlalchemy/orm/attributes.py | 1708 +++++++ sqlalchemy/orm/collections.py | 1438 ++++++ sqlalchemy/orm/dependency.py | 575 +++ sqlalchemy/orm/dynamic.py | 293 ++ sqlalchemy/orm/evaluator.py | 104 + sqlalchemy/orm/exc.py | 98 + sqlalchemy/orm/identity.py | 251 + sqlalchemy/orm/interfaces.py | 1098 +++++ sqlalchemy/orm/mapper.py | 1958 ++++++++ sqlalchemy/orm/properties.py | 1205 +++++ sqlalchemy/orm/query.py | 2469 ++++++++++ sqlalchemy/orm/scoping.py | 205 + sqlalchemy/orm/session.py | 1604 +++++++ sqlalchemy/orm/shard.py | 15 + sqlalchemy/orm/state.py | 527 ++ sqlalchemy/orm/strategies.py | 1229 +++++ sqlalchemy/orm/sync.py | 98 + sqlalchemy/orm/unitofwork.py | 781 +++ sqlalchemy/orm/uowdumper.py | 101 + sqlalchemy/orm/util.py | 668 +++ sqlalchemy/pool.py | 913 ++++ sqlalchemy/processors.py | 101 + sqlalchemy/queue.py | 183 + sqlalchemy/schema.py | 2386 +++++++++ sqlalchemy/sql/__init__.py | 58 + sqlalchemy/sql/compiler.py | 1612 +++++++ sqlalchemy/sql/expression.py | 4258 +++++++++++++++++ sqlalchemy/sql/functions.py | 104 + sqlalchemy/sql/operators.py | 135 + sqlalchemy/sql/util.py | 651 +++ sqlalchemy/sql/visitors.py | 256 + sqlalchemy/test/__init__.py | 26 + sqlalchemy/test/assertsql.py | 285 ++ sqlalchemy/test/config.py | 180 + sqlalchemy/test/engines.py | 300 ++ sqlalchemy/test/entities.py | 83 + sqlalchemy/test/noseplugin.py | 162 + sqlalchemy/test/orm.py | 111 + sqlalchemy/test/pickleable.py | 75 + sqlalchemy/test/profiling.py | 222 + sqlalchemy/test/requires.py | 259 + sqlalchemy/test/schema.py | 79 + sqlalchemy/test/testing.py | 779 +++ sqlalchemy/test/util.py | 53 + sqlalchemy/topological.py | 297 ++ sqlalchemy/types.py | 1742 +++++++ sqlalchemy/util.py | 1651 +++++++ 127 files changed, 57530 insertions(+) create mode 100644 SQLAlchemy.egg-info/PKG-INFO create mode 100644 SQLAlchemy.egg-info/SOURCES.txt create mode 100644 SQLAlchemy.egg-info/dependency_links.txt create mode 100644 SQLAlchemy.egg-info/entry_points.txt create mode 100644 SQLAlchemy.egg-info/top_level.txt create mode 100644 sqlalchemy/__init__.py create mode 100644 sqlalchemy/cextension/processors.c create mode 100644 sqlalchemy/cextension/resultproxy.c create mode 100644 sqlalchemy/connectors/__init__.py create mode 100644 sqlalchemy/connectors/mxodbc.py create mode 100644 sqlalchemy/connectors/pyodbc.py create mode 100644 sqlalchemy/connectors/zxJDBC.py create mode 100644 sqlalchemy/databases/__init__.py create mode 100644 sqlalchemy/dialects/__init__.py create mode 100644 sqlalchemy/dialects/access/__init__.py create mode 100644 sqlalchemy/dialects/access/base.py create mode 100644 sqlalchemy/dialects/firebird/__init__.py create mode 100644 sqlalchemy/dialects/firebird/base.py create mode 100644 sqlalchemy/dialects/firebird/kinterbasdb.py create mode 100644 sqlalchemy/dialects/informix/__init__.py create mode 100644 sqlalchemy/dialects/informix/base.py create mode 100644 sqlalchemy/dialects/informix/informixdb.py create mode 100644 sqlalchemy/dialects/maxdb/__init__.py create mode 100644 sqlalchemy/dialects/maxdb/base.py create mode 100644 sqlalchemy/dialects/maxdb/sapdb.py create mode 100644 sqlalchemy/dialects/mssql/__init__.py create mode 100644 sqlalchemy/dialects/mssql/adodbapi.py create mode 100644 sqlalchemy/dialects/mssql/base.py create mode 100644 sqlalchemy/dialects/mssql/information_schema.py create mode 100644 sqlalchemy/dialects/mssql/mxodbc.py create mode 100644 sqlalchemy/dialects/mssql/pymssql.py create mode 100644 sqlalchemy/dialects/mssql/pyodbc.py create mode 100644 sqlalchemy/dialects/mssql/zxjdbc.py create mode 100644 sqlalchemy/dialects/mysql/__init__.py create mode 100644 sqlalchemy/dialects/mysql/base.py create mode 100644 sqlalchemy/dialects/mysql/mysqlconnector.py create mode 100644 sqlalchemy/dialects/mysql/mysqldb.py create mode 100644 sqlalchemy/dialects/mysql/oursql.py create mode 100644 sqlalchemy/dialects/mysql/pyodbc.py create mode 100644 sqlalchemy/dialects/mysql/zxjdbc.py create mode 100644 sqlalchemy/dialects/oracle/__init__.py create mode 100644 sqlalchemy/dialects/oracle/base.py create mode 100644 sqlalchemy/dialects/oracle/cx_oracle.py create mode 100644 sqlalchemy/dialects/oracle/zxjdbc.py create mode 100644 sqlalchemy/dialects/postgres.py create mode 100644 sqlalchemy/dialects/postgresql/__init__.py create mode 100644 sqlalchemy/dialects/postgresql/base.py create mode 100644 sqlalchemy/dialects/postgresql/pg8000.py create mode 100644 sqlalchemy/dialects/postgresql/psycopg2.py create mode 100644 sqlalchemy/dialects/postgresql/pypostgresql.py create mode 100644 sqlalchemy/dialects/postgresql/zxjdbc.py create mode 100644 sqlalchemy/dialects/sqlite/__init__.py create mode 100644 sqlalchemy/dialects/sqlite/base.py create mode 100644 sqlalchemy/dialects/sqlite/pysqlite.py create mode 100644 sqlalchemy/dialects/sybase/__init__.py create mode 100644 sqlalchemy/dialects/sybase/base.py create mode 100644 sqlalchemy/dialects/sybase/mxodbc.py create mode 100644 sqlalchemy/dialects/sybase/pyodbc.py create mode 100644 sqlalchemy/dialects/sybase/pysybase.py create mode 100644 sqlalchemy/dialects/type_migration_guidelines.txt create mode 100644 sqlalchemy/engine/__init__.py create mode 100644 sqlalchemy/engine/base.py create mode 100644 sqlalchemy/engine/ddl.py create mode 100644 sqlalchemy/engine/default.py create mode 100644 sqlalchemy/engine/reflection.py create mode 100644 sqlalchemy/engine/strategies.py create mode 100644 sqlalchemy/engine/threadlocal.py create mode 100644 sqlalchemy/engine/url.py create mode 100644 sqlalchemy/exc.py create mode 100644 sqlalchemy/ext/__init__.py create mode 100644 sqlalchemy/ext/associationproxy.py create mode 100644 sqlalchemy/ext/compiler.py create mode 100644 sqlalchemy/ext/declarative.py create mode 100644 sqlalchemy/ext/horizontal_shard.py create mode 100644 sqlalchemy/ext/orderinglist.py create mode 100644 sqlalchemy/ext/serializer.py create mode 100644 sqlalchemy/ext/sqlsoup.py create mode 100644 sqlalchemy/interfaces.py create mode 100644 sqlalchemy/log.py create mode 100644 sqlalchemy/orm/__init__.py create mode 100644 sqlalchemy/orm/attributes.py create mode 100644 sqlalchemy/orm/collections.py create mode 100644 sqlalchemy/orm/dependency.py create mode 100644 sqlalchemy/orm/dynamic.py create mode 100644 sqlalchemy/orm/evaluator.py create mode 100644 sqlalchemy/orm/exc.py create mode 100644 sqlalchemy/orm/identity.py create mode 100644 sqlalchemy/orm/interfaces.py create mode 100644 sqlalchemy/orm/mapper.py create mode 100644 sqlalchemy/orm/properties.py create mode 100644 sqlalchemy/orm/query.py create mode 100644 sqlalchemy/orm/scoping.py create mode 100644 sqlalchemy/orm/session.py create mode 100644 sqlalchemy/orm/shard.py create mode 100644 sqlalchemy/orm/state.py create mode 100644 sqlalchemy/orm/strategies.py create mode 100644 sqlalchemy/orm/sync.py create mode 100644 sqlalchemy/orm/unitofwork.py create mode 100644 sqlalchemy/orm/uowdumper.py create mode 100644 sqlalchemy/orm/util.py create mode 100644 sqlalchemy/pool.py create mode 100644 sqlalchemy/processors.py create mode 100644 sqlalchemy/queue.py create mode 100644 sqlalchemy/schema.py create mode 100644 sqlalchemy/sql/__init__.py create mode 100644 sqlalchemy/sql/compiler.py create mode 100644 sqlalchemy/sql/expression.py create mode 100644 sqlalchemy/sql/functions.py create mode 100644 sqlalchemy/sql/operators.py create mode 100644 sqlalchemy/sql/util.py create mode 100644 sqlalchemy/sql/visitors.py create mode 100644 sqlalchemy/test/__init__.py create mode 100644 sqlalchemy/test/assertsql.py create mode 100644 sqlalchemy/test/config.py create mode 100644 sqlalchemy/test/engines.py create mode 100644 sqlalchemy/test/entities.py create mode 100644 sqlalchemy/test/noseplugin.py create mode 100644 sqlalchemy/test/orm.py create mode 100644 sqlalchemy/test/pickleable.py create mode 100644 sqlalchemy/test/profiling.py create mode 100644 sqlalchemy/test/requires.py create mode 100644 sqlalchemy/test/schema.py create mode 100644 sqlalchemy/test/testing.py create mode 100644 sqlalchemy/test/util.py create mode 100644 sqlalchemy/topological.py create mode 100644 sqlalchemy/types.py create mode 100644 sqlalchemy/util.py diff --git a/SQLAlchemy.egg-info/PKG-INFO b/SQLAlchemy.egg-info/PKG-INFO new file mode 100644 index 0000000..5ac791d --- /dev/null +++ b/SQLAlchemy.egg-info/PKG-INFO @@ -0,0 +1,45 @@ +Metadata-Version: 1.0 +Name: SQLAlchemy +Version: 0.6beta3 +Summary: Database Abstraction Library +Home-page: http://www.sqlalchemy.org +Author: Mike Bayer +Author-email: mike_mp@zzzcomputing.com +License: MIT License +Description: SQLAlchemy is: + + * The Python SQL toolkit and Object Relational Mapper that gives application developers the full power and flexibility of SQL. SQLAlchemy provides a full suite of well known enterprise-level persistence patterns, designed for efficient and high-performing database access, adapted into a simple and Pythonic domain language. + * extremely easy to use for all the basic tasks, such as: accessing pooled connections, constructing SQL from Python expressions, finding object instances, and commiting object modifications back to the database. + * powerful enough for complicated tasks, such as: eager load a graph of objects and their dependencies via joins; map recursive adjacency structures automatically; map objects to not just tables but to any arbitrary join or select statement; combine multiple tables together to load whole sets of otherwise unrelated objects from a single result set; commit entire graphs of object changes in one step. + * built to conform to what DBAs demand, including the ability to swap out generated SQL with hand-optimized statements, full usage of bind parameters for all literal values, fully transactionalized and consistent updates using Unit of Work. + * modular. Different parts of SQLAlchemy can be used independently of the rest, including the connection pool, SQL construction, and ORM. SQLAlchemy is constructed in an open style that allows plenty of customization, with an architecture that supports custom datatypes, custom SQL extensions, and ORM plugins which can augment or extend mapping functionality. + + SQLAlchemy's Philosophy: + + * SQL databases behave less and less like object collections the more size and performance start to matter; object collections behave less and less like tables and rows the more abstraction starts to matter. SQLAlchemy aims to accomodate both of these principles. + * Your classes aren't tables, and your objects aren't rows. Databases aren't just collections of tables; they're relational algebra engines. You don't have to select from just tables, you can select from joins, subqueries, and unions. Database and domain concepts should be visibly decoupled from the beginning, allowing both sides to develop to their full potential. + * For example, table metadata (objects that describe tables) are declared distinctly from the classes theyre designed to store. That way database relationship concepts don't interfere with your object design concepts, and vice-versa; the transition from table-mapping to selectable-mapping is seamless; a class can be mapped against the database in more than one way. SQLAlchemy provides a powerful mapping layer that can work as automatically or as manually as you choose, determining relationships based on foreign keys or letting you define the join conditions explicitly, to bridge the gap between database and domain. + + SQLAlchemy's Advantages: + + * The Unit Of Work system organizes pending CRUD operations into queues and commits them all in one batch. It then performs a topological "dependency sort" of all items to be committed and deleted and groups redundant statements together. This produces the maxiumum efficiency and transaction safety, and minimizes chances of deadlocks. Modeled after Fowler's "Unit of Work" pattern as well as Java Hibernate. + * Function-based query construction allows boolean expressions, operators, functions, table aliases, selectable subqueries, create/update/insert/delete queries, correlated updates, correlated EXISTS clauses, UNION clauses, inner and outer joins, bind parameters, free mixing of literal text within expressions, as little or as much as desired. Query-compilation is vendor-specific; the same query object can be compiled into any number of resulting SQL strings depending on its compilation algorithm. + * Database mapping and class design are totally separate. Persisted objects have no subclassing requirement (other than 'object') and are POPO's : plain old Python objects. They retain serializability (pickling) for usage in various caching systems and session objects. SQLAlchemy "decorates" classes with non-intrusive property accessors to automatically log object creates and modifications with the UnitOfWork engine, to lazyload related data, as well as to track attribute change histories. + * Custom list classes can be used with eagerly or lazily loaded child object lists, allowing rich relationships to be created on the fly as SQLAlchemy appends child objects to an object attribute. + * Composite (multiple-column) primary keys are supported, as are "association" objects that represent the middle of a "many-to-many" relationship. + * Self-referential tables and mappers are supported. Adjacency list structures can be created, saved, and deleted with proper cascading, with no extra programming. + * Data mapping can be used in a row-based manner. Any bizarre hyper-optimized query that you or your DBA can cook up, you can run in SQLAlchemy, and as long as it returns the expected columns within a rowset, you can get your objects from it. For a rowset that contains more than one kind of object per row, multiple mappers can be chained together to return multiple object instance lists from a single database round trip. + * The type system allows pre- and post- processing of data, both at the bind parameter and the result set level. User-defined types can be freely mixed with built-in types. Generic types as well as SQL-specific types are available. + + SVN version: + + + +Platform: UNKNOWN +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Topic :: Database :: Front-Ends +Classifier: Operating System :: OS Independent diff --git a/SQLAlchemy.egg-info/SOURCES.txt b/SQLAlchemy.egg-info/SOURCES.txt new file mode 100644 index 0000000..8c59c00 --- /dev/null +++ b/SQLAlchemy.egg-info/SOURCES.txt @@ -0,0 +1,499 @@ +.hgignore +.hgtags +CHANGES +CHANGES_PRE_05 +LICENSE +MANIFEST.in +README +README.py3k +README.unittests +distribute_setup.py +ez_setup.py +sa2to3.py +setup.cfg +setup.py +sqla_nose.py +doc/copyright.html +doc/dbengine.html +doc/examples.html +doc/genindex.html +doc/index.html +doc/intro.html +doc/mappers.html +doc/metadata.html +doc/ormtutorial.html +doc/search.html +doc/searchindex.js +doc/session.html +doc/sqlexpression.html +doc/_images/sqla_arch_small.jpg +doc/_sources/copyright.txt +doc/_sources/dbengine.txt +doc/_sources/examples.txt +doc/_sources/index.txt +doc/_sources/intro.txt +doc/_sources/mappers.txt +doc/_sources/metadata.txt +doc/_sources/ormtutorial.txt +doc/_sources/session.txt +doc/_sources/sqlexpression.txt +doc/_sources/reference/index.txt +doc/_sources/reference/dialects/access.txt +doc/_sources/reference/dialects/firebird.txt +doc/_sources/reference/dialects/index.txt +doc/_sources/reference/dialects/informix.txt +doc/_sources/reference/dialects/maxdb.txt +doc/_sources/reference/dialects/mssql.txt +doc/_sources/reference/dialects/mysql.txt +doc/_sources/reference/dialects/oracle.txt +doc/_sources/reference/dialects/postgresql.txt +doc/_sources/reference/dialects/sqlite.txt +doc/_sources/reference/dialects/sybase.txt +doc/_sources/reference/ext/associationproxy.txt +doc/_sources/reference/ext/compiler.txt +doc/_sources/reference/ext/declarative.txt +doc/_sources/reference/ext/horizontal_shard.txt +doc/_sources/reference/ext/index.txt +doc/_sources/reference/ext/orderinglist.txt +doc/_sources/reference/ext/serializer.txt +doc/_sources/reference/ext/sqlsoup.txt +doc/_sources/reference/orm/collections.txt +doc/_sources/reference/orm/index.txt +doc/_sources/reference/orm/interfaces.txt +doc/_sources/reference/orm/mapping.txt +doc/_sources/reference/orm/query.txt +doc/_sources/reference/orm/sessions.txt +doc/_sources/reference/orm/utilities.txt +doc/_sources/reference/sqlalchemy/connections.txt +doc/_sources/reference/sqlalchemy/expressions.txt +doc/_sources/reference/sqlalchemy/index.txt +doc/_sources/reference/sqlalchemy/interfaces.txt +doc/_sources/reference/sqlalchemy/pooling.txt +doc/_sources/reference/sqlalchemy/schema.txt +doc/_sources/reference/sqlalchemy/types.txt +doc/_static/basic.css +doc/_static/default.css +doc/_static/docs.css +doc/_static/doctools.js +doc/_static/file.png +doc/_static/init.js +doc/_static/jquery.js +doc/_static/minus.png +doc/_static/plus.png +doc/_static/pygments.css +doc/_static/searchtools.js +doc/build/Makefile +doc/build/conf.py +doc/build/copyright.rst +doc/build/dbengine.rst +doc/build/examples.rst +doc/build/index.rst +doc/build/intro.rst +doc/build/mappers.rst +doc/build/metadata.rst +doc/build/ormtutorial.rst +doc/build/session.rst +doc/build/sqla_arch_small.jpg +doc/build/sqlexpression.rst +doc/build/testdocs.py +doc/build/builder/__init__.py +doc/build/builder/builders.py +doc/build/builder/util.py +doc/build/reference/index.rst +doc/build/reference/dialects/access.rst +doc/build/reference/dialects/firebird.rst +doc/build/reference/dialects/index.rst +doc/build/reference/dialects/informix.rst +doc/build/reference/dialects/maxdb.rst +doc/build/reference/dialects/mssql.rst +doc/build/reference/dialects/mysql.rst +doc/build/reference/dialects/oracle.rst +doc/build/reference/dialects/postgresql.rst +doc/build/reference/dialects/sqlite.rst +doc/build/reference/dialects/sybase.rst +doc/build/reference/ext/associationproxy.rst +doc/build/reference/ext/compiler.rst +doc/build/reference/ext/declarative.rst +doc/build/reference/ext/horizontal_shard.rst +doc/build/reference/ext/index.rst +doc/build/reference/ext/orderinglist.rst +doc/build/reference/ext/serializer.rst +doc/build/reference/ext/sqlsoup.rst +doc/build/reference/orm/collections.rst +doc/build/reference/orm/index.rst +doc/build/reference/orm/interfaces.rst +doc/build/reference/orm/mapping.rst +doc/build/reference/orm/query.rst +doc/build/reference/orm/sessions.rst +doc/build/reference/orm/utilities.rst +doc/build/reference/sqlalchemy/connections.rst +doc/build/reference/sqlalchemy/expressions.rst +doc/build/reference/sqlalchemy/index.rst +doc/build/reference/sqlalchemy/interfaces.rst +doc/build/reference/sqlalchemy/pooling.rst +doc/build/reference/sqlalchemy/schema.rst +doc/build/reference/sqlalchemy/types.rst +doc/build/static/docs.css +doc/build/static/init.js +doc/build/templates/genindex.mako +doc/build/templates/layout.mako +doc/build/templates/page.mako +doc/build/templates/search.mako +doc/build/templates/site_base.mako +doc/build/templates/static_base.mako +doc/build/texinputs/sphinx.sty +doc/reference/index.html +doc/reference/dialects/access.html +doc/reference/dialects/firebird.html +doc/reference/dialects/index.html +doc/reference/dialects/informix.html +doc/reference/dialects/maxdb.html +doc/reference/dialects/mssql.html +doc/reference/dialects/mysql.html +doc/reference/dialects/oracle.html +doc/reference/dialects/postgresql.html +doc/reference/dialects/sqlite.html +doc/reference/dialects/sybase.html +doc/reference/ext/associationproxy.html +doc/reference/ext/compiler.html +doc/reference/ext/declarative.html +doc/reference/ext/horizontal_shard.html +doc/reference/ext/index.html +doc/reference/ext/orderinglist.html +doc/reference/ext/serializer.html +doc/reference/ext/sqlsoup.html +doc/reference/orm/collections.html +doc/reference/orm/index.html +doc/reference/orm/interfaces.html +doc/reference/orm/mapping.html +doc/reference/orm/query.html +doc/reference/orm/sessions.html +doc/reference/orm/utilities.html +doc/reference/sqlalchemy/connections.html +doc/reference/sqlalchemy/expressions.html +doc/reference/sqlalchemy/index.html +doc/reference/sqlalchemy/interfaces.html +doc/reference/sqlalchemy/pooling.html +doc/reference/sqlalchemy/schema.html +doc/reference/sqlalchemy/types.html +examples/__init__.py +examples/adjacency_list/__init__.py +examples/adjacency_list/adjacency_list.py +examples/association/__init__.py +examples/association/basic_association.py +examples/association/proxied_association.py +examples/beaker_caching/__init__.py +examples/beaker_caching/advanced.py +examples/beaker_caching/environment.py +examples/beaker_caching/fixture_data.py +examples/beaker_caching/helloworld.py +examples/beaker_caching/local_session_caching.py +examples/beaker_caching/meta.py +examples/beaker_caching/model.py +examples/beaker_caching/relation_caching.py +examples/custom_attributes/__init__.py +examples/custom_attributes/custom_management.py +examples/custom_attributes/listen_for_events.py +examples/derived_attributes/__init__.py +examples/derived_attributes/attributes.py +examples/dynamic_dict/__init__.py +examples/dynamic_dict/dynamic_dict.py +examples/elementtree/__init__.py +examples/elementtree/adjacency_list.py +examples/elementtree/optimized_al.py +examples/elementtree/pickle.py +examples/elementtree/test.xml +examples/elementtree/test2.xml +examples/elementtree/test3.xml +examples/graphs/__init__.py +examples/graphs/directed_graph.py +examples/inheritance/__init__.py +examples/inheritance/concrete.py +examples/inheritance/polymorph.py +examples/inheritance/single.py +examples/large_collection/__init__.py +examples/large_collection/large_collection.py +examples/nested_sets/__init__.py +examples/nested_sets/nested_sets.py +examples/poly_assoc/__init__.py +examples/poly_assoc/poly_assoc.py +examples/poly_assoc/poly_assoc_fk.py +examples/poly_assoc/poly_assoc_generic.py +examples/postgis/__init__.py +examples/postgis/postgis.py +examples/sharding/__init__.py +examples/sharding/attribute_shard.py +examples/versioning/__init__.py +examples/versioning/history_meta.py +examples/versioning/test_versioning.py +examples/vertical/__init__.py +examples/vertical/dictlike-polymorphic.py +examples/vertical/dictlike.py +lib/SQLAlchemy.egg-info/PKG-INFO +lib/SQLAlchemy.egg-info/SOURCES.txt +lib/SQLAlchemy.egg-info/dependency_links.txt +lib/SQLAlchemy.egg-info/entry_points.txt +lib/SQLAlchemy.egg-info/top_level.txt +lib/sqlalchemy/__init__.py +lib/sqlalchemy/exc.py +lib/sqlalchemy/interfaces.py +lib/sqlalchemy/log.py +lib/sqlalchemy/pool.py +lib/sqlalchemy/processors.py +lib/sqlalchemy/queue.py +lib/sqlalchemy/schema.py +lib/sqlalchemy/topological.py +lib/sqlalchemy/types.py +lib/sqlalchemy/util.py +lib/sqlalchemy/cextension/processors.c +lib/sqlalchemy/cextension/resultproxy.c +lib/sqlalchemy/connectors/__init__.py +lib/sqlalchemy/connectors/mxodbc.py +lib/sqlalchemy/connectors/pyodbc.py +lib/sqlalchemy/connectors/zxJDBC.py +lib/sqlalchemy/databases/__init__.py +lib/sqlalchemy/dialects/__init__.py +lib/sqlalchemy/dialects/postgres.py +lib/sqlalchemy/dialects/type_migration_guidelines.txt +lib/sqlalchemy/dialects/access/__init__.py +lib/sqlalchemy/dialects/access/base.py +lib/sqlalchemy/dialects/firebird/__init__.py +lib/sqlalchemy/dialects/firebird/base.py +lib/sqlalchemy/dialects/firebird/kinterbasdb.py +lib/sqlalchemy/dialects/informix/__init__.py +lib/sqlalchemy/dialects/informix/base.py +lib/sqlalchemy/dialects/informix/informixdb.py +lib/sqlalchemy/dialects/maxdb/__init__.py +lib/sqlalchemy/dialects/maxdb/base.py +lib/sqlalchemy/dialects/maxdb/sapdb.py +lib/sqlalchemy/dialects/mssql/__init__.py +lib/sqlalchemy/dialects/mssql/adodbapi.py +lib/sqlalchemy/dialects/mssql/base.py +lib/sqlalchemy/dialects/mssql/information_schema.py +lib/sqlalchemy/dialects/mssql/mxodbc.py +lib/sqlalchemy/dialects/mssql/pymssql.py +lib/sqlalchemy/dialects/mssql/pyodbc.py +lib/sqlalchemy/dialects/mssql/zxjdbc.py +lib/sqlalchemy/dialects/mysql/__init__.py +lib/sqlalchemy/dialects/mysql/base.py +lib/sqlalchemy/dialects/mysql/mysqlconnector.py +lib/sqlalchemy/dialects/mysql/mysqldb.py +lib/sqlalchemy/dialects/mysql/oursql.py +lib/sqlalchemy/dialects/mysql/pyodbc.py +lib/sqlalchemy/dialects/mysql/zxjdbc.py +lib/sqlalchemy/dialects/oracle/__init__.py +lib/sqlalchemy/dialects/oracle/base.py +lib/sqlalchemy/dialects/oracle/cx_oracle.py +lib/sqlalchemy/dialects/oracle/zxjdbc.py +lib/sqlalchemy/dialects/postgresql/__init__.py +lib/sqlalchemy/dialects/postgresql/base.py +lib/sqlalchemy/dialects/postgresql/pg8000.py +lib/sqlalchemy/dialects/postgresql/psycopg2.py +lib/sqlalchemy/dialects/postgresql/pypostgresql.py +lib/sqlalchemy/dialects/postgresql/zxjdbc.py +lib/sqlalchemy/dialects/sqlite/__init__.py +lib/sqlalchemy/dialects/sqlite/base.py +lib/sqlalchemy/dialects/sqlite/pysqlite.py +lib/sqlalchemy/dialects/sybase/__init__.py +lib/sqlalchemy/dialects/sybase/base.py +lib/sqlalchemy/dialects/sybase/mxodbc.py +lib/sqlalchemy/dialects/sybase/pyodbc.py +lib/sqlalchemy/dialects/sybase/pysybase.py +lib/sqlalchemy/engine/__init__.py +lib/sqlalchemy/engine/base.py +lib/sqlalchemy/engine/ddl.py +lib/sqlalchemy/engine/default.py +lib/sqlalchemy/engine/reflection.py +lib/sqlalchemy/engine/strategies.py +lib/sqlalchemy/engine/threadlocal.py +lib/sqlalchemy/engine/url.py +lib/sqlalchemy/ext/__init__.py +lib/sqlalchemy/ext/associationproxy.py +lib/sqlalchemy/ext/compiler.py +lib/sqlalchemy/ext/declarative.py +lib/sqlalchemy/ext/horizontal_shard.py +lib/sqlalchemy/ext/orderinglist.py +lib/sqlalchemy/ext/serializer.py +lib/sqlalchemy/ext/sqlsoup.py +lib/sqlalchemy/orm/__init__.py +lib/sqlalchemy/orm/attributes.py +lib/sqlalchemy/orm/collections.py +lib/sqlalchemy/orm/dependency.py +lib/sqlalchemy/orm/dynamic.py +lib/sqlalchemy/orm/evaluator.py +lib/sqlalchemy/orm/exc.py +lib/sqlalchemy/orm/identity.py +lib/sqlalchemy/orm/interfaces.py +lib/sqlalchemy/orm/mapper.py +lib/sqlalchemy/orm/properties.py +lib/sqlalchemy/orm/query.py +lib/sqlalchemy/orm/scoping.py +lib/sqlalchemy/orm/session.py +lib/sqlalchemy/orm/shard.py +lib/sqlalchemy/orm/state.py +lib/sqlalchemy/orm/strategies.py +lib/sqlalchemy/orm/sync.py +lib/sqlalchemy/orm/unitofwork.py +lib/sqlalchemy/orm/uowdumper.py +lib/sqlalchemy/orm/util.py +lib/sqlalchemy/sql/__init__.py +lib/sqlalchemy/sql/compiler.py +lib/sqlalchemy/sql/expression.py +lib/sqlalchemy/sql/functions.py +lib/sqlalchemy/sql/operators.py +lib/sqlalchemy/sql/util.py +lib/sqlalchemy/sql/visitors.py +lib/sqlalchemy/test/__init__.py +lib/sqlalchemy/test/assertsql.py +lib/sqlalchemy/test/config.py +lib/sqlalchemy/test/engines.py +lib/sqlalchemy/test/entities.py +lib/sqlalchemy/test/noseplugin.py +lib/sqlalchemy/test/orm.py +lib/sqlalchemy/test/pickleable.py +lib/sqlalchemy/test/profiling.py +lib/sqlalchemy/test/requires.py +lib/sqlalchemy/test/schema.py +lib/sqlalchemy/test/testing.py +lib/sqlalchemy/test/util.py +test/__init__.py +test/binary_data_one.dat +test/binary_data_two.dat +test/aaa_profiling/__init__.py +test/aaa_profiling/test_compiler.py +test/aaa_profiling/test_memusage.py +test/aaa_profiling/test_orm.py +test/aaa_profiling/test_pool.py +test/aaa_profiling/test_resultset.py +test/aaa_profiling/test_zoomark.py +test/aaa_profiling/test_zoomark_orm.py +test/base/__init__.py +test/base/test_dependency.py +test/base/test_except.py +test/base/test_utils.py +test/dialect/__init__.py +test/dialect/test_access.py +test/dialect/test_firebird.py +test/dialect/test_informix.py +test/dialect/test_maxdb.py +test/dialect/test_mssql.py +test/dialect/test_mxodbc.py +test/dialect/test_mysql.py +test/dialect/test_oracle.py +test/dialect/test_postgresql.py +test/dialect/test_sqlite.py +test/dialect/test_sybase.py +test/engine/__init__.py +test/engine/_base.py +test/engine/test_bind.py +test/engine/test_ddlevents.py +test/engine/test_execute.py +test/engine/test_metadata.py +test/engine/test_parseconnect.py +test/engine/test_pool.py +test/engine/test_reconnect.py +test/engine/test_reflection.py +test/engine/test_transaction.py +test/ex/__init__.py +test/ex/test_examples.py +test/ext/__init__.py +test/ext/test_associationproxy.py +test/ext/test_compiler.py +test/ext/test_declarative.py +test/ext/test_horizontal_shard.py +test/ext/test_orderinglist.py +test/ext/test_serializer.py +test/ext/test_sqlsoup.py +test/orm/__init__.py +test/orm/_base.py +test/orm/_fixtures.py +test/orm/test_association.py +test/orm/test_assorted_eager.py +test/orm/test_attributes.py +test/orm/test_backref_mutations.py +test/orm/test_bind.py +test/orm/test_cascade.py +test/orm/test_collection.py +test/orm/test_compile.py +test/orm/test_cycles.py +test/orm/test_defaults.py +test/orm/test_deprecations.py +test/orm/test_dynamic.py +test/orm/test_eager_relations.py +test/orm/test_evaluator.py +test/orm/test_expire.py +test/orm/test_extendedattr.py +test/orm/test_generative.py +test/orm/test_instrumentation.py +test/orm/test_lazy_relations.py +test/orm/test_lazytest1.py +test/orm/test_manytomany.py +test/orm/test_mapper.py +test/orm/test_merge.py +test/orm/test_naturalpks.py +test/orm/test_onetoone.py +test/orm/test_pickled.py +test/orm/test_query.py +test/orm/test_relationships.py +test/orm/test_scoping.py +test/orm/test_selectable.py +test/orm/test_session.py +test/orm/test_subquery_relations.py +test/orm/test_transaction.py +test/orm/test_unitofwork.py +test/orm/test_utils.py +test/orm/test_versioning.py +test/orm/inheritance/__init__.py +test/orm/inheritance/test_abc_inheritance.py +test/orm/inheritance/test_abc_polymorphic.py +test/orm/inheritance/test_basic.py +test/orm/inheritance/test_concrete.py +test/orm/inheritance/test_magazine.py +test/orm/inheritance/test_manytomany.py +test/orm/inheritance/test_poly_linked_list.py +test/orm/inheritance/test_polymorph.py +test/orm/inheritance/test_polymorph2.py +test/orm/inheritance/test_productspec.py +test/orm/inheritance/test_query.py +test/orm/inheritance/test_selects.py +test/orm/inheritance/test_single.py +test/perf/cascade_speed.py +test/perf/insertspeed.py +test/perf/masscreate.py +test/perf/masscreate2.py +test/perf/masseagerload.py +test/perf/massload.py +test/perf/massload2.py +test/perf/masssave.py +test/perf/objselectspeed.py +test/perf/objupdatespeed.py +test/perf/ormsession.py +test/perf/poolload.py +test/perf/sessions.py +test/perf/stress_all.py +test/perf/stresstest.py +test/perf/threaded_compile.py +test/perf/wsgi.py +test/sql/__init__.py +test/sql/_base.py +test/sql/test_case_statement.py +test/sql/test_columns.py +test/sql/test_compiler.py +test/sql/test_constraints.py +test/sql/test_defaults.py +test/sql/test_functions.py +test/sql/test_generative.py +test/sql/test_labels.py +test/sql/test_query.py +test/sql/test_quote.py +test/sql/test_returning.py +test/sql/test_rowcount.py +test/sql/test_selectable.py +test/sql/test_types.py +test/sql/test_unicode.py +test/zblog/__init__.py +test/zblog/blog.py +test/zblog/mappers.py +test/zblog/tables.py +test/zblog/test_zblog.py +test/zblog/user.py \ No newline at end of file diff --git a/SQLAlchemy.egg-info/dependency_links.txt b/SQLAlchemy.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/SQLAlchemy.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/SQLAlchemy.egg-info/entry_points.txt b/SQLAlchemy.egg-info/entry_points.txt new file mode 100644 index 0000000..250bc61 --- /dev/null +++ b/SQLAlchemy.egg-info/entry_points.txt @@ -0,0 +1,3 @@ +[nose.plugins.0.10] +sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy + diff --git a/SQLAlchemy.egg-info/top_level.txt b/SQLAlchemy.egg-info/top_level.txt new file mode 100644 index 0000000..39fb2be --- /dev/null +++ b/SQLAlchemy.egg-info/top_level.txt @@ -0,0 +1 @@ +sqlalchemy diff --git a/sqlalchemy/__init__.py b/sqlalchemy/__init__.py new file mode 100644 index 0000000..376b13e --- /dev/null +++ b/sqlalchemy/__init__.py @@ -0,0 +1,119 @@ +# __init__.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import inspect +import sys + +import sqlalchemy.exc as exceptions +sys.modules['sqlalchemy.exceptions'] = exceptions + +from sqlalchemy.sql import ( + alias, + and_, + asc, + between, + bindparam, + case, + cast, + collate, + delete, + desc, + distinct, + except_, + except_all, + exists, + extract, + func, + insert, + intersect, + intersect_all, + join, + literal, + literal_column, + modifier, + not_, + null, + or_, + outerjoin, + outparam, + select, + subquery, + text, + tuple_, + union, + union_all, + update, + ) + +from sqlalchemy.types import ( + BLOB, + BOOLEAN, + BigInteger, + Binary, + Boolean, + CHAR, + CLOB, + DATE, + DATETIME, + DECIMAL, + Date, + DateTime, + Enum, + FLOAT, + Float, + INT, + INTEGER, + Integer, + Interval, + LargeBinary, + NCHAR, + NVARCHAR, + NUMERIC, + Numeric, + PickleType, + SMALLINT, + SmallInteger, + String, + TEXT, + TIME, + TIMESTAMP, + Text, + Time, + Unicode, + UnicodeText, + VARCHAR, + ) + + +from sqlalchemy.schema import ( + CheckConstraint, + Column, + ColumnDefault, + Constraint, + DDL, + DefaultClause, + FetchedValue, + ForeignKey, + ForeignKeyConstraint, + Index, + MetaData, + PassiveDefault, + PrimaryKeyConstraint, + Sequence, + Table, + ThreadLocalMetaData, + UniqueConstraint, + ) + +from sqlalchemy.engine import create_engine, engine_from_config + + +__all__ = sorted(name for name, obj in locals().items() + if not (name.startswith('_') or inspect.ismodule(obj))) + +__version__ = '0.6beta3' + +del inspect, sys diff --git a/sqlalchemy/cextension/processors.c b/sqlalchemy/cextension/processors.c new file mode 100644 index 0000000..6e33027 --- /dev/null +++ b/sqlalchemy/cextension/processors.c @@ -0,0 +1,393 @@ +/* +processors.c +Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com + +This module is part of SQLAlchemy and is released under +the MIT License: http://www.opensource.org/licenses/mit-license.php +*/ + +#include +#include + +static PyObject * +int_to_boolean(PyObject *self, PyObject *arg) +{ + long l = 0; + PyObject *res; + + if (arg == Py_None) + Py_RETURN_NONE; + + l = PyInt_AsLong(arg); + if (l == 0) { + res = Py_False; + } else if (l == 1) { + res = Py_True; + } else if ((l == -1) && PyErr_Occurred()) { + /* -1 can be either the actual value, or an error flag. */ + return NULL; + } else { + PyErr_SetString(PyExc_ValueError, + "int_to_boolean only accepts None, 0 or 1"); + return NULL; + } + + Py_INCREF(res); + return res; +} + +static PyObject * +to_str(PyObject *self, PyObject *arg) +{ + if (arg == Py_None) + Py_RETURN_NONE; + + return PyObject_Str(arg); +} + +static PyObject * +to_float(PyObject *self, PyObject *arg) +{ + if (arg == Py_None) + Py_RETURN_NONE; + + return PyNumber_Float(arg); +} + +static PyObject * +str_to_datetime(PyObject *self, PyObject *arg) +{ + const char *str; + unsigned int year, month, day, hour, minute, second, microsecond = 0; + + if (arg == Py_None) + Py_RETURN_NONE; + + str = PyString_AsString(arg); + if (str == NULL) + return NULL; + + /* microseconds are optional */ + /* + TODO: this is slightly less picky than the Python version which would + not accept "2000-01-01 00:00:00.". I don't know which is better, but they + should be coherent. + */ + if (sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day, + &hour, &minute, &second, µsecond) < 6) { + PyErr_SetString(PyExc_ValueError, "Couldn't parse datetime string."); + return NULL; + } + return PyDateTime_FromDateAndTime(year, month, day, + hour, minute, second, microsecond); +} + +static PyObject * +str_to_time(PyObject *self, PyObject *arg) +{ + const char *str; + unsigned int hour, minute, second, microsecond = 0; + + if (arg == Py_None) + Py_RETURN_NONE; + + str = PyString_AsString(arg); + if (str == NULL) + return NULL; + + /* microseconds are optional */ + /* + TODO: this is slightly less picky than the Python version which would + not accept "00:00:00.". I don't know which is better, but they should be + coherent. + */ + if (sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second, + µsecond) < 3) { + PyErr_SetString(PyExc_ValueError, "Couldn't parse time string."); + return NULL; + } + return PyTime_FromTime(hour, minute, second, microsecond); +} + +static PyObject * +str_to_date(PyObject *self, PyObject *arg) +{ + const char *str; + unsigned int year, month, day; + + if (arg == Py_None) + Py_RETURN_NONE; + + str = PyString_AsString(arg); + if (str == NULL) + return NULL; + + if (sscanf(str, "%4u-%2u-%2u", &year, &month, &day) != 3) { + PyErr_SetString(PyExc_ValueError, "Couldn't parse date string."); + return NULL; + } + return PyDate_FromDate(year, month, day); +} + + +/*********** + * Structs * + ***********/ + +typedef struct { + PyObject_HEAD + PyObject *encoding; + PyObject *errors; +} UnicodeResultProcessor; + +typedef struct { + PyObject_HEAD + PyObject *type; + PyObject *format; +} DecimalResultProcessor; + + + +/************************** + * UnicodeResultProcessor * + **************************/ + +static int +UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args, + PyObject *kwds) +{ + PyObject *encoding, *errors = NULL; + static char *kwlist[] = {"encoding", "errors", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist, + &encoding, &errors)) + return -1; + + Py_INCREF(encoding); + self->encoding = encoding; + + if (errors) { + Py_INCREF(errors); + } else { + errors = PyString_FromString("strict"); + if (errors == NULL) + return -1; + } + self->errors = errors; + + return 0; +} + +static PyObject * +UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value) +{ + const char *encoding, *errors; + char *str; + Py_ssize_t len; + + if (value == Py_None) + Py_RETURN_NONE; + + if (PyString_AsStringAndSize(value, &str, &len)) + return NULL; + + encoding = PyString_AS_STRING(self->encoding); + errors = PyString_AS_STRING(self->errors); + + return PyUnicode_Decode(str, len, encoding, errors); +} + +static PyMethodDef UnicodeResultProcessor_methods[] = { + {"process", (PyCFunction)UnicodeResultProcessor_process, METH_O, + "The value processor itself."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject UnicodeResultProcessorType = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */ + sizeof(UnicodeResultProcessor), /* tp_basicsize */ + 0, /* tp_itemsize */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "UnicodeResultProcessor objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + UnicodeResultProcessor_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)UnicodeResultProcessor_init, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ +}; + +/************************** + * DecimalResultProcessor * + **************************/ + +static int +DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args, + PyObject *kwds) +{ + PyObject *type, *format; + + if (!PyArg_ParseTuple(args, "OS", &type, &format)) + return -1; + + Py_INCREF(type); + self->type = type; + + Py_INCREF(format); + self->format = format; + + return 0; +} + +static PyObject * +DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value) +{ + PyObject *str, *result, *args; + + if (value == Py_None) + Py_RETURN_NONE; + + if (PyFloat_CheckExact(value)) { + /* Decimal does not accept float values directly */ + args = PyTuple_Pack(1, value); + if (args == NULL) + return NULL; + + str = PyString_Format(self->format, args); + if (str == NULL) + return NULL; + + result = PyObject_CallFunctionObjArgs(self->type, str, NULL); + Py_DECREF(str); + return result; + } else { + return PyObject_CallFunctionObjArgs(self->type, value, NULL); + } +} + +static PyMethodDef DecimalResultProcessor_methods[] = { + {"process", (PyCFunction)DecimalResultProcessor_process, METH_O, + "The value processor itself."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject DecimalResultProcessorType = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "sqlalchemy.DecimalResultProcessor", /* tp_name */ + sizeof(DecimalResultProcessor), /* tp_basicsize */ + 0, /* tp_itemsize */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "DecimalResultProcessor objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DecimalResultProcessor_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)DecimalResultProcessor_init, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ +}; + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif + + +static PyMethodDef module_methods[] = { + {"int_to_boolean", int_to_boolean, METH_O, + "Convert an integer to a boolean."}, + {"to_str", to_str, METH_O, + "Convert any value to its string representation."}, + {"to_float", to_float, METH_O, + "Convert any value to its floating point representation."}, + {"str_to_datetime", str_to_datetime, METH_O, + "Convert an ISO string to a datetime.datetime object."}, + {"str_to_time", str_to_time, METH_O, + "Convert an ISO string to a datetime.time object."}, + {"str_to_date", str_to_date, METH_O, + "Convert an ISO string to a datetime.date object."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initcprocessors(void) +{ + PyObject *m; + + UnicodeResultProcessorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&UnicodeResultProcessorType) < 0) + return; + + DecimalResultProcessorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&DecimalResultProcessorType) < 0) + return; + + m = Py_InitModule3("cprocessors", module_methods, + "Module containing C versions of data processing functions."); + if (m == NULL) + return; + + PyDateTime_IMPORT; + + Py_INCREF(&UnicodeResultProcessorType); + PyModule_AddObject(m, "UnicodeResultProcessor", + (PyObject *)&UnicodeResultProcessorType); + + Py_INCREF(&DecimalResultProcessorType); + PyModule_AddObject(m, "DecimalResultProcessor", + (PyObject *)&DecimalResultProcessorType); +} + diff --git a/sqlalchemy/cextension/resultproxy.c b/sqlalchemy/cextension/resultproxy.c new file mode 100644 index 0000000..b530b65 --- /dev/null +++ b/sqlalchemy/cextension/resultproxy.c @@ -0,0 +1,586 @@ +/* +resultproxy.c +Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com + +This module is part of SQLAlchemy and is released under +the MIT License: http://www.opensource.org/licenses/mit-license.php +*/ + +#include + + +/*********** + * Structs * + ***********/ + +typedef struct { + PyObject_HEAD + PyObject *parent; + PyObject *row; + PyObject *processors; + PyObject *keymap; +} BaseRowProxy; + +/**************** + * BaseRowProxy * + ****************/ + +static PyObject * +safe_rowproxy_reconstructor(PyObject *self, PyObject *args) +{ + PyObject *cls, *state, *tmp; + BaseRowProxy *obj; + + if (!PyArg_ParseTuple(args, "OO", &cls, &state)) + return NULL; + + obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls); + if (obj == NULL) + return NULL; + + tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state); + if (tmp == NULL) { + Py_DECREF(obj); + return NULL; + } + Py_DECREF(tmp); + + if (obj->parent == NULL || obj->row == NULL || + obj->processors == NULL || obj->keymap == NULL) { + PyErr_SetString(PyExc_RuntimeError, + "__setstate__ for BaseRowProxy subclasses must set values " + "for parent, row, processors and keymap"); + Py_DECREF(obj); + return NULL; + } + + return (PyObject *)obj; +} + +static int +BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds) +{ + PyObject *parent, *row, *processors, *keymap; + + if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4, + &parent, &row, &processors, &keymap)) + return -1; + + Py_INCREF(parent); + self->parent = parent; + + if (!PyTuple_CheckExact(row)) { + PyErr_SetString(PyExc_TypeError, "row must be a tuple"); + return -1; + } + Py_INCREF(row); + self->row = row; + + if (!PyList_CheckExact(processors)) { + PyErr_SetString(PyExc_TypeError, "processors must be a list"); + return -1; + } + Py_INCREF(processors); + self->processors = processors; + + if (!PyDict_CheckExact(keymap)) { + PyErr_SetString(PyExc_TypeError, "keymap must be a dict"); + return -1; + } + Py_INCREF(keymap); + self->keymap = keymap; + + return 0; +} + +/* We need the reduce method because otherwise the default implementation + * does very weird stuff for pickle protocol 0 and 1. It calls + * BaseRowProxy.__new__(RowProxy_instance) upon *pickling*. + */ +static PyObject * +BaseRowProxy_reduce(PyObject *self) +{ + PyObject *method, *state; + PyObject *module, *reconstructor, *cls; + + method = PyObject_GetAttrString(self, "__getstate__"); + if (method == NULL) + return NULL; + + state = PyObject_CallObject(method, NULL); + Py_DECREF(method); + if (state == NULL) + return NULL; + + module = PyImport_ImportModule("sqlalchemy.engine.base"); + if (module == NULL) + return NULL; + + reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor"); + Py_DECREF(module); + if (reconstructor == NULL) { + Py_DECREF(state); + return NULL; + } + + cls = PyObject_GetAttrString(self, "__class__"); + if (cls == NULL) { + Py_DECREF(reconstructor); + Py_DECREF(state); + return NULL; + } + + return Py_BuildValue("(N(NN))", reconstructor, cls, state); +} + +static void +BaseRowProxy_dealloc(BaseRowProxy *self) +{ + Py_XDECREF(self->parent); + Py_XDECREF(self->row); + Py_XDECREF(self->processors); + Py_XDECREF(self->keymap); + self->ob_type->tp_free((PyObject *)self); +} + +static PyObject * +BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple) +{ + Py_ssize_t num_values, num_processors; + PyObject **valueptr, **funcptr, **resultptr; + PyObject *func, *result, *processed_value; + + num_values = Py_SIZE(values); + num_processors = Py_SIZE(processors); + if (num_values != num_processors) { + PyErr_SetString(PyExc_RuntimeError, + "number of values in row differ from number of column processors"); + return NULL; + } + + if (astuple) { + result = PyTuple_New(num_values); + } else { + result = PyList_New(num_values); + } + if (result == NULL) + return NULL; + + /* we don't need to use PySequence_Fast as long as values, processors and + * result are simple tuple or lists. */ + valueptr = PySequence_Fast_ITEMS(values); + funcptr = PySequence_Fast_ITEMS(processors); + resultptr = PySequence_Fast_ITEMS(result); + while (--num_values >= 0) { + func = *funcptr; + if (func != Py_None) { + processed_value = PyObject_CallFunctionObjArgs(func, *valueptr, + NULL); + if (processed_value == NULL) { + Py_DECREF(result); + return NULL; + } + *resultptr = processed_value; + } else { + Py_INCREF(*valueptr); + *resultptr = *valueptr; + } + valueptr++; + funcptr++; + resultptr++; + } + return result; +} + +static PyListObject * +BaseRowProxy_values(BaseRowProxy *self) +{ + return (PyListObject *)BaseRowProxy_processvalues(self->row, + self->processors, 0); +} + +static PyTupleObject * +BaseRowProxy_tuplevalues(BaseRowProxy *self) +{ + return (PyTupleObject *)BaseRowProxy_processvalues(self->row, + self->processors, 1); +} + +static PyObject * +BaseRowProxy_iter(BaseRowProxy *self) +{ + PyObject *values, *result; + + values = (PyObject *)BaseRowProxy_tuplevalues(self); + if (values == NULL) + return NULL; + + result = PyObject_GetIter(values); + Py_DECREF(values); + if (result == NULL) + return NULL; + + return result; +} + +static Py_ssize_t +BaseRowProxy_length(BaseRowProxy *self) +{ + return Py_SIZE(self->row); +} + +static PyObject * +BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) +{ + PyObject *processors, *values; + PyObject *processor, *value; + PyObject *record, *result, *indexobject; + PyObject *exc_module, *exception; + char *cstr_key; + long index; + + if (PyInt_CheckExact(key)) { + index = PyInt_AS_LONG(key); + } else if (PyLong_CheckExact(key)) { + index = PyLong_AsLong(key); + if ((index == -1) && PyErr_Occurred()) + /* -1 can be either the actual value, or an error flag. */ + return NULL; + } else if (PySlice_Check(key)) { + values = PyObject_GetItem(self->row, key); + if (values == NULL) + return NULL; + + processors = PyObject_GetItem(self->processors, key); + if (processors == NULL) { + Py_DECREF(values); + return NULL; + } + + result = BaseRowProxy_processvalues(values, processors, 1); + Py_DECREF(values); + Py_DECREF(processors); + return result; + } else { + record = PyDict_GetItem((PyObject *)self->keymap, key); + if (record == NULL) { + record = PyObject_CallMethod(self->parent, "_key_fallback", + "O", key); + if (record == NULL) + return NULL; + } + + indexobject = PyTuple_GetItem(record, 1); + if (indexobject == NULL) + return NULL; + + if (indexobject == Py_None) { + exc_module = PyImport_ImportModule("sqlalchemy.exc"); + if (exc_module == NULL) + return NULL; + + exception = PyObject_GetAttrString(exc_module, + "InvalidRequestError"); + Py_DECREF(exc_module); + if (exception == NULL) + return NULL; + + cstr_key = PyString_AsString(key); + if (cstr_key == NULL) + return NULL; + + PyErr_Format(exception, + "Ambiguous column name '%s' in result set! " + "try 'use_labels' option on select statement.", cstr_key); + return NULL; + } + + index = PyInt_AsLong(indexobject); + if ((index == -1) && PyErr_Occurred()) + /* -1 can be either the actual value, or an error flag. */ + return NULL; + } + processor = PyList_GetItem(self->processors, index); + if (processor == NULL) + return NULL; + + value = PyTuple_GetItem(self->row, index); + if (value == NULL) + return NULL; + + if (processor != Py_None) { + return PyObject_CallFunctionObjArgs(processor, value, NULL); + } else { + Py_INCREF(value); + return value; + } +} + +static PyObject * +BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name) +{ + PyObject *tmp; + + if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + return NULL; + PyErr_Clear(); + } + else + return tmp; + + return BaseRowProxy_subscript(self, name); +} + +/*********************** + * getters and setters * + ***********************/ + +static PyObject * +BaseRowProxy_getparent(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->parent); + return self->parent; +} + +static int +BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure) +{ + PyObject *module, *cls; + + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'parent' attribute"); + return -1; + } + + module = PyImport_ImportModule("sqlalchemy.engine.base"); + if (module == NULL) + return -1; + + cls = PyObject_GetAttrString(module, "ResultMetaData"); + Py_DECREF(module); + if (cls == NULL) + return -1; + + if (PyObject_IsInstance(value, cls) != 1) { + PyErr_SetString(PyExc_TypeError, + "The 'parent' attribute value must be an instance of " + "ResultMetaData"); + return -1; + } + Py_DECREF(cls); + Py_XDECREF(self->parent); + Py_INCREF(value); + self->parent = value; + + return 0; +} + +static PyObject * +BaseRowProxy_getrow(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->row); + return self->row; +} + +static int +BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'row' attribute"); + return -1; + } + + if (!PyTuple_CheckExact(value)) { + PyErr_SetString(PyExc_TypeError, + "The 'row' attribute value must be a tuple"); + return -1; + } + + Py_XDECREF(self->row); + Py_INCREF(value); + self->row = value; + + return 0; +} + +static PyObject * +BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->processors); + return self->processors; +} + +static int +BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'processors' attribute"); + return -1; + } + + if (!PyList_CheckExact(value)) { + PyErr_SetString(PyExc_TypeError, + "The 'processors' attribute value must be a list"); + return -1; + } + + Py_XDECREF(self->processors); + Py_INCREF(value); + self->processors = value; + + return 0; +} + +static PyObject * +BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->keymap); + return self->keymap; +} + +static int +BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'keymap' attribute"); + return -1; + } + + if (!PyDict_CheckExact(value)) { + PyErr_SetString(PyExc_TypeError, + "The 'keymap' attribute value must be a dict"); + return -1; + } + + Py_XDECREF(self->keymap); + Py_INCREF(value); + self->keymap = value; + + return 0; +} + +static PyGetSetDef BaseRowProxy_getseters[] = { + {"_parent", + (getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent, + "ResultMetaData", + NULL}, + {"_row", + (getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow, + "Original row tuple", + NULL}, + {"_processors", + (getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors, + "list of type processors", + NULL}, + {"_keymap", + (getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap, + "Key to (processor, index) dict", + NULL}, + {NULL} +}; + +static PyMethodDef BaseRowProxy_methods[] = { + {"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS, + "Return the values represented by this BaseRowProxy as a list."}, + {"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS, + "Pickle support method."}, + {NULL} /* Sentinel */ +}; + +static PySequenceMethods BaseRowProxy_as_sequence = { + (lenfunc)BaseRowProxy_length, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + 0, /* sq_item */ + 0, /* sq_slice */ + 0, /* sq_ass_item */ + 0, /* sq_ass_slice */ + 0, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ +}; + +static PyMappingMethods BaseRowProxy_as_mapping = { + (lenfunc)BaseRowProxy_length, /* mp_length */ + (binaryfunc)BaseRowProxy_subscript, /* mp_subscript */ + 0 /* mp_ass_subscript */ +}; + +static PyTypeObject BaseRowProxyType = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */ + sizeof(BaseRowProxy), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)BaseRowProxy_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + &BaseRowProxy_as_sequence, /* tp_as_sequence */ + &BaseRowProxy_as_mapping, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + (getattrofunc)BaseRowProxy_getattro,/* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "BaseRowProxy is a abstract base class for RowProxy", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + (getiterfunc)BaseRowProxy_iter, /* tp_iter */ + 0, /* tp_iternext */ + BaseRowProxy_methods, /* tp_methods */ + 0, /* tp_members */ + BaseRowProxy_getseters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)BaseRowProxy_init, /* tp_init */ + 0, /* tp_alloc */ + 0 /* tp_new */ +}; + + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif + + +static PyMethodDef module_methods[] = { + {"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS, + "reconstruct a RowProxy instance from its pickled form."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initcresultproxy(void) +{ + PyObject *m; + + BaseRowProxyType.tp_new = PyType_GenericNew; + if (PyType_Ready(&BaseRowProxyType) < 0) + return; + + m = Py_InitModule3("cresultproxy", module_methods, + "Module containing C versions of core ResultProxy classes."); + if (m == NULL) + return; + + Py_INCREF(&BaseRowProxyType); + PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType); + +} + diff --git a/sqlalchemy/connectors/__init__.py b/sqlalchemy/connectors/__init__.py new file mode 100644 index 0000000..f1383ad --- /dev/null +++ b/sqlalchemy/connectors/__init__.py @@ -0,0 +1,6 @@ + + +class Connector(object): + pass + + \ No newline at end of file diff --git a/sqlalchemy/connectors/mxodbc.py b/sqlalchemy/connectors/mxodbc.py new file mode 100644 index 0000000..816474d --- /dev/null +++ b/sqlalchemy/connectors/mxodbc.py @@ -0,0 +1,146 @@ +""" +Provide an SQLALchemy connector for the eGenix mxODBC commercial +Python adapter for ODBC. This is not a free product, but eGenix +provides SQLAlchemy with a license for use in continuous integration +testing. + +This has been tested for use with mxODBC 3.1.2 on SQL Server 2005 +and 2008, using the SQL Server Native driver. However, it is +possible for this to be used on other database platforms. + +For more info on mxODBC, see http://www.egenix.com/ + +""" + +import sys +import re +import warnings +from decimal import Decimal + +from sqlalchemy.connectors import Connector +from sqlalchemy import types as sqltypes +import sqlalchemy.processors as processors + +class MxODBCConnector(Connector): + driver='mxodbc' + + supports_sane_multi_rowcount = False + supports_unicode_statements = False + supports_unicode_binds = False + + supports_native_decimal = True + + @classmethod + def dbapi(cls): + # this classmethod will normally be replaced by an instance + # attribute of the same name, so this is normally only called once. + cls._load_mx_exceptions() + platform = sys.platform + if platform == 'win32': + from mx.ODBC import Windows as module + # this can be the string "linux2", and possibly others + elif 'linux' in platform: + from mx.ODBC import unixODBC as module + elif platform == 'darwin': + from mx.ODBC import iODBC as module + else: + raise ImportError, "Unrecognized platform for mxODBC import" + return module + + @classmethod + def _load_mx_exceptions(cls): + """ Import mxODBC exception classes into the module namespace, + as if they had been imported normally. This is done here + to avoid requiring all SQLAlchemy users to install mxODBC. + """ + global InterfaceError, ProgrammingError + from mx.ODBC import InterfaceError + from mx.ODBC import ProgrammingError + + def on_connect(self): + def connect(conn): + conn.stringformat = self.dbapi.MIXED_STRINGFORMAT + conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT + conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT + conn.errorhandler = self._error_handler() + return connect + + def _error_handler(self): + """ Return a handler that adjusts mxODBC's raised Warnings to + emit Python standard warnings. + """ + from mx.ODBC.Error import Warning as MxOdbcWarning + def error_handler(connection, cursor, errorclass, errorvalue): + + if issubclass(errorclass, MxOdbcWarning): + errorclass.__bases__ = (Warning,) + warnings.warn(message=str(errorvalue), + category=errorclass, + stacklevel=2) + else: + raise errorclass, errorvalue + return error_handler + + def create_connect_args(self, url): + """ Return a tuple of *args,**kwargs for creating a connection. + + The mxODBC 3.x connection constructor looks like this: + + connect(dsn, user='', password='', + clear_auto_commit=1, errorhandler=None) + + This method translates the values in the provided uri + into args and kwargs needed to instantiate an mxODBC Connection. + + The arg 'errorhandler' is not used by SQLAlchemy and will + not be populated. + + """ + opts = url.translate_connect_args(username='user') + opts.update(url.query) + args = opts.pop('host') + opts.pop('port', None) + opts.pop('database', None) + return (args,), opts + + def is_disconnect(self, e): + # eGenix recommends checking connection.closed here, + # but how can we get a handle on the current connection? + if isinstance(e, self.dbapi.ProgrammingError): + return "connection already closed" in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + + def _get_server_version_info(self, connection): + # eGenix suggests using conn.dbms_version instead of what we're doing here + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + # 18 == pyodbc.SQL_DBMS_VER + for n in r.split(dbapi_con.getinfo(18)[1]): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def do_execute(self, cursor, statement, parameters, context=None): + if context: + native_odbc_execute = context.execution_options.\ + get('native_odbc_execute', 'auto') + if native_odbc_execute is True: + # user specified native_odbc_execute=True + cursor.execute(statement, parameters) + elif native_odbc_execute is False: + # user specified native_odbc_execute=False + cursor.executedirect(statement, parameters) + elif context.is_crud: + # statement is UPDATE, DELETE, INSERT + cursor.execute(statement, parameters) + else: + # all other statements + cursor.executedirect(statement, parameters) + else: + cursor.executedirect(statement, parameters) diff --git a/sqlalchemy/connectors/pyodbc.py b/sqlalchemy/connectors/pyodbc.py new file mode 100644 index 0000000..b291f3e --- /dev/null +++ b/sqlalchemy/connectors/pyodbc.py @@ -0,0 +1,113 @@ +from sqlalchemy.connectors import Connector +from sqlalchemy.util import asbool + +import sys +import re +import urllib +import decimal + +class PyODBCConnector(Connector): + driver='pyodbc' + + supports_sane_multi_rowcount = False + # PyODBC unicode is broken on UCS-4 builds + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = supports_unicode + supports_native_decimal = True + default_paramstyle = 'named' + + # for non-DSN connections, this should + # hold the desired driver name + pyodbc_driver_name = None + + # will be set to True after initialize() + # if the freetds.so is detected + freetds = False + + @classmethod + def dbapi(cls): + return __import__('pyodbc') + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + + keys = opts + query = url.query + + connect_args = {} + for param in ('ansi', 'unicode_results', 'autocommit'): + if param in keys: + connect_args[param] = asbool(keys.pop(param)) + + if 'odbc_connect' in keys: + connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] + else: + dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) + if dsn_connection: + connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] + else: + port = '' + if 'port' in keys and not 'port' in query: + port = ',%d' % int(keys.pop('port')) + + connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name), + 'Server=%s%s' % (keys.pop('host', ''), port), + 'Database=%s' % keys.pop('database', '') ] + + user = keys.pop("user", None) + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % keys.pop('password', '')) + else: + connectors.append("Trusted_Connection=Yes") + + # if set to 'Yes', the ODBC layer will try to automagically convert + # textual data from your database encoding to your client encoding + # This should obviously be set to 'No' if you query a cp1253 encoded + # database from a latin1 client... + if 'odbc_autotranslate' in keys: + connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) + + connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) + return [[";".join (connectors)], connect_args] + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." in str(e) or \ + 'Attempt to use a closed connection.' in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + + def initialize(self, connection): + # determine FreeTDS first. can't issue SQL easily + # without getting unicode_statements/binds set up. + + pyodbc = self.dbapi + + dbapi_con = connection.connection + + self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME))) + + # the "Py2K only" part here is theoretical. + # have not tried pyodbc + python3.1 yet. + # Py2K + self.supports_unicode_statements = not self.freetds + self.supports_unicode_binds = not self.freetds + # end Py2K + + # run other initialization which asks for user name, etc. + super(PyODBCConnector, self).initialize(connection) + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) diff --git a/sqlalchemy/connectors/zxJDBC.py b/sqlalchemy/connectors/zxJDBC.py new file mode 100644 index 0000000..ae43128 --- /dev/null +++ b/sqlalchemy/connectors/zxJDBC.py @@ -0,0 +1,48 @@ +import sys +from sqlalchemy.connectors import Connector + +class ZxJDBCConnector(Connector): + driver = 'zxjdbc' + + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + supports_unicode_binds = True + supports_unicode_statements = sys.version > '2.5.0+' + description_encoding = None + default_paramstyle = 'qmark' + + jdbc_db_name = None + jdbc_driver_name = None + + @classmethod + def dbapi(cls): + from com.ziclix.python.sql import zxJDBC + return zxJDBC + + def _driver_kwargs(self): + """Return kw arg dict to be sent to connect().""" + return {} + + def _create_jdbc_url(self, url): + """Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`""" + return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host, + url.port is not None and ':%s' % url.port or '', + url.database) + + def create_connect_args(self, url): + opts = self._driver_kwargs() + opts.update(url.query) + return [[self._create_jdbc_url(url), url.username, url.password, self.jdbc_driver_name], + opts] + + def is_disconnect(self, e): + if not isinstance(e, self.dbapi.ProgrammingError): + return False + e = str(e) + return 'connection is closed' in e or 'cursor is closed' in e + + def _get_server_version_info(self, connection): + # use connection.connection.dbversion, and parse appropriately + # to get a tuple + raise NotImplementedError() diff --git a/sqlalchemy/databases/__init__.py b/sqlalchemy/databases/__init__.py new file mode 100644 index 0000000..3593f1d --- /dev/null +++ b/sqlalchemy/databases/__init__.py @@ -0,0 +1,31 @@ +# __init__.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.sqlite import base as sqlite +from sqlalchemy.dialects.postgresql import base as postgresql +postgres = postgresql +from sqlalchemy.dialects.mysql import base as mysql +from sqlalchemy.dialects.oracle import base as oracle +from sqlalchemy.dialects.firebird import base as firebird +from sqlalchemy.dialects.maxdb import base as maxdb +from sqlalchemy.dialects.informix import base as informix +from sqlalchemy.dialects.mssql import base as mssql +from sqlalchemy.dialects.access import base as access +from sqlalchemy.dialects.sybase import base as sybase + + +__all__ = ( + 'access', + 'firebird', + 'informix', + 'maxdb', + 'mssql', + 'mysql', + 'postgresql', + 'sqlite', + 'oracle', + 'sybase', + ) diff --git a/sqlalchemy/dialects/__init__.py b/sqlalchemy/dialects/__init__.py new file mode 100644 index 0000000..91ca91f --- /dev/null +++ b/sqlalchemy/dialects/__init__.py @@ -0,0 +1,12 @@ +__all__ = ( +# 'access', +# 'firebird', +# 'informix', +# 'maxdb', +# 'mssql', + 'mysql', + 'oracle', + 'postgresql', + 'sqlite', +# 'sybase', + ) diff --git a/sqlalchemy/dialects/access/__init__.py b/sqlalchemy/dialects/access/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sqlalchemy/dialects/access/base.py b/sqlalchemy/dialects/access/base.py new file mode 100644 index 0000000..2b76b93 --- /dev/null +++ b/sqlalchemy/dialects/access/base.py @@ -0,0 +1,418 @@ +# access.py +# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk +# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Support for the Microsoft Access database. + +This dialect is *not* ported to SQLAlchemy 0.6. + +This dialect is *not* tested on SQLAlchemy 0.6. + + +""" +from sqlalchemy import sql, schema, types, exc, pool +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import processors + +class AcNumeric(types.Numeric): + def get_col_spec(self): + return "NUMERIC" + + def bind_processor(self, dialect): + return processors.to_str + + def result_processor(self, dialect, coltype): + return None + +class AcFloat(types.Float): + def get_col_spec(self): + return "FLOAT" + + def bind_processor(self, dialect): + """By converting to string, we can use Decimal types round-trip.""" + return processors.to_str + +class AcInteger(types.Integer): + def get_col_spec(self): + return "INTEGER" + +class AcTinyInteger(types.Integer): + def get_col_spec(self): + return "TINYINT" + +class AcSmallInteger(types.SmallInteger): + def get_col_spec(self): + return "SMALLINT" + +class AcDateTime(types.DateTime): + def __init__(self, *a, **kw): + super(AcDateTime, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class AcDate(types.Date): + def __init__(self, *a, **kw): + super(AcDate, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class AcText(types.Text): + def get_col_spec(self): + return "MEMO" + +class AcString(types.String): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + +class AcUnicode(types.Unicode): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + return None + +class AcChar(types.CHAR): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + +class AcBinary(types.LargeBinary): + def get_col_spec(self): + return "BINARY" + +class AcBoolean(types.Boolean): + def get_col_spec(self): + return "YESNO" + +class AcTimeStamp(types.TIMESTAMP): + def get_col_spec(self): + return "TIMESTAMP" + +class AccessExecutionContext(default.DefaultExecutionContext): + def _has_implicit_sequence(self, column): + if column.primary_key and column.autoincrement: + if isinstance(column.type, types.Integer) and not column.foreign_keys: + if column.default is None or (isinstance(column.default, schema.Sequence) and \ + column.default.optional): + return True + return False + + def post_exec(self): + """If we inserted into a row with a COUNTER column, fetch the ID""" + + if self.compiled.isinsert: + tbl = self.compiled.statement.table + if not hasattr(tbl, 'has_sequence'): + tbl.has_sequence = None + for column in tbl.c: + if getattr(column, 'sequence', False) or self._has_implicit_sequence(column): + tbl.has_sequence = column + break + + if bool(tbl.has_sequence): + # TBD: for some reason _last_inserted_ids doesn't exist here + # (but it does at corresponding point in mssql???) + #if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:] + # print "LAST ROW ID", self._last_inserted_ids + + super(AccessExecutionContext, self).post_exec() + + +const, daoEngine = None, None +class AccessDialect(default.DefaultDialect): + colspecs = { + types.Unicode : AcUnicode, + types.Integer : AcInteger, + types.SmallInteger: AcSmallInteger, + types.Numeric : AcNumeric, + types.Float : AcFloat, + types.DateTime : AcDateTime, + types.Date : AcDate, + types.String : AcString, + types.LargeBinary : AcBinary, + types.Boolean : AcBoolean, + types.Text : AcText, + types.CHAR: AcChar, + types.TIMESTAMP: AcTimeStamp, + } + name = 'access' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + ported_sqla_06 = False + + def type_descriptor(self, typeobj): + newobj = types.adapt_type(typeobj, self.colspecs) + return newobj + + def __init__(self, **params): + super(AccessDialect, self).__init__(**params) + self.text_as_varchar = False + self._dtbs = None + + def dbapi(cls): + import win32com.client, pythoncom + + global const, daoEngine + if const is None: + const = win32com.client.constants + for suffix in (".36", ".35", ".30"): + try: + daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix) + break + except pythoncom.com_error: + pass + else: + raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.") + + import pyodbc as module + return module + dbapi = classmethod(dbapi) + + def create_connect_args(self, url): + opts = url.translate_connect_args() + connectors = ["Driver={Microsoft Access Driver (*.mdb)}"] + connectors.append("Dbq=%s" % opts["database"]) + user = opts.get("username", None) + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % opts.get("password", "")) + return [[";".join(connectors)], {}] + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def do_execute(self, cursor, statement, params, **kwargs): + if params == {}: + params = () + super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs) + + def _execute(self, c, statement, parameters): + try: + if parameters == {}: + parameters = () + c.execute(statement, parameters) + self.context.rowcount = c.rowcount + except Exception, e: + raise exc.DBAPIError.instance(statement, parameters, e) + + def has_table(self, connection, tablename, schema=None): + # This approach seems to be more reliable that using DAO + try: + connection.execute('select top 1 * from [%s]' % tablename) + return True + except Exception, e: + return False + + def reflecttable(self, connection, table, include_columns): + # This is defined in the function, as it relies on win32com constants, + # that aren't imported until dbapi method is called + if not hasattr(self, 'ischema_names'): + self.ischema_names = { + const.dbByte: AcBinary, + const.dbInteger: AcInteger, + const.dbLong: AcInteger, + const.dbSingle: AcFloat, + const.dbDouble: AcFloat, + const.dbDate: AcDateTime, + const.dbLongBinary: AcBinary, + const.dbMemo: AcText, + const.dbBoolean: AcBoolean, + const.dbText: AcUnicode, # All Access strings are unicode + const.dbCurrency: AcNumeric, + } + + # A fresh DAO connection is opened for each reflection + # This is necessary, so we get the latest updates + dtbs = daoEngine.OpenDatabase(connection.engine.url.database) + + try: + for tbl in dtbs.TableDefs: + if tbl.Name.lower() == table.name.lower(): + break + else: + raise exc.NoSuchTableError(table.name) + + for col in tbl.Fields: + coltype = self.ischema_names[col.Type] + if col.Type == const.dbText: + coltype = coltype(col.Size) + + colargs = \ + { + 'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField), + } + default = col.DefaultValue + + if col.Attributes & const.dbAutoIncrField: + colargs['default'] = schema.Sequence(col.Name + '_seq') + elif default: + if col.Type == const.dbBoolean: + default = default == 'Yes' and '1' or '0' + colargs['server_default'] = schema.DefaultClause(sql.text(default)) + + table.append_column(schema.Column(col.Name, coltype, **colargs)) + + # TBD: check constraints + + # Find primary key columns first + for idx in tbl.Indexes: + if idx.Primary: + for col in idx.Fields: + thecol = table.c[col.Name] + table.primary_key.add(thecol) + if isinstance(thecol.type, AcInteger) and \ + not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)): + thecol.autoincrement = False + + # Then add other indexes + for idx in tbl.Indexes: + if not idx.Primary: + if len(idx.Fields) == 1: + col = table.c[idx.Fields[0].Name] + if not col.primary_key: + col.index = True + col.unique = idx.Unique + else: + pass # TBD: multi-column indexes + + + for fk in dtbs.Relations: + if fk.ForeignTable != table.name: + continue + scols = [c.ForeignName for c in fk.Fields] + rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields] + table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True)) + + finally: + dtbs.Close() + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + # A fresh DAO connection is opened for each reflection + # This is necessary, so we get the latest updates + dtbs = daoEngine.OpenDatabase(connection.engine.url.database) + + names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"] + dtbs.Close() + return names + + +class AccessCompiler(compiler.SQLCompiler): + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'month': 'm', + 'day': 'd', + 'year': 'yyyy', + 'second': 's', + 'hour': 'h', + 'doy': 'y', + 'minute': 'n', + 'quarter': 'q', + 'dow': 'w', + 'week': 'ww' + }) + + def visit_select_precolumns(self, select): + """Access puts TOP, it's version of LIMIT here """ + s = select.distinct and "DISTINCT " or "" + if select.limit: + s += "TOP %s " % (select.limit) + if select.offset: + raise exc.InvalidRequestError('Access does not support LIMIT with an offset') + return s + + def limit_clause(self, select): + """Limit in access is after the select keyword""" + return "" + + def binary_operator_string(self, binary): + """Access uses "mod" instead of "%" """ + return binary.operator == '%' and 'mod' or binary.operator + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label() + else: + return super(AccessCompiler, self).label_select_column(select, column, asfrom) + + function_rewrites = {'current_date': 'now', + 'current_timestamp': 'now', + 'length': 'len', + } + def visit_function(self, func): + """Access function names differ from the ANSI SQL names; rewrite common ones""" + func.name = self.function_rewrites.get(func.name, func.name) + return super(AccessCompiler, self).visit_function(func) + + def for_update_clause(self, select): + """FOR UPDATE is not supported by Access; silently ignore""" + return '' + + # Strip schema + def visit_table(self, table, asfrom=False, **kwargs): + if asfrom: + return self.preparer.quote(table.name, table.quote) + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \ + self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) + +class AccessDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() + + # install a sequence if we have an implicit IDENTITY column + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys: + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if not column.nullable: + colspec += " NOT NULL" + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + colspec = self.preparer.format_column(column) + " counter" + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + index = drop.element + self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False))) + +class AccessIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = compiler.RESERVED_WORDS.copy() + reserved_words.update(['value', 'text']) + def __init__(self, dialect): + super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + + +dialect = AccessDialect +dialect.poolclass = pool.SingletonThreadPool +dialect.statement_compiler = AccessCompiler +dialect.ddlcompiler = AccessDDLCompiler +dialect.preparer = AccessIdentifierPreparer +dialect.execution_ctx_cls = AccessExecutionContext diff --git a/sqlalchemy/dialects/firebird/__init__.py b/sqlalchemy/dialects/firebird/__init__.py new file mode 100644 index 0000000..f39e93c --- /dev/null +++ b/sqlalchemy/dialects/firebird/__init__.py @@ -0,0 +1,16 @@ +from sqlalchemy.dialects.firebird import base, kinterbasdb + +base.dialect = kinterbasdb.dialect + +from sqlalchemy.dialects.firebird.base import \ + SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \ + TEXT, NUMERIC, FLOAT, TIMESTAMP, VARCHAR, CHAR, BLOB,\ + dialect + +__all__ = ( + 'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME', + 'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB', + 'dialect' +) + + diff --git a/sqlalchemy/dialects/firebird/base.py b/sqlalchemy/dialects/firebird/base.py new file mode 100644 index 0000000..7031815 --- /dev/null +++ b/sqlalchemy/dialects/firebird/base.py @@ -0,0 +1,619 @@ +# firebird.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Support for the Firebird database. + +Connectivity is usually supplied via the kinterbasdb_ DBAPI module. + +Dialects +~~~~~~~~ + +Firebird offers two distinct dialects_ (not to be confused with a +SQLAlchemy ``Dialect``): + +dialect 1 + This is the old syntax and behaviour, inherited from Interbase pre-6.0. + +dialect 3 + This is the newer and supported syntax, introduced in Interbase 6.0. + +The SQLAlchemy Firebird dialect detects these versions and +adjusts its representation of SQL accordingly. However, +support for dialect 1 is not well tested and probably has +incompatibilities. + +Locking Behavior +~~~~~~~~~~~~~~~~ + +Firebird locks tables aggressively. For this reason, a DROP TABLE may +hang until other transactions are released. SQLAlchemy does its best +to release transactions as quickly as possible. The most common cause +of hanging transactions is a non-fully consumed result set, i.e.:: + + result = engine.execute("select * from table") + row = result.fetchone() + return + +Where above, the ``ResultProxy`` has not been fully consumed. The +connection will be returned to the pool and the transactional state +rolled back once the Python garbage collector reclaims the objects +which hold onto the connection, which often occurs asynchronously. +The above use case can be alleviated by calling ``first()`` on the +``ResultProxy`` which will fetch the first row and immediately close +all remaining cursor/connection resources. + +RETURNING support +~~~~~~~~~~~~~~~~~ + +Firebird 2.0 supports returning a result set from inserts, and 2.1 +extends that to deletes and updates. This is generically exposed by +the SQLAlchemy ``returning()`` method, such as:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\\ + values(name='foo') + print result.fetchall() + + # UPDATE..RETURNING + raises = empl.update().returning(empl.c.id, empl.c.salary).\\ + where(empl.c.sales>100).\\ + values(dict(salary=empl.c.salary * 1.1)) + print raises.fetchall() + + +.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html + +""" + +import datetime, re + +from sqlalchemy import schema as sa_schema +from sqlalchemy import exc, types as sqltypes, sql, util +from sqlalchemy.sql import expression +from sqlalchemy.engine import base, default, reflection +from sqlalchemy.sql import compiler + + +from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE, + FLOAT, INTEGER, NUMERIC, SMALLINT, + TEXT, TIME, TIMESTAMP, VARCHAR) + + +RESERVED_WORDS = set([ + "active", "add", "admin", "after", "all", "alter", "and", "any", "as", + "asc", "ascending", "at", "auto", "avg", "before", "begin", "between", + "bigint", "bit_length", "blob", "both", "by", "case", "cast", "char", + "character", "character_length", "char_length", "check", "close", + "collate", "column", "commit", "committed", "computed", "conditional", + "connect", "constraint", "containing", "count", "create", "cross", + "cstring", "current", "current_connection", "current_date", + "current_role", "current_time", "current_timestamp", + "current_transaction", "current_user", "cursor", "database", "date", + "day", "dec", "decimal", "declare", "default", "delete", "desc", + "descending", "disconnect", "distinct", "do", "domain", "double", + "drop", "else", "end", "entry_point", "escape", "exception", + "execute", "exists", "exit", "external", "extract", "fetch", "file", + "filter", "float", "for", "foreign", "from", "full", "function", + "gdscode", "generator", "gen_id", "global", "grant", "group", + "having", "hour", "if", "in", "inactive", "index", "inner", + "input_type", "insensitive", "insert", "int", "integer", "into", "is", + "isolation", "join", "key", "leading", "left", "length", "level", + "like", "long", "lower", "manual", "max", "maximum_segment", "merge", + "min", "minute", "module_name", "month", "names", "national", + "natural", "nchar", "no", "not", "null", "numeric", "octet_length", + "of", "on", "only", "open", "option", "or", "order", "outer", + "output_type", "overflow", "page", "pages", "page_size", "parameter", + "password", "plan", "position", "post_event", "precision", "primary", + "privileges", "procedure", "protected", "rdb$db_key", "read", "real", + "record_version", "recreate", "recursive", "references", "release", + "reserv", "reserving", "retain", "returning_values", "returns", + "revoke", "right", "rollback", "rows", "row_count", "savepoint", + "schema", "second", "segment", "select", "sensitive", "set", "shadow", + "shared", "singular", "size", "smallint", "snapshot", "some", "sort", + "sqlcode", "stability", "start", "starting", "starts", "statistics", + "sub_type", "sum", "suspend", "table", "then", "time", "timestamp", + "to", "trailing", "transaction", "trigger", "trim", "uncommitted", + "union", "unique", "update", "upper", "user", "using", "value", + "values", "varchar", "variable", "varying", "view", "wait", "when", + "where", "while", "with", "work", "write", "year", + ]) + + +colspecs = { +} + +ischema_names = { + 'SHORT': SMALLINT, + 'LONG': BIGINT, + 'QUAD': FLOAT, + 'FLOAT': FLOAT, + 'DATE': DATE, + 'TIME': TIME, + 'TEXT': TEXT, + 'INT64': NUMERIC, + 'DOUBLE': FLOAT, + 'TIMESTAMP': TIMESTAMP, + 'VARYING': VARCHAR, + 'CSTRING': CHAR, + 'BLOB': BLOB, + } + + +# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc. +# as bind/result functionality is required) + +class FBTypeCompiler(compiler.GenericTypeCompiler): + def visit_boolean(self, type_): + return self.visit_SMALLINT(type_) + + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_TEXT(self, type_): + return "BLOB SUB_TYPE 1" + + def visit_BLOB(self, type_): + return "BLOB SUB_TYPE 0" + + +class FBCompiler(sql.compiler.SQLCompiler): + """Firebird specific idiosincrasies""" + + def visit_mod(self, binary, **kw): + # Firebird lacks a builtin modulo operator, but there is + # an equivalent function in the ib_udf library. + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_alias(self, alias, asfrom=False, **kwargs): + if self.dialect._version_two: + return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs) + else: + # Override to not use the AS keyword which FB 1.5 does not like + if asfrom: + alias_name = isinstance(alias.name, expression._generated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \ + self.preparer.format_alias(alias, alias_name) + else: + return self.process(alias.original, **kwargs) + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0]) + start = self.process(func.clauses.clauses[1]) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2]) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + + def visit_length_func(self, function, **kw): + if self.dialect._version_two: + return "char_length" + self.function_argspec(function) + else: + return "strlen" + self.function_argspec(function) + + visit_char_length_func = visit_length_func + + def function_argspec(self, func, **kw): + if func.clauses is not None and len(func.clauses): + return self.process(func.clause_expr) + else: + return "" + + def default_from(self): + return " FROM rdb$database" + + def visit_sequence(self, seq): + return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) + + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just + before column list Firebird puts the limit and offset right + after the ``SELECT``... + """ + + result = "" + if select._limit: + result += "FIRST %d " % select._limit + if select._offset: + result +="SKIP %d " % select._offset + if select._distinct: + result += "DISTINCT " + return result + + def limit_clause(self, select): + """Already taken care of in the `get_select_precolumns` method.""" + + return "" + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self.process( + self.label_select_column(None, c, asfrom=False), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) + ] + return 'RETURNING ' + ', '.join(columns) + + +class FBDDLCompiler(sql.compiler.DDLCompiler): + """Firebird syntactic idiosincrasies""" + + def visit_create_sequence(self, create): + """Generate a ``CREATE GENERATOR`` statement for the sequence.""" + + # no syntax for these + # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html + if create.element.start is not None: + raise NotImplemented("Firebird SEQUENCE doesn't support START WITH") + if create.element.increment is not None: + raise NotImplemented("Firebird SEQUENCE doesn't support INCREMENT BY") + + if self.dialect._version_two: + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + else: + return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + """Generate a ``DROP GENERATOR`` statement for the sequence.""" + + if self.dialect._version_two: + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + else: + return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element) + + +class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): + """Install Firebird specific reserved words.""" + + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) + + +class FBExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq): + """Get the next value from the sequence using ``gen_id()``.""" + + return self._execute_scalar("SELECT gen_id(%s, 1) FROM rdb$database" % \ + self.dialect.identifier_preparer.format_sequence(seq)) + + +class FBDialect(default.DefaultDialect): + """Firebird dialect""" + + name = 'firebird' + + max_identifier_length = 31 + + supports_sequences = True + sequences_optional = False + supports_default_values = True + postfetch_lastrowid = False + + supports_native_boolean = False + + requires_name_normalize = True + supports_empty_insert = False + + + statement_compiler = FBCompiler + ddl_compiler = FBDDLCompiler + preparer = FBIdentifierPreparer + type_compiler = FBTypeCompiler + execution_ctx_cls = FBExecutionContext + + colspecs = colspecs + ischema_names = ischema_names + + # defaults to dialect ver. 3, + # will be autodetected off upon + # first connect + _version_two = True + + def initialize(self, connection): + super(FBDialect, self).initialize(connection) + self._version_two = self.server_version_info > (2, ) + if not self._version_two: + # TODO: whatever other pre < 2.0 stuff goes here + self.ischema_names = ischema_names.copy() + self.ischema_names['TIMESTAMP'] = sqltypes.DATE + self.colspecs = { + sqltypes.DateTime: sqltypes.DATE + } + else: + self.implicit_returning = True + + def normalize_name(self, name): + # Remove trailing spaces: FB uses a CHAR() type, + # that is padded with spaces + name = name and name.rstrip() + if name is None: + return None + elif name.upper() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.lower() + else: + return name + + def denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.upper() + else: + return name + + def has_table(self, connection, table_name, schema=None): + """Return ``True`` if the given table exists, ignoring the `schema`.""" + + tblqry = """ + SELECT 1 FROM rdb$database + WHERE EXISTS (SELECT rdb$relation_name + FROM rdb$relations + WHERE rdb$relation_name=?) + """ + c = connection.execute(tblqry, [self.denormalize_name(table_name)]) + return c.first() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + """Return ``True`` if the given sequence (generator) exists.""" + + genqry = """ + SELECT 1 FROM rdb$database + WHERE EXISTS (SELECT rdb$generator_name + FROM rdb$generators + WHERE rdb$generator_name=?) + """ + c = connection.execute(genqry, [self.denormalize_name(sequence_name)]) + return c.first() is not None + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + s = """ + SELECT DISTINCT rdb$relation_name + FROM rdb$relation_fields + WHERE rdb$system_flag=0 AND rdb$view_context IS NULL + """ + return [self.normalize_name(row[0]) for row in connection.execute(s)] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + s = """ + SELECT distinct rdb$view_name + FROM rdb$view_relations + """ + return [self.normalize_name(row[0]) for row in connection.execute(s)] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + qry = """ + SELECT rdb$view_source AS view_source + FROM rdb$relations + WHERE rdb$relation_name=? + """ + rp = connection.execute(qry, [self.denormalize_name(view_name)]) + row = rp.first() + if row: + return row['view_source'] + else: + return None + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + # Query to extract the PK/FK constrained fields of the given table + keyqry = """ + SELECT se.rdb$field_name AS fname + FROM rdb$relation_constraints rc + JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + """ + tablename = self.denormalize_name(table_name) + # get primary key fields + c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) + pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] + return pkfields + + @reflection.cache + def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw): + tablename = self.denormalize_name(table_name) + colname = self.denormalize_name(column_name) + # Heuristic-query to determine the generator associated to a PK field + genqry = """ + SELECT trigdep.rdb$depended_on_name AS fgenerator + FROM rdb$dependencies tabdep + JOIN rdb$dependencies trigdep + ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name + AND trigdep.rdb$depended_on_type=14 + AND trigdep.rdb$dependent_type=2 + JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name + WHERE tabdep.rdb$depended_on_name=? + AND tabdep.rdb$depended_on_type=0 + AND trig.rdb$trigger_type=1 + AND tabdep.rdb$field_name=? + AND (SELECT count(*) + FROM rdb$dependencies trigdep2 + WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 + """ + genr = connection.execute(genqry, [tablename, colname]).first() + if genr is not None: + return dict(name=self.normalize_name(genr['fgenerator'])) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + # Query to extract the details of all the fields of the given table + tblqry = """ + SELECT DISTINCT r.rdb$field_name AS fname, + r.rdb$null_flag AS null_flag, + t.rdb$type_name AS ftype, + f.rdb$field_sub_type AS stype, + f.rdb$field_length/COALESCE(cs.rdb$bytes_per_character,1) AS flen, + f.rdb$field_precision AS fprec, + f.rdb$field_scale AS fscale, + COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault + FROM rdb$relation_fields r + JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name + JOIN rdb$types t + ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE' + LEFT JOIN rdb$character_sets cs ON f.rdb$character_set_id=cs.rdb$character_set_id + WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? + ORDER BY r.rdb$field_position + """ + # get the PK, used to determine the eventual associated sequence + pkey_cols = self.get_primary_keys(connection, table_name) + + tablename = self.denormalize_name(table_name) + # get all of the fields for this table + c = connection.execute(tblqry, [tablename]) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + name = self.normalize_name(row['fname']) + orig_colname = row['fname'] + + # get the data type + colspec = row['ftype'].rstrip() + coltype = self.ischema_names.get(colspec) + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % + (colspec, name)) + coltype = sqltypes.NULLTYPE + elif colspec == 'INT64': + coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1) + elif colspec in ('VARYING', 'CSTRING'): + coltype = coltype(row['flen']) + elif colspec == 'TEXT': + coltype = TEXT(row['flen']) + elif colspec == 'BLOB': + if row['stype'] == 1: + coltype = TEXT() + else: + coltype = BLOB() + else: + coltype = coltype(row) + + # does it have a default value? + defvalue = None + if row['fdefault'] is not None: + # the value comes down as "DEFAULT 'value'": there may be + # more than one whitespace around the "DEFAULT" keyword + # (see also http://tracker.firebirdsql.org/browse/CORE-356) + defexpr = row['fdefault'].lstrip() + assert defexpr[:8].rstrip()=='DEFAULT', "Unrecognized default value: %s" % defexpr + defvalue = defexpr[8:].strip() + if defvalue == 'NULL': + # Redundant + defvalue = None + col_d = { + 'name' : name, + 'type' : coltype, + 'nullable' : not bool(row['null_flag']), + 'default' : defvalue + } + + if orig_colname.lower() == orig_colname: + col_d['quote'] = True + + # if the PK is a single field, try to see if its linked to + # a sequence thru a trigger + if len(pkey_cols)==1 and name==pkey_cols[0]: + seq_d = self.get_column_sequence(connection, tablename, name) + if seq_d is not None: + col_d['sequence'] = seq_d + + cols.append(col_d) + return cols + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # Query to extract the details of each UK/FK of the given table + fkqry = """ + SELECT rc.rdb$constraint_name AS cname, + cse.rdb$field_name AS fname, + ix2.rdb$relation_name AS targetrname, + se.rdb$field_name AS targetfname + FROM rdb$relation_constraints rc + JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name + JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key + JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name + JOIN rdb$index_segments se + ON se.rdb$index_name=ix2.rdb$index_name + AND se.rdb$field_position=cse.rdb$field_position + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + ORDER BY se.rdb$index_name, se.rdb$field_position + """ + tablename = self.denormalize_name(table_name) + + c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) + fks = util.defaultdict(lambda:{ + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + }) + + for row in c: + cname = self.normalize_name(row['cname']) + fk = fks[cname] + if not fk['name']: + fk['name'] = cname + fk['referred_table'] = self.normalize_name(row['targetrname']) + fk['constrained_columns'].append(self.normalize_name(row['fname'])) + fk['referred_columns'].append( + self.normalize_name(row['targetfname'])) + return fks.values() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + qry = """ + SELECT ix.rdb$index_name AS index_name, + ix.rdb$unique_flag AS unique_flag, + ic.rdb$field_name AS field_name + FROM rdb$indices ix + JOIN rdb$index_segments ic + ON ix.rdb$index_name=ic.rdb$index_name + LEFT OUTER JOIN rdb$relation_constraints + ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name + WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL + AND rdb$relation_constraints.rdb$constraint_type IS NULL + ORDER BY index_name, field_name + """ + c = connection.execute(qry, [self.denormalize_name(table_name)]) + + indexes = util.defaultdict(dict) + for row in c: + indexrec = indexes[row['index_name']] + if 'name' not in indexrec: + indexrec['name'] = self.normalize_name(row['index_name']) + indexrec['column_names'] = [] + indexrec['unique'] = bool(row['unique_flag']) + + indexrec['column_names'].append(self.normalize_name(row['field_name'])) + + return indexes.values() + + def do_execute(self, cursor, statement, parameters, **kwargs): + # kinterbase does not accept a None, but wants an empty list + # when there are no arguments. + cursor.execute(statement, parameters or []) + + def do_rollback(self, connection): + # Use the retaining feature, that keeps the transaction going + connection.rollback(True) + + def do_commit(self, connection): + # Use the retaining feature, that keeps the transaction going + connection.commit(True) diff --git a/sqlalchemy/dialects/firebird/kinterbasdb.py b/sqlalchemy/dialects/firebird/kinterbasdb.py new file mode 100644 index 0000000..9984d32 --- /dev/null +++ b/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -0,0 +1,120 @@ +# kinterbasdb.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +The most common way to connect to a Firebird engine is implemented by +kinterbasdb__, currently maintained__ directly by the Firebird people. + +The connection URL is of the form +``firebird[+kinterbasdb]://user:password@host:port/path/to/db[?key=value&key=value...]``. + +Kinterbasedb backend specific keyword arguments are: + +type_conv + select the kind of mapping done on the types: by default SQLAlchemy + uses 200 with Unicode, datetime and decimal support (see details__). + +concurrency_level + set the backend policy with regards to threading issues: by default + SQLAlchemy uses policy 1 (see details__). + +__ http://sourceforge.net/projects/kinterbasdb +__ http://firebirdsql.org/index.php?op=devel&sub=python +__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation +__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency +""" + +from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler +from sqlalchemy import util, types as sqltypes + +class _FBNumeric_kinterbasdb(sqltypes.Numeric): + def bind_processor(self, dialect): + def process(value): + if value is not None: + return str(value) + else: + return value + return process + +class FBDialect_kinterbasdb(FBDialect): + driver = 'kinterbasdb' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + supports_native_decimal = True + + colspecs = util.update_copy( + FBDialect.colspecs, + { + sqltypes.Numeric:_FBNumeric_kinterbasdb + } + + ) + + def __init__(self, type_conv=200, concurrency_level=1, **kwargs): + super(FBDialect_kinterbasdb, self).__init__(**kwargs) + + self.type_conv = type_conv + self.concurrency_level = concurrency_level + + @classmethod + def dbapi(cls): + k = __import__('kinterbasdb') + return k + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if opts.get('port'): + opts['host'] = "%s/%s" % (opts['host'], opts['port']) + del opts['port'] + opts.update(url.query) + + type_conv = opts.pop('type_conv', self.type_conv) + concurrency_level = opts.pop('concurrency_level', self.concurrency_level) + + if self.dbapi is not None: + initialized = getattr(self.dbapi, 'initialized', None) + if initialized is None: + # CVS rev 1.96 changed the name of the attribute: + # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96 + initialized = getattr(self.dbapi, '_initialized', False) + if not initialized: + self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. This is more than enough for our purposes, + # as this is mainly (only?) used by the testsuite. + + from re import match + + fbconn = connection.connection + version = fbconn.server_version + m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) + if not m: + raise AssertionError("Could not determine version from string '%s'" % version) + return tuple([int(x) for x in m.group(5, 6, 4)]) + + def is_disconnect(self, e): + if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)): + msg = str(e) + return ('Unable to complete network request to host' in msg or + 'Invalid connection state' in msg or + 'Invalid cursor state' in msg) + else: + return False + +dialect = FBDialect_kinterbasdb diff --git a/sqlalchemy/dialects/informix/__init__.py b/sqlalchemy/dialects/informix/__init__.py new file mode 100644 index 0000000..f2fcc76 --- /dev/null +++ b/sqlalchemy/dialects/informix/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.informix import base, informixdb + +base.dialect = informixdb.dialect \ No newline at end of file diff --git a/sqlalchemy/dialects/informix/base.py b/sqlalchemy/dialects/informix/base.py new file mode 100644 index 0000000..266a74a --- /dev/null +++ b/sqlalchemy/dialects/informix/base.py @@ -0,0 +1,306 @@ +# informix.py +# Copyright (C) 2005,2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# coding: gbk +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the Informix database. + +This dialect is *not* tested on SQLAlchemy 0.6. + + +""" + + +import datetime + +from sqlalchemy import sql, schema, exc, pool, util +from sqlalchemy.sql import compiler +from sqlalchemy.engine import default, reflection +from sqlalchemy import types as sqltypes + + +class InfoDateTime(sqltypes.DateTime): + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace(microsecond=0) + return value + return process + +class InfoTime(sqltypes.Time): + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace(microsecond=0) + return value + return process + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + else: + return value + return process + + +colspecs = { + sqltypes.DateTime : InfoDateTime, + sqltypes.Time: InfoTime, +} + + +ischema_names = { + 0 : sqltypes.CHAR, # CHAR + 1 : sqltypes.SMALLINT, # SMALLINT + 2 : sqltypes.INTEGER, # INT + 3 : sqltypes.FLOAT, # Float + 3 : sqltypes.Float, # SmallFloat + 5 : sqltypes.DECIMAL, # DECIMAL + 6 : sqltypes.Integer, # Serial + 7 : sqltypes.DATE, # DATE + 8 : sqltypes.Numeric, # MONEY + 10 : sqltypes.DATETIME, # DATETIME + 11 : sqltypes.LargeBinary, # BYTE + 12 : sqltypes.TEXT, # TEXT + 13 : sqltypes.VARCHAR, # VARCHAR + 15 : sqltypes.NCHAR, # NCHAR + 16 : sqltypes.NVARCHAR, # NVARCHAR + 17 : sqltypes.Integer, # INT8 + 18 : sqltypes.Integer, # Serial8 + 43 : sqltypes.String, # LVARCHAR + -1 : sqltypes.BLOB, # BLOB + -1 : sqltypes.CLOB, # CLOB +} + + +class InfoTypeCompiler(compiler.GenericTypeCompiler): + def visit_DATETIME(self, type_): + return "DATETIME YEAR TO SECOND" + + def visit_TIME(self, type_): + return "DATETIME HOUR TO SECOND" + + def visit_large_binary(self, type_): + return "BYTE" + + def visit_boolean(self, type_): + return "SMALLINT" + +class InfoSQLCompiler(compiler.SQLCompiler): + + def default_from(self): + return " from systables where tabname = 'systables' " + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + # only has limit + if select._limit: + s += " FIRST %s " % select._limit + else: + s += "" + return s + + def visit_select(self, select): + # the column in order by clause must in select too + + def __label(c): + try: + return c._label.lower() + except: + return '' + + # TODO: dont modify the original select, generate a new one + a = [__label(c) for c in select._raw_columns] + for c in select._order_by_clause.clauses: + if __label(c) not in a: + select.append_column(c) + + return compiler.SQLCompiler.visit_select(self, select) + + def limit_clause(self, select): + if select._offset is not None and select._offset > 0: + raise NotImplementedError("Informix does not support OFFSET") + return "" + + def visit_function(self, func): + if func.name.lower() == 'current_date': + return "today" + elif func.name.lower() == 'current_time': + return "CURRENT HOUR TO SECOND" + elif func.name.lower() in ('current_timestamp', 'now'): + return "CURRENT YEAR TO SECOND" + else: + return compiler.SQLCompiler.visit_function(self, func) + + +class InfoDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, first_pk=False): + colspec = self.preparer.format_column(column) + if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ + isinstance(column.type, sqltypes.Integer) and first_pk: + colspec += " SERIAL" + else: + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + return colspec + + +class InfoIdentifierPreparer(compiler.IdentifierPreparer): + def __init__(self, dialect): + super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") + + def format_constraint(self, constraint): + # informix doesnt support names for constraints + return '' + + def _requires_quotes(self, value): + return False + +class InformixDialect(default.DefaultDialect): + name = 'informix' + + max_identifier_length = 128 # adjusts at runtime based on server version + + type_compiler = InfoTypeCompiler + statement_compiler = InfoSQLCompiler + ddl_compiler = InfoDDLCompiler + preparer = InfoIdentifierPreparer + colspecs = colspecs + ischema_names = ischema_names + + def initialize(self, connection): + super(InformixDialect, self).initialize(connection) + + # http://www.querix.com/support/knowledge-base/error_number_message/error_200 + if self.server_version_info < (9, 2): + self.max_identifier_length = 18 + else: + self.max_identifier_length = 128 + + def do_begin(self, connect): + cu = connect.cursor() + cu.execute('SET LOCK MODE TO WAIT') + #cu.execute('SET ISOLATION TO REPEATABLE READ') + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + s = "select tabname from systables" + return [row[0] for row in connection.execute(s)] + + def has_table(self, connection, table_name, schema=None): + cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower()) + return cursor.first() is not None + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from + syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3 + where t1.tabid = t2.tabid and t2.tabname=? + and t3.tabid = t2.tabid and t3.colno = t1.colno + order by t1.colno""", table.name.lower()) + columns = [] + for name, colattr, collength, default, colno in rows: + name = name.lower() + if include_columns and name not in include_columns: + continue + + # in 7.31, coltype = 0x000 + # ^^-- column type + # ^-- 1 not null, 0 null + nullable, coltype = divmod(colattr, 256) + if coltype not in (0, 13) and default: + default = default.split()[-1] + + if coltype == 0 or coltype == 13: # char, varchar + coltype = ischema_names[coltype](collength) + if default: + default = "'%s'" % default + elif coltype == 5: # decimal + precision, scale = (collength & 0xFF00) >> 8, collength & 0xFF + if scale == 255: + scale = 0 + coltype = sqltypes.Numeric(precision, scale) + else: + try: + coltype = ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, name)) + coltype = sqltypes.NULLTYPE + + # TODO: nullability ?? + nullable = True + + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default) + columns.append(column_info) + return columns + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # FK + c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , + t4.colname as local_column , t7.tabname as remote_table , + t6.colname as remote_column + from sysconstraints as t1 , systables as t2 , + sysindexes as t3 , syscolumns as t4 , + sysreferences as t5 , syscolumns as t6 , systables as t7 , + sysconstraints as t8 , sysindexes as t9 + where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'R' + and t3.tabid = t2.tabid and t3.idxname = t1.idxname + and t4.tabid = t2.tabid and t4.colno = t3.part1 + and t5.constrid = t1.constrid and t8.constrid = t5.primary + and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname + and t7.tabid = t5.ptabid""", table.name.lower()) + + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for cons_name, cons_type, local_column, remote_table, remote_column in rows: + + rec = fkeys[cons_name] + rec['name'] = cons_name + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + if not rec['referred_table']: + rec['referred_table'] = remote_table + + local_cols.append(local_column) + remote_cols.append(remote_column) + + return fkeys.values() + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + c = connection.execute("""select t4.colname as local_column + from sysconstraints as t1 , systables as t2 , + sysindexes as t3 , syscolumns as t4 + where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'P' + and t3.tabid = t2.tabid and t3.idxname = t1.idxname + and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower()) + return [r[0] for r in c.fetchall()] + + @reflection.cache + def get_indexes(self, connection, table_name, schema, **kw): + # TODO + return [] diff --git a/sqlalchemy/dialects/informix/informixdb.py b/sqlalchemy/dialects/informix/informixdb.py new file mode 100644 index 0000000..a1305c4 --- /dev/null +++ b/sqlalchemy/dialects/informix/informixdb.py @@ -0,0 +1,46 @@ +from sqlalchemy.dialects.informix.base import InformixDialect +from sqlalchemy.engine import default + +class InformixExecutionContext_informixdb(default.DefaultExecutionContext): + def post_exec(self): + if self.isinsert: + self._lastrowid = [self.cursor.sqlerrd[1]] + + +class InformixDialect_informixdb(InformixDialect): + driver = 'informixdb' + default_paramstyle = 'qmark' + execution_context_cls = InformixExecutionContext_informixdb + + @classmethod + def dbapi(cls): + return __import__('informixdb') + + def create_connect_args(self, url): + if url.host: + dsn = '%s@%s' % (url.database, url.host) + else: + dsn = url.database + + if url.username: + opt = {'user': url.username, 'password': url.password} + else: + opt = {} + + return ([dsn], opt) + + def _get_server_version_info(self, connection): + # http://informixdb.sourceforge.net/manual.html#inspecting-version-numbers + vers = connection.dbms_version + + # TODO: not tested + return tuple([int(x) for x in vers.split('.')]) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + else: + return False + + +dialect = InformixDialect_informixdb diff --git a/sqlalchemy/dialects/maxdb/__init__.py b/sqlalchemy/dialects/maxdb/__init__.py new file mode 100644 index 0000000..3f12448 --- /dev/null +++ b/sqlalchemy/dialects/maxdb/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.maxdb import base, sapdb + +base.dialect = sapdb.dialect \ No newline at end of file diff --git a/sqlalchemy/dialects/maxdb/base.py b/sqlalchemy/dialects/maxdb/base.py new file mode 100644 index 0000000..2e1d6a5 --- /dev/null +++ b/sqlalchemy/dialects/maxdb/base.py @@ -0,0 +1,1058 @@ +# maxdb.py +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the MaxDB database. + +This dialect is *not* ported to SQLAlchemy 0.6. + +This dialect is *not* tested on SQLAlchemy 0.6. + +Overview +-------- + +The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007 +and 7.6.00.037. Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM. +The earlier version has severe ``LEFT JOIN`` limitations and will return +incorrect results from even very simple ORM queries. + +Only the native Python DB-API is currently supported. ODBC driver support +is a future enhancement. + +Connecting +---------- + +The username is case-sensitive. If you usually connect to the +database with sqlcli and other tools in lower case, you likely need to +use upper case for DB-API. + +Implementation Notes +-------------------- + +Also check the DatabaseNotes page on the wiki for detailed information. + +With the 7.6.00.37 driver and Python 2.5, it seems that all DB-API +generated exceptions are broken and can cause Python to crash. + +For 'somecol.in_([])' to work, the IN operator's generation must be changed +to cast 'NULL' to a numeric, i.e. NUM(NULL). The DB-API doesn't accept a +bind parameter there, so that particular generation must inline the NULL value, +which depends on [ticket:807]. + +The DB-API is very picky about where bind params may be used in queries. + +Bind params for some functions (e.g. MOD) need type information supplied. +The dialect does not yet do this automatically. + +Max will occasionally throw up 'bad sql, compile again' exceptions for +perfectly valid SQL. The dialect does not currently handle these, more +research is needed. + +MaxDB 7.5 and Sap DB <= 7.4 reportedly do not support schemas. A very +slightly different version of this dialect would be required to support +those versions, and can easily be added if there is demand. Some other +required components such as an Max-aware 'old oracle style' join compiler +(thetas with (+) outer indicators) are already done and available for +integration- email the devel list if you're interested in working on +this. + +""" +import datetime, itertools, re + +from sqlalchemy import exc, schema, sql, util, processors +from sqlalchemy.sql import operators as sql_operators, expression as sql_expr +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.engine import base as engine_base, default, reflection +from sqlalchemy import types as sqltypes + + +class _StringType(sqltypes.String): + _type = None + + def __init__(self, length=None, encoding=None, **kw): + super(_StringType, self).__init__(length=length, **kw) + self.encoding = encoding + + def bind_processor(self, dialect): + if self.encoding == 'unicode': + return None + else: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + else: + return value + return process + + def result_processor(self, dialect, coltype): + #XXX: this code is probably very slow and one should try (if at all + # possible) to determine the correct code path on a per-connection + # basis (ie, here in result_processor, instead of inside the processor + # function itself) and probably also use a few generic + # processors, or possibly per query (though there is no mechanism + # for that yet). + def process(value): + while True: + if value is None: + return None + elif isinstance(value, unicode): + return value + elif isinstance(value, str): + if self.convert_unicode or dialect.convert_unicode: + return value.decode(dialect.encoding) + else: + return value + elif hasattr(value, 'read'): + # some sort of LONG, snarf and retry + value = value.read(value.remainingLength()) + continue + else: + # unexpected type, return as-is + return value + return process + + +class MaxString(_StringType): + _type = 'VARCHAR' + + def __init__(self, *a, **kw): + super(MaxString, self).__init__(*a, **kw) + + +class MaxUnicode(_StringType): + _type = 'VARCHAR' + + def __init__(self, length=None, **kw): + super(MaxUnicode, self).__init__(length=length, encoding='unicode') + + +class MaxChar(_StringType): + _type = 'CHAR' + + +class MaxText(_StringType): + _type = 'LONG' + + def __init__(self, *a, **kw): + super(MaxText, self).__init__(*a, **kw) + + def get_col_spec(self): + spec = 'LONG' + if self.encoding is not None: + spec = ' '.join((spec, self.encoding)) + elif self.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + + +class MaxNumeric(sqltypes.Numeric): + """The FIXED (also NUMERIC, DECIMAL) data type.""" + + def __init__(self, precision=None, scale=None, **kw): + kw.setdefault('asdecimal', True) + super(MaxNumeric, self).__init__(scale=scale, precision=precision, + **kw) + + def bind_processor(self, dialect): + return None + + +class MaxTimestamp(sqltypes.DateTime): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y%m%d%H%M%S" + ("%06u" % ms)) + elif dialect.datetimeformat == 'iso': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms)) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect, coltype): + if dialect.datetimeformat == 'internal': + def process(value): + if value is None: + return None + else: + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[4:6], value[6:8], + value[8:10], value[10:12], value[12:14], + value[14:])]) + elif dialect.datetimeformat == 'iso': + def process(value): + if value is None: + return None + else: + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[5:7], value[8:10], + value[11:13], value[14:16], value[17:19], + value[20:])]) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % + dialect.datetimeformat) + return process + + +class MaxDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%Y%m%d") + elif dialect.datetimeformat == 'iso': + return value.strftime("%Y-%m-%d") + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect, coltype): + if dialect.datetimeformat == 'internal': + def process(value): + if value is None: + return None + else: + return datetime.date(int(value[0:4]), int(value[4:6]), + int(value[6:8])) + elif dialect.datetimeformat == 'iso': + def process(value): + if value is None: + return None + else: + return datetime.date(int(value[0:4]), int(value[5:7]), + int(value[8:10])) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % + dialect.datetimeformat) + return process + + +class MaxTime(sqltypes.Time): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%H%M%S") + elif dialect.datetimeformat == 'iso': + return value.strftime("%H-%M-%S") + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect, coltype): + if dialect.datetimeformat == 'internal': + def process(value): + if value is None: + return None + else: + return datetime.time(int(value[0:4]), int(value[4:6]), + int(value[6:8])) + elif dialect.datetimeformat == 'iso': + def process(value): + if value is None: + return None + else: + return datetime.time(int(value[0:4]), int(value[5:7]), + int(value[8:10])) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % + dialect.datetimeformat) + return process + + +class MaxBlob(sqltypes.LargeBinary): + def bind_processor(self, dialect): + return processors.to_str + + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return None + else: + return value.read(value.remainingLength()) + return process + +class MaxDBTypeCompiler(compiler.GenericTypeCompiler): + def _string_spec(self, string_spec, type_): + if type_.length is None: + spec = 'LONG' + else: + spec = '%s(%s)' % (string_spec, type_.length) + + if getattr(type_, 'encoding'): + spec = ' '.join([spec, getattr(type_, 'encoding').upper()]) + return spec + + def visit_text(self, type_): + spec = 'LONG' + if getattr(type_, 'encoding', None): + spec = ' '.join((spec, type_.encoding)) + elif type_.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + + def visit_char(self, type_): + return self._string_spec("CHAR", type_) + + def visit_string(self, type_): + return self._string_spec("VARCHAR", type_) + + def visit_large_binary(self, type_): + return "LONG BYTE" + + def visit_numeric(self, type_): + if type_.scale and type_.precision: + return 'FIXED(%s, %s)' % (type_.precision, type_.scale) + elif type_.precision: + return 'FIXED(%s)' % type_.precision + else: + return 'INTEGER' + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + +colspecs = { + sqltypes.Numeric: MaxNumeric, + sqltypes.DateTime: MaxTimestamp, + sqltypes.Date: MaxDate, + sqltypes.Time: MaxTime, + sqltypes.String: MaxString, + sqltypes.Unicode:MaxUnicode, + sqltypes.LargeBinary: MaxBlob, + sqltypes.Text: MaxText, + sqltypes.CHAR: MaxChar, + sqltypes.TIMESTAMP: MaxTimestamp, + sqltypes.BLOB: MaxBlob, + sqltypes.Unicode: MaxUnicode, + } + +ischema_names = { + 'boolean': sqltypes.BOOLEAN, + 'char': sqltypes.CHAR, + 'character': sqltypes.CHAR, + 'date': sqltypes.DATE, + 'fixed': sqltypes.Numeric, + 'float': sqltypes.FLOAT, + 'int': sqltypes.INT, + 'integer': sqltypes.INT, + 'long binary': sqltypes.BLOB, + 'long unicode': sqltypes.Text, + 'long': sqltypes.Text, + 'long': sqltypes.Text, + 'smallint': sqltypes.SmallInteger, + 'time': sqltypes.Time, + 'timestamp': sqltypes.TIMESTAMP, + 'varchar': sqltypes.VARCHAR, + } + +# TODO: migrate this to sapdb.py +class MaxDBExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + # DB-API bug: if there were any functions as values, + # then do another select and pull CURRVAL from the + # autoincrement column's implicit sequence... ugh + if self.compiled.isinsert and not self.executemany: + table = self.compiled.statement.table + index, serial_col = _autoserial_column(table) + + if serial_col and (not self.compiled._safeserial or + not(self._last_inserted_ids) or + self._last_inserted_ids[index] in (None, 0)): + if table.schema: + sql = "SELECT %s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + else: + sql = "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + + rs = self.cursor.execute(sql) + id = rs.fetchone()[0] + + if not self._last_inserted_ids: + # This shouldn't ever be > 1? Right? + self._last_inserted_ids = \ + [None] * len(table.primary_key.columns) + self._last_inserted_ids[index] = id + + super(MaxDBExecutionContext, self).post_exec() + + def get_result_proxy(self): + if self.cursor.description is not None: + for column in self.cursor.description: + if column[1] in ('Long Binary', 'Long', 'Long Unicode'): + return MaxDBResultProxy(self) + return engine_base.ResultProxy(self) + + @property + def rowcount(self): + if hasattr(self, '_rowcount'): + return self._rowcount + else: + return self.cursor.rowcount + + def fire_sequence(self, seq): + if seq.optional: + return None + return self._execute_scalar("SELECT %s.NEXTVAL FROM DUAL" % ( + self.dialect.identifier_preparer.format_sequence(seq))) + +class MaxDBCachedColumnRow(engine_base.RowProxy): + """A RowProxy that only runs result_processors once per column.""" + + def __init__(self, parent, row): + super(MaxDBCachedColumnRow, self).__init__(parent, row) + self.columns = {} + self._row = row + self._parent = parent + + def _get_col(self, key): + if key not in self.columns: + self.columns[key] = self._parent._get_col(self._row, key) + return self.columns[key] + + def __iter__(self): + for i in xrange(len(self._row)): + yield self._get_col(i) + + def __repr__(self): + return repr(list(self)) + + def __eq__(self, other): + return ((other is self) or + (other == tuple([self._get_col(key) + for key in xrange(len(self._row))]))) + def __getitem__(self, key): + if isinstance(key, slice): + indices = key.indices(len(self._row)) + return tuple([self._get_col(i) for i in xrange(*indices)]) + else: + return self._get_col(key) + + def __getattr__(self, name): + try: + return self._get_col(name) + except KeyError: + raise AttributeError(name) + + +class MaxDBResultProxy(engine_base.ResultProxy): + _process_row = MaxDBCachedColumnRow + +class MaxDBCompiler(compiler.SQLCompiler): + + function_conversion = { + 'CURRENT_DATE': 'DATE', + 'CURRENT_TIME': 'TIME', + 'CURRENT_TIMESTAMP': 'TIMESTAMP', + } + + # These functions must be written without parens when called with no + # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' + bare_functions = set([ + 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP', + 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', + 'UTCDATE', 'UTCDIFF']) + + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def default_from(self): + return ' FROM DUAL' + + def for_update_clause(self, select): + clause = select.for_update + if clause is True: + return " WITH LOCK EXCLUSIVE" + elif clause is None: + return "" + elif clause == "read": + return " WITH LOCK" + elif clause == "ignore": + return " WITH LOCK (IGNORE) EXCLUSIVE" + elif clause == "nowait": + return " WITH LOCK (NOWAIT) EXCLUSIVE" + elif isinstance(clause, basestring): + return " WITH LOCK %s" % clause.upper() + elif not clause: + return "" + else: + return " WITH LOCK EXCLUSIVE" + + def function_argspec(self, fn, **kw): + if fn.name.upper() in self.bare_functions: + return "" + elif len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def visit_function(self, fn, **kw): + transform = self.function_conversion.get(fn.name.upper(), None) + if transform: + fn = fn._clone() + fn.name = transform + return super(MaxDBCompiler, self).visit_function(fn, **kw) + + def visit_cast(self, cast, **kwargs): + # MaxDB only supports casts * to NUMERIC, * to VARCHAR or + # date/time to VARCHAR. Casts of LONGs will fail. + if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): + return "NUM(%s)" % self.process(cast.clause) + elif isinstance(cast.type, sqltypes.String): + return "CHR(%s)" % self.process(cast.clause) + else: + return self.process(cast.clause) + + def visit_sequence(self, sequence): + if sequence.optional: + return None + else: + return (self.dialect.identifier_preparer.format_sequence(sequence) + + ".NEXTVAL") + + class ColumnSnagger(visitors.ClauseVisitor): + def __init__(self): + self.count = 0 + self.column = None + def visit_column(self, column): + self.column = column + self.count += 1 + + def _find_labeled_columns(self, columns, use_labels=False): + labels = {} + for column in columns: + if isinstance(column, basestring): + continue + snagger = self.ColumnSnagger() + snagger.traverse(column) + if snagger.count == 1: + if isinstance(column, sql_expr._Label): + labels[unicode(snagger.column)] = column.name + elif use_labels: + labels[unicode(snagger.column)] = column._label + + return labels + + def order_by_clause(self, select, **kw): + order_by = self.process(select._order_by_clause, **kw) + + # ORDER BY clauses in DISTINCT queries must reference aliased + # inner columns by alias name, not true column name. + if order_by and getattr(select, '_distinct', False): + labels = self._find_labeled_columns(select.inner_columns, + select.use_labels) + if labels: + for needs_alias in labels.keys(): + r = re.compile(r'(^| )(%s)(,| |$)' % + re.escape(needs_alias)) + order_by = r.sub((r'\1%s\3' % labels[needs_alias]), + order_by) + + # No ORDER BY in subqueries. + if order_by: + if self.is_subquery(): + # It's safe to simply drop the ORDER BY if there is no + # LIMIT. Right? Other dialects seem to get away with + # dropping order. + if select._limit: + raise exc.InvalidRequestError( + "MaxDB does not support ORDER BY in subqueries") + else: + return "" + return " ORDER BY " + order_by + else: + return "" + + def get_select_precolumns(self, select): + # Convert a subquery's LIMIT to TOP + sql = select._distinct and 'DISTINCT ' or '' + if self.is_subquery() and select._limit: + if select._offset: + raise exc.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + sql += 'TOP %s ' % select._limit + return sql + + def limit_clause(self, select): + # The docs say offsets are supported with LIMIT. But they're not. + # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate? + if self.is_subquery(): + # sub queries need TOP + return '' + elif select._offset: + raise exc.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + else: + return ' \n LIMIT %s' % (select._limit,) + + def visit_insert(self, insert): + self.isinsert = True + self._safeserial = True + + colparams = self._get_colparams(insert) + for value in (insert.parameters or {}).itervalues(): + if isinstance(value, sql_expr.Function): + self._safeserial = False + break + + return ''.join(('INSERT INTO ', + self.preparer.format_table(insert.table), + ' (', + ', '.join([self.preparer.format_column(c[0]) + for c in colparams]), + ') VALUES (', + ', '.join([c[1] for c in colparams]), + ')')) + + +class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set([ + 'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha', + 'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary', + 'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char', + 'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos', + 'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime', + 'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth', + 'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default', + 'degrees', 'delete', 'digits', 'distinct', 'double', 'except', + 'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for', + 'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest', + 'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore', + 'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal', + 'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left', + 'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long', + 'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime', + 'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod', + 'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround', + 'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on', + 'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians', + 'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round', + 'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd', + 'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some', + 'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev', + 'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba', + 'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone', + 'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc', + 'truncate', 'ucase', 'uid', 'unicode', 'union', 'update', 'upper', + 'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values', + 'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when', + 'where', 'with', 'year', 'zoned' ]) + + def _normalize_name(self, name): + if name is None: + return None + if name.isupper(): + lc_name = name.lower() + if not self._requires_quotes(lc_name): + return lc_name + return name + + def _denormalize_name(self, name): + if name is None: + return None + elif (name.islower() and + not self._requires_quotes(name)): + return name.upper() + else: + return name + + def _maybe_quote_identifier(self, name): + if self._requires_quotes(name): + return self.quote_identifier(name) + else: + return name + + +class MaxDBDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + colspec = [self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type)] + + if not column.nullable: + colspec.append('NOT NULL') + + default = column.default + default_str = self.get_column_default_string(column) + + # No DDL default for columns specified with non-optional sequence- + # this defaulting behavior is entirely client-side. (And as a + # consequence, non-reflectable.) + if (default and isinstance(default, schema.Sequence) and + not default.optional): + pass + # Regular default + elif default_str is not None: + colspec.append('DEFAULT %s' % default_str) + # Assign DEFAULT SERIAL heuristically + elif column.primary_key and column.autoincrement: + # For SERIAL on a non-primary key member, use + # DefaultClause(text('SERIAL')) + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + (isinstance(c.type, sqltypes.Integer) or + (isinstance(c.type, MaxNumeric) and + c.type.precision)) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('DEFAULT SERIAL') + except IndexError: + pass + return ' '.join(colspec) + + def get_column_default_string(self, column): + if isinstance(column.server_default, schema.DefaultClause): + if isinstance(column.default.arg, basestring): + if isinstance(column.type, sqltypes.Integer): + return str(column.default.arg) + else: + return "'%s'" % column.default.arg + else: + return unicode(self._compile(column.default.arg, None)) + else: + return None + + def visit_create_sequence(self, create): + """Creates a SEQUENCE. + + TODO: move to module doc? + + start + With an integer value, set the START WITH option. + + increment + An integer value to increment by. Default is the database default. + + maxdb_minvalue + maxdb_maxvalue + With an integer value, sets the corresponding sequence option. + + maxdb_no_minvalue + maxdb_no_maxvalue + Defaults to False. If true, sets the corresponding sequence option. + + maxdb_cycle + Defaults to False. If true, sets the CYCLE option. + + maxdb_cache + With an integer value, sets the CACHE option. + + maxdb_no_cache + Defaults to False. If true, sets NOCACHE. + """ + sequence = create.element + + if (not sequence.optional and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name))): + + ddl = ['CREATE SEQUENCE', + self.preparer.format_sequence(sequence)] + + sequence.increment = 1 + + if sequence.increment is not None: + ddl.extend(('INCREMENT BY', str(sequence.increment))) + + if sequence.start is not None: + ddl.extend(('START WITH', str(sequence.start))) + + opts = dict([(pair[0][6:].lower(), pair[1]) + for pair in sequence.kwargs.items() + if pair[0].startswith('maxdb_')]) + + if 'maxvalue' in opts: + ddl.extend(('MAXVALUE', str(opts['maxvalue']))) + elif opts.get('no_maxvalue', False): + ddl.append('NOMAXVALUE') + if 'minvalue' in opts: + ddl.extend(('MINVALUE', str(opts['minvalue']))) + elif opts.get('no_minvalue', False): + ddl.append('NOMINVALUE') + + if opts.get('cycle', False): + ddl.append('CYCLE') + + if 'cache' in opts: + ddl.extend(('CACHE', str(opts['cache']))) + elif opts.get('no_cache', False): + ddl.append('NOCACHE') + + return ' '.join(ddl) + + +class MaxDBDialect(default.DefaultDialect): + name = 'maxdb' + supports_alter = True + supports_unicode_statements = True + max_identifier_length = 32 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + preparer = MaxDBIdentifierPreparer + statement_compiler = MaxDBCompiler + ddl_compiler = MaxDBDDLCompiler + execution_ctx_cls = MaxDBExecutionContext + + ported_sqla_06 = False + + colspecs = colspecs + ischema_names = ischema_names + + # MaxDB-specific + datetimeformat = 'internal' + + def __init__(self, _raise_known_sql_errors=False, **kw): + super(MaxDBDialect, self).__init__(**kw) + self._raise_known = _raise_known_sql_errors + + if self.dbapi is None: + self.dbapi_type_map = {} + else: + self.dbapi_type_map = { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + + def do_execute(self, cursor, statement, parameters, context=None): + res = cursor.execute(statement, parameters) + if isinstance(res, int) and context is not None: + context._rowcount = res + + def do_release_savepoint(self, connection, name): + # Does MaxDB truly support RELEASE SAVEPOINT ? All my attempts + # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS + # BEGIN SQLSTATE: I7065" + # Note that ROLLBACK TO works fine. In theory, a RELEASE should + # just free up some transactional resources early, before the overall + # COMMIT/ROLLBACK so omitting it should be relatively ok. + pass + + def _get_default_schema_name(self, connection): + return self.identifier_preparer._normalize_name( + connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) + + def has_table(self, connection, table_name, schema=None): + denormalize = self.identifier_preparer._denormalize_name + bind = [denormalize(table_name)] + if schema is None: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME=? AND" + " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + else: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME = ? AND" + " TABLES.SCHEMANAME=? ") + bind.append(denormalize(schema)) + + rp = connection.execute(sql, bind) + return bool(rp.first()) + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=CURRENT_SCHEMA ") + rs = connection.execute(sql) + else: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=? ") + matchname = self.identifier_preparer._denormalize_name(schema) + rs = connection.execute(sql, matchname) + normalize = self.identifier_preparer._normalize_name + return [normalize(row[0]) for row in rs] + + def reflecttable(self, connection, table, include_columns): + denormalize = self.identifier_preparer._denormalize_name + normalize = self.identifier_preparer._normalize_name + + st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' + ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' + 'FROM COLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY POS') + + fk = ('SELECT COLUMNNAME, FKEYNAME, ' + ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' + ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' + ' THEN 1 ELSE 0 END) AS in_schema ' + 'FROM FOREIGNKEYCOLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY FKEYNAME ') + + params = [denormalize(table.name)] + if not table.schema: + st = st % 'CURRENT_SCHEMA' + fk = fk % 'CURRENT_SCHEMA' + else: + st = st % '?' + fk = fk % '?' + params.append(denormalize(table.schema)) + + rows = connection.execute(st, params).fetchall() + if not rows: + raise exc.NoSuchTableError(table.fullname) + + include_columns = set(include_columns or []) + + for row in rows: + (name, mode, col_type, encoding, length, scale, + nullable, constant_def, func_def) = row + + name = normalize(name) + + if include_columns and name not in include_columns: + continue + + type_args, type_kw = [], {} + if col_type == 'FIXED': + type_args = length, scale + # Convert FIXED(10) DEFAULT SERIAL to our Integer + if (scale == 0 and + func_def is not None and func_def.startswith('SERIAL')): + col_type = 'INTEGER' + type_args = length, + elif col_type in 'FLOAT': + type_args = length, + elif col_type in ('CHAR', 'VARCHAR'): + type_args = length, + type_kw['encoding'] = encoding + elif col_type == 'LONG': + type_kw['encoding'] = encoding + + try: + type_cls = ischema_names[col_type.lower()] + type_instance = type_cls(*type_args, **type_kw) + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (col_type, name)) + type_instance = sqltypes.NullType + + col_kw = {'autoincrement': False} + col_kw['nullable'] = (nullable == 'YES') + col_kw['primary_key'] = (mode == 'KEY') + + if func_def is not None: + if func_def.startswith('SERIAL'): + if col_kw['primary_key']: + # No special default- let the standard autoincrement + # support handle SERIAL pk columns. + col_kw['autoincrement'] = True + else: + # strip current numbering + col_kw['server_default'] = schema.DefaultClause( + sql.text('SERIAL')) + col_kw['autoincrement'] = True + else: + col_kw['server_default'] = schema.DefaultClause( + sql.text(func_def)) + elif constant_def is not None: + col_kw['server_default'] = schema.DefaultClause(sql.text( + "'%s'" % constant_def.replace("'", "''"))) + + table.append_column(schema.Column(name, type_instance, **col_kw)) + + fk_sets = itertools.groupby(connection.execute(fk, params), + lambda row: row.FKEYNAME) + for fkeyname, fkey in fk_sets: + fkey = list(fkey) + if include_columns: + key_cols = set([r.COLUMNNAME for r in fkey]) + if key_cols != include_columns: + continue + + columns, referants = [], [] + quote = self.identifier_preparer._maybe_quote_identifier + + for row in fkey: + columns.append(normalize(row.COLUMNNAME)) + if table.schema or not row.in_schema: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFSCHEMANAME', 'REFTABLENAME', + 'REFCOLUMNNAME')])) + else: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) + + constraint_kw = {'name': fkeyname.lower()} + if fkey[0].RULE is not None: + rule = fkey[0].RULE + if rule.startswith('DELETE '): + rule = rule[7:] + constraint_kw['ondelete'] = rule + + table_kw = {} + if table.schema or not row.in_schema: + table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) + + ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), + table_kw.get('schema')) + if ref_key not in table.metadata.tables: + schema.Table(normalize(fkey[0].REFTABLENAME), + table.metadata, + autoload=True, autoload_with=connection, + **table_kw) + + constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True, + **constraint_kw) + table.append_constraint(constraint) + + def has_sequence(self, connection, name): + # [ticket:726] makes this schema-aware. + denormalize = self.identifier_preparer._denormalize_name + sql = ("SELECT sequence_name FROM SEQUENCES " + "WHERE SEQUENCE_NAME=? ") + + rp = connection.execute(sql, denormalize(name)) + return bool(rp.first()) + + +def _autoserial_column(table): + """Finds the effective DEFAULT SERIAL column of a Table, if any.""" + + for index, col in enumerate(table.primary_key.columns): + if (isinstance(col.type, (sqltypes.Integer, sqltypes.Numeric)) and + col.autoincrement): + if isinstance(col.default, schema.Sequence): + if col.default.optional: + return index, col + elif (col.default is None or + (not isinstance(col.server_default, schema.DefaultClause))): + return index, col + + return None, None + diff --git a/sqlalchemy/dialects/maxdb/sapdb.py b/sqlalchemy/dialects/maxdb/sapdb.py new file mode 100644 index 0000000..f363239 --- /dev/null +++ b/sqlalchemy/dialects/maxdb/sapdb.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.maxdb.base import MaxDBDialect + +class MaxDBDialect_sapdb(MaxDBDialect): + driver = 'sapdb' + + @classmethod + def dbapi(cls): + from sapdb import dbapi as _dbapi + return _dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + return [], opts + + +dialect = MaxDBDialect_sapdb \ No newline at end of file diff --git a/sqlalchemy/dialects/mssql/__init__.py b/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 0000000..65ae3e3 --- /dev/null +++ b/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,19 @@ +from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc + +base.dialect = pyodbc.dialect + +from sqlalchemy.dialects.mssql.base import \ + INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \ + NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\ + DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \ + BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP,\ + MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, dialect + + +__all__ = ( + 'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR', + 'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME', + 'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME', + 'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP', + 'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect' +) \ No newline at end of file diff --git a/sqlalchemy/dialects/mssql/adodbapi.py b/sqlalchemy/dialects/mssql/adodbapi.py new file mode 100644 index 0000000..502a02a --- /dev/null +++ b/sqlalchemy/dialects/mssql/adodbapi.py @@ -0,0 +1,59 @@ +""" +The adodbapi dialect is not implemented for 0.6 at this time. + +""" +from sqlalchemy import types as sqltypes, util +from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect +import sys + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect, coltype): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + return process + + +class MSDialect_adodbapi(MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = True + driver = 'adodbapi' + + @classmethod + def import_dbapi(cls): + import adodbapi as module + return module + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.DateTime:MSDateTime_adodbapi + } + ) + + def create_connect_args(self, url): + keys = url.query + + connectors = ["Provider=SQLOLEDB"] + if 'port' in keys: + connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) + else: + connectors.append ("Data Source=%s" % keys.get("host")) + connectors.append ("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join (connectors)], {}] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) + +dialect = MSDialect_adodbapi diff --git a/sqlalchemy/dialects/mssql/base.py b/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 0000000..066ab8d --- /dev/null +++ b/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,1297 @@ +# mssql.py + +"""Support for the Microsoft SQL Server database. + +Connecting +---------- + +See the individual driver sections below for details on connecting. + +Auto Increment Behavior +----------------------- + +``IDENTITY`` columns are supported by using SQLAlchemy +``schema.Sequence()`` objects. In other words:: + + Table('test', mss_engine, + Column('id', Integer, + Sequence('blah',100,10), primary_key=True), + Column('name', String(20)) + ).create() + +would yield:: + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +Note that the ``start`` and ``increment`` values for sequences are +optional and will default to 1,1. + +Implicit ``autoincrement`` behavior works the same in MSSQL as it +does in other dialects and results in an ``IDENTITY`` column. + +* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for + ``INSERT`` s) + +* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on + ``INSERT`` + +Collation Support +----------------- + +MSSQL specific string types support a collation parameter that +creates a column-level specific collation for the column. The +collation parameter accepts a Windows Collation Name or a SQL +Collation Name. Supported types are MSChar, MSNChar, MSString, +MSNVarchar, MSText, and MSNText. For example:: + + Column('login', String(32, collation='Latin1_General_CI_AS')) + +will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is +supported directly through the ``TOP`` Transact SQL keyword:: + + select.limit + +will yield:: + + SELECT TOP n + +If using SQL Server 2005 or above, LIMIT with OFFSET +support is available through the ``ROW_NUMBER OVER`` construct. +For versions below 2005, LIMIT with OFFSET usage will fail. + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatibile with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always retrun the database +server version information (in this case SQL2005) and not the +compatibiility level information. Because of this, if running under +a backwards compatibility mode SQAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +Known Issues +------------ + +* No support for more than one ``IDENTITY`` column per table + +""" +import datetime, decimal, inspect, operator, sys, re +import itertools + +from sqlalchemy import sql, schema as sa_schema, exc, util +from sqlalchemy.sql import select, compiler, expression, \ + operators as sql_operators, \ + functions as sql_functions, util as sql_util +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import types as sqltypes +from sqlalchemy import processors +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ + FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ + VARBINARY, BLOB + +from sqlalchemy.dialects.mssql import information_schema as ischema + +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = set( + ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', + 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', + 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', + 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', + 'containstable', 'continue', 'convert', 'create', 'cross', 'current', + 'current_date', 'current_time', 'current_timestamp', 'current_user', + 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', + 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', + 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', + 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', + 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', + 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', + 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', + 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', + 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', + 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', + 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', + 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', + 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', + 'reconfigure', 'references', 'replication', 'restore', 'restrict', + 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', + 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', + 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', + 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', + 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', + 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', + 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', + 'writetext', + ]) + + +class REAL(sqltypes.Float): + """A type for ``real`` numbers.""" + + __visit_name__ = 'REAL' + + def __init__(self): + super(REAL, self).__init__(precision=24) + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, basestring): + return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(TIME, self).__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, basestring): + return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + + +class _DateTimeBase(object): + def bind_processor(self, dialect): + def process(value): + # TODO: why ? + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'SMALLDATETIME' + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'DATETIME2' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +# TODO: is this not an Interval ? +class DATETIMEOFFSET(sqltypes.TypeEngine): + __visit_name__ = 'DATETIMEOFFSET' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + +class _StringType(object): + """Base for MSSQL string types.""" + + def __init__(self, collation=None): + self.collation = collation + +class TEXT(_StringType, sqltypes.TEXT): + """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" + + def __init__(self, *args, **kw): + """Construct a TEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.Text.__init__(self, *args, **kw) + +class NTEXT(_StringType, sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = 'NTEXT' + + def __init__(self, *args, **kwargs): + """Construct a NTEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kwargs.pop('collation', None) + _StringType.__init__(self, collation) + length = kwargs.pop('length', None) + sqltypes.UnicodeText.__init__(self, length, **kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a VARCHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.VARCHAR.__init__(self, *args, **kw) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MSSQL NVARCHAR type. + + For variable-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a NVARCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NVARCHAR.__init__(self, *args, **kw) + +class CHAR(_StringType, sqltypes.CHAR): + """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a CHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.CHAR.__init__(self, *args, **kw) + +class NCHAR(_StringType, sqltypes.NCHAR): + """MSSQL NCHAR type. + + For fixed-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct an NCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NCHAR.__init__(self, *args, **kw) + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = 'IMAGE' + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = 'MONEY' + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = 'SMALLMONEY' + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = 'SQL_VARIANT' + +# old names. +MSDateTime = _MSDateTime +MSDate = _MSDate +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +ischema_names = { + 'int' : INTEGER, + 'bigint': BIGINT, + 'smallint' : SMALLINT, + 'tinyint' : TINYINT, + 'varchar' : VARCHAR, + 'nvarchar' : NVARCHAR, + 'char' : CHAR, + 'nchar' : NCHAR, + 'text' : TEXT, + 'ntext' : NTEXT, + 'decimal' : DECIMAL, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'datetime' : DATETIME, + 'datetime2' : DATETIME2, + 'datetimeoffset' : DATETIMEOFFSET, + 'date': DATE, + 'time': TIME, + 'smalldatetime' : SMALLDATETIME, + 'binary' : BINARY, + 'varbinary' : VARBINARY, + 'bit': BIT, + 'real' : REAL, + 'image' : IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': SMALLMONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, + 'sql_variant': SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, 'collation', None): + collation = 'COLLATE %s' % type_.collation + else: + collation = None + + if type_.length: + spec = spec + "(%d)" % type_.length + + return ' '.join([c for c in (spec, collation) + if c is not None]) + + def visit_FLOAT(self, type_): + precision = getattr(type_, 'precision', None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': precision} + + def visit_REAL(self, type_): + return "REAL" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_DATETIMEOFFSET(self, type_): + if type_.precision: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_TIME(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_DATETIME2(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_): + return "SMALLDATETIME" + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_unicode_text(self, type_): + return self.visit_NTEXT(type_) + + def visit_NTEXT(self, type_): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_): + return self._extend("VARCHAR", type_) + + def visit_CHAR(self, type_): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_): + return self._extend("NVARCHAR", type_) + + def visit_date(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_DATE(type_) + + def visit_time(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_TIME(type_) + + def visit_large_binary(self, type_): + return self.visit_IMAGE(type_) + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return 'SMALLMONEY' + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_): + return 'SQL_VARIANT' + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _result_proxy = None + _lastrowid = None + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + self._select_lastrowid = insert_has_sequence and \ + not self.compiled.returning and \ + not self._enable_identity_insert and \ + not self.executemany + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl)) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid", ()) + else: + self.cursor.execute("SELECT @@identity AS lastrowid", ()) + # fetchall() ensures the cursor is consumed without closing it + row = self.cursor.fetchall()[0] + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning: + self._result_proxy = base.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + self.cursor.execute( + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer. + format_table(self.compiled.statement.table) + ) + + def get_lastrowid(self): + return self._lastrowid + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % + self.dialect.\ + identifier_preparer.\ + format_table(self.compiled.statement.table) + ) + except: + pass + + def get_result_proxy(self): + if self._result_proxy: + return self._result_proxy + else: + return base.ResultProxy(self) + +class MSSQLCompiler(compiler.SQLCompiler): + returning_precedes_values = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) + + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) + self.tablealiases = {} + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op(self, binary, **kw): + return "%s + %s" % (self.process(binary.left, **kw), self.process(binary.right, **kw)) + + def visit_match_op(self, binary, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) + + def get_select_precolumns(self, select): + """ MS-SQL puts TOP, it's version of LIMIT here """ + if select._distinct or select._limit: + s = select._distinct and "DISTINCT " or "" + + if select._limit: + if not select._offset: + s += "TOP %s " % (select._limit,) + return s + return compiler.SQLCompiler.get_select_precolumns(self, select) + + def limit_clause(self, select): + # Limit in mssql is after the select keyword + return "" + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + """ + if not getattr(select, '_mssql_visit', None) and select._offset: + # to use ROW_NUMBER(), an ORDER BY is required. + orderby = self.process(select._order_by_clause) + if not orderby: + raise exc.InvalidRequestError('MSSQL requires an order_by when ' + 'using an offset.') + + _offset = select._offset + _limit = select._limit + select._mssql_visit = True + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" + % orderby).label("mssql_rn") + ).order_by(None).alias() + + limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) + else: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + def visit_alias(self, alias, **kwargs): + # translate for schema-qualified table aliases + self.tablealiases[alias.original] = alias + kwargs['mssql_aliased'] = True + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return ("ROLLBACK TRANSACTION %s" + % self.preparer.format_savepoint(savepoint_stmt)) + + def visit_column(self, column, result_map=None, **kwargs): + if column.table is not None and \ + (not self.isupdate and not self.isdelete) or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = expression._corresponding_column_or_error(t, column) + + if result_map is not None: + result_map[column.name.lower()] = (column.name, (column, ), + column.type) + + return super(MSSQLCompiler, self).visit_column(converted, + result_map=None, + **kwargs) + + return super(MSSQLCompiler, self).visit_column(column, + result_map=result_map, + **kwargs) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if ( + isinstance(binary.left, expression._BindParamClause) + and binary.operator == operator.eq + and not isinstance(binary.right, expression._BindParamClause) + ): + return self.process(expression._BinaryExpression(binary.right, + binary.left, + binary.operator), + **kwargs) + else: + if ( + + (binary.operator is operator.eq or binary.operator is operator.ne) + and ( + (isinstance(binary.left, expression._FromGrouping) + and isinstance(binary.left.element, + expression._ScalarSelect)) + or (isinstance(binary.right, expression._FromGrouping) + and isinstance(binary.right.element, + expression._ScalarSelect)) + or isinstance(binary.left, expression._ScalarSelect) + or isinstance(binary.right, expression._ScalarSelect) + ) + + ): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, + binary.right, op), + **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + def col_label(col): + adapted = adapter.traverse(col) + if isinstance(col, expression._Label): + return adapted.label(c.key) + else: + return self.label_select_column(None, adapted, asfrom=False) + + columns = [ + self.process( + col_label(c), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) + ] + return 'OUTPUT ' + ', '.join(columns) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select, **kw): + order_by = self.process(select._order_by_clause, **kw) + + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + +class MSSQLStrictCompiler(MSSQLCompiler): + """A subclass of MSSQLCompiler which disables the usage of bind + parameters where not allowed natively by MS-SQL. + + A dialect may use this compiler on a platform where native + binds are used. + + """ + ansi_bind_rules = True + + def visit_in_op(self, binary, **kw): + kw['literal_binds'] = True + return "%s IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_notin_op(self, binary, **kw): + kw['literal_binds'] = True + return "%s NOT IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_function(self, func, **kw): + kw['literal_binds'] = True + return super(MSSQLStrictCompiler, self).visit_function(func, **kw) + + def render_literal_value(self, value, type_): + """ + For date and datetime values, convert to a string + format acceptable to MSSQL. That seems to be the + so-called ODBC canonical date format which looks + like this: + + yyyy-mm-dd hh:mi:ss.mmm(24h) + + For other data types, call the base class implementation. + """ + # datetime and date are both subclasses of datetime.date + if issubclass(type(value), datetime.date): + # SQL Server wants single quotes around the date string. + return "'" + str(value) + "'" + else: + return super(MSSQLStrictCompiler, self).render_literal_value(value, type_) + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = (self.preparer.format_column(column) + " " + + self.dialect.type_compiler.process(column.type)) + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + if column.table is None: + raise exc.InvalidRequestError("mssql requires Table-bound columns " + "in order to generate DDL") + + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = isinstance(column.default, sa_schema.Sequence) and column.default + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(drop.element.table.name), + self.preparer.quote(self._validate_identifier(drop.element.name, False), + drop.element.quote) + ) + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', + final_quote=']') + + def _escape_identifier(self, value): + return value + + def quote_schema(self, schema, force=True): + """Prepare a quoted table and schema name.""" + result = '.'.join([self.quote(x, force) for x in schema.split('.')]) + return result + +class MSDialect(default.DefaultDialect): + name = 'mssql' + supports_default_values = True + supports_empty_insert = False + execution_ctx_cls = MSExecutionContext + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + + colspecs = { + sqltypes.DateTime : _MSDateTime, + sqltypes.Date : _MSDate, + sqltypes.Time : TIME, + } + + ischema_names = ischema_names + + supports_native_boolean = False + supports_unicode_binds = True + postfetch_lastrowid = True + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + def __init__(self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name=u"dbo", **opts): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.max_identifier_length = int(max_identifier_length or 0) or \ + self.max_identifier_length + super(MSDialect, self).__init__(**opts) + + def do_savepoint(self, connection, name): + util.warn("Savepoint support in mssql is experimental and " + "may lead to data loss.") + connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + connection.execute("SAVE TRANSACTION %s" % name) + + def do_release_savepoint(self, connection, name): + pass + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + if self.server_version_info >= MS_2005_VERSION and \ + 'implicit_returning' not in self.__dict__: + self.implicit_returning = True + + def _get_default_schema_name(self, connection): + user_name = connection.scalar("SELECT user_name() as user_name;") + if user_name is not None: + # now, get the default schema + query = """ + SELECT default_schema_name FROM + sys.database_principals + WHERE name = ? + AND type = 'S' + """ + try: + default_schema_name = connection.scalar(query, [user_name]) + if default_schema_name is not None: + return unicode(default_schema_name) + except: + pass + return self.schema_name + + + def has_table(self, connection, tablename, schema=None): + current_schema = schema or self.default_schema_name + columns = ischema.columns + if current_schema: + whereclause = sql.and_(columns.c.table_name==tablename, + columns.c.table_schema==current_schema) + else: + whereclause = columns.c.table_name==tablename + s = sql.select([columns], whereclause) + c = connection.execute(s) + return c.first() is not None + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select([ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name] + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == u'BASE TABLE' + ), + order_by=[tables.c.table_name] + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == u'VIEW' + ), + order_by=[tables.c.table_name] + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + # The cursor reports it is closed after executing the sp. + @reflection.cache + def get_indexes(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + col_finder = re.compile("(\w+)") + full_tname = "%s.%s" % (current_schema, tablename) + indexes = [] + s = sql.text("exec sp_helpindex '%s'" % full_tname) + rp = connection.execute(s) + if rp.closed: + # did not work for this setup. + return [] + for row in rp: + if 'primary key' not in row['index_description']: + indexes.append({ + 'name' : row['index_name'], + 'column_names' : col_finder.findall(row['index_keys']), + 'unique': 'unique' in row['index_description'] + }) + return indexes + + @reflection.cache + def get_view_definition(self, connection, viewname, schema=None, **kw): + current_schema = schema or self.default_schema_name + views = ischema.views + s = sql.select([views.c.view_definition], + sql.and_( + views.c.table_schema == current_schema, + views.c.table_name == viewname + ), + ) + rp = connection.execute(s) + if rp: + view_def = rp.scalar() + return view_def + + @reflection.cache + def get_columns(self, connection, tablename, schema=None, **kw): + # Get base columns + current_schema = schema or self.default_schema_name + columns = ischema.columns + if current_schema: + whereclause = sql.and_(columns.c.table_name==tablename, + columns.c.table_schema==current_schema) + else: + whereclause = columns.c.table_name==tablename + s = sql.select([columns], whereclause, order_by=[columns.c.ordinal_position]) + c = connection.execute(s) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type, nullable, charlen, numericprec, numericscale, default, collation) = ( + row[columns.c.column_name], + row[columns.c.data_type], + row[columns.c.is_nullable] == 'YES', + row[columns.c.character_maximum_length], + row[columns.c.numeric_precision], + row[columns.c.numeric_scale], + row[columns.c.column_default], + row[columns.c.collation_name] + ) + coltype = self.ischema_names.get(type, None) + + kwargs = {} + if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, + MSNText, MSBinary, MSVarBinary, sqltypes.LargeBinary): + kwargs['length'] = charlen + if collation: + kwargs['collation'] = collation + if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1): + kwargs.pop('length') + + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % (type, name)) + coltype = sqltypes.NULLTYPE + + if issubclass(coltype, sqltypes.Numeric) and coltype is not MSReal: + kwargs['scale'] = numericscale + kwargs['precision'] = numericprec + + coltype = coltype(**kwargs) + cdict = { + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'autoincrement':False, + } + cols.append(cdict) + # autoincrement and identity + colmap = {} + for col in cols: + colmap[col['name']] = col + # We also run an sp_columns to check for identity columns: + cursor = connection.execute("sp_columns @table_name = '%s', " + "@table_owner = '%s'" + % (tablename, current_schema)) + ic = None + while True: + row = cursor.fetchone() + if row is None: + break + (col_name, type_name) = row[3], row[5] + if type_name.endswith("identity") and col_name in colmap: + ic = col_name + colmap[col_name]['autoincrement'] = True + colmap[col_name]['sequence'] = dict( + name='%s_identity' % col_name) + break + cursor.close() + + if ic is not None and self.server_version_info >= MS_2005_VERSION: + table_fullname = "%s.%s" % (current_schema, tablename) + cursor = connection.execute( + "select ident_seed('%s'), ident_incr('%s')" + % (table_fullname, table_fullname) + ) + + row = cursor.first() + if row is not None and row[0] is not None: + colmap[ic]['sequence'].update({ + 'start' : int(row[0]), + 'increment' : int(row[1]) + }) + return cols + + @reflection.cache + def get_primary_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + pkeys = [] + RR = ischema.ref_constraints # information_schema.referential_constraints + TC = ischema.constraints # information_schema.table_constraints + C = ischema.key_constraints.alias('C') # information_schema.constraint_column_usage: + # the constrained column + R = ischema.key_constraints.alias('R') # information_schema.constraint_column_usage: + # the referenced column + + # Primary key constraints + s = sql.select([C.c.column_name, TC.c.constraint_type], + sql.and_(TC.c.constraint_name == C.c.constraint_name, + C.c.table_name == tablename, + C.c.table_schema == current_schema) + ) + c = connection.execute(s) + for row in c: + if 'PRIMARY' in row[TC.c.constraint_type.name]: + pkeys.append(row[0]) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') # information_schema.constraint_column_usage: + # the constrained column + R = ischema.key_constraints.alias('R') # information_schema.constraint_column_usage: + # the referenced column + + # Foreign key constraints + s = sql.select([C.c.column_name, + R.c.table_schema, R.c.table_name, R.c.column_name, + RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, + RR.c.delete_rule], + sql.and_(C.c.table_name == tablename, + C.c.table_schema == current_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + C.c.ordinal_position == R.c.ordinal_position + ), + order_by = [RR.c.constraint_name, R.c.ordinal_position]) + + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + fknm, scols, rcols = (None, [], []) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): + scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r + + rec = fkeys[rfknm] + rec['name'] = rfknm + if not rec['referred_table']: + rec['referred_table'] = rtbl + + if schema is not None or current_schema != rschema: + rec['referred_schema'] = rschema + + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + local_cols.append(scol) + remote_cols.append(rcol) + + return fkeys.values() + diff --git a/sqlalchemy/dialects/mssql/information_schema.py b/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 0000000..312e83c --- /dev/null +++ b/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,83 @@ +from sqlalchemy import Table, MetaData, Column, ForeignKey +from sqlalchemy.types import String, Unicode, Integer, TypeDecorator + +ischema = MetaData() + +class CoerceUnicode(TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + if isinstance(value, str): + value = value.decode(dialect.encoding) + return value + +schemata = Table("SCHEMATA", ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA") + +tables = Table("TABLES", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"), + schema="INFORMATION_SCHEMA") + +columns = Table("COLUMNS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA") + +constraints = Table("TABLE_CONSTRAINTS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"), + schema="INFORMATION_SCHEMA") + +column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA") + +key_constraints = Table("KEY_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA") + +ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"), # TODO: is CATLOG misspelled ? + Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"), + Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA") + +views = Table("VIEWS", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA") + diff --git a/sqlalchemy/dialects/mssql/mxodbc.py b/sqlalchemy/dialects/mssql/mxodbc.py new file mode 100644 index 0000000..efe7636 --- /dev/null +++ b/sqlalchemy/dialects/mssql/mxodbc.py @@ -0,0 +1,83 @@ +""" +Support for MS-SQL via mxODBC. + +mxODBC is available at: + + http://www.egenix.com/ + +This was tested with mxODBC 3.1.2 and the SQL Server Native +Client connected to MSSQL 2005 and 2008 Express Editions. + +Connecting +~~~~~~~~~~ + +Connection is via DSN:: + + mssql+mxodbc://:@ + +Execution Modes +~~~~~~~~~~~~~~~ + +mxODBC features two styles of statement execution, using the ``cursor.execute()`` +and ``cursor.executedirect()`` methods (the second being an extension to the +DBAPI specification). The former makes use of the native +parameter binding services of the ODBC driver, while the latter uses string escaping. +The primary advantage to native parameter binding is that the same statement, when +executed many times, is only prepared once. Whereas the primary advantage to the +latter is that the rules for bind parameter placement are relaxed. MS-SQL has very +strict rules for native binds, including that they cannot be placed within the argument +lists of function calls, anywhere outside the FROM, or even within subqueries within the +FROM clause - making the usage of bind parameters within SELECT statements impossible for +all but the most simplistic statements. For this reason, the mxODBC dialect uses the +"native" mode by default only for INSERT, UPDATE, and DELETE statements, and uses the +escaped string mode for all other statements. This behavior can be controlled completely +via :meth:`~sqlalchemy.sql.expression.Executable.execution_options` +using the ``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a value of +``True`` will unconditionally use native bind parameters and a value of ``False`` will +uncondtionally use string-escaped parameters. + +""" + +import re +import sys + +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.connectors.mxodbc import MxODBCConnector +from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc +from sqlalchemy.dialects.mssql.base import (MSExecutionContext, MSDialect, + MSSQLCompiler, MSSQLStrictCompiler, + _MSDateTime, _MSDate, TIME) + + + +class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): + """ + The pyodbc execution context is useful for enabling + SELECT SCOPE_IDENTITY in cases where OUTPUT clause + does not work (tables with insert triggers). + """ + #todo - investigate whether the pyodbc execution context + # is really only being used in cases where OUTPUT + # won't work. + +class MSDialect_mxodbc(MxODBCConnector, MSDialect): + + # TODO: may want to use this only if FreeTDS is not in use, + # since FreeTDS doesn't seem to use native binds. + statement_compiler = MSSQLStrictCompiler + execution_ctx_cls = MSExecutionContext_mxodbc + colspecs = { + #sqltypes.Numeric : _MSNumeric, + sqltypes.DateTime : _MSDateTime, + sqltypes.Date : _MSDate, + sqltypes.Time : TIME, + } + + + def __init__(self, description_encoding='latin-1', **params): + super(MSDialect_mxodbc, self).__init__(**params) + self.description_encoding = description_encoding + +dialect = MSDialect_mxodbc + diff --git a/sqlalchemy/dialects/mssql/pymssql.py b/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 0000000..ca1c4a1 --- /dev/null +++ b/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,101 @@ +""" +Support for the pymssql dialect. + +This dialect supports pymssql 1.0 and greater. + +pymssql is available at: + + http://pymssql.sourceforge.net/ + +Connecting +^^^^^^^^^^ + +Sample connect string:: + + mssql+pymssql://:@ + +Adding "?charset=utf8" or similar will cause pymssql to return +strings as Python unicode objects. This can potentially improve +performance in some scenarios as decoding of strings is +handled natively. + +Limitations +^^^^^^^^^^^ + +pymssql inherits a lot of limitations from FreeTDS, including: + +* no support for multibyte schema identifiers +* poor support for large decimals +* poor support for binary fields +* poor support for VARCHAR/CHAR fields over 255 characters + +Please consult the pymssql documentation for further information. + +""" +from sqlalchemy.dialects.mssql.base import MSDialect +from sqlalchemy import types as sqltypes, util, processors +import re +import decimal + +class _MSNumeric_pymssql(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + +class MSDialect_pymssql(MSDialect): + supports_sane_rowcount = False + max_identifier_length = 30 + driver = 'pymssql' + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric:_MSNumeric_pymssql, + sqltypes.Float:sqltypes.Float, + } + ) + @classmethod + def dbapi(cls): + module = __import__('pymssql') + # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal + module.Binary = str + + client_ver = tuple(int(x) for x in module.__version__.split(".")) + if client_ver < (1, ): + util.warn("The pymssql dialect expects at least " + "the 1.0 series of the pymssql DBAPI.") + return module + + def __init__(self, **params): + super(MSDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + + def _get_server_version_info(self, connection): + vers = connection.scalar("select @@version") + m = re.match(r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers) + if m: + return tuple(int(x) for x in m.group(1, 2, 3, 4)) + else: + return None + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + opts.pop('port', None) + return [[], opts] + + def is_disconnect(self, e): + for msg in ( + "Error 10054", + "Not connected to any MS SQL server", + "Connection is closed" + ): + if msg in str(e): + return True + else: + return False + +dialect = MSDialect_pymssql \ No newline at end of file diff --git a/sqlalchemy/dialects/mssql/pyodbc.py b/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 0000000..c74be0e --- /dev/null +++ b/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,197 @@ +""" +Support for MS-SQL via pyodbc. + +pyodbc is available at: + + http://pypi.python.org/pypi/pyodbc/ + +Connecting +^^^^^^^^^^ + +Examples of pyodbc connection string URLs: + +* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``. + The connection string that is created will appear like:: + + dsn=mydsn;Trusted_Connection=Yes + +* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named + ``mydsn`` passing in the ``UID`` and ``PWD`` information. The + connection string that is created will appear like:: + + dsn=mydsn;UID=user;PWD=pass + +* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects + using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` + information, plus the additional connection configuration option + ``LANGUAGE``. The connection string that is created will appear + like:: + + dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english + +* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection string + dynamically created that would appear like:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass + +* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection + string that is dynamically created, which also includes the port + information using the comma syntax. If your connection string + requires the port information to be passed as a ``port`` keyword + see the next example. This will create the following connection + string:: + + DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass + +* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection + string that is dynamically created that includes the port + information as a separate ``port`` keyword. This will create the + following connection string:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 + +If you require a connection string that is outside the options +presented above, use the ``odbc_connect`` keyword to pass in a +urlencoded connection string. What gets passed in will be urldecoded +and passed directly. + +For example:: + + mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb + +would create the following connection string:: + + dsn=mydsn;Database=db + +Encoding your connection string can be easily accomplished through +the python shell. For example:: + + >>> import urllib + >>> urllib.quote_plus('dsn=mydsn;Database=db') + 'dsn%3Dmydsn%3BDatabase%3Ddb' + + +""" + +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy import types as sqltypes, util +import decimal + +class _MSNumeric_pyodbc(sqltypes.Numeric): + """Turns Decimals with adjusted() < 0 or > 7 into strings. + + This is the only method that is proven to work with Pyodbc+MSSQL + without crashing (floats can be used but seem to cause sporadic + crashes). + + """ + + def bind_processor(self, dialect): + super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect) + + def process(value): + if self.asdecimal and \ + isinstance(value, decimal.Decimal): + + adjusted = value.adjusted() + if adjusted < 0: + return self._small_dec_to_string(value) + elif adjusted > 7: + return self._large_dec_to_string(value) + + if super_process: + return super_process(value) + else: + return value + return process + + def _small_dec_to_string(self, value): + return "%s0.%s%s" % ( + (value < 0 and '-' or ''), + '0' * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value._int])) + + def _large_dec_to_string(self, value): + if 'E' in str(value): + result = "%s%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int]), + "0" * (value.adjusted() - (len(value._int)-1))) + else: + if (len(value._int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1]), + "".join([str(s) for s in value._int][value.adjusted() + 1:])) + else: + result = "%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1])) + return result + + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same statement. + + Background on why "scope_identity()" is preferable to "@@identity": + http://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super(MSExecutionContext_pyodbc, self).pre_exec() + + # don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES" + if self._select_lastrowid and \ + self.dialect.use_scope_identity and \ + len(self.parameters[0]): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with no data (due to triggers, etc.) + while True: + try: + # fetchall() ensures the cursor is consumed + # without closing it (FreeTDS particularly) + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error, e: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + + self._lastrowid = int(row[0]) + else: + super(MSExecutionContext_pyodbc, self).post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + + execution_ctx_cls = MSExecutionContext_pyodbc + + pyodbc_driver_name = 'SQL Server' + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric:_MSNumeric_pyodbc + } + ) + + def __init__(self, description_encoding='latin-1', **params): + super(MSDialect_pyodbc, self).__init__(**params) + self.description_encoding = description_encoding + self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') + +dialect = MSDialect_pyodbc diff --git a/sqlalchemy/dialects/mssql/zxjdbc.py b/sqlalchemy/dialects/mssql/zxjdbc.py new file mode 100644 index 0000000..b11eb17 --- /dev/null +++ b/sqlalchemy/dialects/mssql/zxjdbc.py @@ -0,0 +1,64 @@ +"""Support for the Microsoft SQL Server database via the zxjdbc JDBC +connector. + +JDBC Driver +----------- + +Requires the jTDS driver, available from: http://jtds.sourceforge.net/ + +Connecting +---------- + +URLs are of the standard form of +``mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]``. + +Additional arguments which may be specified either as query string +arguments on the URL, or as keyword arguments to +:func:`~sqlalchemy.create_engine()` will be passed as Connection +properties to the underlying JDBC driver. + +""" +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.mssql.base import MSDialect, MSExecutionContext +from sqlalchemy.engine import base + +class MSExecutionContext_zxjdbc(MSExecutionContext): + + _embedded_scope_identity = False + + def pre_exec(self): + super(MSExecutionContext_zxjdbc, self).pre_exec() + # scope_identity after the fact returns null in jTDS so we must + # embed it + if self._select_lastrowid and self.dialect.use_scope_identity: + self._embedded_scope_identity = True + self.statement += "; SELECT scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + while True: + try: + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error, e: + self.cursor.nextset() + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning: + self._result_proxy = base.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + table = self.dialect.identifier_preparer.format_table(self.compiled.statement.table) + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table) + + +class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect): + jdbc_db_name = 'jtds:sqlserver' + jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver' + + execution_ctx_cls = MSExecutionContext_zxjdbc + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.dbversion.split('.')) + +dialect = MSDialect_zxjdbc diff --git a/sqlalchemy/dialects/mysql/__init__.py b/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 0000000..f37a0c7 --- /dev/null +++ b/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.mysql import base, mysqldb, oursql, pyodbc, zxjdbc, mysqlconnector + +# default dialect +base.dialect = mysqldb.dialect + +from sqlalchemy.dialects.mysql.base import \ + BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, DOUBLE, ENUM, DECIMAL,\ + FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, MEDIUMTEXT, NCHAR, \ + NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT,\ + VARBINARY, VARCHAR, YEAR, dialect + +__all__ = ( +'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE', +'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', +'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP', +'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect' +) diff --git a/sqlalchemy/dialects/mysql/base.py b/sqlalchemy/dialects/mysql/base.py new file mode 100644 index 0000000..6a07614 --- /dev/null +++ b/sqlalchemy/dialects/mysql/base.py @@ -0,0 +1,2528 @@ +# -*- fill-column: 78 -*- +# mysql/base.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# and Jason Kirtland. +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the MySQL database. + +Supported Versions and Features +------------------------------- + +SQLAlchemy supports 6 major MySQL versions: 3.23, 4.0, 4.1, 5.0, 5.1 and 6.0, +with capabilities increasing with more modern servers. + +Versions 4.1 and higher support the basic SQL functionality that SQLAlchemy +uses in the ORM and SQL expressions. These versions pass the applicable tests +in the suite 100%. No heroic measures are taken to work around major missing +SQL features- if your server version does not support sub-selects, for +example, they won't work in SQLAlchemy either. + +Most available DBAPI drivers are supported; see below. + +===================================== =============== +Feature Minimum Version +===================================== =============== +sqlalchemy.orm 4.1.1 +Table Reflection 3.23.x +DDL Generation 4.1.1 +utf8/Full Unicode Connections 4.1.1 +Transactions 3.23.15 +Two-Phase Transactions 5.0.3 +Nested Transactions 5.0.3 +===================================== =============== + +See the official MySQL documentation for detailed information about features +supported in any given server release. + +Connecting +---------- + +See the API documentation on individual drivers for details on connecting. + +Data Types +---------- + +All of MySQL's standard types are supported. These can also be specified within +table metadata, for the purpose of issuing CREATE TABLE statements +which include MySQL-specific extensions. The types are available +from the module, as in:: + + from sqlalchemy.dialects import mysql + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('ittybittyblob', mysql.TINYBLOB), + Column('biggy', mysql.BIGINT(unsigned=True))) + +See the API documentation on specific column types for further details. + +Connection Timeouts +------------------- + +MySQL features an automatic connection close behavior, for connections that have +been idle for eight hours or more. To circumvent having this issue, use the +``pool_recycle`` option which controls the maximum age of any connection:: + + engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) + +Storage Engines +--------------- + +Most MySQL server installations have a default table type of ``MyISAM``, a +non-transactional table type. During a transaction, non-transactional storage +engines do not participate and continue to store table changes in autocommit +mode. For fully atomic transactions, all participating tables must use a +transactional engine such as ``InnoDB``, ``Falcon``, ``SolidDB``, `PBXT`, etc. + +Storage engines can be elected when creating tables in SQLAlchemy by supplying +a ``mysql_engine='whatever'`` to the ``Table`` constructor. Any MySQL table +creation option can be specified in this syntax:: + + Table('mytable', metadata, + Column('data', String(32)), + mysql_engine='InnoDB', + mysql_charset='utf8' + ) + +Keys +---- + +Not all MySQL storage engines support foreign keys. For ``MyISAM`` and +similar engines, the information loaded by table reflection will not include +foreign keys. For these tables, you may supply a +:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: + + Table('mytable', metadata, + ForeignKeyConstraint(['other_id'], ['othertable.other_id']), + autoload=True + ) + +When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT``` on +an integer primary key column:: + + >>> t = Table('mytable', metadata, + ... Column('mytable_id', Integer, primary_key=True) + ... ) + >>> t.create() + CREATE TABLE mytable ( + id INTEGER NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + ) + +You can disable this behavior by supplying ``autoincrement=False`` to the +:class:`~sqlalchemy.Column`. This flag can also be used to enable +auto-increment on a secondary column in a multi-column key for some storage +engines:: + + Table('mytable', metadata, + Column('gid', Integer, primary_key=True, autoincrement=False), + Column('id', Integer, primary_key=True) + ) + +SQL Mode +-------- + +MySQL SQL modes are supported. Modes that enable ``ANSI_QUOTES`` (such as +``ANSI``) require an engine option to modify SQLAlchemy's quoting style. +When using an ANSI-quoting mode, supply ``use_ansiquotes=True`` when +creating your ``Engine``:: + + create_engine('mysql://localhost/test', use_ansiquotes=True) + +This is an engine-wide option and is not toggleable on a per-connection basis. +SQLAlchemy does not presume to ``SET sql_mode`` for you with this option. For +the best performance, set the quoting style server-wide in ``my.cnf`` or by +supplying ``--sql-mode`` to ``mysqld``. You can also use a +:class:`sqlalchemy.pool.Pool` listener hook to issue a ``SET SESSION +sql_mode='...'`` on connect to configure each connection. + +If you do not specify ``use_ansiquotes``, the regular MySQL quoting style is +used by default. + +If you do issue a ``SET sql_mode`` through SQLAlchemy, the dialect must be +updated if the quoting style is changed. Again, this change will affect all +connections:: + + connection.execute('SET sql_mode="ansi"') + connection.dialect.use_ansiquotes = True + +MySQL SQL Extensions +-------------------- + +Many of the MySQL SQL extensions are handled through SQLAlchemy's generic +function and operator support:: + + table.select(table.c.password==func.md5('plaintext')) + table.select(table.c.username.op('regexp')('^[a-d]')) + +And of course any valid MySQL statement can be executed as a string as well. + +Some limited direct support for MySQL extensions to SQL is currently +available. + +* SELECT pragma:: + + select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) + +* UPDATE with LIMIT:: + + update(..., mysql_limit=10) + +Troubleshooting +--------------- + +If you have problems that seem server related, first check that you are +using the most recent stable MySQL-Python package available. The Database +Notes page on the wiki at http://www.sqlalchemy.org is a good resource for +timely information affecting MySQL in SQLAlchemy. + +""" + +import datetime, inspect, re, sys + +from sqlalchemy import schema as sa_schema +from sqlalchemy import exc, log, sql, util +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy.sql import functions as sql_functions +from sqlalchemy.sql import compiler +from array import array as _array + +from sqlalchemy.engine import reflection +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import types as sqltypes + +from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME, \ + BLOB, BINARY, VARBINARY + +RESERVED_WORDS = set( + ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', + 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', + 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', + 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', + 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', + 'current_user', 'cursor', 'database', 'databases', 'day_hour', + 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', + 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', + 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', + 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', + 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', + 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', 'having', + 'high_priority', 'hour_microsecond', 'hour_minute', 'hour_second', 'if', + 'ignore', 'in', 'index', 'infile', 'inner', 'inout', 'insensitive', + 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', 'integer', + 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', 'kill', + 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', 'load', + 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', 'longtext', + 'loop', 'low_priority', 'master_ssl_verify_server_cert', 'match', + 'mediumblob', 'mediumint', 'mediumtext', 'middleint', + 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', + 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', + 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', + 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', + 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', + 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', + 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', + 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', + 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', + 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', + 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', + 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', + 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', + 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', + 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', + 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 + 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 + 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', + 'read_only', 'read_write', # 5.1 + ]) + +AUTOCOMMIT_RE = re.compile( + r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', + re.I | re.UNICODE) +SET_RE = re.compile( + r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', + re.I | re.UNICODE) + + +class _NumericType(object): + """Base for MySQL numeric types.""" + + def __init__(self, **kw): + self.unsigned = kw.pop('unsigned', False) + self.zerofill = kw.pop('zerofill', False) + super(_NumericType, self).__init__(**kw) + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and \ + ( + (precision is None and scale is not None) or + (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether.") + + super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw) + self.scale = scale + +class _IntegerType(_NumericType, sqltypes.Integer): + def __init__(self, display_width=None, **kw): + self.display_width = display_width + super(_IntegerType, self).__init__(**kw) + +class _StringType(sqltypes.String): + """Base for MySQL string types.""" + + def __init__(self, charset=None, collation=None, + ascii=False, binary=False, + national=False, **kw): + self.charset = charset + # allow collate= or collation= + self.collation = kw.pop('collate', collation) + self.ascii = ascii + # We have to munge the 'unicode' param strictly as a dict + # otherwise 2to3 will turn it into str. + self.__dict__['unicode'] = kw.get('unicode', False) + # sqltypes.String does not accept the 'unicode' arg at all. + if 'unicode' in kw: + del kw['unicode'] + self.binary = binary + self.national = national + super(_StringType, self).__init__(**kw) + + def __repr__(self): + attributes = inspect.getargspec(self.__init__)[0][1:] + attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) + + params = {} + for attr in attributes: + val = getattr(self, attr) + if val is not None and val is not False: + params[attr] = val + + return "%s(%s)" % (self.__class__.__name__, + ', '.join(['%s=%r' % (k, params[k]) for k in params])) + + +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """MySQL NUMERIC type.""" + + __visit_name__ = 'NUMERIC' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a NUMERIC. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(NUMERIC, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + + +class DECIMAL(_NumericType, sqltypes.DECIMAL): + """MySQL DECIMAL type.""" + + __visit_name__ = 'DECIMAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DECIMAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(DECIMAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class DOUBLE(_FloatType): + """MySQL DOUBLE type.""" + + __visit_name__ = 'DOUBLE' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DOUBLE. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(DOUBLE, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + +class REAL(_FloatType): + """MySQL REAL type.""" + + __visit_name__ = 'REAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a REAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(REAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + +class FLOAT(_FloatType, sqltypes.FLOAT): + """MySQL FLOAT type.""" + + __visit_name__ = 'FLOAT' + + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + """Construct a FLOAT. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(FLOAT, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + def bind_processor(self, dialect): + return None + +class INTEGER(_IntegerType, sqltypes.INTEGER): + """MySQL INTEGER type.""" + + __visit_name__ = 'INTEGER' + + def __init__(self, display_width=None, **kw): + """Construct an INTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(INTEGER, self).__init__(display_width=display_width, **kw) + +class BIGINT(_IntegerType, sqltypes.BIGINT): + """MySQL BIGINTEGER type.""" + + __visit_name__ = 'BIGINT' + + def __init__(self, display_width=None, **kw): + """Construct a BIGINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(BIGINT, self).__init__(display_width=display_width, **kw) + +class MEDIUMINT(_IntegerType): + """MySQL MEDIUMINTEGER type.""" + + __visit_name__ = 'MEDIUMINT' + + def __init__(self, display_width=None, **kw): + """Construct a MEDIUMINTEGER + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(MEDIUMINT, self).__init__(display_width=display_width, **kw) + +class TINYINT(_IntegerType): + """MySQL TINYINT type.""" + + __visit_name__ = 'TINYINT' + + def __init__(self, display_width=None, **kw): + """Construct a TINYINT. + + Note: following the usual MySQL conventions, TINYINT(1) columns + reflected during Table(..., autoload=True) are treated as + Boolean columns. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(TINYINT, self).__init__(display_width=display_width, **kw) + +class SMALLINT(_IntegerType, sqltypes.SMALLINT): + """MySQL SMALLINTEGER type.""" + + __visit_name__ = 'SMALLINT' + + def __init__(self, display_width=None, **kw): + """Construct a SMALLINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(SMALLINT, self).__init__(display_width=display_width, **kw) + +class BIT(sqltypes.TypeEngine): + """MySQL BIT type. + + This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for + MyISAM, MEMORY, InnoDB and BDB. For older versions, use a MSTinyInteger() + type. + + """ + + __visit_name__ = 'BIT' + + def __init__(self, length=None): + """Construct a BIT. + + :param length: Optional, number of bits. + + """ + self.length = length + + def result_processor(self, dialect, coltype): + """Convert a MySQL's 64 bit, variable length binary string to a long. + + TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector + already do this, so this logic should be moved to those dialects. + + """ + + def process(value): + if value is not None: + v = 0L + for i in map(ord, value): + v = v << 8 | i + return v + return value + return process + +class _MSTime(sqltypes.Time): + """MySQL TIME type.""" + + __visit_name__ = 'TIME' + + def result_processor(self, dialect, coltype): + time = datetime.time + def process(value): + # convert from a timedelta value + if value is not None: + seconds = value.seconds + minutes = seconds / 60 + return time(minutes / 60, minutes % 60, seconds - minutes * 60) + else: + return None + return process + +class TIMESTAMP(sqltypes.TIMESTAMP): + """MySQL TIMESTAMP type.""" + __visit_name__ = 'TIMESTAMP' + +class YEAR(sqltypes.TypeEngine): + """MySQL YEAR type, for single byte storage of years 1901-2155.""" + + __visit_name__ = 'YEAR' + + def __init__(self, display_width=None): + self.display_width = display_width + +class TEXT(_StringType, sqltypes.TEXT): + """MySQL TEXT type, for text up to 2^16 characters.""" + + __visit_name__ = 'TEXT' + + def __init__(self, length=None, **kw): + """Construct a TEXT. + + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` characters. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TEXT, self).__init__(length=length, **kw) + +class TINYTEXT(_StringType): + """MySQL TINYTEXT type, for text up to 2^8 characters.""" + + __visit_name__ = 'TINYTEXT' + + def __init__(self, **kwargs): + """Construct a TINYTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TINYTEXT, self).__init__(**kwargs) + +class MEDIUMTEXT(_StringType): + """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + + __visit_name__ = 'MEDIUMTEXT' + + def __init__(self, **kwargs): + """Construct a MEDIUMTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(MEDIUMTEXT, self).__init__(**kwargs) + +class LONGTEXT(_StringType): + """MySQL LONGTEXT type, for text up to 2^32 characters.""" + + __visit_name__ = 'LONGTEXT' + + def __init__(self, **kwargs): + """Construct a LONGTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(LONGTEXT, self).__init__(**kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MySQL VARCHAR type, for variable-length character data.""" + + __visit_name__ = 'VARCHAR' + + def __init__(self, length=None, **kwargs): + """Construct a VARCHAR. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(VARCHAR, self).__init__(length=length, **kwargs) + +class CHAR(_StringType, sqltypes.CHAR): + """MySQL CHAR type, for fixed-length character data.""" + + __visit_name__ = 'CHAR' + + def __init__(self, length, **kwargs): + """Construct a CHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + super(CHAR, self).__init__(length=length, **kwargs) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MySQL NVARCHAR type. + + For variable-length character data in the server's configured national + character set. + """ + + __visit_name__ = 'NVARCHAR' + + def __init__(self, length=None, **kwargs): + """Construct an NVARCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs['national'] = True + super(NVARCHAR, self).__init__(length=length, **kwargs) + + +class NCHAR(_StringType, sqltypes.NCHAR): + """MySQL NCHAR type. + + For fixed-length character data in the server's configured national + character set. + """ + + __visit_name__ = 'NCHAR' + + def __init__(self, length=None, **kwargs): + """Construct an NCHAR. Arguments are: + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs['national'] = True + super(NCHAR, self).__init__(length=length, **kwargs) + + + + +class TINYBLOB(sqltypes._Binary): + """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" + + __visit_name__ = 'TINYBLOB' + +class MEDIUMBLOB(sqltypes._Binary): + """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" + + __visit_name__ = 'MEDIUMBLOB' + +class LONGBLOB(sqltypes._Binary): + """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" + + __visit_name__ = 'LONGBLOB' + +class ENUM(sqltypes.Enum, _StringType): + """MySQL ENUM type.""" + + __visit_name__ = 'ENUM' + + def __init__(self, *enums, **kw): + """Construct an ENUM. + + Example: + + Column('myenum', MSEnum("foo", "bar", "baz")) + + Arguments are: + + :param enums: The range of valid values for this ENUM. Values will be + quoted when generating the schema according to the quoting flag (see + below). + + :param strict: Defaults to False: ensure that a given value is in this + ENUM's range of permissible values when inserting or updating rows. + Note that MySQL will not raise a fatal error if you attempt to store + an out of range value- an alternate value will be stored instead. + (See MySQL ENUM documentation.) + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + :param quoting: Defaults to 'auto': automatically determine enum value + quoting. If all enum values are surrounded by the same quoting + character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. + + 'quoted': values in enums are already quoted, they will be used + directly when generating the schema - this usage is deprecated. + + 'unquoted': values in enums are not quoted, they will be escaped and + surrounded by single quotes when generating the schema. + + Previous versions of this type always required manually quoted + values to be supplied; future versions will always quote the string + literals for you. This is a transitional option. + + """ + self.quoting = kw.pop('quoting', 'auto') + + if self.quoting == 'auto' and len(enums): + # What quoting character are we using? + q = None + for e in enums: + if len(e) == 0: + self.quoting = 'unquoted' + break + elif q is None: + q = e[0] + + if e[0] != q or e[-1] != q: + self.quoting = 'unquoted' + break + else: + self.quoting = 'quoted' + + if self.quoting == 'quoted': + util.warn_deprecated( + 'Manually quoting ENUM value literals is deprecated. Supply ' + 'unquoted values and use the quoting= option in cases of ' + 'ambiguity.') + enums = self._strip_enums(enums) + + self.strict = kw.pop('strict', False) + length = max([len(v) for v in enums] + [0]) + kw.pop('metadata', None) + kw.pop('schema', None) + kw.pop('name', None) + kw.pop('quote', None) + _StringType.__init__(self, length=length, **kw) + sqltypes.Enum.__init__(self, *enums) + + @classmethod + def _strip_enums(cls, enums): + strip_enums = [] + for a in enums: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_enums.append(a) + return strip_enums + + def bind_processor(self, dialect): + super_convert = super(ENUM, self).bind_processor(dialect) + def process(value): + if self.strict and value is not None and value not in self.enums: + raise exc.InvalidRequestError('"%s" not a valid value for ' + 'this enum' % value) + if super_convert: + return super_convert(value) + else: + return value + return process + +class SET(_StringType): + """MySQL SET type.""" + + __visit_name__ = 'SET' + + def __init__(self, *values, **kw): + """Construct a SET. + + Example:: + + Column('myset', MSSet("'foo'", "'bar'", "'baz'")) + + Arguments are: + + :param values: The range of valid values for this SET. Values will be + used exactly as they appear when generating schemas. Strings must + be quoted, as in the example above. Single-quotes are suggested for + ANSI compatibility and are required for portability to servers with + ANSI_QUOTES enabled. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + self._ddl_values = values + + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + + self.values = strip_values + length = max([len(v) for v in strip_values] + [0]) + super(SET, self).__init__(length=length, **kw) + + def result_processor(self, dialect, coltype): + def process(value): + # The good news: + # No ',' quoting issues- commas aren't allowed in SET values + # The bad news: + # Plenty of driver inconsistencies here. + if isinstance(value, util.set_types): + # ..some versions convert '' to an empty set + if not value: + value.add('') + # ..some return sets.Set, even for pythons that have __builtin__.set + if not isinstance(value, set): + value = set(value) + return value + # ...and some versions return strings + if value is not None: + return set(value.split(',')) + else: + return value + return process + + def bind_processor(self, dialect): + super_convert = super(SET, self).bind_processor(dialect) + def process(value): + if value is None or isinstance(value, (int, long, basestring)): + pass + else: + if None in value: + value = set(value) + value.remove(None) + value.add('') + value = ','.join(value) + if super_convert: + return super_convert(value) + else: + return value + return process + +# old names +MSTime = _MSTime +MSSet = SET +MSEnum = ENUM +MSLongBlob = LONGBLOB +MSMediumBlob = MEDIUMBLOB +MSTinyBlob = TINYBLOB +MSBlob = BLOB +MSBinary = BINARY +MSVarBinary = VARBINARY +MSNChar = NCHAR +MSNVarChar = NVARCHAR +MSChar = CHAR +MSString = VARCHAR +MSLongText = LONGTEXT +MSMediumText = MEDIUMTEXT +MSTinyText = TINYTEXT +MSText = TEXT +MSYear = YEAR +MSTimeStamp = TIMESTAMP +MSBit = BIT +MSSmallInteger = SMALLINT +MSTinyInteger = TINYINT +MSMediumInteger = MEDIUMINT +MSBigInteger = BIGINT +MSNumeric = NUMERIC +MSDecimal = DECIMAL +MSDouble = DOUBLE +MSReal = REAL +MSFloat = FLOAT +MSInteger = INTEGER + +colspecs = { + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Time: _MSTime, + sqltypes.Enum: ENUM, +} + +# Everything 3.23 through 5.1 excepting OpenGIS types. +ischema_names = { + 'bigint': BIGINT, + 'binary': BINARY, + 'bit': BIT, + 'blob': BLOB, + 'boolean': BOOLEAN, + 'char': CHAR, + 'date': DATE, + 'datetime': DATETIME, + 'decimal': DECIMAL, + 'double': DOUBLE, + 'enum': ENUM, + 'fixed': DECIMAL, + 'float': FLOAT, + 'int': INTEGER, + 'integer': INTEGER, + 'longblob': LONGBLOB, + 'longtext': LONGTEXT, + 'mediumblob': MEDIUMBLOB, + 'mediumint': MEDIUMINT, + 'mediumtext': MEDIUMTEXT, + 'nchar': NCHAR, + 'nvarchar': NVARCHAR, + 'numeric': NUMERIC, + 'set': SET, + 'smallint': SMALLINT, + 'text': TEXT, + 'time': TIME, + 'timestamp': TIMESTAMP, + 'tinyblob': TINYBLOB, + 'tinyint': TINYINT, + 'tinytext': TINYTEXT, + 'varbinary': VARBINARY, + 'varchar': VARCHAR, + 'year': YEAR, +} + +class MySQLExecutionContext(default.DefaultExecutionContext): + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_RE.match(statement) + +class MySQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'milliseconds': 'millisecond', + }) + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_utc_timestamp_func(self, fn, **kw): + return "UTC_TIMESTAMP" + + def visit_concat_op(self, binary, **kw): + return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary, **kw): + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right)) + + def get_from_hint_text(self, table, text): + return text + + def visit_typeclause(self, typeclause): + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.Integer): + if getattr(type_, 'unsigned', False): + return 'UNSIGNED INTEGER' + else: + return 'SIGNED INTEGER' + elif isinstance(type_, sqltypes.TIMESTAMP): + return 'DATETIME' + elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, sqltypes.Date, sqltypes.Time)): + return self.dialect.type_compiler.process(type_) + elif isinstance(type_, sqltypes.Text): + return 'CHAR' + elif (isinstance(type_, sqltypes.String) and not + isinstance(type_, (ENUM, SET))): + if getattr(type_, 'length'): + return 'CHAR(%s)' % type_.length + else: + return 'CHAR' + elif isinstance(type_, sqltypes._Binary): + return 'BINARY' + elif isinstance(type_, NUMERIC): + return self.dialect.type_compiler.process(type_).replace('NUMERIC', 'DECIMAL') + else: + return None + + def visit_cast(self, cast, **kwargs): + # No cast until 4, no decimals until 5. + type_ = self.process(cast.typeclause) + if type_ is None: + return self.process(cast.clause) + + return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) + + def get_select_precolumns(self, select): + if isinstance(select._distinct, basestring): + return select._distinct.upper() + " " + elif select._distinct: + return "DISTINCT " + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + # 'JOIN ... ON ...' for inner joins isn't available until 4.0. + # Apparently < 3.23.17 requires theta joins for inner joins + # (but not outer). Not generating these currently, but + # support can be added, preferably after dialects are + # refactored to be version-sensitive. + return ''.join( + (self.process(join.left, asfrom=True, **kwargs), + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), + self.process(join.right, asfrom=True, **kwargs), + " ON ", + self.process(join.onclause, **kwargs))) + + def for_update_clause(self, select): + if select.for_update == 'read': + return ' LOCK IN SHARE MODE' + else: + return super(MySQLCompiler, self).for_update_clause(select) + + def limit_clause(self, select): + # MySQL supports: + # LIMIT + # LIMIT , + # and in server versions > 3.3: + # LIMIT OFFSET + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit, offset = select._limit, select._offset + + if (limit, offset) == (None, None): + return '' + elif offset is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + if limit is None: + limit = 18446744073709551615 + return ' \n LIMIT %s, %s' % (offset, limit) + else: + # No offset provided, so just use the limit + return ' \n LIMIT %s' % (limit,) + + def visit_update(self, update_stmt): + self.stack.append({'from': set([update_stmt.table])}) + + self.isupdate = True + colparams = self._get_colparams(update_stmt) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \ + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) + + if update_stmt._whereclause is not None: + text += " WHERE " + self.process(update_stmt._whereclause) + + limit = update_stmt.kwargs.get('mysql_limit', None) + if limit: + text += " LIMIT %s" % limit + + self.stack.pop(-1) + + return text + +# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. +# Starting with MySQL 4.1.2, these indexes are created automatically. +# In older versions, the indexes must be created explicitly or the +# creation of foreign key constraints fails." + +class MySQLDDLCompiler(compiler.DDLCompiler): + def create_table_constraints(self, table): + """Get table constraints.""" + constraint_string = super(MySQLDDLCompiler, self).create_table_constraints(table) + + is_innodb = table.kwargs.has_key('mysql_engine') and \ + table.kwargs['mysql_engine'].lower() == 'innodb' + + auto_inc_column = table._autoincrement_column + + if is_innodb and \ + auto_inc_column is not None and \ + auto_inc_column is not list(table.primary_key)[0]: + if constraint_string: + constraint_string += ", \n\t" + constraint_string += "KEY `idx_autoinc_%s`(`%s`)" % (auto_inc_column.name, \ + self.preparer.format_column(auto_inc_column)) + + return constraint_string + + + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + + colspec = [self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type) + ] + + default = self.get_column_default_string(column) + if default is not None: + colspec.append('DEFAULT ' + default) + + is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP) + if not column.nullable and not is_timestamp: + colspec.append('NOT NULL') + + elif column.nullable and is_timestamp and default is None: + colspec.append('NULL') + + if column.primary_key and column.autoincrement: + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + isinstance(c.type, sqltypes.Integer) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('AUTO_INCREMENT') + except IndexError: + pass + + return ' '.join(colspec) + + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + for k in table.kwargs: + if k.startswith('mysql_'): + opt = k[6:].upper() + + arg = table.kwargs[k] + if opt in _options_of_type_string: + arg = "'%s'" % arg.replace("\\", "\\\\").replace("'", "''") + + if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY', + 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', 'DEFAULT_CHARSET', + 'DEFAULT_COLLATE'): + opt = opt.replace('_', ' ') + + joiner = '=' + if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', + 'CHARACTER SET', 'COLLATE'): + joiner = ' ' + + table_opts.append(joiner.join((opt, arg))) + return ' '.join(table_opts) + + def visit_drop_index(self, drop): + index = drop.element + + return "\nDROP INDEX %s ON %s" % \ + (self.preparer.quote(self._validate_identifier(index.name, False), index.quote), + self.preparer.format_table(index.table)) + + def visit_drop_constraint(self, drop): + constraint = drop.element + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % \ + (self.preparer.format_table(constraint.table), + qual, const) + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += ' UNSIGNED' + if type_.zerofill: + spec += ' ZEROFILL' + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. + + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr('charset'): + charset = 'CHARACTER SET %s' % attr('charset') + elif attr('ascii'): + charset = 'ASCII' + elif attr('unicode'): + charset = 'UNICODE' + else: + charset = None + + if attr('collation'): + collation = 'COLLATE %s' % type_.collation + elif attr('binary'): + collation = 'BINARY' + else: + collation = None + + if attr('national'): + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. + return ' '.join([c for c in ('NATIONAL', spec, collation) + if c is not None]) + return ' '.join([c for c in (spec, charset, collation) + if c is not None]) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType)) + + def visit_NUMERIC(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric(type_, "NUMERIC(%(precision)s)" % {'precision': type_.precision}) + else: + return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DECIMAL(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric(type_, "DECIMAL(%(precision)s)" % {'precision': type_.precision}) + else: + return self._extend_numeric(type_, "DECIMAL(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DOUBLE(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'DOUBLE') + + def visit_REAL(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'REAL') + + def visit_FLOAT(self, type_): + if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + elif type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s)" % (type_.precision,)) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "INTEGER(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "BIGINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "MEDIUMINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_YEAR(self, type_): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_): + if type_.length: + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "TEXT") + + def visit_TINYTEXT(self, type_): + return self._extend_string(type_, {}, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_): + return self._extend_string(type_, {}, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_): + return self._extend_string(type_, {}, "LONGTEXT") + + def visit_VARCHAR(self, type_): + if type_.length: + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) + else: + raise exc.InvalidRequestError("VARCHAR requires a length when rendered on MySQL") + + def visit_CHAR(self, type_): + if type_.length: + return self._extend_string(type_, {}, "CHAR(%(length)s)" % {'length' : type_.length}) + else: + return self._extend_string(type_, {}, "CHAR") + + def visit_NVARCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + if type_.length: + return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length}) + else: + raise exc.InvalidRequestError("NVARCHAR requires a length when rendered on MySQL") + + def visit_NCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". + if type_.length: + return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length}) + else: + return self._extend_string(type_, {'national':True}, "CHAR") + + def visit_VARBINARY(self, type_): + return "VARBINARY(%d)" % type_.length + + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_enum(self, type_): + if not type_.native_enum: + return super(MySQLTypeCompiler, self).visit_enum(type_) + else: + return self.visit_ENUM(type_) + + def visit_BLOB(self, type_): + if type_.length: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_): + return "LONGBLOB" + + def visit_ENUM(self, type_): + quoted_enums = [] + for e in type_.enums: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string(type_, {}, "ENUM(%s)" % ",".join(quoted_enums)) + + def visit_SET(self, type_): + return self._extend_string(type_, {}, "SET(%s)" % ",".join(type_._ddl_values)) + + def visit_BOOLEAN(self, type): + return "BOOL" + + +class MySQLIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = RESERVED_WORDS + + def __init__(self, dialect, server_ansiquotes=False, **kw): + if not server_ansiquotes: + quote = "`" + else: + quote = '"' + + super(MySQLIdentifierPreparer, self).__init__( + dialect, + initial_quote=quote, + escape_quote=quote) + + def _quote_free_identifiers(self, *ids): + """Unilaterally identifier-quote any number of strings.""" + + return tuple([self.quote_identifier(i) for i in ids if i is not None]) + +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. Not used directly in application code.""" + + name = 'mysql' + supports_alter = True + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + + supports_native_enum = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + default_paramstyle = 'format' + colspecs = colspecs + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler = MySQLTypeCompiler + ischema_names = ischema_names + preparer = MySQLIdentifierPreparer + + def __init__(self, use_ansiquotes=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + + def do_commit(self, connection): + """Execute a COMMIT.""" + + # COMMIT/ROLLBACK were introduced in 3.23.15. + # Yes, we have at least one user who has to talk to these old versions! + # + # Ignore commit/rollback if support isn't present, otherwise even basic + # operations via autocommit fail. + try: + connection.commit() + except: + if self.server_version_info < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_rollback(self, connection): + """Execute a ROLLBACK.""" + + try: + connection.rollback() + except: + if self.server_version_info < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid"), xid=xid) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA PREPARE :xid"), xid=xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid"), xid=xid) + + def do_recover_twophase(self, connection): + resultset = connection.execute("XA RECOVER") + return [row['data'][0:row['gtrid_length']] for row in resultset] + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return self._extract_error_code(e) in (2006, 2013, 2014, 2045, 2055) + elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get + return "(0, '')" in str(e) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" + + return [_DecodingRowProxy(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + + return _DecodingRowProxy(rp.fetchone(), charset) + + def _compat_first(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + + return _DecodingRowProxy(rp.first(), charset) + + def _extract_error_code(self, exception): + raise NotImplementedError() + + def _get_default_schema_name(self, connection): + return connection.execute('SELECT DATABASE()').scalar() + + + def has_table(self, connection, table_name, schema=None): + # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly + # on macosx (and maybe win?) with multibyte table names. + # + # TODO: if this is not a problem on win, make the strategy swappable + # based on platform. DESCRIBE is slower. + + # [ticket:726] + # full_name = self.identifier_preparer.format_table(table, + # use_schema=True) + + + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + + st = "DESCRIBE %s" % full_name + rs = None + try: + try: + rs = connection.execute(st) + have = rs.rowcount > 0 + rs.close() + return have + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + return False + raise + finally: + if rs: + rs.close() + + def initialize(self, connection): + default.DefaultDialect.initialize(self, connection) + self._connection_charset = self._detect_charset(connection) + self._server_casing = self._detect_casing(connection) + self._server_collations = self._detect_collations(connection) + self._server_ansiquotes = self._detect_ansiquotes(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer(self, server_ansiquotes=self._server_ansiquotes) + + @reflection.cache + def get_schema_names(self, connection, **kw): + rp = connection.execute("SHOW schemas") + return [r[0] for r in rp] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + """Return a Unicode SHOW TABLES from a given schema.""" + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + charset = self._connection_charset + if self.server_version_info < (5, 0, 2): + rp = connection.execute("SHOW TABLES FROM %s" % + self.identifier_preparer.quote_identifier(current_schema)) + return [row[0] for row in self._compat_fetchall(rp, charset=charset)] + else: + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(current_schema)) + + return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ + if row[1] == 'BASE TABLE'] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + charset = self._connection_charset + if self.server_version_info < (5, 0, 2): + raise NotImplementedError + if schema is None: + schema = self.default_schema_name + if self.server_version_info < (5, 0, 2): + return self.get_table_names(connection, schema) + charset = self._connection_charset + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ + if row[1] == 'VIEW'] + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + return parsed_state.table_options + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + return parsed_state.columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + for key in parsed_state.keys: + if key['type'] == 'PRIMARY': + # There can be only one. + ##raise Exception, str(key) + return [s[0] for s in key['columns']] + return [] + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + default_schema = None + + fkeys = [] + + for spec in parsed_state.constraints: + # only FOREIGN KEYs + ref_name = spec['table'][-1] + ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema + + if not ref_schema: + if default_schema is None: + default_schema = \ + connection.dialect.default_schema_name + if schema == default_schema: + ref_schema = schema + + loc_names = spec['local'] + ref_names = spec['foreign'] + + con_kw = {} + for opt in ('name', 'onupdate', 'ondelete'): + if spec.get(opt, False): + con_kw[opt] = spec[opt] + + fkey_d = { + 'name' : spec['name'], + 'constrained_columns' : loc_names, + 'referred_schema' : ref_schema, + 'referred_table' : ref_name, + 'referred_columns' : ref_names, + 'options' : con_kw + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + + indexes = [] + for spec in parsed_state.keys: + unique = False + flavor = spec['type'] + if flavor == 'PRIMARY': + continue + if flavor == 'UNIQUE': + unique = True + elif flavor in (None, 'FULLTEXT', 'SPATIAL'): + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY" % flavor) + pass + index_d = {} + index_d['name'] = spec['name'] + index_d['column_names'] = [s[0] for s in spec['columns']] + index_d['unique'] = unique + index_d['type'] = flavor + indexes.append(index_d) + return indexes + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + + charset = self._connection_charset + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, view_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) + return sql + + def _parsed_state_or_create(self, connection, table_name, schema=None, **kw): + return self._setup_parser( + connection, + table_name, + schema, + info_cache=kw.get('info_cache', None) + ) + + @util.memoized_property + def _tabledef_parser(self): + """return the MySQLTableDefinitionParser, generate if needed. + + The deferred creation ensures that the dialect has + retrieved server version information first. + + """ + if (self.server_version_info < (4, 1) and self._server_ansiquotes): + # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 + preparer = self.preparer(self, server_ansiquotes=False) + else: + preparer = self.identifier_preparer + return MySQLTableDefinitionParser(self, preparer) + + @reflection.cache + def _setup_parser(self, connection, table_name, schema=None, **kw): + charset = self._connection_charset + parser = self._tabledef_parser + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) + if sql.startswith('CREATE ALGORITHM'): + # Adapt views to something table-like. + columns = self._describe_table(connection, None, charset, + full_name=full_name) + sql = parser._describe_to_create(table_name, columns) + return parser.parse(sql, charset) + + def _adjust_casing(self, table, charset=None): + """Adjust Table name to the server case sensitivity, if needed.""" + + casing = self._server_casing + + # For winxx database hosts. TODO: is this really needed? + if casing == 1 and table.name != table.name.lower(): + table.name = table.name.lower() + lc_alias = sa_schema._get_table_key(table.name, table.schema) + table.metadata.tables[lc_alias] = table + + def _detect_charset(self, connection): + raise NotImplementedError() + + def _detect_casing(self, connection): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + + """ + # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html + + charset = self._connection_charset + row = self._compat_first(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset) + if not row: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if row[1] == 'OFF': + cs = 0 + elif row[1] == 'ON': + cs = 1 + else: + cs = int(row[1]) + return cs + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + collations = {} + if self.server_version_info < (4, 1, 0): + pass + else: + charset = self._connection_charset + rs = connection.execute('SHOW COLLATION') + for row in self._compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" + + row = self._compat_first( + connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), + charset=self._connection_charset) + + if not row: + mode = '' + else: + mode = row[1] or '' + # 4.0 + if mode.isdigit(): + mode_no = int(mode) + mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + + return 'ANSI_QUOTES' in mode + + def _show_create_table(self, connection, table, charset=None, + full_name=None): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None + try: + rp = connection.execute(st) + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + raise exc.NoSuchTableError(full_name) + else: + raise + row = self._compat_first(rp, charset=charset) + if not row: + raise exc.NoSuchTableError(full_name) + return row[1].strip() + + return sql + + def _describe_table(self, connection, table, charset=None, + full_name=None): + """Run DESCRIBE for a ``Table`` and return processed rows.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "DESCRIBE %s" % full_name + + rp, rows = None, None + try: + try: + rp = connection.execute(st) + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + raise exc.NoSuchTableError(full_name) + else: + raise + rows = self._compat_fetchall(rp, charset=charset) + finally: + if rp: + rp.close() + return rows + +class ReflectedState(object): + """Stores raw information about a SHOW CREATE TABLE statement.""" + + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.constraints = [] + +class MySQLTableDefinitionParser(object): + """Parses the results of a SHOW CREATE TABLE statement.""" + + def __init__(self, dialect, preparer): + self.dialect = dialect + self.preparer = preparer + self._prep_regexes() + + def parse(self, show_create, charset): + state = ReflectedState() + state.charset = charset + for line in re.split(r'\r?\n', show_create): + if line.startswith(' ' + self.preparer.initial_quote): + self._parse_column(line, state) + # a regular table options line + elif line.startswith(') '): + self._parse_table_options(line, state) + # an ANSI-mode table options line + elif line == ')': + pass + elif line.startswith('CREATE '): + self._parse_table_name(line, state) + # Not present in real reflection, but may be if loading from a file. + elif not line: + pass + else: + type_, spec = self._parse_constraints(line) + if type_ is None: + util.warn("Unknown schema content: %r" % line) + elif type_ == 'key': + state.keys.append(spec) + elif type_ == 'constraint': + state.constraints.append(spec) + else: + pass + + return state + + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + :param line: A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + spec['columns'] = self._parse_keyexprs(spec['columns']) + return 'key', spec + + # CONSTRAINT + m = self._re_constraint.match(line) + if m: + spec = m.groupdict() + spec['table'] = \ + self.preparer.unformat_identifiers(spec['table']) + spec['local'] = [c[0] + for c in self._parse_keyexprs(spec['local'])] + spec['foreign'] = [c[0] + for c in self._parse_keyexprs(spec['foreign'])] + return 'constraint', spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return 'partition', line + + # No match. + return (None, line) + + def _parse_table_name(self, line, state): + """Extract the table name. + + :param line: The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + state.table_name = cleanup(m.group('name')) + + def _parse_table_options(self, line, state): + """Build a dictionary of all reflected table-level options. + + :param line: The final line of SHOW CREATE TABLE output. + """ + + options = {} + + if not line or line == ')': + pass + + else: + rest_of_line = line[:] + for regex, cleanup in self._pr_options: + m = regex.search(rest_of_line) + if not m: + continue + directive, value = m.group('directive'), m.group('val') + if cleanup: + value = cleanup(value) + options[directive.lower()] = value + rest_of_line = regex.sub('', rest_of_line) + + for nope in ('auto_increment', 'data directory', 'index directory'): + options.pop(nope, None) + + for opt, val in options.items(): + state.table_options['mysql_%s' % opt] = val + + def _parse_column(self, line, state): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + :param line: Any column-bearing line from SHOW CREATE TABLE + """ + + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec['full'] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec['full'] = False + if not spec: + util.warn("Unknown column definition %r" % line) + return + if not spec['full']: + util.warn("Incomplete reflection of column definition %r" % line) + + name, type_, args, notnull = \ + spec['name'], spec['coltype'], spec['arg'], spec['notnull'] + + try: + col_type = self.dialect.ischema_names[type_] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (type_, name)) + col_type = sqltypes.NullType + + # Column type positional arguments eg. varchar(32) + if args is None or args == '': + type_args = [] + elif args[0] == "'" and args[-1] == "'": + type_args = self._re_csv_str.findall(args) + else: + type_args = [int(v) for v in self._re_csv_int.findall(args)] + + # Column type keyword options + type_kw = {} + for kw in ('unsigned', 'zerofill'): + if spec.get(kw, False): + type_kw[kw] = True + for kw in ('charset', 'collate'): + if spec.get(kw, False): + type_kw[kw] = spec[kw] + + if type_ == 'enum': + type_args = ENUM._strip_enums(type_args) + + type_instance = col_type(*type_args, **type_kw) + + col_args, col_kw = [], {} + + # NOT NULL + col_kw['nullable'] = True + if spec.get('notnull', False): + col_kw['nullable'] = False + + # AUTO_INCREMENT + if spec.get('autoincr', False): + col_kw['autoincrement'] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw['autoincrement'] = False + + # DEFAULT + default = spec.get('default', None) + + if default == 'NULL': + # eliminates the need to deal with this later. + default = None + + col_d = dict(name=name, type=type_instance, default=default) + col_d.update(col_kw) + state.columns.append(col_d) + + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. + + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. + + :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. + """ + + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = \ + [row[i] for i in (0, 1, 2, 4, 5)] + + line = [' '] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append('NOT NULL') + if default: + if 'auto_increment' in default: + pass + elif (col_type.startswith('timestamp') and + default.startswith('C')): + line.append('DEFAULT') + line.append(default) + elif default == 'NULL': + line.append('DEFAULT') + line.append(default) + else: + line.append('DEFAULT') + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) + + buffer.append(' '.join(line)) + + return ''.join([('CREATE TABLE %s (\n' % + self.preparer.quote_identifier(table_name)), + ',\n'.join(buffer), + '\n) ']) + + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + + return self._re_keyexprs.findall(identifiers) + + def _prep_regexes(self): + """Pre-compile regular expressions.""" + + self._re_columns = [] + self._pr_options = [] + + _final = self.preparer.final_quote + + quotes = dict(zip(('iq', 'fq', 'esc_fq'), + [re.escape(s) for s in + (self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final))])) + + self._pr_name = _pr_compile( + r'^CREATE (?:\w+ +)?TABLE +' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes, + self.preparer._unescape_identifier) + + # `col`,`col2`(32),`col3`(15) DESC + # + # Note: ASC and DESC aren't reflected, so we'll punt... + self._re_keyexprs = _re_compile( + r'(?:' + r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)' + r'(?:\((\d+)\))?(?=\,|$))+' % quotes) + + # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' + self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27') + + # 123 or 123,456 + self._re_csv_int = _re_compile(r'\d+') + + + # `colname` [type opts] + # (NOT NULL | NULL) + # DEFAULT ('value' | CURRENT_TIMESTAMP...) + # COMMENT 'comment' + # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT) + # STORAGE (DISK|MEMORY) + self._re_column = _re_compile( + r' ' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'(?P\w+)' + r'(?:\((?P(?:\d+|\d+,\d+|' + r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?' + r'(?: +(?PUNSIGNED))?' + r'(?: +(?PZEROFILL))?' + r'(?: +CHARACTER SET +(?P[\w_]+))?' + r'(?: +COLLATE +(?P[\w_]+))?' + r'(?: +(?PNOT NULL))?' + r'(?: +DEFAULT +(?P' + r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+)' + r'(?:ON UPDATE \w+)?' + r'))?' + r'(?: +(?PAUTO_INCREMENT))?' + r'(?: +COMMENT +(P(?:\x27\x27|[^\x27])+))?' + r'(?: +COLUMN_FORMAT +(?P\w+))?' + r'(?: +STORAGE +(?P\w+))?' + r'(?: +(?P.*))?' + r',?$' + % quotes + ) + + # Fallback, try to parse as little as possible + self._re_column_loose = _re_compile( + r' ' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'(?P\w+)' + r'(?:\((?P(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?' + r'.*?(?PNOT NULL)?' + % quotes + ) + + # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? + # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) + # KEY_BLOCK_SIZE size | WITH PARSER name + self._re_key = _re_compile( + r' ' + r'(?:(?P\S+) )?KEY' + r'(?: +%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?' + r'(?: +USING +(?P\S+))?' + r' +\((?P.+?)\)' + r'(?: +USING +(?P\S+))?' + r'(?: +KEY_BLOCK_SIZE +(?P\S+))?' + r'(?: +WITH PARSER +(?P\S+))?' + r',?$' + % quotes + ) + + # CONSTRAINT `name` FOREIGN KEY (`local_col`) + # REFERENCES `remote` (`remote_col`) + # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE + # ON DELETE CASCADE ON UPDATE RESTRICT + # + # unique constraints come back as KEYs + kw = quotes.copy() + kw['on'] = 'RESTRICT|CASCASDE|SET NULL|NOACTION' + self._re_constraint = _re_compile( + r' ' + r'CONSTRAINT +' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'FOREIGN KEY +' + r'\((?P[^\)]+?)\) REFERENCES +' + r'(?P%(iq)s[^%(fq)s]+%(fq)s(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +' + r'\((?P[^\)]+?)\)' + r'(?: +(?PMATCH \w+))?' + r'(?: +ON DELETE (?P%(on)s))?' + r'(?: +ON UPDATE (?P%(on)s))?' + % kw + ) + + # PARTITION + # + # punt! + self._re_partition = _re_compile( + r' ' + r'(?:SUB)?PARTITION') + + # Table-level options (COLLATE, ENGINE, etc.) + # Do the string options first, since they have quoted strings we need to get rid of. + for option in _options_of_type_string: + self._add_option_string(option) + + for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT', + 'AVG_ROW_LENGTH', 'CHARACTER SET', + 'DEFAULT CHARSET', 'CHECKSUM', + 'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD', + 'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT', + 'KEY_BLOCK_SIZE'): + self._add_option_word(option) + + self._add_option_regex('UNION', r'\([^\)]+\)') + self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK') + self._add_option_regex('RAID_TYPE', + r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') + + _optional_equals = r'(?:\s*(?:=\s*)|\s+)' + + def _add_option_string(self, directive): + regex = (r'(?P%s)%s' + r"'(?P(?:[^']|'')*?)'(?!')" % + (re.escape(directive), self._optional_equals)) + self._pr_options.append( + _pr_compile(regex, lambda v: v.replace("\\\\","\\").replace("''", "'"))) + + def _add_option_word(self, directive): + regex = (r'(?P%s)%s' + r'(?P\w+)' % + (re.escape(directive), self._optional_equals)) + self._pr_options.append(_pr_compile(regex)) + + def _add_option_regex(self, directive, regex): + regex = (r'(?P%s)%s' + r'(?P%s)' % + (re.escape(directive), self._optional_equals, regex)) + self._pr_options.append(_pr_compile(regex)) + +_options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY', + 'PASSWORD', 'CONNECTION') + +log.class_logger(MySQLTableDefinitionParser) +log.class_logger(MySQLDialect) + + +class _DecodingRowProxy(object): + """Return unicode-decoded values based on type inspection. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = charset + + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + # Py2K + if self.charset and isinstance(item, str): + # end Py2K + # Py3K + #if self.charset and isinstance(item, bytes): + return item.decode(self.charset) + else: + return item + + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + # Py2K + if self.charset and isinstance(item, str): + # end Py2K + # Py3K + #if self.charset and isinstance(item, bytes): + return item.decode(self.charset) + else: + return item + + +def _pr_compile(regex, cleanup=None): + """Prepare a 2-tuple of compiled regex and callable.""" + + return (_re_compile(regex), cleanup) + +def _re_compile(regex): + """Compile a string to regex, I and UNICODE.""" + + return re.compile(regex, re.I | re.UNICODE) + diff --git a/sqlalchemy/dialects/mysql/mysqlconnector.py b/sqlalchemy/dialects/mysql/mysqlconnector.py new file mode 100644 index 0000000..2da18e5 --- /dev/null +++ b/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -0,0 +1,132 @@ +"""Support for the MySQL database via the MySQL Connector/Python adapter. + +MySQL Connector/Python is available at: + + https://launchpad.net/myconnpy + +Connecting +----------- + +Connect string format:: + + mysql+mysqlconnector://:@[:]/ + +""" + +import re + +from sqlalchemy.dialects.mysql.base import (MySQLDialect, + MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer, + BIT) + +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import exc, log, schema, sql, types as sqltypes, util +from sqlalchemy import processors + +class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): + + def get_lastrowid(self): + return self.cursor.lastrowid + + +class MySQLCompiler_mysqlconnector(MySQLCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + return text.replace('%', '%%') + +class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace("%", "%%") + +class _myconnpyBIT(BIT): + def result_processor(self, dialect, coltype): + """MySQL-connector already converts mysql bits, so.""" + + return None + +class MySQLDialect_mysqlconnector(MySQLDialect): + driver = 'mysqlconnector' + supports_unicode_statements = True + supports_unicode_binds = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = 'format' + execution_ctx_cls = MySQLExecutionContext_mysqlconnector + statement_compiler = MySQLCompiler_mysqlconnector + + preparer = MySQLIdentifierPreparer_mysqlconnector + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + BIT: _myconnpyBIT, + } + ) + + @classmethod + def dbapi(cls): + from mysql import connector + return connector + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + + util.coerce_kw_type(opts, 'buffered', bool) + util.coerce_kw_type(opts, 'raise_on_warnings', bool) + opts['buffered'] = True + opts['raise_on_warnings'] = True + + # FOUND_ROWS must be set in ClientFlag to enable + # supports_sane_rowcount. + if self.dbapi is not None: + try: + from mysql.connector.constants import ClientFlag + client_flags = opts.get('client_flags', ClientFlag.get_default()) + client_flags |= ClientFlag.FOUND_ROWS + opts['client_flags'] = client_flags + except: + pass + return [[], opts] + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + + from mysql.connector.constants import ClientFlag + dbapi_con.set_client_flag(ClientFlag.FOUND_ROWS) + + version = dbapi_con.get_server_version() + return tuple(version) + + def _detect_charset(self, connection): + return connection.connection.get_characterset_info() + + def _extract_error_code(self, exception): + try: + return exception.orig.errno + except AttributeError: + return None + + def is_disconnect(self, e): + errnos = (2006, 2013, 2014, 2045, 2055, 2048) + exceptions = (self.dbapi.OperationalError,self.dbapi.InterfaceError) + if isinstance(e, exceptions): + return e.errno in errnos + else: + return False + + def _compat_fetchall(self, rp, charset=None): + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + return rp.fetchone() + +dialect = MySQLDialect_mysqlconnector diff --git a/sqlalchemy/dialects/mysql/mysqldb.py b/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 0000000..6e6bb0e --- /dev/null +++ b/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,202 @@ +"""Support for the MySQL database via the MySQL-python adapter. + +MySQL-Python is available at: + + http://sourceforge.net/projects/mysql-python + +At least version 1.2.1 or 1.2.2 should be used. + +Connecting +----------- + +Connect string format:: + + mysql+mysqldb://:@[:]/ + +Character Sets +-------------- + +Many MySQL server installations default to a ``latin1`` encoding for client +connections. All data sent through the connection will be converted into +``latin1``, even if you have ``utf8`` or another character set on your tables +and columns. With versions 4.1 and higher, you can change the connection +character set either through server configuration or by including the +``charset`` parameter in the URL used for ``create_engine``. The ``charset`` +option is passed through to MySQL-Python and has the side-effect of also +enabling ``use_unicode`` in the driver by default. For regular encoded +strings, also pass ``use_unicode=0`` in the connection arguments:: + + # set client encoding to utf8; all strings come back as unicode + create_engine('mysql+mysqldb:///mydb?charset=utf8') + + # set client encoding to utf8; all strings come back as utf8 str + create_engine('mysql+mysqldb:///mydb?charset=utf8&use_unicode=0') + +Known Issues +------------- + +MySQL-python at least as of version 1.2.2 has a serious memory leak related +to unicode conversion, a feature which is disabled via ``use_unicode=0``. +The recommended connection form with SQLAlchemy is:: + + engine = create_engine('mysql://scott:tiger@localhost/test?charset=utf8&use_unicode=0', pool_recycle=3600) + + +""" + +import re + +from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext, + MySQLCompiler, MySQLIdentifierPreparer) +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import exc, log, schema, sql, types as sqltypes, util +from sqlalchemy import processors + +class MySQLExecutionContext_mysqldb(MySQLExecutionContext): + + @property + def rowcount(self): + if hasattr(self, '_rowcount'): + return self._rowcount + else: + return self.cursor.rowcount + + +class MySQLCompiler_mysqldb(MySQLCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace("%", "%%") + +class MySQLDialect_mysqldb(MySQLDialect): + driver = 'mysqldb' + supports_unicode_statements = False + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = 'format' + execution_ctx_cls = MySQLExecutionContext_mysqldb + statement_compiler = MySQLCompiler_mysqldb + preparer = MySQLIdentifierPreparer_mysqldb + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + } + ) + + @classmethod + def dbapi(cls): + return __import__('MySQLdb') + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'connect_timeout', int) + util.coerce_kw_type(opts, 'client_flag', int) + util.coerce_kw_type(opts, 'local_infile', int) + # Note: using either of the below will cause all strings to be returned + # as Unicode, both in raw SQL operations and with column types like + # String and MSString. + util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, 'charset', str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + + ssl = {} + for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], str) + del opts[key] + if ssl: + opts['ssl'] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get('client_flag', 0) + if self.dbapi is not None: + try: + from MySQLdb.constants import CLIENT as CLIENT_FLAGS + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except: + pass + opts['client_flag'] = client_flag + return [[], opts] + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.get_server_info()): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _extract_error_code(self, exception): + try: + return exception.orig.args[0] + except AttributeError: + return None + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Note: MySQL-python 1.2.1c7 seems to ignore changes made + # on a connection via set_character_set() + if self.server_version_info < (4, 1, 0): + try: + return connection.connection.character_set_name() + except AttributeError: + # < 1.2.1 final MySQL-python drivers have no charset support. + # a query is needed. + pass + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + + if 'character_set_results' in opts: + return opts['character_set_results'] + try: + return connection.connection.character_set_name() + except AttributeError: + # Still no charset on < 1.2.1 final... + if 'character_set' in opts: + return opts['character_set'] + else: + util.warn( + "Could not detect the connection character set with this " + "combination of MySQL server and MySQL-python. " + "MySQL-python >= 1.2.2 is recommended. Assuming latin1.") + return 'latin1' + + +dialect = MySQLDialect_mysqldb diff --git a/sqlalchemy/dialects/mysql/oursql.py b/sqlalchemy/dialects/mysql/oursql.py new file mode 100644 index 0000000..ebc7264 --- /dev/null +++ b/sqlalchemy/dialects/mysql/oursql.py @@ -0,0 +1,255 @@ +"""Support for the MySQL database via the oursql adapter. + +OurSQL is available at: + + http://packages.python.org/oursql/ + +Connecting +----------- + +Connect string format:: + + mysql+oursql://:@[:]/ + +Character Sets +-------------- + +oursql defaults to using ``utf8`` as the connection charset, but other +encodings may be used instead. Like the MySQL-Python driver, unicode support +can be completely disabled:: + + # oursql sets the connection charset to utf8 automatically; all strings come + # back as utf8 str + create_engine('mysql+oursql:///mydb?use_unicode=0') + +To not automatically use ``utf8`` and instead use whatever the connection +defaults to, there is a separate parameter:: + + # use the default connection charset; all strings come back as unicode + create_engine('mysql+oursql:///mydb?default_charset=1') + + # use latin1 as the connection charset; all strings come back as unicode + create_engine('mysql+oursql:///mydb?charset=latin1') +""" + +import re + +from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext, + MySQLCompiler, MySQLIdentifierPreparer) +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import exc, log, schema, sql, types as sqltypes, util +from sqlalchemy import processors + + + +class _oursqlBIT(BIT): + def result_processor(self, dialect, coltype): + """oursql already converts mysql bits, so.""" + + return None + + +class MySQLExecutionContext_oursql(MySQLExecutionContext): + + @property + def plain_query(self): + return self.execution_options.get('_oursql_plain_query', False) + +class MySQLDialect_oursql(MySQLDialect): + driver = 'oursql' +# Py3K +# description_encoding = None +# Py2K + supports_unicode_binds = True + supports_unicode_statements = True +# end Py2K + + supports_native_decimal = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + execution_ctx_cls = MySQLExecutionContext_oursql + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Time: sqltypes.Time, + BIT: _oursqlBIT, + } + ) + + @classmethod + def dbapi(cls): + return __import__('oursql') + + def do_execute(self, cursor, statement, parameters, context=None): + """Provide an implementation of *cursor.execute(statement, parameters)*.""" + + if context and context.plain_query: + cursor.execute(statement, plain_query=True) + else: + cursor.execute(statement, parameters) + + def do_begin(self, connection): + connection.cursor().execute('BEGIN', plain_query=True) + + def _xa_query(self, connection, query, xid): +# Py2K + arg = connection.connection._escape_string(xid) +# end Py2K +# Py3K +# charset = self._connection_charset +# arg = connection.connection._escape_string(xid.encode(charset)).decode(charset) + connection.execution_options(_oursql_plain_query=True).execute(query % arg) + + # Because mysql is bad, these methods have to be + # reimplemented to use _PlainQuery. Basically, some queries + # refuse to return any data if they're run through + # the parameterized query API, or refuse to be parameterized + # in the first place. + def do_begin_twophase(self, connection, xid): + self._xa_query(connection, 'XA BEGIN "%s"', xid) + + def do_prepare_twophase(self, connection, xid): + self._xa_query(connection, 'XA END "%s"', xid) + self._xa_query(connection, 'XA PREPARE "%s"', xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self._xa_query(connection, 'XA END "%s"', xid) + self._xa_query(connection, 'XA ROLLBACK "%s"', xid) + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + self._xa_query(connection, 'XA COMMIT "%s"', xid) + + # Q: why didn't we need all these "plain_query" overrides earlier ? + # am i on a newer/older version of OurSQL ? + def has_table(self, connection, table_name, schema=None): + return MySQLDialect.has_table(self, + connection.connect().\ + execution_options(_oursql_plain_query=True), + table_name, schema) + + def get_table_options(self, connection, table_name, schema=None, **kw): + return MySQLDialect.get_table_options(self, + connection.connect().\ + execution_options(_oursql_plain_query=True), + table_name, + schema = schema, + **kw + ) + + + def get_columns(self, connection, table_name, schema=None, **kw): + return MySQLDialect.get_columns(self, + connection.connect().\ + execution_options(_oursql_plain_query=True), + table_name, + schema=schema, + **kw + ) + + def get_view_names(self, connection, schema=None, **kw): + return MySQLDialect.get_view_names(self, + connection.connect().\ + execution_options(_oursql_plain_query=True), + schema=schema, + **kw + ) + + def get_table_names(self, connection, schema=None, **kw): + return MySQLDialect.get_table_names(self, + connection.connect().\ + execution_options(_oursql_plain_query=True), + schema + ) + + def get_schema_names(self, connection, **kw): + return MySQLDialect.get_schema_names(self, + connection.connect().\ + execution_options(_oursql_plain_query=True), + **kw + ) + + def initialize(self, connection): + return MySQLDialect.initialize( + self, + connection.execution_options(_oursql_plain_query=True) + ) + + def _show_create_table(self, connection, table, charset=None, + full_name=None): + return MySQLDialect._show_create_table(self, + connection.contextual_connect(close_with_result=True). + execution_options(_oursql_plain_query=True), + table, charset, full_name) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.ProgrammingError): + return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed') + else: + return e.errno in (2006, 2013, 2014, 2045, 2055) + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'port', int) + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'autoping', bool) + + util.coerce_kw_type(opts, 'default_charset', bool) + if opts.pop('default_charset', False): + opts['charset'] = None + else: + util.coerce_kw_type(opts, 'charset', str) + opts['use_unicode'] = opts.get('use_unicode', True) + util.coerce_kw_type(opts, 'use_unicode', bool) + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + opts.setdefault('found_rows', True) + + return [[], opts] + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.server_info): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _extract_error_code(self, exception): + try: + return exception.orig.errno + except AttributeError: + return None + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + return connection.connection.charset + + def _compat_fetchall(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchone() + + def _compat_first(self, rp, charset=None): + return rp.first() + + +dialect = MySQLDialect_oursql diff --git a/sqlalchemy/dialects/mysql/pyodbc.py b/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 0000000..1f73c6e --- /dev/null +++ b/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,76 @@ +"""Support for the MySQL database via the pyodbc adapter. + +pyodbc is available at: + + http://pypi.python.org/pypi/pyodbc/ + +Connecting +---------- + +Connect string:: + + mysql+pyodbc://:@ + +Limitations +----------- + +The mysql-pyodbc dialect is subject to unresolved character encoding issues +which exist within the current ODBC drivers available. +(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage +of OurSQL, MySQLdb, or MySQL-connector/Python. + +""" + +from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy.engine import base as engine_base +from sqlalchemy import util +import re + +class MySQLExecutionContext_pyodbc(MySQLExecutionContext): + + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + +class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): + supports_unicode_statements = False + execution_ctx_cls = MySQLExecutionContext_pyodbc + + pyodbc_driver_name = "MySQL" + + def __init__(self, **kw): + # deal with http://code.google.com/p/pyodbc/issues/detail?id=25 + kw.setdefault('convert_unicode', True) + super(MySQLDialect_pyodbc, self).__init__(**kw) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + +dialect = MySQLDialect_pyodbc diff --git a/sqlalchemy/dialects/mysql/zxjdbc.py b/sqlalchemy/dialects/mysql/zxjdbc.py new file mode 100644 index 0000000..06d3e66 --- /dev/null +++ b/sqlalchemy/dialects/mysql/zxjdbc.py @@ -0,0 +1,111 @@ +"""Support for the MySQL database via Jython's zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official MySQL JDBC driver is at +http://dev.mysql.com/downloads/connector/j/. + +Connecting +---------- + +Connect string format: + + mysql+zxjdbc://:@[:]/ + +Character Sets +-------------- + +SQLAlchemy zxjdbc dialects pass unicode straight through to the +zxjdbc/JDBC layer. To allow multiple character sets to be sent from the +MySQL Connector/J JDBC driver, by default SQLAlchemy sets its +``characterEncoding`` connection property to ``UTF-8``. It may be +overriden via a ``create_engine`` URL parameter. + +""" +import re + +from sqlalchemy import types as sqltypes, util +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext + +class _ZxJDBCBit(BIT): + def result_processor(self, dialect, coltype): + """Converts boolean or byte arrays from MySQL Connector/J to longs.""" + def process(value): + if value is None: + return value + if isinstance(value, bool): + return int(value) + v = 0L + for i in value: + v = v << 8 | (i & 0xff) + value = v + return value + return process + + +class MySQLExecutionContext_zxjdbc(MySQLExecutionContext): + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): + jdbc_db_name = 'mysql' + jdbc_driver_name = 'com.mysql.jdbc.Driver' + + execution_ctx_cls = MySQLExecutionContext_zxjdbc + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Time: sqltypes.Time, + BIT: _ZxJDBCBit + } + ) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs)) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _driver_kwargs(self): + """return kw arg dict to be sent to connect().""" + return dict(characterEncoding='UTF-8', yearIsDateType='false') + + def _extract_error_code(self, exception): + # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist + # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' () + m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + + def _get_server_version_info(self,connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.dbversion): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + +dialect = MySQLDialect_zxjdbc diff --git a/sqlalchemy/dialects/oracle/__init__.py b/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 0000000..78d3c8f --- /dev/null +++ b/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc + +base.dialect = cx_oracle.dialect + +from sqlalchemy.dialects.oracle.base import \ + VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, NUMBER,\ + BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\ + FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\ + VARCHAR2, NVARCHAR2 + + +__all__ = ( +'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'DATETIME', 'NUMBER', +'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW', +'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL', +'VARCHAR2', 'NVARCHAR2' +) diff --git a/sqlalchemy/dialects/oracle/base.py b/sqlalchemy/dialects/oracle/base.py new file mode 100644 index 0000000..4757309 --- /dev/null +++ b/sqlalchemy/dialects/oracle/base.py @@ -0,0 +1,1030 @@ +# oracle/base.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the Oracle database. + +Oracle version 8 through current (11g at the time of this writing) are supported. + +For information on connecting via specific drivers, see the documentation +for that driver. + +Connect Arguments +----------------- + +The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which +affect the behavior of the dialect regardless of driver in use. + +* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults + to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins. + +* *optimize_limits* - defaults to ``False``. see the section on LIMIT/OFFSET. + +Auto Increment Behavior +----------------------- + +SQLAlchemy Table objects which include integer primary keys are usually assumed to have +"autoincrementing" behavior, meaning they can generate their own primary key values upon +INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences +to produce these values. With the Oracle dialect, *a sequence must always be explicitly +specified to enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. To specify sequences, +use the sqlalchemy.schema.Sequence object which is passed to a Column construct:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + Column(...), ... + ) + +This step is also required when using table reflection, i.e. autoload=True:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + autoload=True + ) + +Identifier Casing +----------------- + +In Oracle, the data dictionary represents all case insensitive identifier names +using UPPERCASE text. SQLAlchemy on the other hand considers an all-lower case identifier +name to be case insensitive. The Oracle dialect converts all case insensitive identifiers +to and from those two formats during schema level communication, such as reflection of +tables and indexes. Using an UPPERCASE name on the SQLAlchemy side indicates a +case sensitive identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names have been +truly created as case sensitive (i.e. using quoted names), all lowercase names should be +used on the SQLAlchemy side. + +Unicode +------- + +SQLAlchemy 0.6 uses the "native unicode" mode provided as of cx_oracle 5. cx_oracle 5.0.2 +or greater is recommended for support of NCLOB. If not using cx_oracle 5, the NLS_LANG +environment variable needs to be set in order for the oracle client library to use +proper encoding, such as "AMERICAN_AMERICA.UTF8". + +Also note that Oracle supports unicode data through the NVARCHAR and NCLOB data types. +When using the SQLAlchemy Unicode and UnicodeText types, these DDL types will be used +within CREATE TABLE statements. Usage of VARCHAR2 and CLOB with unicode text still +requires NLS_LANG to be set. + +LIMIT/OFFSET Support +-------------------- + +Oracle has no support for the LIMIT or OFFSET keywords. Whereas previous versions of SQLAlchemy +used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses +a wrapped subquery approach in conjunction with ROWNUM. The exact methodology is taken from +http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . Note that the +"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt +this was stepping into the bounds of optimization that is better left on the DBA side, but this +prefix can be added by enabling the optimize_limits=True flag on create_engine(). + +ON UPDATE CASCADE +----------------- + +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based solution +is available at http://asktom.oracle.com/tkyte/update_cascade/index.html . + +When using the SQLAlchemy ORM, the ORM has limited ability to manually issue +cascading updates - specify ForeignKey objects using the +"deferrable=True, initially='deferred'" keyword arguments, +and specify "passive_updates=False" on each relationship(). + +Oracle 8 Compatibility +---------------------- + +When using Oracle 8, a "use_ansi=False" flag is available which converts all +JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN +makes use of Oracle's (+) operator. + +Synonym/DBLINK Reflection +------------------------- + +When using reflection with Table objects, the dialect can optionally search for tables +indicated by synonyms that reference DBLINK-ed tables by passing the flag +oracle_resolve_synonyms=True as a keyword argument to the Table construct. If DBLINK +is not in use this flag should be left off. + +""" + +import random, re + +from sqlalchemy import schema as sa_schema +from sqlalchemy import util, sql, log +from sqlalchemy.engine import default, base, reflection +from sqlalchemy.sql import compiler, visitors, expression +from sqlalchemy.sql import operators as sql_operators, functions as sql_functions +from sqlalchemy import types as sqltypes +from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, \ + BLOB, CLOB, TIMESTAMP, FLOAT + +RESERVED_WORDS = set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN ' + 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ' + 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ANY ' + 'TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE BY ASC ' + 'REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES AS IN VIEW ' + 'EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS NOT TRIGGER ' + 'ELSE CREATE INTERSECT PCTFREE DISTINCT USER CONNECT SET MODE ' + 'OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR DECIMAL UNION PUBLIC ' + 'AND START UID COMMENT'.split()) + +class RAW(sqltypes.LargeBinary): + pass +OracleRaw = RAW + +class NCLOB(sqltypes.Text): + __visit_name__ = 'NCLOB' + +VARCHAR2 = VARCHAR +NVARCHAR2 = NVARCHAR + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = 'NUMBER' + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super(NUMBER, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal) + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class DOUBLE_PRECISION(sqltypes.Numeric): + __visit_name__ = 'DOUBLE_PRECISION' + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = False + + super(DOUBLE_PRECISION, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal) + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = 'BFILE' + +class LONG(sqltypes.Text): + __visit_name__ = 'LONG' + +class INTERVAL(sqltypes.TypeEngine): + __visit_name__ = 'INTERVAL' + + def __init__(self, + day_precision=None, + second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs (cx_oracle and zxjdbc). + + :param day_precision: the day precision value. this is the number of digits + to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the number of digits + to store for the fractional seconds field. Defaults to "6". + + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL(day_precision=interval.day_precision, + second_precision=interval.second_precision) + + def adapt(self, impltype): + return impltype(day_precision=self.day_precision, + second_precision=self.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + +colspecs = { + sqltypes.Boolean : _OracleBoolean, + sqltypes.Interval : INTERVAL, +} + +ischema_names = { + 'VARCHAR2' : VARCHAR, + 'NVARCHAR2' : NVARCHAR, + 'CHAR' : CHAR, + 'DATE' : DATE, + 'NUMBER' : NUMBER, + 'BLOB' : BLOB, + 'BFILE' : BFILE, + 'CLOB' : CLOB, + 'NCLOB' : NCLOB, + 'TIMESTAMP' : TIMESTAMP, + 'TIMESTAMP WITH TIME ZONE' : TIMESTAMP, + 'INTERVAL DAY TO SECOND' : INTERVAL, + 'RAW' : RAW, + 'FLOAT' : FLOAT, + 'DOUBLE PRECISION' : DOUBLE_PRECISION, + 'LONG' : LONG, +} + + +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_datetime(self, type_): + return self.visit_DATE(type_) + + def visit_float(self, type_): + return self.visit_FLOAT(type_) + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_INTERVAL(self, type_): + return "INTERVAL DAY%s TO SECOND%s" % ( + type_.day_precision is not None and + "(%d)" % type_.day_precision or + "", + type_.second_precision is not None and + "(%d)" % type_.second_precision or + "", + ) + + def visit_TIMESTAMP(self, type_): + if type_.timezone: + return "TIMESTAMP WITH TIME ZONE" + else: + return "TIMESTAMP" + + def visit_DOUBLE_PRECISION(self, type_): + return self._generate_numeric(type_, "DOUBLE PRECISION") + + def visit_NUMBER(self, type_, **kw): + return self._generate_numeric(type_, "NUMBER", **kw) + + def _generate_numeric(self, type_, name, precision=None, scale=None): + if precision is None: + precision = type_.precision + + if scale is None: + scale = getattr(type_, 'scale', None) + + if precision is None: + return name + elif scale is None: + return "%(name)s(%(precision)s)" % {'name':name,'precision': precision} + else: + return "%(name)s(%(precision)s, %(scale)s)" % {'name':name,'precision': precision, 'scale' : scale} + + def visit_VARCHAR(self, type_): + if self.dialect.supports_char_length: + return "VARCHAR(%(length)s CHAR)" % {'length' : type_.length} + else: + return "VARCHAR(%(length)s)" % {'length' : type_.length} + + def visit_NVARCHAR(self, type_): + return "NVARCHAR2(%(length)s)" % {'length' : type_.length} + + def visit_text(self, type_): + return self.visit_CLOB(type_) + + def visit_unicode_text(self, type_): + return self.visit_NCLOB(type_) + + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_big_integer(self, type_): + return self.visit_NUMBER(type_, precision=19) + + def visit_boolean(self, type_): + return self.visit_SMALLINT(type_) + + def visit_RAW(self, type_): + return "RAW(%(length)s)" % {'length' : type_.length} + +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + + compound_keywords = util.update_copy( + compiler.SQLCompiler.compound_keywords, + { + expression.CompoundSelect.EXCEPT : 'MINUS' + } + ) + + def __init__(self, *args, **kwargs): + super(OracleCompiler, self).__init__(*args, **kwargs) + self.__wheres = {} + self._quoted_bind_names = {} + + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "LENGTH" + self.function_argspec(fn, **kw) + + def visit_match_op(self, binary, **kw): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def get_select_hint_text(self, byfroms): + return " ".join( + "/*+ %s */" % text for table, text in byfroms.items() + ) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def visit_join(self, join, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join(self, join, **kwargs) + else: + kwargs['asfrom'] = True + return self.process(join.left, **kwargs) + \ + ", " + self.process(join.right, **kwargs) + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + def visit_binary(binary): + if binary.operator == sql_operators.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) + else: + clauses.append(join.onclause) + + for j in join.left, join.right: + if isinstance(j, expression.Join): + visit_join(j) + + for f in froms: + if isinstance(f, expression.Join): + visit_join(f) + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def visit_sequence(self, seq): + return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + + def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs): + """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" + + if asfrom or ashint: + alias_name = isinstance(alias.name, expression._generated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + + if ashint: + return alias_name + elif asfrom: + return self.process(alias.original, asfrom=asfrom, **kwargs) + \ + " " + self.preparer.format_alias(alias, alias_name) + else: + return self.process(alias.original, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + def create_out_param(col, i): + bindparam = sql.outparam("ret_%d" % i, type_=col.type) + self.binds[bindparam.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + columnlist = list(expression._select_iterables(returning_cols)) + + # within_columns_clause =False so that labels (foo AS bar) don't render + columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) for c in columnlist] + + binds = [create_out_param(c, i) for i, c in enumerate(columnlist)] + + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + + def _TODO_visit_compound_select(self, select): + """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + pass + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``rownum`` criterion. + """ + + if not getattr(select, '_oracle_visit', None): + if not self.dialect.use_ansi: + if self.stack and 'from' in self.stack[-1]: + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + + froms = select._get_display_froms(existingfroms) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause is not None: + select = select.where(whereclause) + select._oracle_visit = True + + if select._limit is not None or select._offset is not None: + # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html + # + # Generalized form of an Oracle pagination query: + # select ... from ( + # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from ( + # select distinct ... where ... order by ... + # ) where ROWNUM <= :limit+:offset + # ) where ora_rn > :offset + # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0 + + # TODO: use annotations instead of clone + attr set ? + select = select._generate() + select._oracle_visit = True + + # Wrap the middle select and add the hint + limitselect = sql.select([c for c in select.c]) + if select._limit and self.dialect.optimize_limits: + limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit) + + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + # If needed, add the limiting clause + if select._limit is not None: + max_row = select._limit + if select._offset is not None: + max_row += select._offset + limitselect.append_whereclause( + sql.literal_column("ROWNUM")<=max_row) + + # If needed, add the ora_rn, and wrap again with offset. + if select._offset is None: + select = limitselect + else: + limitselect = limitselect.column( + sql.literal_column("ROWNUM").label("ora_rn")) + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + offsetselect = sql.select( + [c for c in limitselect.c if c.key!='ora_rn']) + offsetselect._oracle_visit = True + offsetselect._is_wrapper = True + + offsetselect.append_whereclause( + sql.literal_column("ora_rn")>select._offset) + + select = offsetselect + + kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def limit_clause(self, select): + return "" + + def for_update_clause(self, select): + if select.for_update == "nowait": + return " FOR UPDATE NOWAIT" + else: + return super(OracleCompiler, self).for_update_clause(select) + +class OracleDDLCompiler(compiler.DDLCompiler): + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + + # oracle has no ON UPDATE CASCADE - + # its only available via triggers http://asktom.oracle.com/tkyte/update_cascade/index.html + if constraint.onupdate is not None: + util.warn( + "Oracle does not contain native UPDATE CASCADE " + "functionality - onupdates will not be rendered for foreign keys. " + "Consider using deferrable=True, initially='deferred' or triggers.") + + return text + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = set([x.lower() for x in RESERVED_WORDS]) + illegal_initial_characters = set(xrange(0, 10)).union(["_", "$"]) + + def _bindparam_requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return (lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(unicode(value)) + ) + + def format_savepoint(self, savepoint): + name = re.sub(r'^_+', '', savepoint.ident) + return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + + +class OracleExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq): + return self._execute_scalar("SELECT " + + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL") + +class OracleDialect(default.DefaultDialect): + name = 'oracle' + supports_alter = True + supports_unicode_statements = False + supports_unicode_binds = False + max_identifier_length = 30 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + supports_sequences = True + sequences_optional = False + postfetch_lastrowid = False + + default_paramstyle = 'named' + colspecs = colspecs + ischema_names = ischema_names + requires_name_normalize = True + + supports_default_values = False + supports_empty_insert = False + + statement_compiler = OracleCompiler + ddl_compiler = OracleDDLCompiler + type_compiler = OracleTypeCompiler + preparer = OracleIdentifierPreparer + execution_ctx_cls = OracleExecutionContext + + reflection_options = ('oracle_resolve_synonyms', ) + + supports_char_length = True + + def __init__(self, + use_ansi=True, + optimize_limits=False, + **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.use_ansi = use_ansi + self.optimize_limits = optimize_limits + + def initialize(self, connection): + super(OracleDialect, self).initialize(connection) + self.implicit_returning = self.server_version_info > (10, ) and \ + self.__dict__.get('implicit_returning', True) + + self.supports_char_length = self.server_version_info >= (9, ) + + if self.server_version_info < (9,): + self.colspecs = self.colspecs.copy() + self.colspecs.pop(sqltypes.Interval) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def has_table(self, connection, table_name, schema=None): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text("SELECT table_name FROM all_tables " + "WHERE table_name = :name AND owner = :schema_name"), + name=self.denormalize_name(table_name), schema_name=self.denormalize_name(schema)) + return cursor.first() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text("SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name AND sequence_owner = :schema_name"), + name=self.denormalize_name(sequence_name), schema_name=self.denormalize_name(schema)) + return cursor.first() is not None + + def normalize_name(self, name): + if name is None: + return None + # Py2K + if isinstance(name, str): + name = name.decode(self.encoding) + # end Py2K + if name.upper() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.lower() + else: + return name + + def denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): + name = name.upper() + # Py2K + if not self.supports_unicode_binds: + name = name.encode(self.encoding) + else: + name = unicode(name) + # end Py2K + return name + + def _get_default_schema_name(self, connection): + return self.normalize_name(connection.execute(u'SELECT USER FROM DUAL').scalar()) + + def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): + """search for a local synonym matching the given desired owner/name. + + if desired_owner is None, attempts to locate a distinct owner. + + returns the actual name, owner, dblink name, and synonym name if found. + """ + + q = "SELECT owner, table_owner, table_name, db_link, synonym_name FROM all_synonyms WHERE " + clauses = [] + params = {} + if desired_synonym: + clauses.append("synonym_name = :synonym_name") + params['synonym_name'] = desired_synonym + if desired_owner: + clauses.append("table_owner = :desired_owner") + params['desired_owner'] = desired_owner + if desired_table: + clauses.append("table_name = :tname") + params['tname'] = desired_table + + q += " AND ".join(clauses) + + result = connection.execute(sql.text(q), **params) + if desired_owner: + row = result.first() + if row: + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + else: + rows = result.fetchall() + if len(rows) > 1: + raise AssertionError("There are multiple tables visible to the schema, you must specify owner") + elif len(rows) == 1: + row = rows[0] + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + + @reflection.cache + def _prepare_reflection_args(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + if resolve_synonyms: + actual_name, owner, dblink, synonym = self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(schema), + desired_synonym=self.denormalize_name(table_name) + ) + else: + actual_name, owner, dblink, synonym = None, None, None, None + if not actual_name: + actual_name = self.denormalize_name(table_name) + if not dblink: + dblink = '' + if not owner: + owner = self.denormalize_name(schema or self.default_schema_name) + return (actual_name, owner, dblink, synonym) + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "SELECT username FROM all_users ORDER BY username" + cursor = connection.execute(s,) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.default_schema_name) + + # note that table_names() isnt loading DBLINKed or synonym'ed tables + if schema is None: + schema = self.default_schema_name + s = sql.text( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') " + "AND OWNER = :owner " + "AND IOT_NAME IS NULL") + cursor = connection.execute(s, owner=schema) + return [self.normalize_name(row[0]) for row in cursor] + + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.default_schema_name) + s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") + cursor = connection.execute(s, owner=self.denormalize_name(schema)) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + columns = [] + if self.supports_char_length: + char_length_col = 'char_length' + else: + char_length_col = 'data_length' + + c = connection.execute(sql.text( + "SELECT column_name, data_type, %(char_length_col)s, data_precision, data_scale, " + "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s " + "WHERE table_name = :table_name AND owner = :owner " + "ORDER BY column_id" % {'dblink': dblink, 'char_length_col':char_length_col}), + table_name=table_name, owner=schema) + + for row in c: + (colname, orig_colname, coltype, length, precision, scale, nullable, default) = \ + (self.normalize_name(row[0]), row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) + + if coltype == 'NUMBER' : + coltype = NUMBER(precision, scale) + elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'): + coltype = self.ischema_names.get(coltype)(length) + elif 'WITH TIME ZONE' in coltype: + coltype = TIMESTAMP(timezone=True) + else: + coltype = re.sub(r'\(\d+\)', '', coltype) + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, colname)) + coltype = sqltypes.NULLTYPE + + cdict = { + 'name': colname, + 'type': coltype, + 'nullable': nullable, + 'default': default, + } + if orig_colname.lower() == orig_colname: + cdict['quote'] = True + + columns.append(cdict) + return columns + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + + info_cache = kw.get('info_cache') + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + indexes = [] + q = sql.text(""" + SELECT a.index_name, a.column_name, b.uniqueness + FROM ALL_IND_COLUMNS%(dblink)s a, + ALL_INDEXES%(dblink)s b + WHERE + a.index_name = b.index_name + AND a.table_owner = b.table_owner + AND a.table_name = b.table_name + + AND a.table_name = :table_name + AND a.table_owner = :schema + ORDER BY a.index_name, a.column_position""" % {'dblink': dblink}) + rp = connection.execute(q, table_name=self.denormalize_name(table_name), + schema=self.denormalize_name(schema)) + indexes = [] + last_index_name = None + pkeys = self.get_primary_keys(connection, table_name, schema, + resolve_synonyms=resolve_synonyms, + dblink=dblink, + info_cache=kw.get('info_cache')) + uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + + oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) + for rset in rp: + # don't include the primary key columns + if rset.column_name in [s.upper() for s in pkeys]: + continue + if rset.index_name != last_index_name: + index = dict(name=self.normalize_name(rset.index_name), column_names=[]) + indexes.append(index) + index['unique'] = uniqueness.get(rset.uniqueness, False) + + # filter out Oracle SYS_NC names. could also do an outer join + # to the all_tab_columns table and check for real col names there. + if not oracle_sys_col.match(rset.column_name): + index['column_names'].append(self.normalize_name(rset.column_name)) + last_index_name = rset.index_name + return indexes + + @reflection.cache + def _get_constraint_data(self, connection, table_name, schema=None, + dblink='', **kw): + + rp = connection.execute( + sql.text("""SELECT + ac.constraint_name, + ac.constraint_type, + loc.column_name AS local_column, + rem.table_name AS remote_table, + rem.column_name AS remote_column, + rem.owner AS remote_owner, + loc.position as loc_pos, + rem.position as rem_pos + FROM all_constraints%(dblink)s ac, + all_cons_columns%(dblink)s loc, + all_cons_columns%(dblink)s rem + WHERE ac.table_name = :table_name + AND ac.constraint_type IN ('R','P') + AND ac.owner = :owner + AND ac.owner = loc.owner + AND ac.constraint_name = loc.constraint_name + AND ac.r_owner = rem.owner(+) + AND ac.r_constraint_name = rem.constraint_name(+) + AND (rem.position IS NULL or loc.position=rem.position) + ORDER BY ac.constraint_name, loc.position""" % {'dblink': dblink}), + table_name=table_name, owner=schema) + constraint_data = rp.fetchall() + return constraint_data + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + pkeys = [] + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + for row in constraint_data: + #print "ROW:" , row + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == 'P': + pkeys.append(local_column) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + requested_schema = schema # to check later on + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for row in constraint_data: + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + + if cons_type == 'R': + if remote_table is None: + # ticket 363 + util.warn( + ("Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?") % {'dblink':dblink}) + continue + + rec = fkeys[cons_name] + rec['name'] = cons_name + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + if not rec['referred_table']: + if resolve_synonyms: + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ + self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table) + ) + if ref_synonym: + remote_table = self.normalize_name(ref_synonym) + remote_owner = self.normalize_name(ref_remote_owner) + + rec['referred_table'] = remote_table + + if requested_schema is not None or self.denormalize_name(remote_owner) != schema: + rec['referred_schema'] = remote_owner + + local_cols.append(local_column) + remote_cols.append(remote_column) + + return fkeys.values() + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + info_cache = kw.get('info_cache') + (view_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, view_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + s = sql.text(""" + SELECT text FROM all_views + WHERE owner = :schema + AND view_name = :view_name + """) + rp = connection.execute(s, + view_name=view_name, schema=schema).scalar() + if rp: + return rp.decode(self.encoding) + else: + return None + + + +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + + def __init__(self, column): + self.column = column + + + diff --git a/sqlalchemy/dialects/oracle/cx_oracle.py b/sqlalchemy/dialects/oracle/cx_oracle.py new file mode 100644 index 0000000..91af662 --- /dev/null +++ b/sqlalchemy/dialects/oracle/cx_oracle.py @@ -0,0 +1,529 @@ +"""Support for the Oracle database via the cx_oracle driver. + +Driver +------ + +The Oracle dialect uses the cx_oracle driver, available at +http://cx-oracle.sourceforge.net/ . The dialect has several behaviors +which are specifically tailored towards compatibility with this module. + +Connecting +---------- + +Connecting with create_engine() uses the standard URL approach of +``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the +host, port, and dbname tokens are converted to a TNS name using the cx_oracle +:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name. + +Additional arguments which may be specified either as query string arguments on the +URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are: + +* *allow_twophase* - enable two-phase transactions. Defaults to ``True``. + +* *arraysize* - set the cx_oracle.arraysize value on cursors, in SQLAlchemy + it defaults to 50. See the section on "LOB Objects" below. + +* *auto_convert_lobs* - defaults to True, see the section on LOB objects. + +* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters. + This is required for LOB datatypes but can be disabled to reduce overhead. Defaults + to ``True``. + +* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an + integer value. This value is only available as a URL query string argument. + +* *threaded* - enable multithreaded access to cx_oracle connections. Defaults + to ``True``. Note that this is the opposite default of cx_oracle itself. + +Unicode +------- + +As of cx_oracle 5, Python unicode objects can be bound directly to statements, +and it appears that cx_oracle can handle these even without NLS_LANG being set. +SQLAlchemy tests for version 5 and will pass unicode objects straight to cx_oracle +if this is the case. For older versions of cx_oracle, SQLAlchemy will encode bind +parameters normally using dialect.encoding as the encoding. + +LOB Objects +----------- + +cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set +is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default, +SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First, +the LOB object requires an active cursor association, meaning if you were to fetch many rows +at once such that cx_oracle had to go back to the database and fetch a new batch of rows, +the LOB objects in the already-fetched rows are now unreadable and will raise an error. +SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read. +The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy +defaults to 50 (cx_oracle normally defaults this to one). + +Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to +"normalize" the results to look more like that of other DBAPIs. + +The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place +for all statement executions, even plain string-based statements for which SQLA has no awareness +of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases +without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch" +of LOB objects, can be disabled using auto_convert_lobs=False. + +Two Phase Transaction Support +----------------------------- + +Two Phase transactions are implemented using XA transactions. Success has been reported +with this feature but it should be regarded as experimental. + +""" + +from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, \ + RESERVED_WORDS, OracleExecutionContext +from sqlalchemy.dialects.oracle import base as oracle +from sqlalchemy.engine import base +from sqlalchemy import types as sqltypes, util, exc +from datetime import datetime +import random + +class _OracleNumeric(sqltypes.Numeric): + # cx_oracle accepts Decimal objects, but returns + # floats + def bind_processor(self, dialect): + return None + +class _OracleDate(sqltypes.Date): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return value.date() + else: + return value + return process + +class _LOBMixin(object): + def result_processor(self, dialect, coltype): + if not dialect.auto_convert_lobs: + # return the cx_oracle.LOB directly. + return None + + def process(value): + if value is not None: + return value.read() + else: + return value + return process + +class _NativeUnicodeMixin(object): + # Py2K + def bind_processor(self, dialect): + if dialect._cx_oracle_with_unicode: + def process(value): + if value is None: + return value + else: + return unicode(value) + return process + else: + return super(_NativeUnicodeMixin, self).bind_processor(dialect) + # end Py2K + + def result_processor(self, dialect, coltype): + # if we know cx_Oracle will return unicode, + # don't process results + if dialect._cx_oracle_with_unicode: + return None + elif self.convert_unicode != 'force' and \ + dialect._cx_oracle_native_nvarchar and \ + coltype in dialect._cx_oracle_unicode_types: + return None + else: + return super(_NativeUnicodeMixin, self).result_processor(dialect, coltype) + +class _OracleChar(_NativeUnicodeMixin, sqltypes.CHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_CHAR + +class _OracleNVarChar(_NativeUnicodeMixin, sqltypes.NVARCHAR): + def get_dbapi_type(self, dbapi): + return getattr(dbapi, 'UNICODE', dbapi.STRING) + +class _OracleText(_LOBMixin, sqltypes.Text): + def get_dbapi_type(self, dbapi): + return dbapi.CLOB + +class _OracleString(_NativeUnicodeMixin, sqltypes.String): + pass + +class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText): + def get_dbapi_type(self, dbapi): + return dbapi.NCLOB + + def result_processor(self, dialect, coltype): + lob_processor = _LOBMixin.result_processor(self, dialect, coltype) + if lob_processor is None: + return None + + string_processor = _NativeUnicodeMixin.result_processor(self, dialect, coltype) + + if string_processor is None: + return lob_processor + else: + def process(value): + return string_processor(lob_processor(value)) + return process + +class _OracleInteger(sqltypes.Integer): + def result_processor(self, dialect, coltype): + def to_int(val): + if val is not None: + val = int(val) + return val + return to_int + +class _OracleBinary(_LOBMixin, sqltypes.LargeBinary): + def get_dbapi_type(self, dbapi): + return dbapi.BLOB + + def bind_processor(self, dialect): + return None + +class _OracleInterval(oracle.INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + +class _OracleRaw(oracle.RAW): + pass + +class OracleCompiler_cx_oracle(OracleCompiler): + def bindparam_string(self, name): + if self.preparer._bindparam_requires_quotes(name): + quoted_name = '"%s"' % name + self._quoted_bind_names[name] = quoted_name + return OracleCompiler.bindparam_string(self, quoted_name) + else: + return OracleCompiler.bindparam_string(self, name) + + +class OracleExecutionContext_cx_oracle(OracleExecutionContext): + + def pre_exec(self): + quoted_bind_names = \ + getattr(self.compiled, '_quoted_bind_names', None) + if quoted_bind_names: + if not self.dialect.supports_unicode_binds: + quoted_bind_names = \ + dict( + (fromname, toname.encode(self.dialect.encoding)) + for fromname, toname in + quoted_bind_names.items() + ) + for param in self.parameters: + for fromname, toname in quoted_bind_names.items(): + param[toname] = param[fromname] + del param[fromname] + + if self.dialect.auto_setinputsizes: + # cx_oracle really has issues when you setinputsizes + # on String, including that outparams/RETURNING + # breaks for varchars + self.set_input_sizes(quoted_bind_names, + exclude_types=self.dialect._cx_oracle_string_types + ) + + # if a single execute, check for outparams + if len(self.compiled_parameters) == 1: + for bindparam in self.compiled.binds.values(): + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).\ + get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + if dbtype is None: + raise exc.InvalidRequestError("Cannot create out parameter for parameter " + "%r - it's type %r is not supported by" + " cx_oracle" % + (name, bindparam.type) + ) + name = self.compiled.bind_names[bindparam] + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[0][quoted_bind_names.get(name, name)] = \ + self.out_parameters[name] + + def create_cursor(self): + c = self._connection.connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + return c + + def get_result_proxy(self): + if hasattr(self, 'out_parameters') and self.compiled.returning: + returning_params = dict( + (k, v.getvalue()) + for k, v in self.out_parameters.items() + ) + return ReturningResultProxy(self, returning_params) + + result = None + if self.cursor.description is not None: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect._cx_oracle_binary_types: + result = base.BufferedColumnResultProxy(self) + + if result is None: + result = base.ResultProxy(self) + + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None and \ + len(self.compiled_parameters) == 1: + result.out_parameters = out_parameters = {} + + for bind, name in self.compiled.bind_names.items(): + if name in self.out_parameters: + type = bind.type + impl_type = type.dialect_impl(self.dialect) + dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi) + result_processor = impl_type.\ + result_processor(self.dialect, + dbapi_type) + if result_processor is not None: + out_parameters[name] = \ + result_processor(self.out_parameters[name].getvalue()) + else: + out_parameters[name] = self.out_parameters[name].getvalue() + else: + result.out_parameters = dict( + (k, v.getvalue()) + for k, v in self.out_parameters.items() + ) + + return result + +class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_oracle): + """Support WITH_UNICODE in Python 2.xx. + + WITH_UNICODE allows cx_Oracle's Python 3 unicode handling + behavior under Python 2.x. This mode in some cases disallows + and in other cases silently passes corrupted data when + non-Python-unicode strings (a.k.a. plain old Python strings) + are passed as arguments to connect(), the statement sent to execute(), + or any of the bind parameter keys or values sent to execute(). + This optional context therefore ensures that all statements are + passed as Python unicode objects. + + """ + def __init__(self, *arg, **kw): + OracleExecutionContext_cx_oracle.__init__(self, *arg, **kw) + self.statement = unicode(self.statement) + + def _execute_scalar(self, stmt): + return super(OracleExecutionContext_cx_oracle_with_unicode, self).\ + _execute_scalar(unicode(stmt)) + +class ReturningResultProxy(base.FullyBufferedResultProxy): + """Result proxy which stuffs the _returning clause + outparams into the fetch.""" + + def __init__(self, context, returning_params): + self._returning_params = returning_params + super(ReturningResultProxy, self).__init__(context) + + def _cursor_description(self): + returning = self.context.compiled.returning + + ret = [] + for c in returning: + if hasattr(c, 'name'): + ret.append((c.name, c.type)) + else: + ret.append((c.anon_label, c.type)) + return ret + + def _buffer_rows(self): + return [tuple(self._returning_params["ret_%d" % i] + for i, c in enumerate(self._returning_params))] + +class OracleDialect_cx_oracle(OracleDialect): + execution_ctx_cls = OracleExecutionContext_cx_oracle + statement_compiler = OracleCompiler_cx_oracle + driver = "cx_oracle" + + colspecs = colspecs = { + sqltypes.Numeric: _OracleNumeric, + sqltypes.Date : _OracleDate, # generic type, assume datetime.date is desired + oracle.DATE: oracle.DATE, # non generic type - passthru + sqltypes.LargeBinary : _OracleBinary, + sqltypes.Boolean : oracle._OracleBoolean, + sqltypes.Interval : _OracleInterval, + oracle.INTERVAL : _OracleInterval, + sqltypes.Text : _OracleText, + sqltypes.String : _OracleString, + sqltypes.UnicodeText : _OracleUnicodeText, + sqltypes.CHAR : _OracleChar, + sqltypes.Integer : _OracleInteger, # this is only needed for OUT parameters. + # it would be nice if we could not use it otherwise. + oracle.NUMBER : oracle.NUMBER, # don't let this get converted + oracle.RAW: _OracleRaw, + sqltypes.Unicode: _OracleNVarChar, + sqltypes.NVARCHAR : _OracleNVarChar, + } + + + execute_sequence_format = list + + def __init__(self, + auto_setinputsizes=True, + auto_convert_lobs=True, + threaded=True, + allow_twophase=True, + arraysize=50, **kwargs): + OracleDialect.__init__(self, **kwargs) + self.threaded = threaded + self.arraysize = arraysize + self.allow_twophase = allow_twophase + self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) + self.auto_setinputsizes = auto_setinputsizes + self.auto_convert_lobs = auto_convert_lobs + + if hasattr(self.dbapi, 'version'): + cx_oracle_ver = tuple([int(x) for x in self.dbapi.version.split('.')]) + else: + cx_oracle_ver = (0, 0, 0) + + def types(*names): + return set([ + getattr(self.dbapi, name, None) for name in names + ]).difference([None]) + + self._cx_oracle_string_types = types("STRING", "UNICODE", "NCLOB", "CLOB") + self._cx_oracle_unicode_types = types("UNICODE", "NCLOB") + self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB") + self.supports_unicode_binds = cx_oracle_ver >= (5, 0) + self._cx_oracle_native_nvarchar = cx_oracle_ver >= (5, 0) + + if cx_oracle_ver is None: + # this occurs in tests with mock DBAPIs + self._cx_oracle_string_types = set() + self._cx_oracle_with_unicode = False + elif cx_oracle_ver >= (5,) and not hasattr(self.dbapi, 'UNICODE'): + # cx_Oracle WITH_UNICODE mode. *only* python + # unicode objects accepted for anything + self.supports_unicode_statements = True + self.supports_unicode_binds = True + self._cx_oracle_with_unicode = True + # Py2K + # There's really no reason to run with WITH_UNICODE under Python 2.x. + # Give the user a hint. + util.warn("cx_Oracle is compiled under Python 2.xx using the " + "WITH_UNICODE flag. Consider recompiling cx_Oracle without " + "this flag, which is in no way necessary for full support of Unicode. " + "Otherwise, all string-holding bind parameters must " + "be explicitly typed using SQLAlchemy's String type or one of its subtypes," + "or otherwise be passed as Python unicode. Plain Python strings " + "passed as bind parameters will be silently corrupted by cx_Oracle." + ) + self.execution_ctx_cls = OracleExecutionContext_cx_oracle_with_unicode + # end Py2K + else: + self._cx_oracle_with_unicode = False + + if cx_oracle_ver is None or \ + not self.auto_convert_lobs or \ + not hasattr(self.dbapi, 'CLOB'): + self.dbapi_type_map = {} + else: + # only use this for LOB objects. using it for strings, dates + # etc. leads to a little too much magic, reflection doesn't know if it should + # expect encoded strings or unicodes, etc. + self.dbapi_type_map = { + self.dbapi.CLOB: oracle.CLOB(), + self.dbapi.NCLOB:oracle.NCLOB(), + self.dbapi.BLOB: oracle.BLOB(), + self.dbapi.BINARY: oracle.RAW(), + } + + @classmethod + def dbapi(cls): + import cx_Oracle + return cx_Oracle + + def create_connect_args(self, url): + dialect_opts = dict(url.query) + for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', + 'threaded', 'allow_twophase'): + if opt in dialect_opts: + util.coerce_kw_type(dialect_opts, opt, bool) + setattr(self, opt, dialect_opts[opt]) + + if url.database: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + dsn = self.dbapi.makedsn(url.host, port, url.database) + else: + # we have a local tnsname + dsn = url.host + + opts = dict( + user=url.username, + password=url.password, + dsn=dsn, + threaded=self.threaded, + twophase=self.allow_twophase, + ) + + # Py2K + if self._cx_oracle_with_unicode: + for k, v in opts.items(): + if isinstance(v, str): + opts[k] = unicode(v) + # end Py2K + + if 'mode' in url.query: + opts['mode'] = url.query['mode'] + if isinstance(opts['mode'], basestring): + mode = opts['mode'].upper() + if mode == 'SYSDBA': + opts['mode'] = self.dbapi.SYSDBA + elif mode == 'SYSOPER': + opts['mode'] = self.dbapi.SYSOPER + else: + util.coerce_kw_type(opts, 'mode', int) + return ([], opts) + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.version.split('.')) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.InterfaceError): + return "not connected" in str(e) + else: + return "ORA-03114" in str(e) or "ORA-03113" in str(e) + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified.""" + + id = random.randint(0, 2 ** 128) + return (0x1234, "%032x" % id, "%032x" % 9) + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + + def do_prepare_twophase(self, connection, xid): + connection.connection.prepare() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + pass + +dialect = OracleDialect_cx_oracle diff --git a/sqlalchemy/dialects/oracle/zxjdbc.py b/sqlalchemy/dialects/oracle/zxjdbc.py new file mode 100644 index 0000000..d742654 --- /dev/null +++ b/sqlalchemy/dialects/oracle/zxjdbc.py @@ -0,0 +1,209 @@ +"""Support for the Oracle database via the zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official Oracle JDBC driver is at +http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html. + +""" +import decimal +import re + +from sqlalchemy import sql, types as sqltypes, util +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext +from sqlalchemy.engine import base, default +from sqlalchemy.sql import expression + +SQLException = zxJDBC = None + +class _ZxJDBCDate(sqltypes.Date): + + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return None + else: + return value.date() + return process + + +class _ZxJDBCNumeric(sqltypes.Numeric): + + def result_processor(self, dialect, coltype): + #XXX: does the dialect return Decimal or not??? + # if it does (in all cases), we could use a None processor as well as + # the to_float generic processor + if self.asdecimal: + def process(value): + if isinstance(value, decimal.Decimal): + return value + else: + return decimal.Decimal(str(value)) + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +class OracleCompiler_zxjdbc(OracleCompiler): + + def returning_clause(self, stmt, returning_cols): + self.returning_cols = list(expression._select_iterables(returning_cols)) + + # within_columns_clause=False so that labels (foo AS bar) don't render + columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) + for c in self.returning_cols] + + if not hasattr(self, 'returning_parameters'): + self.returning_parameters = [] + + binds = [] + for i, col in enumerate(self.returning_cols): + dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + self.returning_parameters.append((i + 1, dbtype)) + + bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype)) + self.binds[bindparam.key] = bindparam + binds.append(self.bindparam_string(self._truncate_bindparam(bindparam))) + + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + + +class OracleExecutionContext_zxjdbc(OracleExecutionContext): + + def pre_exec(self): + if hasattr(self.compiled, 'returning_parameters'): + # prepare a zxJDBC statement so we can grab its underlying + # OraclePreparedStatement's getReturnResultSet later + self.statement = self.cursor.prepare(self.statement) + + def get_result_proxy(self): + if hasattr(self.compiled, 'returning_parameters'): + rrs = None + try: + try: + rrs = self.statement.__statement__.getReturnResultSet() + rrs.next() + except SQLException, sqle: + msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode()) + if sqle.getSQLState() is not None: + msg += ' [SQLState: %s]' % sqle.getSQLState() + raise zxJDBC.Error(msg) + else: + row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype) + for index, dbtype in self.compiled.returning_parameters) + return ReturningResultProxy(self, row) + finally: + if rrs is not None: + try: + rrs.close() + except SQLException: + pass + self.statement.close() + + return base.ResultProxy(self) + + def create_cursor(self): + cursor = self._connection.connection.cursor() + cursor.datahandler = self.dialect.DataHandler(cursor.datahandler) + return cursor + + +class ReturningResultProxy(base.FullyBufferedResultProxy): + + """ResultProxy backed by the RETURNING ResultSet results.""" + + def __init__(self, context, returning_row): + self._returning_row = returning_row + super(ReturningResultProxy, self).__init__(context) + + def _cursor_description(self): + ret = [] + for c in self.context.compiled.returning_cols: + if hasattr(c, 'name'): + ret.append((c.name, c.type)) + else: + ret.append((c.anon_label, c.type)) + return ret + + def _buffer_rows(self): + return [self._returning_row] + + +class ReturningParam(object): + + """A bindparam value representing a RETURNING parameter. + + Specially handled by OracleReturningDataHandler. + """ + + def __init__(self, type): + self.type = type + + def __eq__(self, other): + if isinstance(other, ReturningParam): + return self.type == other.type + return NotImplemented + + def __ne__(self, other): + if isinstance(other, ReturningParam): + return self.type != other.type + return NotImplemented + + def __repr__(self): + kls = self.__class__ + return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self), + self.type) + + +class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): + jdbc_db_name = 'oracle' + jdbc_driver_name = 'oracle.jdbc.OracleDriver' + + statement_compiler = OracleCompiler_zxjdbc + execution_ctx_cls = OracleExecutionContext_zxjdbc + + colspecs = util.update_copy( + OracleDialect.colspecs, + { + sqltypes.Date : _ZxJDBCDate, + sqltypes.Numeric: _ZxJDBCNumeric + } + ) + + def __init__(self, *args, **kwargs): + super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs) + global SQLException, zxJDBC + from java.sql import SQLException + from com.ziclix.python.sql import zxJDBC + from com.ziclix.python.sql.handler import OracleDataHandler + class OracleReturningDataHandler(OracleDataHandler): + + """zxJDBC DataHandler that specially handles ReturningParam.""" + + def setJDBCObject(self, statement, index, object, dbtype=None): + if type(object) is ReturningParam: + statement.registerReturnParameter(index, object.type) + elif dbtype is None: + OracleDataHandler.setJDBCObject(self, statement, index, object) + else: + OracleDataHandler.setJDBCObject(self, statement, index, object, dbtype) + self.DataHandler = OracleReturningDataHandler + + def initialize(self, connection): + super(OracleDialect_zxjdbc, self).initialize(connection) + self.implicit_returning = connection.connection.driverversion >= '10.2' + + def _create_jdbc_url(self, url): + return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database) + + def _get_server_version_info(self, connection): + version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1) + return tuple(int(x) for x in version.split('.')) + +dialect = OracleDialect_zxjdbc diff --git a/sqlalchemy/dialects/postgres.py b/sqlalchemy/dialects/postgres.py new file mode 100644 index 0000000..0c1d3fd --- /dev/null +++ b/sqlalchemy/dialects/postgres.py @@ -0,0 +1,10 @@ +# backwards compat with the old name +from sqlalchemy.util import warn_deprecated + +warn_deprecated( + "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. " + "The new URL format is postgresql[+driver]://:@/" + ) + +from sqlalchemy.dialects.postgresql import * +from sqlalchemy.dialects.postgresql import base diff --git a/sqlalchemy/dialects/postgresql/__init__.py b/sqlalchemy/dialects/postgresql/__init__.py new file mode 100644 index 0000000..6aca1e1 --- /dev/null +++ b/sqlalchemy/dialects/postgresql/__init__.py @@ -0,0 +1,14 @@ +from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, pypostgresql, zxjdbc + +base.dialect = psycopg2.dialect + +from sqlalchemy.dialects.postgresql.base import \ + INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \ + CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME,\ + DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect + +__all__ = ( +'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET', +'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', +'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect' +) diff --git a/sqlalchemy/dialects/postgresql/base.py b/sqlalchemy/dialects/postgresql/base.py new file mode 100644 index 0000000..bef2f1c --- /dev/null +++ b/sqlalchemy/dialects/postgresql/base.py @@ -0,0 +1,1161 @@ +# postgresql.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the PostgreSQL database. + +For information on connecting using specific drivers, see the documentation section +regarding that driver. + +Sequences/SERIAL +---------------- + +PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating +new primary key values for integer-based primary key columns. When creating tables, +SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, +which generates a sequence corresponding to the column and associated with it based on +a naming convention. + +To specify a specific named sequence to be used for primary key generation, use the +:func:`~sqlalchemy.schema.Sequence` construct:: + + Table('sometable', metadata, + Column('id', Integer, Sequence('some_id_seq'), primary_key=True) + ) + +Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of +having the "last insert identifier" available, the sequence is executed independently +beforehand and the new value is retrieved, to be used in the subsequent insert. Note +that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior +is used. + +PostgreSQL 8.2 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports +as well. A future release of SQLA will use this feature by default in lieu of +sequence pre-execution in order to retrieve new primary key values, when available. + +INSERT/UPDATE...RETURNING +------------------------- + +The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and ``DELETE..RETURNING`` syntaxes, +but must be explicitly enabled on a per-statement basis:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\\ + values(name='foo') + print result.fetchall() + + # UPDATE..RETURNING + result = table.update().returning(table.c.col1, table.c.col2).\\ + where(table.c.name=='foo').values(name='bar') + print result.fetchall() + + # DELETE..RETURNING + result = table.delete().returning(table.c.col1, table.c.col2).\\ + where(table.c.name=='foo') + print result.fetchall() + +Indexes +------- + +PostgreSQL supports partial indexes. To create them pass a postgresql_where +option to the Index constructor:: + + Index('my_index', my_table.c.id, postgresql_where=tbl.c.value > 10) + +""" + +import re + +from sqlalchemy import schema as sa_schema +from sqlalchemy import sql, schema, exc, util +from sqlalchemy.engine import base, default, reflection +from sqlalchemy.sql import compiler, expression, util as sql_util +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes + +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ + CHAR, TEXT, FLOAT, NUMERIC, \ + DATE, BOOLEAN + +class REAL(sqltypes.Float): + __visit_name__ = "REAL" + +class BYTEA(sqltypes.LargeBinary): + __visit_name__ = 'BYTEA' + +class DOUBLE_PRECISION(sqltypes.Float): + __visit_name__ = 'DOUBLE_PRECISION' + +class INET(sqltypes.TypeEngine): + __visit_name__ = "INET" +PGInet = INET + +class CIDR(sqltypes.TypeEngine): + __visit_name__ = "CIDR" +PGCidr = CIDR + +class MACADDR(sqltypes.TypeEngine): + __visit_name__ = "MACADDR" +PGMacAddr = MACADDR + +class TIMESTAMP(sqltypes.TIMESTAMP): + def __init__(self, timezone=False, precision=None): + super(TIMESTAMP, self).__init__(timezone=timezone) + self.precision = precision + +class TIME(sqltypes.TIME): + def __init__(self, timezone=False, precision=None): + super(TIME, self).__init__(timezone=timezone) + self.precision = precision + +class INTERVAL(sqltypes.TypeEngine): + __visit_name__ = 'INTERVAL' + def __init__(self, precision=None): + self.precision = precision + + def adapt(self, impltype): + return impltype(self.precision) + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + +PGInterval = INTERVAL + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' +PGBit = BIT + +class UUID(sqltypes.TypeEngine): + __visit_name__ = 'UUID' +PGUuid = UUID + +class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): + __visit_name__ = 'ARRAY' + + def __init__(self, item_type, mutable=True): + """Construct an ARRAY. + + E.g.:: + + Column('myarray', ARRAY(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that dimensionality is + irrelevant here, so multi-dimensional arrays like ``INTEGER[][]``, are constructed as + ``ARRAY(Integer)``, not as ``ARRAY(ARRAY(Integer))`` or such. The type mapping figures + out on the fly + + :param mutable: Defaults to True: specify whether lists passed to this class should be + considered mutable. If so, generic copy operations (typically used by the ORM) will + shallow-copy values. + + """ + if isinstance(item_type, ARRAY): + raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype") + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.mutable = mutable + + def copy_value(self, value): + if value is None: + return None + elif self.mutable: + return list(value) + else: + return value + + def compare_values(self, x, y): + return x == y + + def is_mutable(self): + return self.mutable + + def dialect_impl(self, dialect, **kwargs): + impl = super(ARRAY, self).dialect_impl(dialect, **kwargs) + if impl is self: + impl = self.__class__.__new__(self.__class__) + impl.__dict__.update(self.__dict__) + impl.item_type = self.item_type.dialect_impl(dialect) + return impl + + def adapt(self, impltype): + return impltype( + self.item_type, + mutable=self.mutable + ) + + def bind_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + if item_proc: + def convert_item(item): + if isinstance(item, (list, tuple)): + return [convert_item(child) for child in item] + else: + return item_proc(item) + else: + def convert_item(item): + if isinstance(item, (list, tuple)): + return [convert_item(child) for child in item] + else: + return item + def process(value): + if value is None: + return value + return [convert_item(item) for item in value] + return process + + def result_processor(self, dialect, coltype): + item_proc = self.item_type.result_processor(dialect, coltype) + if item_proc: + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + return item_proc(item) + else: + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + return item + def process(value): + if value is None: + return value + return [convert_item(item) for item in value] + return process +PGArray = ARRAY + +class ENUM(sqltypes.Enum): + + def create(self, bind=None, checkfirst=True): + if not checkfirst or not bind.dialect.has_type(bind, self.name, schema=self.schema): + bind.execute(CreateEnumType(self)) + + def drop(self, bind=None, checkfirst=True): + if not checkfirst or bind.dialect.has_type(bind, self.name, schema=self.schema): + bind.execute(DropEnumType(self)) + + def _on_table_create(self, event, target, bind, **kw): + self.create(bind=bind, checkfirst=True) + + def _on_metadata_create(self, event, target, bind, **kw): + if self.metadata is not None: + self.create(bind=bind, checkfirst=True) + + def _on_metadata_drop(self, event, target, bind, **kw): + self.drop(bind=bind, checkfirst=True) + +colspecs = { + sqltypes.Interval:INTERVAL, + sqltypes.Enum:ENUM, +} + +ischema_names = { + 'integer' : INTEGER, + 'bigint' : BIGINT, + 'smallint' : SMALLINT, + 'character varying' : VARCHAR, + 'character' : CHAR, + '"char"' : sqltypes.String, + 'name' : sqltypes.String, + 'text' : TEXT, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'real' : REAL, + 'inet': INET, + 'cidr': CIDR, + 'uuid': UUID, + 'bit':BIT, + 'macaddr': MACADDR, + 'double precision' : DOUBLE_PRECISION, + 'timestamp' : TIMESTAMP, + 'timestamp with time zone' : TIMESTAMP, + 'timestamp without time zone' : TIMESTAMP, + 'time with time zone' : TIME, + 'time without time zone' : TIME, + 'date' : DATE, + 'time': TIME, + 'bytea' : BYTEA, + 'boolean' : BOOLEAN, + 'interval':INTERVAL, + 'interval year to month':INTERVAL, + 'interval day to second':INTERVAL, +} + + + +class PGCompiler(compiler.SQLCompiler): + + def visit_match_op(self, binary, **kw): + return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_sequence(self, seq): + if seq.optional: + return None + else: + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT ALL" + text += " OFFSET " + str(select._offset) + return text + + def get_select_precolumns(self, select): + if select._distinct is not False: + if select._distinct is True: + return "DISTINCT " + elif isinstance(select._distinct, (list, tuple)): + return "DISTINCT ON (" + ', '.join( + [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] + )+ ") " + else: + return "DISTINCT ON (" + unicode(select._distinct) + ") " + else: + return "" + + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" + else: + return super(PGCompiler, self).for_update_clause(select) + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self.process( + self.label_select_column(None, c, asfrom=False), + within_columns_clause=True, + result_map=self.result_map) + for c in expression._select_iterables(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + if extract.expr.type: + affinity = extract.expr.type._type_affinity + else: + affinity = None + + casts = { + sqltypes.Date:'date', + sqltypes.DateTime:'timestamp', + sqltypes.Interval:'interval', sqltypes.Time:'time' + } + cast = casts.get(affinity, None) + if isinstance(extract.expr, sql.ColumnElement) and cast is not None: + expr = extract.expr.op('::')(sql.literal_column(cast)) + else: + expr = extract.expr + return "EXTRACT(%s FROM %s)" % ( + field, self.process(expr)) + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + if column.primary_key and \ + len(column.foreign_keys)==0 and \ + column.autoincrement and \ + isinstance(column.type, sqltypes.Integer) and \ + not isinstance(column.type, sqltypes.SmallInteger) and \ + (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if isinstance(column.type, sqltypes.BigInteger): + colspec += " BIGSERIAL" + else: + colspec += " SERIAL" + else: + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def visit_create_enum_type(self, create): + type_ = create.element + + return "CREATE TYPE %s AS ENUM (%s)" % ( + self.preparer.format_type(type_), + ",".join("'%s'" % e for e in type_.enums) + ) + + def visit_drop_enum_type(self, drop): + type_ = drop.element + + return "DROP TYPE %s" % ( + self.preparer.format_type(type_) + ) + + def visit_create_index(self, create): + preparer = self.preparer + index = create.element + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join([preparer.format_column(c) for c in index.columns])) + + if "postgres_where" in index.kwargs: + whereclause = index.kwargs['postgres_where'] + util.warn_deprecated("The 'postgres_where' argument has been renamed to 'postgresql_where'.") + elif 'postgresql_where' in index.kwargs: + whereclause = index.kwargs['postgresql_where'] + else: + whereclause = None + + if whereclause is not None: + whereclause = sql_util.expression_as_ddl(whereclause) + where_compiled = self.sql_compiler.process(whereclause) + text += " WHERE " + where_compiled + return text + + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_INET(self, type_): + return "INET" + + def visit_CIDR(self, type_): + return "CIDR" + + def visit_MACADDR(self, type_): + return "MACADDR" + + def visit_FLOAT(self, type_): + if not type_.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': type_.precision} + + def visit_DOUBLE_PRECISION(self, type_): + return "DOUBLE PRECISION" + + def visit_BIGINT(self, type_): + return "BIGINT" + + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_enum(self, type_): + if not type_.native_enum or not self.dialect.supports_native_enum: + return super(PGTypeCompiler, self).visit_enum(type_) + else: + return self.visit_ENUM(type_) + + def visit_ENUM(self, type_): + return self.dialect.identifier_preparer.format_type(type_) + + def visit_TIMESTAMP(self, type_): + return "TIMESTAMP%s %s" % ( + getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + ) + + def visit_TIME(self, type_): + return "TIME%s %s" % ( + getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + ) + + def visit_INTERVAL(self, type_): + if type_.precision is not None: + return "INTERVAL(%d)" % type_.precision + else: + return "INTERVAL" + + def visit_BIT(self, type_): + return "BIT" + + def visit_UUID(self, type_): + return "UUID" + + def visit_large_binary(self, type_): + return self.visit_BYTEA(type_) + + def visit_BYTEA(self, type_): + return "BYTEA" + + def visit_REAL(self, type_): + return "REAL" + + def visit_ARRAY(self, type_): + return self.process(type_.item_type) + '[]' + + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace(self.escape_to_quote, self.escape_quote) + return value + + def format_type(self, type_, use_schema=True): + if not type_.name: + raise exc.ArgumentError("Postgresql ENUM type requires a name.") + + name = self.quote(type_.name, type_.quote) + if not self.omit_schema and use_schema and type_.schema is not None: + name = self.quote_schema(type_.schema, type_.quote) + "." + name + return name + +class PGInspector(reflection.Inspector): + + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_oid(self, table_name, schema=None): + """Return the oid from `table_name` and `schema`.""" + + return self.dialect.get_table_oid(self.conn, table_name, schema, + info_cache=self.info_cache) + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" + +class PGExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq): + if not seq.optional: + return self._execute_scalar(("select nextval('%s')" % \ + self.dialect.identifier_preparer.format_sequence(seq))) + else: + return None + + def get_insert_default(self, column): + if column.primary_key: + if (isinstance(column.server_default, schema.DefaultClause) and + column.server_default.arg is not None): + + # pre-execute passive defaults on primary key columns + return self._execute_scalar("select %s" % column.server_default.arg) + + elif column is column.table._autoincrement_column \ + and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + + # execute the sequence associated with a SERIAL primary key column. + # for non-primary-key SERIAL, the ID just generates server side. + sch = column.table.schema + + if sch is not None: + exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) + else: + exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) + + return self._execute_scalar(exc) + + return super(PGExecutionContext, self).get_insert_default(column) + +class PGDialect(default.DefaultDialect): + name = 'postgresql' + supports_alter = True + max_identifier_length = 63 + supports_sane_rowcount = True + + supports_native_enum = True + supports_native_boolean = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + + supports_default_values = True + supports_empty_insert = False + default_paramstyle = 'pyformat' + ischema_names = ischema_names + colspecs = colspecs + + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler = PGTypeCompiler + preparer = PGIdentifierPreparer + execution_ctx_cls = PGExecutionContext + inspector = PGInspector + isolation_level = None + + def __init__(self, isolation_level=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + + def initialize(self, connection): + super(PGDialect, self).initialize(connection) + self.implicit_returning = self.server_version_info > (8, 2) and \ + self.__dict__.get('implicit_returning', True) + self.supports_native_enum = self.server_version_info >= (8, 3) + if not self.supports_native_enum: + self.colspecs = self.colspecs.copy() + del self.colspecs[ENUM] + + def on_connect(self): + if self.isolation_level is not None: + def connect(conn): + cursor = conn.cursor() + cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s" + % self.isolation_level) + cursor.execute("COMMIT") + cursor.close() + return connect + else: + return None + + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.execute("PREPARE TRANSACTION '%s'" % xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions + # Must find out a way how to make the dbapi not open a transaction. + connection.execute("ROLLBACK") + connection.execute("ROLLBACK PREPARED '%s'" % xid) + connection.execute("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + connection.execute("ROLLBACK") + connection.execute("COMMIT PREPARED '%s'" % xid) + connection.execute("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) + return [row[0] for row in resultset] + + def _get_default_schema_name(self, connection): + return connection.scalar("select current_schema()") + + def has_table(self, connection, table_name, schema=None): + # seems like case gets folded in pg_class... + if schema is None: + cursor = connection.execute( + sql.text("select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=current_schema() and " + "lower(relname)=:name", + bindparams=[ + sql.bindparam('name', unicode(table_name.lower()), + type_=sqltypes.Unicode)] + ) + ) + else: + cursor = connection.execute( + sql.text("select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name", + bindparams=[ + sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode), + sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] + ) + ) + return bool(cursor.first()) + + def has_sequence(self, connection, sequence_name, schema=None): + if schema is None: + cursor = connection.execute( + sql.text("SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and n.nspname=current_schema()" + " and lower(relname)=:name", + bindparams=[ + sql.bindparam('name', unicode(sequence_name.lower()), + type_=sqltypes.Unicode) + ] + ) + ) + else: + cursor = connection.execute( + sql.text("SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and n.nspname=:schema and " + "lower(relname)=:name", + bindparams=[ + sql.bindparam('name', unicode(sequence_name.lower()), + type_=sqltypes.Unicode), + sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode) + ] + ) + ) + + return bool(cursor.first()) + + def has_type(self, connection, type_name, schema=None): + bindparams = [ + sql.bindparam('typname', + unicode(type_name), type_=sqltypes.Unicode), + sql.bindparam('nspname', + unicode(schema), type_=sqltypes.Unicode), + ] + if schema is not None: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n + WHERE t.typnamespace = n.oid + AND t.typname = :typname + AND n.nspname = :nspname + ) + """ + else: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t + WHERE t.typname = :typname + AND pg_type_is_visible(t.oid) + ) + """ + cursor = connection.execute(sql.text(query, bindparams=bindparams)) + return bool(cursor.scalar()) + + def _get_server_version_info(self, connection): + v = connection.execute("select version()").scalar() + m = re.match('PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?(?:devel)?', v) + if not m: + raise AssertionError("Could not determine version from string '%s'" % v) + return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) + + @reflection.cache + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name. + + Several reflection methods require the table oid. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + table_oid = None + if schema is not None: + schema_where_clause = "n.nspname = :schema" + else: + schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" + query = """ + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in ('r','v') + """ % schema_where_clause + # Since we're binding to unicode, table_name and schema_name must be + # unicode. + table_name = unicode(table_name) + if schema is not None: + schema = unicode(schema) + s = sql.text(query, bindparams=[ + sql.bindparam('table_name', type_=sqltypes.Unicode), + sql.bindparam('schema', type_=sqltypes.Unicode) + ], + typemap={'oid':sqltypes.Integer} + ) + c = connection.execute(s, table_name=table_name, schema=schema) + table_oid = c.scalar() + if table_oid is None: + raise exc.NoSuchTableError(table_name) + return table_oid + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = """ + SELECT nspname + FROM pg_namespace + ORDER BY nspname + """ + rp = connection.execute(s) + # what about system tables? + # Py3K + #schema_names = [row[0] for row in rp \ + # if not row[0].startswith('pg_')] + # Py2K + schema_names = [row[0].decode(self.encoding) for row in rp \ + if not row[0].startswith('pg_')] + # end Py2K + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + result = connection.execute( + sql.text(u"SELECT relname FROM pg_class c " + "WHERE relkind = 'r' " + "AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) " % + current_schema, + typemap = {'relname':sqltypes.Unicode} + ) + ) + return [row[0] for row in result] + + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + s = """ + SELECT relname + FROM pg_class c + WHERE relkind = 'v' + AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) + """ % dict(schema=current_schema) + # Py3K + #view_names = [row[0] for row in connection.execute(s)] + # Py2K + view_names = [row[0].decode(self.encoding) for row in connection.execute(s)] + # end Py2K + return view_names + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + s = """ + SELECT definition FROM pg_views + WHERE schemaname = :schema + AND viewname = :view_name + """ + rp = connection.execute(sql.text(s), + view_name=view_name, schema=current_schema) + if rp: + # Py3K + #view_def = rp.scalar() + # Py2K + view_def = rp.scalar().decode(self.encoding) + # end Py2K + return view_def + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + SQL_COLS = """ + SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) + AS DEFAULT, + a.attnotnull, a.attnum, a.attrelid as table_oid + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = :table_oid + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + """ + s = sql.text(SQL_COLS, + bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], + typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode} + ) + c = connection.execute(s, table_oid=table_oid) + rows = c.fetchall() + domains = self._load_domains(connection) + enums = self._load_enums(connection) + + # format columns + columns = [] + for name, format_type, default, notnull, attnum, table_oid in rows: + ## strip (5) from character varying(5), timestamp(5) with time zone, etc + attype = re.sub(r'\([\d,]+\)', '', format_type) + + # strip '[]' from integer[], etc. + attype = re.sub(r'\[\]', '', attype) + + nullable = not notnull + is_array = format_type.endswith('[]') + charlen = re.search('\(([\d,]+)\)', format_type) + if charlen: + charlen = charlen.group(1) + kwargs = {} + + if attype == 'numeric': + if charlen: + prec, scale = charlen.split(',') + args = (int(prec), int(scale)) + else: + args = () + elif attype == 'double precision': + args = (53, ) + elif attype == 'integer': + args = (32, 0) + elif attype in ('timestamp with time zone', 'time with time zone'): + kwargs['timezone'] = True + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif attype in ('timestamp without time zone', 'time without time zone', 'time'): + kwargs['timezone'] = False + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif attype in ('interval','interval year to month','interval day to second'): + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif charlen: + args = (int(charlen),) + else: + args = () + + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + elif attype in enums: + enum = enums[attype] + coltype = ENUM + if "." in attype: + kwargs['schema'], kwargs['name'] = attype.split('.') + else: + kwargs['name'] = attype + args = tuple(enum['labels']) + elif attype in domains: + domain = domains[attype] + if domain['attype'] in self.ischema_names: + # A table can't override whether the domain is nullable. + nullable = domain['nullable'] + if domain['default'] and not default: + # It can, however, override the default value, but can't set it to null. + default = domain['default'] + coltype = self.ischema_names[domain['attype']] + else: + coltype = None + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = ARRAY(coltype) + else: + util.warn("Did not recognize type '%s' of column '%s'" % + (attype, name)) + coltype = sqltypes.NULLTYPE + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + autoincrement = True + # the default is related to a Sequence + sch = schema + if '.' not in match.group(2) and sch is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / "quote schema" + default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) + + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, autoincrement=autoincrement) + columns.append(column_info) + return columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + PK_SQL = """ + SELECT attname FROM pg_attribute + WHERE attrelid = ( + SELECT indexrelid FROM pg_index i + WHERE i.indrelid = :table_oid + AND i.indisprimary = 't') + ORDER BY attnum + """ + t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + primary_keys = [r[0] for r in c.fetchall()] + return primary_keys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + preparer = self.identifier_preparer + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + FK_SQL = """ + SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table AND r.contype = 'f' + ORDER BY 1 + """ + + t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) + c = connection.execute(t, table=table_oid) + fkeys = [] + for conname, condef in c.fetchall(): + m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() + (constrained_columns, referred_schema, referred_table, referred_columns) = m + constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] + if referred_schema: + referred_schema = preparer._unquote_identifier(referred_schema) + elif schema is not None and schema == self.default_schema_name: + # no schema (i.e. its the default schema), and the table we're + # reflecting has the default schema explicit, then use that. + # i.e. try to use the user's conventions + referred_schema = schema + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] + fkey_d = { + 'name' : conname, + 'constrained_columns' : constrained_columns, + 'referred_schema' : referred_schema, + 'referred_table' : referred_table, + 'referred_columns' : referred_columns + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + IDX_SQL = """ + SELECT c.relname, i.indisunique, i.indexprs, i.indpred, + a.attname + FROM pg_index i, pg_class c, pg_attribute a + WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid + AND a.attrelid = i.indexrelid AND i.indisprimary = 'f' + ORDER BY c.relname, a.attnum + """ + t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + index_names = {} + indexes = [] + sv_idx_name = None + for row in c.fetchall(): + idx_name, unique, expr, prd, col = row + if expr: + if idx_name != sv_idx_name: + util.warn( + "Skipped unsupported reflection of expression-based index %s" + % idx_name) + sv_idx_name = idx_name + continue + if prd and not idx_name == sv_idx_name: + util.warn( + "Predicate of partial index %s ignored during reflection" + % idx_name) + sv_idx_name = idx_name + if idx_name in index_names: + index_d = index_names[idx_name] + else: + index_d = {'column_names':[]} + indexes.append(index_d) + index_names[idx_name] = index_d + index_d['name'] = idx_name + index_d['column_names'].append(col) + index_d['unique'] = unique + return indexes + + def _load_enums(self, connection): + if not self.supports_native_enum: + return {} + + ## Load data types for enums: + SQL_ENUMS = """ + SELECT t.typname as "name", + -- t.typdefault as "default", -- no enum defaults in 8.4 at least + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema", + e.enumlabel as "label" + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid + LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid + WHERE t.typtype = 'e' + ORDER BY "name", e.oid -- e.oid gives us label order + """ + + s = sql.text(SQL_ENUMS, typemap={'attname':sqltypes.Unicode, 'label':sqltypes.Unicode}) + c = connection.execute(s) + + enums = {} + for enum in c.fetchall(): + if enum['visible']: + # 'visible' just means whether or not the enum is in a + # schema that's on the search path -- or not overriden by + # a schema with higher presedence. If it's not visible, + # it will be prefixed with the schema-name when it's used. + name = enum['name'] + else: + name = "%s.%s" % (enum['schema'], enum['name']) + + if name in enums: + enums[name]['labels'].append(enum['label']) + else: + enums[name] = { + 'labels': [enum['label']], + } + + return enums + + def _load_domains(self, connection): + ## Load data types for domains: + SQL_DOMAINS = """ + SELECT t.typname as "name", + pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", + not t.typnotnull as "nullable", + t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema" + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid + WHERE t.typtype = 'd' + """ + + s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode}) + c = connection.execute(s) + + domains = {} + for domain in c.fetchall(): + ## strip (30) from character varying(30) + attype = re.search('([^\(]+)', domain['attype']).group(1) + if domain['visible']: + # 'visible' just means whether or not the domain is in a + # schema that's on the search path -- or not overriden by + # a schema with higher presedence. If it's not visible, + # it will be prefixed with the schema-name when it's used. + name = domain['name'] + else: + name = "%s.%s" % (domain['schema'], domain['name']) + + domains[name] = { + 'attype':attype, + 'nullable': domain['nullable'], + 'default': domain['default'] + } + + return domains + diff --git a/sqlalchemy/dialects/postgresql/pg8000.py b/sqlalchemy/dialects/postgresql/pg8000.py new file mode 100644 index 0000000..a620daa --- /dev/null +++ b/sqlalchemy/dialects/postgresql/pg8000.py @@ -0,0 +1,105 @@ +"""Support for the PostgreSQL database via the pg8000 driver. + +Connecting +---------- + +URLs are of the form +`postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`. + +Unicode +------- + +pg8000 requires that the postgresql client encoding be configured in the postgresql.conf file +in order to use encodings other than ascii. Set this value to the same value as +the "encoding" parameter on create_engine(), usually "utf-8". + +Interval +-------- + +Passing data from/to the Interval type is not supported as of yet. + +""" +import decimal + +from sqlalchemy.engine import default +from sqlalchemy import util, exc +from sqlalchemy import processors +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, \ + PGCompiler, PGIdentifierPreparer, PGExecutionContext + +class _PGNumeric(sqltypes.Numeric): + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in (700, 701): + return processors.to_decimal_processor_factory(decimal.Decimal) + elif coltype == 1700: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + else: + if coltype in (700, 701): + # pg8000 returns float natively for 701 + return None + elif coltype == 1700: + return processors.to_float + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + +class PGExecutionContext_pg8000(PGExecutionContext): + pass + + +class PGCompiler_pg8000(PGCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() " + "expressions to '%%'.") + return text.replace('%', '%%') + + +class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace('%', '%%') + + +class PGDialect_pg8000(PGDialect): + driver = 'pg8000' + + supports_unicode_statements = True + + supports_unicode_binds = True + + default_paramstyle = 'format' + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_pg8000 + statement_compiler = PGCompiler_pg8000 + preparer = PGIdentifierPreparer_pg8000 + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : _PGNumeric, + } + ) + + @classmethod + def dbapi(cls): + return __import__('pg8000').dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + return "connection is closed" in str(e) + +dialect = PGDialect_pg8000 diff --git a/sqlalchemy/dialects/postgresql/psycopg2.py b/sqlalchemy/dialects/postgresql/psycopg2.py new file mode 100644 index 0000000..f21c9a5 --- /dev/null +++ b/sqlalchemy/dialects/postgresql/psycopg2.py @@ -0,0 +1,239 @@ +"""Support for the PostgreSQL database via the psycopg2 driver. + +Driver +------ + +The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . +The dialect has several behaviors which are specifically tailored towards compatibility +with this module. + +Note that psycopg1 is **not** supported. + +Connecting +---------- + +URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`. + +psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: + +* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support + this feature. What this essentially means from a psycopg2 point of view is that the cursor is + created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows + are not immediately pre-fetched and buffered after statement execution, but are instead left + on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` + uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows + at a time are fetched over the wire to reduce conversational overhead. +* *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection. True + by default. +* *isolation_level* - Sets the transaction isolation level for each transaction + within the engine. Valid isolation levels are `READ_COMMITTED`, + `READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`. + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + +NOTICE logging +--------------- + +The psycopg2 dialect will log Postgresql NOTICE messages via the +``sqlalchemy.dialects.postgresql`` logger:: + + import logging + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + + +Per-Statement Execution Options +------------------------------- + +The following per-statement execution options are respected: + +* *stream_results* - Enable or disable usage of server side cursors for the SELECT-statement. + If *None* or not set, the *server_side_cursors* option of the connection is used. If + auto-commit is enabled, the option is ignored. + +""" + +import random +import re +import decimal +import logging + +from sqlalchemy import util +from sqlalchemy import processors +from sqlalchemy.engine import base, default +from sqlalchemy.sql import expression +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \ + PGIdentifierPreparer, PGExecutionContext, \ + ENUM, ARRAY + + +logger = logging.getLogger('sqlalchemy.dialects.postgresql') + + +class _PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in (700, 701): + return processors.to_decimal_processor_factory(decimal.Decimal) + elif coltype == 1700: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + else: + if coltype in (700, 701): + # pg8000 returns float natively for 701 + return None + elif coltype == 1700: + return processors.to_float + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + +class _PGEnum(ENUM): + def __init__(self, *arg, **kw): + super(_PGEnum, self).__init__(*arg, **kw) + if self.convert_unicode: + self.convert_unicode = "force" + +class _PGArray(ARRAY): + def __init__(self, *arg, **kw): + super(_PGArray, self).__init__(*arg, **kw) + # FIXME: this check won't work for setups that + # have convert_unicode only on their create_engine(). + if isinstance(self.item_type, sqltypes.String) and \ + self.item_type.convert_unicode: + self.item_type.convert_unicode = "force" + +# When we're handed literal SQL, ensure it's a SELECT-query. Since +# 8.3, combining cursors and "FOR UPDATE" has been fine. +SERVER_SIDE_CURSOR_RE = re.compile( + r'\s*SELECT', + re.I | re.UNICODE) + +class PGExecutionContext_psycopg2(PGExecutionContext): + def create_cursor(self): + # TODO: coverage for server side cursors + select.for_update() + + if self.dialect.server_side_cursors: + is_server_side = \ + self.execution_options.get('stream_results', True) and ( + (self.compiled and isinstance(self.compiled.statement, expression.Selectable) \ + or \ + ( + (not self.compiled or + isinstance(self.compiled.statement, expression._TextClause)) + and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) + ) + ) + else: + is_server_side = self.execution_options.get('stream_results', False) + + self.__is_server_side = is_server_side + if is_server_side: + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) + return self._connection.connection.cursor(ident) + else: + return self._connection.connection.cursor() + + def get_result_proxy(self): + if logger.isEnabledFor(logging.INFO): + self._log_notices(self.cursor) + + if self.__is_server_side: + return base.BufferedRowResultProxy(self) + else: + return base.ResultProxy(self) + + def _log_notices(self, cursor): + for notice in cursor.connection.notices: + # NOTICE messages have a + # newline character at the end + logger.info(notice.rstrip()) + + cursor.connection.notices[:] = [] + + +class PGCompiler_psycopg2(PGCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace('%', '%%') + +class PGDialect_psycopg2(PGDialect): + driver = 'psycopg2' + supports_unicode_statements = False + default_paramstyle = 'pyformat' + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_psycopg2 + statement_compiler = PGCompiler_psycopg2 + preparer = PGIdentifierPreparer_psycopg2 + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : _PGNumeric, + ENUM : _PGEnum, # needs force_unicode + sqltypes.Enum : _PGEnum, # needs force_unicode + ARRAY : _PGArray, # needs force_unicode + } + ) + + def __init__(self, server_side_cursors=False, use_native_unicode=True, **kwargs): + PGDialect.__init__(self, **kwargs) + self.server_side_cursors = server_side_cursors + self.use_native_unicode = use_native_unicode + self.supports_unicode_binds = use_native_unicode + + @classmethod + def dbapi(cls): + psycopg = __import__('psycopg2') + return psycopg + + def on_connect(self): + base_on_connect = super(PGDialect_psycopg2, self).on_connect() + if self.dbapi and self.use_native_unicode: + extensions = __import__('psycopg2.extensions').extensions + def connect(conn): + extensions.register_type(extensions.UNICODE, conn) + if base_on_connect: + base_on_connect(conn) + return connect + else: + return base_on_connect + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + elif isinstance(e, self.dbapi.InterfaceError): + return 'connection already closed' in str(e) or 'cursor already closed' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + # yes, it really says "losed", not "closed" + return "losed the connection unexpectedly" in str(e) + else: + return False + +dialect = PGDialect_psycopg2 + diff --git a/sqlalchemy/dialects/postgresql/pypostgresql.py b/sqlalchemy/dialects/postgresql/pypostgresql.py new file mode 100644 index 0000000..2e7ea20 --- /dev/null +++ b/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -0,0 +1,69 @@ +"""Support for the PostgreSQL database via py-postgresql. + +Connecting +---------- + +URLs are of the form `postgresql+pypostgresql://user@password@host:port/dbname[?key=value&key=value...]`. + + +""" +from sqlalchemy.engine import default +import decimal +from sqlalchemy import util +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext +from sqlalchemy import processors + +class PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return processors.to_str + + def result_processor(self, dialect, coltype): + if self.asdecimal: + return None + else: + return processors.to_float + +class PGExecutionContext_pypostgresql(PGExecutionContext): + pass + +class PGDialect_pypostgresql(PGDialect): + driver = 'pypostgresql' + + supports_unicode_statements = True + supports_unicode_binds = True + description_encoding = None + default_paramstyle = 'pyformat' + + # requires trunk version to support sane rowcounts + # TODO: use dbapi version information to set this flag appropariately + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + execution_ctx_cls = PGExecutionContext_pypostgresql + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used + } + ) + + @classmethod + def dbapi(cls): + from postgresql.driver import dbapi20 + return dbapi20 + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + else: + opts['port'] = 5432 + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + return "connection is closed" in str(e) + +dialect = PGDialect_pypostgresql diff --git a/sqlalchemy/dialects/postgresql/zxjdbc.py b/sqlalchemy/dialects/postgresql/zxjdbc.py new file mode 100644 index 0000000..a886901 --- /dev/null +++ b/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -0,0 +1,19 @@ +"""Support for the PostgreSQL database via the zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official Postgresql JDBC driver is at http://jdbc.postgresql.org/. + +""" +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.postgresql.base import PGDialect + +class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): + jdbc_db_name = 'postgresql' + jdbc_driver_name = 'org.postgresql.Driver' + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.dbversion.split('.')) + +dialect = PGDialect_zxjdbc diff --git a/sqlalchemy/dialects/sqlite/__init__.py b/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 0000000..fbbde17 --- /dev/null +++ b/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,14 @@ +from sqlalchemy.dialects.sqlite import base, pysqlite + +# default dialect +base.dialect = pysqlite.dialect + + +from sqlalchemy.dialects.sqlite.base import \ + BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER,\ + NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect + +__all__ = ( + 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'FLOAT', 'INTEGER', + 'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', 'TIMESTAMP', 'VARCHAR', 'dialect' +) \ No newline at end of file diff --git a/sqlalchemy/dialects/sqlite/base.py b/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 0000000..ca0a391 --- /dev/null +++ b/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,596 @@ +# sqlite.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the SQLite database. + +For information on connecting using a specific driver, see the documentation +section regarding that driver. + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide +out of the box functionality for translating values between Python `datetime` objects +and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime` +and related types provide date formatting and parsing functionality when SQlite is used. +The implementation classes are :class:`DATETIME`, :class:`DATE` and :class:`TIME`. +These types represent dates and times as ISO formatted strings, which also nicely +support ordering. There's no reliance on typical "libc" internals for these functions +so historical dates are fully supported. + +Auto Incrementing Behavior +-------------------------- + +Background on SQLite's autoincrement is at: http://sqlite.org/autoinc.html + +Two things to note: + +* The AUTOINCREMENT keyword is **not** required for SQLite tables to + generate primary key values automatically. AUTOINCREMENT only means that + the algorithm used to generate ROWID values should be slightly different. +* SQLite does **not** generate primary key (i.e. ROWID) values, even for + one column, if the table has a composite (i.e. multi-column) primary key. + This is regardless of the AUTOINCREMENT keyword being present or not. + +To specifically render the AUTOINCREMENT keyword on the primary key +column when rendering DDL, add the flag ``sqlite_autoincrement=True`` +to the Table construct:: + + Table('sometable', metadata, + Column('id', Integer, primary_key=True), + sqlite_autoincrement=True) + +""" + +import datetime, re, time + +from sqlalchemy import schema as sa_schema +from sqlalchemy import sql, exc, pool, DefaultClause +from sqlalchemy.engine import default +from sqlalchemy.engine import reflection +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.sql import compiler, functions as sql_functions +from sqlalchemy.util import NoneType +from sqlalchemy import processors + +from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\ + FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\ + TIMESTAMP, VARCHAR + + +class _DateTimeMixin(object): + _reg = None + _storage_format = None + + def __init__(self, storage_format=None, regexp=None, **kwargs): + if regexp is not None: + self._reg = re.compile(regexp) + if storage_format is not None: + self._storage_format = storage_format + +class DATETIME(_DateTimeMixin, sqltypes.DateTime): + _storage_format = "%04d-%02d-%02d %02d:%02d:%02d.%06d" + + def bind_processor(self, dialect): + datetime_datetime = datetime.datetime + datetime_date = datetime.date + format = self._storage_format + def process(value): + if value is None: + return None + elif isinstance(value, datetime_datetime): + return format % (value.year, value.month, value.day, + value.hour, value.minute, value.second, + value.microsecond) + elif isinstance(value, datetime_date): + return format % (value.year, value.month, value.day, + 0, 0, 0, 0) + else: + raise TypeError("SQLite DateTime type only accepts Python " + "datetime and date objects as input.") + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.datetime) + else: + return processors.str_to_datetime + +class DATE(_DateTimeMixin, sqltypes.Date): + _storage_format = "%04d-%02d-%02d" + + def bind_processor(self, dialect): + datetime_date = datetime.date + format = self._storage_format + def process(value): + if value is None: + return None + elif isinstance(value, datetime_date): + return format % (value.year, value.month, value.day) + else: + raise TypeError("SQLite Date type only accepts Python " + "date objects as input.") + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.date) + else: + return processors.str_to_date + +class TIME(_DateTimeMixin, sqltypes.Time): + _storage_format = "%02d:%02d:%02d.%06d" + + def bind_processor(self, dialect): + datetime_time = datetime.time + format = self._storage_format + def process(value): + if value is None: + return None + elif isinstance(value, datetime_time): + return format % (value.hour, value.minute, value.second, + value.microsecond) + else: + raise TypeError("SQLite Time type only accepts Python " + "time objects as input.") + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.time) + else: + return processors.str_to_time + +colspecs = { + sqltypes.Date: DATE, + sqltypes.DateTime: DATETIME, + sqltypes.Time: TIME, +} + +ischema_names = { + 'BLOB': sqltypes.BLOB, + 'BOOL': sqltypes.BOOLEAN, + 'BOOLEAN': sqltypes.BOOLEAN, + 'CHAR': sqltypes.CHAR, + 'DATE': sqltypes.DATE, + 'DATETIME': sqltypes.DATETIME, + 'DECIMAL': sqltypes.DECIMAL, + 'FLOAT': sqltypes.FLOAT, + 'INT': sqltypes.INTEGER, + 'INTEGER': sqltypes.INTEGER, + 'NUMERIC': sqltypes.NUMERIC, + 'REAL': sqltypes.Numeric, + 'SMALLINT': sqltypes.SMALLINT, + 'TEXT': sqltypes.TEXT, + 'TIME': sqltypes.TIME, + 'TIMESTAMP': sqltypes.TIMESTAMP, + 'VARCHAR': sqltypes.VARCHAR, +} + + + +class SQLiteCompiler(compiler.SQLCompiler): + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + 'month': '%m', + 'day': '%d', + 'year': '%Y', + 'second': '%S', + 'hour': '%H', + 'doy': '%j', + 'minute': '%M', + 'epoch': '%s', + 'dow': '%w', + 'week': '%W' + }) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "length%s" % self.function_argspec(fn) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super(SQLiteCompiler, self).visit_cast(cast) + else: + return self.process(cast.clause) + + def visit_extract(self, extract, **kw): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], self.process(extract.expr, **kw)) + except KeyError: + raise exc.ArgumentError( + "%s is not a valid extract argument." % extract.field) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select._offset) + else: + text += " OFFSET 0" + return text + + def for_update_clause(self, select): + # sqlite has no "FOR UPDATE" AFAICT + return '' + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + if column.primary_key and \ + column.table.kwargs.get('sqlite_autoincrement', False) and \ + len(column.table.primary_key.columns) == 1 and \ + isinstance(column.type, sqltypes.Integer) and \ + not column.foreign_keys: + colspec += " PRIMARY KEY AUTOINCREMENT" + + return colspec + + def visit_primary_key_constraint(self, constraint): + # for columns with sqlite_autoincrement=True, + # the PRIMARY KEY constraint can only be inline + # with the column itself. + if len(constraint.columns) == 1: + c = list(constraint)[0] + if c.primary_key and \ + c.table.kwargs.get('sqlite_autoincrement', False) and \ + isinstance(c.type, sqltypes.Integer) and \ + not c.foreign_keys: + return '' + + return super(SQLiteDDLCompiler, self).\ + visit_primary_key_constraint(constraint) + + + def visit_create_index(self, create): + index = create.element + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.format_index(index, + name=self._validate_identifier(index.name, True)), + preparer.format_table(index.table, use_schema=False), + ', '.join(preparer.quote(c.name, c.quote) + for c in index.columns)) + return text + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set([ + 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', + 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', + 'conflict', 'constraint', 'create', 'cross', 'current_date', + 'current_time', 'current_timestamp', 'database', 'default', + 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', + 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', + 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', + 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', + 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', + 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', + 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', + 'plan', 'pragma', 'primary', 'query', 'raise', 'references', + 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', + 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', + 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', + 'vacuum', 'values', 'view', 'virtual', 'when', 'where', + ]) + + def format_index(self, index, use_schema=True, name=None): + """Prepare a quoted index and schema name.""" + + if name is None: + name = index.name + result = self.quote(name, index.quote) + if not self.omit_schema and use_schema and getattr(index.table, "schema", None): + result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result + return result + +class SQLiteDialect(default.DefaultDialect): + name = 'sqlite' + supports_alter = False + supports_unicode_statements = True + supports_unicode_binds = True + supports_default_values = True + supports_empty_insert = False + supports_cast = True + + default_paramstyle = 'qmark' + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + colspecs = colspecs + isolation_level = None + + supports_cast = True + supports_default_values = True + + def __init__(self, isolation_level=None, native_datetime=False, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + if isolation_level and isolation_level not in ('SERIALIZABLE', + 'READ UNCOMMITTED'): + raise exc.ArgumentError("Invalid value for isolation_level. " + "Valid isolation levels for sqlite are 'SERIALIZABLE' and " + "'READ UNCOMMITTED'.") + self.isolation_level = isolation_level + + # this flag used by pysqlite dialect, and perhaps others in the + # future, to indicate the driver is handling date/timestamp + # conversions (and perhaps datetime/time as well on some + # hypothetical driver ?) + self.native_datetime = native_datetime + + if self.dbapi is not None: + self.supports_default_values = \ + self.dbapi.sqlite_version_info >= (3, 3, 8) + self.supports_cast = \ + self.dbapi.sqlite_version_info >= (3, 2, 3) + + + def on_connect(self): + if self.isolation_level is not None: + if self.isolation_level == 'READ UNCOMMITTED': + isolation_level = 1 + else: + isolation_level = 0 + + def connect(conn): + cursor = conn.cursor() + cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) + cursor.close() + return connect + else: + return None + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + def has_table(self, connection, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) + row = cursor.fetchone() + + # consume remaining rows, to work around + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + while cursor.fetchone() is not None: + pass + + return (row is not None) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='view' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT sql FROM %s WHERE name = '%s'" + "AND type='view'") % (master, view_name) + rs = connection.execute(s) + else: + try: + s = ("SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + + result = rs.fetchall() + if result: + return result[0].sql + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) + found_table = False + columns = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5]) + name = re.sub(r'^\"|\"$', '', name) + if default: + default = re.sub(r"^\'|\'$", '', default) + match = re.match(r'(\w+)(\(.*?\))?', type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = "VARCHAR" + args = '' + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, name)) + coltype = sqltypes.NullType + if args is not None: + args = re.findall(r'(\d+)', args) + coltype = coltype(*[int(a) for a in args]) + + columns.append({ + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'primary_key': primary_key + }) + return columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + cols = self.get_columns(connection, table_name, schema, **kw) + pkeys = [] + for col in cols: + if col['primary_key']: + pkeys.append(col['name']) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))) + fkeys = [] + fks = {} + while True: + row = c.fetchone() + if row is None: + break + (constraint_name, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) + rtbl = re.sub(r'^\"|\"$', '', rtbl) + lcol = re.sub(r'^\"|\"$', '', lcol) + rcol = re.sub(r'^\"|\"$', '', rcol) + try: + fk = fks[constraint_name] + except KeyError: + fk = { + 'name' : constraint_name, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : rtbl, + 'referred_columns' : [] + } + fkeys.append(fk) + fks[constraint_name] = fk + + # look up the table based on the given table's engine, not 'self', + # since it could be a ProxyEngine + if lcol not in fk['constrained_columns']: + fk['constrained_columns'].append(lcol) + if rcol not in fk['referred_columns']: + fk['referred_columns'].append(rcol) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + include_auto_indexes = kw.pop('include_auto_indexes', False) + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable))) + indexes = [] + while True: + row = c.fetchone() + if row is None: + break + # ignore implicit primary key index. + # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html + elif not include_auto_indexes and row[1].startswith('sqlite_autoindex'): + continue + + indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + # loop thru unique indexes to get the column names. + for idx in indexes: + c = connection.execute("%sindex_info(%s)" % (pragma, quote(idx['name']))) + cols = idx['column_names'] + while True: + row = c.fetchone() + if row is None: + break + cols.append(row[2]) + return indexes + + +def _pragma_cursor(cursor): + """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows.""" + + if cursor.closed: + cursor._fetchone_impl = lambda: None + return cursor diff --git a/sqlalchemy/dialects/sqlite/pysqlite.py b/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 0000000..575cb37 --- /dev/null +++ b/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,236 @@ +"""Support for the SQLite database via pysqlite. + +Note that pysqlite is the same driver as the ``sqlite3`` +module included with the Python distribution. + +Driver +------ + +When using Python 2.5 and above, the built in ``sqlite3`` driver is +already installed and no additional installation is needed. Otherwise, +the ``pysqlite2`` driver needs to be present. This is the same driver as +``sqlite3``, just with a different name. + +The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` +is loaded. This allows an explicitly installed pysqlite driver to take +precedence over the built in one. As with all dialects, a specific +DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control +this explicitly:: + + from sqlite3 import dbapi2 as sqlite + e = create_engine('sqlite+pysqlite:///file.db', module=sqlite) + +Full documentation on pysqlite is available at: +``_ + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" portion of +the URL. Note that the format of a url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to the +**right** of the third slash. So connecting to a relative filepath looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you need **four** +slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be used. +Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') + +The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify +``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://') + +Compatibility with sqlite3 "native" date and datetime types +----------------------------------------------------------- + +The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and +sqlite3.PARSE_COLNAMES options, which have the effect of any column +or expression explicitly cast as "date" or "timestamp" will be converted +to a Python date or datetime object. The date and datetime types provided +with the pysqlite dialect are not currently compatible with these options, +since they render the ISO date/datetime including microseconds, which +pysqlite's driver does not. Additionally, SQLAlchemy does not at +this time automatically render the "cast" syntax required for the +freestanding functions "current_timestamp" and "current_date" to return +datetime/date types natively. Unfortunately, pysqlite +does not provide the standard DBAPI types in `cursor.description`, +leaving SQLAlchemy with no way to detect these types on the fly +without expensive per-row type checks. + +Usage of PARSE_DECLTYPES can be forced if one configures +"native_datetime=True" on create_engine():: + + engine = create_engine('sqlite://', + connect_args={'detect_types': sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, + native_datetime=True + ) + +With this flag enabled, the DATE and TIMESTAMP types (but note - not the DATETIME +or TIME types...confused yet ?) will not perform any bind parameter or result +processing. Execution of "func.current_date()" will return a string. +"func.current_timestamp()" is registered as returning a DATETIME type in +SQLAlchemy, so this function still receives SQLAlchemy-level result processing. + +Threading Behavior +------------------ + +Pysqlite connections do not support being moved between threads, unless +the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, +when using an in-memory SQLite database, the full database exists only within +the scope of a single connection. It is reported that an in-memory +database does not support being shared between threads regardless of the +``check_same_thread`` flag - which means that a multithreaded +application **cannot** share data from a ``:memory:`` database across threads +unless access to the connection is limited to a single worker thread which communicates +through a queueing mechanism to concurrent threads. + +To provide a default which accomodates SQLite's default threading capabilities +somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` +be used by default. This pool maintains a single SQLite connection per thread +that is held open up to a count of five concurrent threads. When more than five threads +are used, a cleanup mechanism will dispose of excess unused connections. + +Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: + + * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded + application using an in-memory database, assuming the threading issues inherent in + pysqlite are somehow accomodated for. This pool holds persistently onto a single connection + which is never closed, and is returned for all requests. + + * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that + makes use of a file-based sqlite database. This pool disables any actual "pooling" + behavior, and simply opens and closes real connections corresonding to the :func:`connect()` + and :func:`close()` methods. SQLite can "connect" to a particular file with very high + efficiency, so this option may actually perform better without the extra overhead + of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection + useless since the database would be lost as soon as the connection is "returned" to the pool. + +Unicode +------- + +In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's +default behavior regarding Unicode is that all strings are returned as Python unicode objects +in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is +*not* used, you will still always receive unicode data back from a result set. It is +**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type +to represent strings, since it will raise a warning if a non-unicode Python string is +passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can +quickly create confusion, particularly when using the ORM as internal data is not +always represented by an actual database result string. + +""" + +from sqlalchemy.dialects.sqlite.base import SQLiteDialect, DATETIME, DATE +from sqlalchemy import schema, exc, pool +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes +from sqlalchemy import util + + +class _SQLite_pysqliteTimeStamp(DATETIME): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATETIME.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATETIME.result_processor(self, dialect, coltype) + +class _SQLite_pysqliteDate(DATE): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATE.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATE.result_processor(self, dialect, coltype) + +class SQLiteDialect_pysqlite(SQLiteDialect): + default_paramstyle = 'qmark' + poolclass = pool.SingletonThreadPool + + colspecs = util.update_copy( + SQLiteDialect.colspecs, + { + sqltypes.Date:_SQLite_pysqliteDate, + sqltypes.TIMESTAMP:_SQLite_pysqliteTimeStamp, + } + ) + + # Py3K + #description_encoding = None + + driver = 'pysqlite' + + def __init__(self, **kwargs): + SQLiteDialect.__init__(self, **kwargs) + + if self.dbapi is not None: + sqlite_ver = self.dbapi.version_info + if sqlite_ver < (2, 1, 3): + util.warn( + ("The installed version of pysqlite2 (%s) is out-dated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended.") % + '.'.join([str(subver) for subver in sqlite_ver])) + + @classmethod + def dbapi(cls): + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError, e: + try: + from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. + except ImportError: + raise e + return sqlite + + def _get_server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,)) + filename = url.database or ':memory:' + + opts = url.query.copy() + util.coerce_kw_type(opts, 'timeout', float) + util.coerce_kw_type(opts, 'isolation_level', str) + util.coerce_kw_type(opts, 'detect_types', int) + util.coerce_kw_type(opts, 'check_same_thread', bool) + util.coerce_kw_type(opts, 'cached_statements', int) + + return ([filename], opts) + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) + +dialect = SQLiteDialect_pysqlite diff --git a/sqlalchemy/dialects/sybase/__init__.py b/sqlalchemy/dialects/sybase/__init__.py new file mode 100644 index 0000000..400bb29 --- /dev/null +++ b/sqlalchemy/dialects/sybase/__init__.py @@ -0,0 +1,20 @@ +from sqlalchemy.dialects.sybase import base, pysybase, pyodbc + + +from base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ + TEXT,DATE,DATETIME, FLOAT, NUMERIC,\ + BIGINT,INT, INTEGER, SMALLINT, BINARY,\ + VARBINARY,UNITEXT,UNICHAR,UNIVARCHAR,\ + IMAGE,BIT,MONEY,SMALLMONEY,TINYINT + +# default dialect +base.dialect = pyodbc.dialect + +__all__ = ( + 'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR', + 'TEXT','DATE','DATETIME', 'FLOAT', 'NUMERIC', + 'BIGINT','INT', 'INTEGER', 'SMALLINT', 'BINARY', + 'VARBINARY','UNITEXT','UNICHAR','UNIVARCHAR', + 'IMAGE','BIT','MONEY','SMALLMONEY','TINYINT', + 'dialect' +) diff --git a/sqlalchemy/dialects/sybase/base.py b/sqlalchemy/dialects/sybase/base.py new file mode 100644 index 0000000..6719b42 --- /dev/null +++ b/sqlalchemy/dialects/sybase/base.py @@ -0,0 +1,420 @@ +# sybase/base.py +# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com +# get_select_precolumns(), limit_clause() implementation +# copyright (C) 2007 Fisch Asset Management +# AG http://www.fam.ch, with coding by Alexander Houben +# alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for Sybase Adaptive Server Enterprise (ASE). + +Note that this dialect is no longer specific to Sybase iAnywhere. +ASE is the primary support platform. + +""" + +import operator +from sqlalchemy.sql import compiler, expression, text, bindparam +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import schema as sa_schema +from sqlalchemy import util, sql, exc + +from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ + TEXT,DATE,DATETIME, FLOAT, NUMERIC,\ + BIGINT,INT, INTEGER, SMALLINT, BINARY,\ + VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ + UnicodeText + +RESERVED_WORDS = set([ + "add", "all", "alter", "and", + "any", "as", "asc", "backup", + "begin", "between", "bigint", "binary", + "bit", "bottom", "break", "by", + "call", "capability", "cascade", "case", + "cast", "char", "char_convert", "character", + "check", "checkpoint", "close", "comment", + "commit", "connect", "constraint", "contains", + "continue", "convert", "create", "cross", + "cube", "current", "current_timestamp", "current_user", + "cursor", "date", "dbspace", "deallocate", + "dec", "decimal", "declare", "default", + "delete", "deleting", "desc", "distinct", + "do", "double", "drop", "dynamic", + "else", "elseif", "encrypted", "end", + "endif", "escape", "except", "exception", + "exec", "execute", "existing", "exists", + "externlogin", "fetch", "first", "float", + "for", "force", "foreign", "forward", + "from", "full", "goto", "grant", + "group", "having", "holdlock", "identified", + "if", "in", "index", "index_lparen", + "inner", "inout", "insensitive", "insert", + "inserting", "install", "instead", "int", + "integer", "integrated", "intersect", "into", + "iq", "is", "isolation", "join", + "key", "lateral", "left", "like", + "lock", "login", "long", "match", + "membership", "message", "mode", "modify", + "natural", "new", "no", "noholdlock", + "not", "notify", "null", "numeric", + "of", "off", "on", "open", + "option", "options", "or", "order", + "others", "out", "outer", "over", + "passthrough", "precision", "prepare", "primary", + "print", "privileges", "proc", "procedure", + "publication", "raiserror", "readtext", "real", + "reference", "references", "release", "remote", + "remove", "rename", "reorganize", "resource", + "restore", "restrict", "return", "revoke", + "right", "rollback", "rollup", "save", + "savepoint", "scroll", "select", "sensitive", + "session", "set", "setuser", "share", + "smallint", "some", "sqlcode", "sqlstate", + "start", "stop", "subtrans", "subtransaction", + "synchronize", "syntax_error", "table", "temporary", + "then", "time", "timestamp", "tinyint", + "to", "top", "tran", "trigger", + "truncate", "tsequal", "unbounded", "union", + "unique", "unknown", "unsigned", "update", + "updating", "user", "using", "validate", + "values", "varbinary", "varchar", "variable", + "varying", "view", "wait", "waitfor", + "when", "where", "while", "window", + "with", "with_cube", "with_lparen", "with_rollup", + "within", "work", "writetext", + ]) + + +class _SybaseUnitypeMixin(object): + """these types appear to return a buffer object.""" + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return str(value) #.decode("ucs-2") + else: + return None + return process + +class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode): + __visit_name__ = 'UNICHAR' + +class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode): + __visit_name__ = 'UNIVARCHAR' + +class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText): + __visit_name__ = 'UNITEXT' + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = "SMALLMONEY" + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = 'IMAGE' + + +class SybaseTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_): + return self.visit_IMAGE(type_) + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_UNICHAR(self, type_): + return "UNICHAR(%d)" % type_.length + + def visit_UNIVARCHAR(self, type_): + return "UNIVARCHAR(%d)" % type_.length + + def visit_UNITEXT(self, type_): + return "UNITEXT" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return "SMALLMONEY" + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + +ischema_names = { + 'integer' : INTEGER, + 'unsigned int' : INTEGER, # TODO: unsigned flags + 'unsigned smallint' : SMALLINT, # TODO: unsigned flags + 'unsigned bigint' : BIGINT, # TODO: unsigned flags + 'bigint': BIGINT, + 'smallint' : SMALLINT, + 'tinyint' : TINYINT, + 'varchar' : VARCHAR, + 'long varchar' : TEXT, # TODO + 'char' : CHAR, + 'decimal' : DECIMAL, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'double' : NUMERIC, # TODO + 'binary' : BINARY, + 'varbinary' : VARBINARY, + 'bit': BIT, + 'image' : IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': MONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, + +} + + +class SybaseExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + + def set_ddl_autocommit(self, connection, value): + """Must be implemented by subclasses to accommodate DDL executions. + + "connection" is the raw unwrapped DBAPI connection. "value" + is True or False. when True, the connection should be configured + such that a DDL can take place subsequently. when False, + a DDL has taken place and the connection should be resumed + into non-autocommit mode. + + """ + raise NotImplementedError() + + def pre_exec(self): + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl)) + + if self.isddl: + # TODO: to enhance this, we can detect "ddl in tran" on the + # database settings. this error message should be improved to + # include a note about that. + if not self.should_autocommit: + raise exc.InvalidRequestError("The Sybase dialect only supports " + "DDL in 'autocommit' mode at this time.") + + self.root_connection.engine.logger.info("AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") + + self.set_ddl_autocommit(self.root_connection.connection.connection, True) + + + def post_exec(self): + if self.isddl: + self.set_ddl_autocommit(self.root_connection, False) + + if self._enable_identity_insert: + self.cursor.execute( + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer. + format_table(self.compiled.statement.table) + ) + + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT @@identity AS lastrowid") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + +class SybaseSQLCompiler(compiler.SQLCompiler): + ansi_bind_rules = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond' + }) + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + if select._limit: + #if select._limit == 1: + #s += "FIRST " + #else: + #s += "TOP %s " % (select._limit,) + s += "TOP %s " % (select._limit,) + if select._offset: + if not select._limit: + # FIXME: sybase doesn't allow an offset without a limit + # so use a huge value for TOP here + s += "TOP 1000000 " + s += "START AT %s " % (select._offset+1,) + return s + + def get_from_hint_text(self, table, text): + return text + + def limit_clause(self, select): + # Limit in sybase is after the select keyword + return "" + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select, **kw): + kw['literal_binds'] = True + order_by = self.process(select._order_by_clause, **kw) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class SybaseDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) + + if column.table is None: + raise exc.InvalidRequestError("The Sybase dialect requires Table-bound "\ + "columns in order to generate DDL") + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = isinstance(column.default, sa_schema.Sequence) and column.default + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + if (start, increment) == (1, 1): + colspec += " IDENTITY" + else: + # TODO: need correct syntax for this + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) + ) + +class SybaseIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + +class SybaseDialect(default.DefaultDialect): + name = 'sybase' + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + supports_native_boolean = False + supports_unicode_binds = False + postfetch_lastrowid = True + + colspecs = {} + ischema_names = ischema_names + + type_compiler = SybaseTypeCompiler + statement_compiler = SybaseSQLCompiler + ddl_compiler = SybaseDDLCompiler + preparer = SybaseIdentifierPreparer + + def _get_default_schema_name(self, connection): + return connection.scalar( + text("SELECT user_name() as user_name", typemap={'user_name':Unicode}) + ) + + def initialize(self, connection): + super(SybaseDialect, self).initialize(connection) + if self.server_version_info is not None and\ + self.server_version_info < (15, ): + self.max_identifier_length = 30 + else: + self.max_identifier_length = 255 + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + result = connection.execute( + text("select sysobjects.name from sysobjects, sysusers " + "where sysobjects.uid=sysusers.uid and " + "sysusers.name=:schemaname and " + "sysobjects.type='U'", + bindparams=[ + bindparam('schemaname', schema) + ]) + ) + return [r[0] for r in result] + + def has_table(self, connection, tablename, schema=None): + if schema is None: + schema = self.default_schema_name + + result = connection.execute( + text("select sysobjects.name from sysobjects, sysusers " + "where sysobjects.uid=sysusers.uid and " + "sysobjects.name=:tablename and " + "sysusers.name=:schemaname and " + "sysobjects.type='U'", + bindparams=[ + bindparam('tablename', tablename), + bindparam('schemaname', schema) + ]) + ) + return result.scalar() is not None + + def reflecttable(self, connection, table, include_columns): + raise NotImplementedError() + diff --git a/sqlalchemy/dialects/sybase/mxodbc.py b/sqlalchemy/dialects/sybase/mxodbc.py new file mode 100644 index 0000000..1481799 --- /dev/null +++ b/sqlalchemy/dialects/sybase/mxodbc.py @@ -0,0 +1,17 @@ +""" +Support for Sybase via mxodbc. + +This dialect is a stub only and is likely non functional at this time. + + +""" +from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext +from sqlalchemy.connectors.mxodbc import MxODBCConnector + +class SybaseExecutionContext_mxodbc(SybaseExecutionContext): + pass + +class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_mxodbc + +dialect = SybaseDialect_mxodbc diff --git a/sqlalchemy/dialects/sybase/pyodbc.py b/sqlalchemy/dialects/sybase/pyodbc.py new file mode 100644 index 0000000..e34f260 --- /dev/null +++ b/sqlalchemy/dialects/sybase/pyodbc.py @@ -0,0 +1,75 @@ +""" +Support for Sybase via pyodbc. + +http://pypi.python.org/pypi/pyodbc/ + +Connect strings are of the form:: + + sybase+pyodbc://:@/ + sybase+pyodbc://:@/ + +Unicode Support +--------------- + +The pyodbc driver currently supports usage of these Sybase types with +Unicode or multibyte strings:: + + CHAR + NCHAR + NVARCHAR + TEXT + VARCHAR + +Currently *not* supported are:: + + UNICHAR + UNITEXT + UNIVARCHAR + +""" + +from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector +import decimal +from sqlalchemy import types as sqltypes, util, processors + +class _SybNumeric_pyodbc(sqltypes.Numeric): + """Turns Decimals with adjusted() < -6 into floats. + + It's not yet known how to get decimals with many + significant digits or very large adjusted() into Sybase + via pyodbc. + + """ + + def bind_processor(self, dialect): + super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect) + + def process(value): + if self.asdecimal and \ + isinstance(value, decimal.Decimal): + + if value.adjusted() < -6: + return processors.to_float(value) + + if super_process: + return super_process(value) + else: + return value + return process + +class SybaseExecutionContext_pyodbc(SybaseExecutionContext): + def set_ddl_autocommit(self, connection, value): + if value: + connection.autocommit = True + else: + connection.autocommit = False + +class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_pyodbc + + colspecs = { + sqltypes.Numeric:_SybNumeric_pyodbc, + } + +dialect = SybaseDialect_pyodbc diff --git a/sqlalchemy/dialects/sybase/pysybase.py b/sqlalchemy/dialects/sybase/pysybase.py new file mode 100644 index 0000000..ee19382 --- /dev/null +++ b/sqlalchemy/dialects/sybase/pysybase.py @@ -0,0 +1,98 @@ +# pysybase.py +# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Support for Sybase via the python-sybase driver. + +http://python-sybase.sourceforge.net/ + +Connect strings are of the form:: + + sybase+pysybase://:@/[database name] + +Unicode Support +--------------- + +The python-sybase driver does not appear to support non-ASCII strings of any +kind at this time. + +""" + +from sqlalchemy import types as sqltypes, processors +from sqlalchemy.dialects.sybase.base import SybaseDialect, \ + SybaseExecutionContext, SybaseSQLCompiler + + +class _SybNumeric(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + +class SybaseExecutionContext_pysybase(SybaseExecutionContext): + + def set_ddl_autocommit(self, dbapi_connection, value): + if value: + # call commit() on the Sybase connection directly, + # to avoid any side effects of calling a Connection + # transactional method inside of pre_exec() + dbapi_connection.commit() + + def pre_exec(self): + SybaseExecutionContext.pre_exec(self) + + for param in self.parameters: + for key in list(param): + param["@" + key] = param[key] + del param[key] + + +class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): + def bindparam_string(self, name): + return "@" + name + +class SybaseDialect_pysybase(SybaseDialect): + driver = 'pysybase' + execution_ctx_cls = SybaseExecutionContext_pysybase + statement_compiler = SybaseSQLCompiler_pysybase + + colspecs={ + sqltypes.Numeric:_SybNumeric, + sqltypes.Float:sqltypes.Float + } + + @classmethod + def dbapi(cls): + import Sybase + return Sybase + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user', password='passwd') + + return ([opts.pop('host')], opts) + + def do_executemany(self, cursor, statement, parameters, context=None): + # calling python-sybase executemany yields: + # TypeError: string too long for buffer + for param in parameters: + cursor.execute(statement, param) + + def _get_server_version_info(self, connection): + vers = connection.scalar("select @@version_number") + # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0), (12, 5, 0, 0) + return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10) + + def is_disconnect(self, e): + if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)): + msg = str(e) + return ('Unable to complete network request to host' in msg or + 'Invalid connection state' in msg or + 'Invalid cursor state' in msg) + else: + return False + +dialect = SybaseDialect_pysybase diff --git a/sqlalchemy/dialects/type_migration_guidelines.txt b/sqlalchemy/dialects/type_migration_guidelines.txt new file mode 100644 index 0000000..c26b65e --- /dev/null +++ b/sqlalchemy/dialects/type_migration_guidelines.txt @@ -0,0 +1,145 @@ +Rules for Migrating TypeEngine classes to 0.6 +--------------------------------------------- + +1. the TypeEngine classes are used for: + + a. Specifying behavior which needs to occur for bind parameters + or result row columns. + + b. Specifying types that are entirely specific to the database + in use and have no analogue in the sqlalchemy.types package. + + c. Specifying types where there is an analogue in sqlalchemy.types, + but the database in use takes vendor-specific flags for those + types. + + d. If a TypeEngine class doesn't provide any of this, it should be + *removed* from the dialect. + +2. the TypeEngine classes are *no longer* used for generating DDL. Dialects +now have a TypeCompiler subclass which uses the same visit_XXX model as +other compilers. + +3. the "ischema_names" and "colspecs" dictionaries are now required members on +the Dialect class. + +4. The names of types within dialects are now important. If a dialect-specific type +is a subclass of an existing generic type and is only provided for bind/result behavior, +the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case, +end users would never need to use _PGNumeric directly. However, if a dialect-specific +type is specifying a type *or* arguments that are not present generically, it should +match the real name of the type on that backend, in uppercase. E.g. postgresql.INET, +mysql.ENUM, postgresql.ARRAY. + +Or follow this handy flowchart: + + is the type meant to provide bind/result is the type the same name as an + behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ? + type in types.py ? | | + | no yes + yes | | + | | does your type need special + | +<--- yes --- behavior or arguments ? + | | | + | | no + name the type using | | + _MixedCase, i.e. v V + _OracleBoolean. it name the type don't make a + stays private to the dialect identically as that type, make sure the dialect's + and is invoked *only* via within the DB, base.py imports the types.py + the colspecs dict. using UPPERCASE UPPERCASE name into its namespace + | (i.e. BIT, NCHAR, INTERVAL). + | Users can import it. + | | + v v + subclass the closest is the name of this type + MixedCase type types.py, identical to an UPPERCASE + i.e. <--- no ------- name in types.py ? + class _DateTime(types.DateTime), + class DATETIME2(types.DateTime), | + class BIT(types.TypeEngine). yes + | + v + the type should + subclass the + UPPERCASE + type in types.py + (i.e. class BLOB(types.BLOB)) + + +Example 1. pysqlite needs bind/result processing for the DateTime type in types.py, +which applies to all DateTimes and subclasses. It's named _SLDateTime and +subclasses types.DateTime. + +Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument +that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py, +and subclasses types.TIME. Users can then say mssql.TIME(precision=10). + +Example 3. MS-SQL dialects also need special bind/result processing for date +But its DATE type doesn't render DDL differently than that of a plain +DATE, i.e. it takes no special arguments. Therefore we are just adding behavior +to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses +types.Date. + +Example 4. MySQL has a SET type, there's no analogue for this in types.py. So +MySQL names it SET in the dialect's base.py, and it subclasses types.String, since +it ultimately deals with strings. + +Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly, +and no special arguments are used in PG's DDL beyond what types.py provides. +Postgresql dialect therefore imports types.DATETIME into its base.py. + +Ideally one should be able to specify a schema using names imported completely from a +dialect, all matching the real name on that backend: + + from sqlalchemy.dialects.postgresql import base as pg + + t = Table('mytable', metadata, + Column('id', pg.INTEGER, primary_key=True), + Column('name', pg.VARCHAR(300)), + Column('inetaddr', pg.INET) + ) + +where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types, +but the PG dialect makes them available in its own namespace. + +5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types +linked to types specified in the dialect. Again, if a type in the dialect does not +specify any special behavior for bind_processor() or result_processor() and does not +indicate a special type only available in this database, it must be *removed* from the +module and from this dictionary. + +6. "ischema_names" indicates string descriptions of types as returned from the database +linked to TypeEngine classes. + + a. The string name should be matched to the most specific type possible within + sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which + case it points to a dialect type. *It doesn't matter* if the dialect has it's + own subclass of that type with special bind/result behavior - reflect to the types.py + UPPERCASE type as much as possible. With very few exceptions, all types + should reflect to an UPPERCASE type. + + b. If the dialect contains a matching dialect-specific type that takes extra arguments + which the generic one does not, then point to the dialect-specific type. E.g. + mssql.VARCHAR takes a "collation" parameter which should be preserved. + +5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by +a subclass of compiler.GenericTypeCompiler. + + a. your TypeCompiler class will receive generic and uppercase types from + sqlalchemy.types. Do not assume the presence of dialect-specific attributes on + these types. + + b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with + methods that produce a different DDL name. Uppercase types don't do any kind of + "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in + all cases, regardless of whether or not that type is legal on the backend database. + + c. the visit_UPPERCASE methods *should* be overridden with methods that add additional + arguments and flags to those types. + + d. the visit_lowercase methods are overridden to provide an interpretation of a generic + type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)". + + e. visit_lowercase methods should *never* render strings directly - it should always + be via calling a visit_UPPERCASE() method. diff --git a/sqlalchemy/engine/__init__.py b/sqlalchemy/engine/__init__.py new file mode 100644 index 0000000..9b3dbed --- /dev/null +++ b/sqlalchemy/engine/__init__.py @@ -0,0 +1,274 @@ +# engine/__init__.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQL connections, SQL execution and high-level DB-API interface. + +The engine package defines the basic components used to interface +DB-API modules with higher-level statement construction, +connection-management, execution and result contexts. The primary +"entry point" class into this package is the Engine and it's public +constructor ``create_engine()``. + +This package includes: + +base.py + Defines interface classes and some implementation classes which + comprise the basic components used to interface between a DB-API, + constructed and plain-text statements, connections, transactions, + and results. + +default.py + Contains default implementations of some of the components defined + in base.py. All current database dialects use the classes in + default.py as base classes for their own database-specific + implementations. + +strategies.py + The mechanics of constructing ``Engine`` objects are represented + here. Defines the ``EngineStrategy`` class which represents how + to go from arguments specified to the ``create_engine()`` + function, to a fully constructed ``Engine``, including + initialization of connection pooling, dialects, and specific + subclasses of ``Engine``. + +threadlocal.py + The ``TLEngine`` class is defined here, which is a subclass of + the generic ``Engine`` and tracks ``Connection`` and + ``Transaction`` objects against the identity of the current + thread. This allows certain programming patterns based around + the concept of a "thread-local connection" to be possible. + The ``TLEngine`` is created by using the "threadlocal" engine + strategy in conjunction with the ``create_engine()`` function. + +url.py + Defines the ``URL`` class which represents the individual + components of a string URL passed to ``create_engine()``. Also + defines a basic module-loading strategy for the dialect specifier + within a URL. +""" + +# not sure what this was used for +#import sqlalchemy.databases + +from sqlalchemy.engine.base import ( + BufferedColumnResultProxy, + BufferedColumnRow, + BufferedRowResultProxy, + Compiled, + Connectable, + Connection, + Dialect, + Engine, + ExecutionContext, + NestedTransaction, + ResultProxy, + RootTransaction, + RowProxy, + Transaction, + TwoPhaseTransaction, + TypeCompiler + ) +from sqlalchemy.engine import strategies +from sqlalchemy import util + + +__all__ = ( + 'BufferedColumnResultProxy', + 'BufferedColumnRow', + 'BufferedRowResultProxy', + 'Compiled', + 'Connectable', + 'Connection', + 'Dialect', + 'Engine', + 'ExecutionContext', + 'NestedTransaction', + 'ResultProxy', + 'RootTransaction', + 'RowProxy', + 'Transaction', + 'TwoPhaseTransaction', + 'TypeCompiler', + 'create_engine', + 'engine_from_config', + ) + + +default_strategy = 'plain' +def create_engine(*args, **kwargs): + """Create a new Engine instance. + + The standard method of specifying the engine is via URL as the + first positional argument, to indicate the appropriate database + dialect and connection arguments, with additional keyword + arguments sent as options to the dialect and resulting Engine. + + The URL is a string in the form + ``dialect+driver://user:password@host/dbname[?key=value..]``, where + ``dialect`` is a database name such as ``mysql``, ``oracle``, + ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as + ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively, + the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. + + `**kwargs` takes a wide variety of options which are routed + towards their appropriate components. Arguments may be + specific to the Engine, the underlying Dialect, as well as the + Pool. Specific dialects also accept keyword arguments that + are unique to that dialect. Here, we describe the parameters + that are common to most ``create_engine()`` usage. + + :param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode + object is passed when SQLAlchemy would coerce into an encoding + (note: but **not** when the DBAPI handles unicode objects natively). + To suppress or raise this warning to an + error, use the Python warnings filter documented at: + http://docs.python.org/library/warnings.html + + :param connect_args: a dictionary of options which will be + passed directly to the DBAPI's ``connect()`` method as + additional keyword arguments. + + :param convert_unicode=False: if set to True, all + String/character based types will convert Unicode values to raw + byte values going into the database, and all raw byte values to + Python Unicode coming out in result sets. This is an + engine-wide method to provide unicode conversion across the + board. For unicode conversion on a column-by-column level, use + the ``Unicode`` column type instead, described in `types`. + + :param creator: a callable which returns a DBAPI connection. + This creation function will be passed to the underlying + connection pool and will be used to create all new database + connections. Usage of this function causes connection + parameters specified in the URL argument to be bypassed. + + :param echo=False: if True, the Engine will log all statements + as well as a repr() of their parameter lists to the engines + logger, which defaults to sys.stdout. The ``echo`` attribute of + ``Engine`` can be modified at any time to turn logging on and + off. If set to the string ``"debug"``, result rows will be + printed to the standard output as well. This flag ultimately + controls a Python logger; see :ref:`dbengine_logging` for + information on how to configure logging directly. + + :param echo_pool=False: if True, the connection pool will log + all checkouts/checkins to the logging stream, which defaults to + sys.stdout. This flag ultimately controls a Python logger; see + :ref:`dbengine_logging` for information on how to configure logging + directly. + + :param encoding='utf-8': the encoding to use for all Unicode + translations, both by engine-wide unicode conversion as well as + the ``Unicode`` type object. + + :param label_length=None: optional integer value which limits + the size of dynamically generated column labels to that many + characters. If less than 6, labels are generated as + "_(counter)". If ``None``, the value of + ``dialect.max_identifier_length`` is used instead. + + :param listeners: A list of one or more + :class:`~sqlalchemy.interfaces.PoolListener` objects which will + receive connection pool events. + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.engine" logger. Defaults to a hexstring of the + object's id. + + :param max_overflow=10: the number of connections to allow in + connection pool "overflow", that is connections that can be + opened above and beyond the pool_size setting, which defaults + to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`. + + :param module=None: used by database implementations which + support multiple DBAPI modules, this is a reference to a DBAPI2 + module to be used instead of the engine's default module. For + PostgreSQL, the default is psycopg2. For Oracle, it's cx_Oracle. + + :param pool=None: an already-constructed instance of + :class:`~sqlalchemy.pool.Pool`, such as a + :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this + pool will be used directly as the underlying connection pool + for the engine, bypassing whatever connection parameters are + present in the URL argument. For information on constructing + connection pools manually, see `pooling`. + + :param poolclass=None: a :class:`~sqlalchemy.pool.Pool` + subclass, which will be used to create a connection pool + instance using the connection parameters given in the URL. Note + this differs from ``pool`` in that you don't actually + instantiate the pool in this case, you just indicate what type + of pool to be used. + + :param pool_logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + id. + + :param pool_size=5: the number of connections to keep open + inside the connection pool. This used with :class:`~sqlalchemy.pool.QueuePool` as + well as :class:`~sqlalchemy.pool.SingletonThreadPool`. + + :param pool_recycle=-1: this setting causes the pool to recycle + connections after the given number of seconds has passed. It + defaults to -1, or no timeout. For example, setting to 3600 + means connections will be recycled after one hour. Note that + MySQL in particular will ``disconnect automatically`` if no + activity is detected on a connection for eight hours (although + this is configurable with the MySQLDB connection itself and the + server configuration as well). + + :param pool_timeout=30: number of seconds to wait before giving + up on getting a connection from the pool. This is only used + with :class:`~sqlalchemy.pool.QueuePool`. + + :param strategy='plain': used to invoke alternate :class:`~sqlalchemy.engine.base.Engine.` + implementations. Currently available is the ``threadlocal`` + strategy, which is described in :ref:`threadlocal_strategy`. + + """ + + strategy = kwargs.pop('strategy', default_strategy) + strategy = strategies.strategies[strategy] + return strategy.create(*args, **kwargs) + +def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): + """Create a new Engine instance using a configuration dictionary. + + The dictionary is typically produced from a config file where keys + are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The + 'prefix' argument indicates the prefix to be searched for. + + A select set of keyword arguments will be "coerced" to their + expected type based on string values. In a future release, this + functionality will be expanded and include dialect-specific + arguments. + """ + + opts = _coerce_config(configuration, prefix) + opts.update(kwargs) + url = opts.pop('url') + return create_engine(url, **opts) + +def _coerce_config(configuration, prefix): + """Convert configuration values to expected types.""" + + options = dict((key[len(prefix):], configuration[key]) + for key in configuration + if key.startswith(prefix)) + for option, type_ in ( + ('convert_unicode', bool), + ('pool_timeout', int), + ('echo', bool), + ('echo_pool', bool), + ('pool_recycle', int), + ('pool_size', int), + ('max_overflow', int), + ('pool_threadlocal', bool), + ): + util.coerce_kw_type(options, option, type_) + return options diff --git a/sqlalchemy/engine/base.py b/sqlalchemy/engine/base.py new file mode 100644 index 0000000..dc42ed9 --- /dev/null +++ b/sqlalchemy/engine/base.py @@ -0,0 +1,2422 @@ +# engine/base.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +"""Basic components for SQL execution and interfacing with DB-API. + +Defines the basic components used to interface DB-API modules with +higher-level statement-construction, connection-management, execution +and result contexts. +""" + +__all__ = [ + 'BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy', + 'Compiled', 'Connectable', 'Connection', 'Dialect', 'Engine', + 'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction', + 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', + 'connection_memoize'] + +import inspect, StringIO, sys, operator +from itertools import izip +from sqlalchemy import exc, schema, util, types, log +from sqlalchemy.sql import expression + +class Dialect(object): + """Define the behavior of a specific database and DB-API combination. + + Any aspect of metadata definition, SQL query generation, + execution, result-set handling, or anything else which varies + between databases is defined under the general category of the + Dialect. The Dialect acts as a factory for other + database-specific object implementations including + ExecutionContext, Compiled, DefaultGenerator, and TypeEngine. + + All Dialects implement the following attributes: + + name + identifying name for the dialect from a DBAPI-neutral point of view + (i.e. 'sqlite') + + driver + identifying name for the dialect's DBAPI + + positional + True if the paramstyle for this Dialect is positional. + + paramstyle + the paramstyle to be used (some DB-APIs support multiple + paramstyles). + + convert_unicode + True if Unicode conversion should be applied to all ``str`` + types. + + encoding + type of encoding to use for unicode, usually defaults to + 'utf-8'. + + statement_compiler + a :class:`~Compiled` class used to compile SQL statements + + ddl_compiler + a :class:`~Compiled` class used to compile DDL statements + + server_version_info + a tuple containing a version number for the DB backend in use. + This value is only available for supporting dialects, and is + typically populated during the initial connection to the database. + + default_schema_name + the name of the default schema. This value is only available for + supporting dialects, and is typically populated during the + initial connection to the database. + + execution_ctx_cls + a :class:`ExecutionContext` class used to handle statement execution + + execute_sequence_format + either the 'tuple' or 'list' type, depending on what cursor.execute() + accepts for the second argument (they vary). + + preparer + a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to + quote identifiers. + + supports_alter + ``True`` if the database supports ``ALTER TABLE``. + + max_identifier_length + The maximum length of identifier names. + + supports_unicode_statements + Indicate whether the DB-API can receive SQL statements as Python + unicode strings + + supports_unicode_binds + Indicate whether the DB-API can receive string bind parameters + as Python unicode strings + + supports_sane_rowcount + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements. + + supports_sane_multi_rowcount + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements when executed via + executemany. + + preexecute_autoincrement_sequences + True if 'implicit' primary key functions must be executed separately + in order to get their value. This is currently oriented towards + Postgresql. + + implicit_returning + use RETURNING or equivalent during INSERT execution in order to load + newly generated primary keys and other column defaults in one execution, + which are then available via inserted_primary_key. + If an insert statement has returning() specified explicitly, + the "implicit" functionality is not used and inserted_primary_key + will not be available. + + dbapi_type_map + A mapping of DB-API type objects present in this Dialect's + DB-API implementation mapped to TypeEngine implementations used + by the dialect. + + This is used to apply types to result sets based on the DB-API + types present in cursor.description; it only takes effect for + result sets against textual statements where no explicit + typemap was present. + + colspecs + A dictionary of TypeEngine classes from sqlalchemy.types mapped + to subclasses that are specific to the dialect class. This + dictionary is class-level only and is not accessed from the + dialect instance itself. + + supports_default_values + Indicates if the construct ``INSERT INTO tablename DEFAULT + VALUES`` is supported + + supports_sequences + Indicates if the dialect supports CREATE SEQUENCE or similar. + + sequences_optional + If True, indicates if the "optional" flag on the Sequence() construct + should signal to not generate a CREATE SEQUENCE. Applies only to + dialects that support sequences. Currently used only to allow Postgresql + SERIAL to be used on a column that specifies Sequence() for usage on + other backends. + + supports_native_enum + Indicates if the dialect supports a native ENUM construct. + This will prevent types.Enum from generating a CHECK + constraint when that type is used. + + supports_native_boolean + Indicates if the dialect supports a native boolean construct. + This will prevent types.Boolean from generating a CHECK + constraint when that type is used. + + """ + + def create_connect_args(self, url): + """Build DB-API compatible connection arguments. + + Given a :class:`~sqlalchemy.engine.url.URL` object, returns a tuple + consisting of a `*args`/`**kwargs` suitable to send directly + to the dbapi's connect function. + + """ + + raise NotImplementedError() + + @classmethod + def type_descriptor(cls, typeobj): + """Transform a generic type to a dialect-specific type. + + Dialect classes will usually use the + :func:`~sqlalchemy.types.adapt_type` function in the types module to + make this job easy. + + The returned result is cached *per dialect class* so can + contain no dialect-instance state. + + """ + + raise NotImplementedError() + + def initialize(self, connection): + """Called during strategized creation of the dialect with a connection. + + Allows dialects to configure options based on server version info or + other properties. + + The connection passed here is a SQLAlchemy Connection object, + with full capabilities. + + The initalize() method of the base dialect should be called via + super(). + + """ + + pass + + def reflecttable(self, connection, table, include_columns=None): + """Load table description from the database. + + Given a :class:`~sqlalchemy.engine.Connection` and a + :class:`~sqlalchemy.schema.Table` object, reflect its columns and + properties from the database. If include_columns (a list or + set) is specified, limit the autoload to the given column + names. + + The default implementation uses the + :class:`~sqlalchemy.engine.reflection.Inspector` interface to + provide the output, building upon the granular table/column/ + constraint etc. methods of :class:`Dialect`. + + """ + + raise NotImplementedError() + + def get_columns(self, connection, table_name, schema=None, **kw): + """Return information about columns in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return column + information as a list of dictionaries with these keys: + + name + the column's name + + type + [sqlalchemy.types#TypeEngine] + + nullable + boolean + + default + the column's default value + + autoincrement + boolean + + sequence + a dictionary of the form + {'name' : str, 'start' :int, 'increment': int} + + Additional column attributes may be present. + """ + + raise NotImplementedError() + + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return primary + key information as a list of column names. + """ + + raise NotImplementedError() + + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return foreign + key information as a list of dicts with these keys: + + name + the constraint's name + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + """ + + raise NotImplementedError() + + def get_table_names(self, connection, schema=None, **kw): + """Return a list of table names for `schema`.""" + + raise NotImplementedError + + def get_view_names(self, connection, schema=None, **kw): + """Return a list of all view names available in the database. + + schema: + Optional, retrieve names from a non-default schema. + """ + + raise NotImplementedError() + + def get_view_definition(self, connection, view_name, schema=None, **kw): + """Return view definition. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `view_name`, and an optional string `schema`, return the view + definition. + """ + + raise NotImplementedError() + + def get_indexes(self, connection, table_name, schema=None, **kw): + """Return information about indexes in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name` and an optional string `schema`, return index + information as a list of dictionaries with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + """ + + raise NotImplementedError() + + def normalize_name(self, name): + """convert the given name to lowercase if it is detected as case insensitive. + + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + + def denormalize_name(self, name): + """convert the given name to a case insensitive identifier for the backend + if it is an all-lowercase name. + + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + + def has_table(self, connection, table_name, schema=None): + """Check the existence of a particular table in the database. + + Given a :class:`~sqlalchemy.engine.Connection` object and a string + `table_name`, return True if the given table (possibly within + the specified `schema`) exists in the database, False + otherwise. + """ + + raise NotImplementedError() + + def has_sequence(self, connection, sequence_name, schema=None): + """Check the existence of a particular sequence in the database. + + Given a :class:`~sqlalchemy.engine.Connection` object and a string + `sequence_name`, return True if the given sequence exists in + the database, False otherwise. + """ + + raise NotImplementedError() + + def _get_server_version_info(self, connection): + """Retrieve the server version info from the given connection. + + This is used by the default implementation to populate the + "server_version_info" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def _get_default_schema_name(self, connection): + """Return the string name of the currently selected schema from the given connection. + + This is used by the default implementation to populate the + "default_schema_name" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def do_begin(self, connection): + """Provide an implementation of *connection.begin()*, given a DB-API connection.""" + + raise NotImplementedError() + + def do_rollback(self, connection): + """Provide an implementation of *connection.rollback()*, given a DB-API connection.""" + + raise NotImplementedError() + + def create_xid(self): + """Create a two-phase transaction ID. + + This id will be passed to do_begin_twophase(), + do_rollback_twophase(), do_commit_twophase(). Its format is + unspecified. + """ + + raise NotImplementedError() + + def do_commit(self, connection): + """Provide an implementation of *connection.commit()*, given a DB-API connection.""" + + raise NotImplementedError() + + def do_savepoint(self, connection, name): + """Create a savepoint with the given name on a SQLAlchemy connection.""" + + raise NotImplementedError() + + def do_rollback_to_savepoint(self, connection, name): + """Rollback a SQL Alchemy connection to the named savepoint.""" + + raise NotImplementedError() + + def do_release_savepoint(self, connection, name): + """Release the named savepoint on a SQL Alchemy connection.""" + + raise NotImplementedError() + + def do_begin_twophase(self, connection, xid): + """Begin a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_prepare_twophase(self, connection, xid): + """Prepare a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + """Rollback a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + """Commit a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_recover_twophase(self, connection): + """Recover list of uncommited prepared two phase transaction identifiers on the given connection.""" + + raise NotImplementedError() + + def do_executemany(self, cursor, statement, parameters, context=None): + """Provide an implementation of *cursor.executemany(statement, parameters)*.""" + + raise NotImplementedError() + + def do_execute(self, cursor, statement, parameters, context=None): + """Provide an implementation of *cursor.execute(statement, parameters)*.""" + + raise NotImplementedError() + + def is_disconnect(self, e): + """Return True if the given DB-API error indicates an invalid connection""" + + raise NotImplementedError() + + def on_connect(self): + """return a callable which sets up a newly created DBAPI connection. + + The callable accepts a single argument "conn" which is the + DBAPI connection itself. It has no return value. + + This is used to set dialect-wide per-connection options such as isolation + modes, unicode modes, etc. + + If a callable is returned, it will be assembled into a pool listener + that receives the direct DBAPI connection, with all wrappers removed. + + If None is returned, no listener will be generated. + + """ + return None + + +class ExecutionContext(object): + """A messenger object for a Dialect that corresponds to a single execution. + + ExecutionContext should have these data members: + + connection + Connection object which can be freely used by default value + generators to execute SQL. This Connection should reference the + same underlying connection/transactional resources of + root_connection. + + root_connection + Connection object which is the source of this ExecutionContext. This + Connection may have close_with_result=True set, in which case it can + only be used once. + + dialect + dialect which created this ExecutionContext. + + cursor + DB-API cursor procured from the connection, + + compiled + if passed to constructor, sqlalchemy.engine.base.Compiled object + being executed, + + statement + string version of the statement to be executed. Is either + passed to the constructor, or must be created from the + sql.Compiled object by the time pre_exec() has completed. + + parameters + bind parameters passed to the execute() method. For compiled + statements, this is a dictionary or list of dictionaries. For + textual statements, it should be in a format suitable for the + dialect's paramstyle (i.e. dict or list of dicts for non + positional, list or list of lists/tuples for positional). + + isinsert + True if the statement is an INSERT. + + isupdate + True if the statement is an UPDATE. + + should_autocommit + True if the statement is a "committable" statement. + + postfetch_cols + a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates. + """ + + def create_cursor(self): + """Return a new cursor generated from this ExecutionContext's connection. + + Some dialects may wish to change the behavior of + connection.cursor(), such as postgresql which may return a PG + "server side" cursor. + """ + + raise NotImplementedError() + + def pre_exec(self): + """Called before an execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `statement` and `parameters` datamembers must be + initialized after this statement is complete. + """ + + raise NotImplementedError() + + def post_exec(self): + """Called after the execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `last_insert_ids`, `last_inserted_params`, etc. + datamembers should be available after this method completes. + """ + + raise NotImplementedError() + + def result(self): + """Return a result object corresponding to this ExecutionContext. + + Returns a ResultProxy. + """ + + raise NotImplementedError() + + def handle_dbapi_exception(self, e): + """Receive a DBAPI exception which occured upon execute, result fetch, etc.""" + + raise NotImplementedError() + + def should_autocommit_text(self, statement): + """Parse the given textual statement and return True if it refers to a "committable" statement""" + + raise NotImplementedError() + + def last_inserted_params(self): + """Return a dictionary of the full parameter dictionary for the last compiled INSERT statement. + + Includes any ColumnDefaults or Sequences that were pre-executed. + """ + + raise NotImplementedError() + + def last_updated_params(self): + """Return a dictionary of the full parameter dictionary for the last compiled UPDATE statement. + + Includes any ColumnDefaults that were pre-executed. + """ + + raise NotImplementedError() + + def lastrow_has_defaults(self): + """Return True if the last INSERT or UPDATE row contained + inlined or database-side defaults. + """ + + raise NotImplementedError() + + def get_rowcount(self): + """Return the number of rows produced (by a SELECT query) + or affected (by an INSERT/UPDATE/DELETE statement). + + Note that this row count may not be properly implemented + in some dialects; this is indicated by the + ``supports_sane_rowcount`` and ``supports_sane_multi_rowcount`` + dialect attributes. + + """ + + raise NotImplementedError() + + +class Compiled(object): + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + def __init__(self, dialect, statement, bind=None): + """Construct a new ``Compiled`` object. + + :param dialect: ``Dialect`` to compile against. + + :param statement: ``ClauseElement`` to be compiled. + + :param bind: Optional Engine or Connection to compile this statement against. + """ + + self.dialect = dialect + self.statement = statement + self.bind = bind + self.can_execute = statement.supports_execution + + def compile(self): + """Produce the internal string representation of this element.""" + + self.string = self.process(self.statement) + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self): + """Return the string text of the generated SQL or DDL.""" + + return self.string or '' + + def construct_params(self, params=None): + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whos values will + override bind values compiled in to the + statement. + """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.bind + if e is None: + raise exc.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.") + return e._execute_compiled(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's scalar value.""" + + return self.execute(*multiparams, **params).scalar() + + +class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_): + return type_._compiler_dispatch(self) + + +class Connectable(object): + """Interface for an object which supports execution of SQL constructs. + + The two implementations of ``Connectable`` are :class:`Connection` and + :class:`Engine`. + + Connectable must also implement the 'dialect' member which references a + :class:`Dialect` instance. + """ + + def contextual_connect(self): + """Return a Connection object which may be part of an ongoing context.""" + + raise NotImplementedError() + + def create(self, entity, **kwargs): + """Create a table or index given an appropriate schema object.""" + + raise NotImplementedError() + + def drop(self, entity, **kwargs): + """Drop a table or index given an appropriate schema object.""" + + raise NotImplementedError() + + def execute(self, object, *multiparams, **params): + raise NotImplementedError() + + def _execute_clauseelement(self, elem, multiparams=None, params=None): + raise NotImplementedError() + + +class Connection(Connectable): + """Provides high-level functionality for a wrapped DB-API connection. + + Provides execution support for string-based SQL statements as well + as ClauseElement, Compiled and DefaultGenerator objects. Provides + a begin method to return Transaction objects. + + The Connection object is **not** thread-safe. + + .. index:: + single: thread safety; Connection + """ + _execution_options = util.frozendict() + + def __init__(self, engine, connection=None, close_with_result=False, + _branch=False, _execution_options=None): + """Construct a new Connection. + + Connection objects are typically constructed by an + :class:`~sqlalchemy.engine.Engine`, see the ``connect()`` and + ``contextual_connect()`` methods of Engine. + """ + self.engine = engine + self.__connection = connection or engine.raw_connection() + self.__transaction = None + self.should_close_with_result = close_with_result + self.__savepoint_seq = 0 + self.__branch = _branch + self.__invalid = False + self._echo = self.engine._should_log_info() + if _execution_options: + self._execution_options = self._execution_options.union(_execution_options) + + def _branch(self): + """Return a new Connection which references this Connection's + engine and connection; but does not have close_with_result enabled, + and also whose close() method does nothing. + + This is used to execute "sub" statements within a single execution, + usually an INSERT statement. + """ + + return self.engine.Connection(self.engine, self.__connection, _branch=True) + + def execution_options(self, **opt): + """ Set non-SQL options for the connection which take effect during execution. + + The method returns a copy of this :class:`Connection` which references + the same underlying DBAPI connection, but also defines the given execution + options which will take effect for a call to :meth:`execute`. As the new + :class:`Connection` references the same underlying resource, it is probably + best to ensure that the copies would be discarded immediately, which + is implicit if used as in:: + + result = connection.execution_options(stream_results=True).execute(stmt) + + The options are the same as those accepted by + :meth:`sqlalchemy.sql.expression.Executable.execution_options`. + + """ + return self.engine.Connection( + self.engine, self.__connection, + _branch=self.__branch, _execution_options=opt) + + @property + def dialect(self): + "Dialect used by this Connection." + + return self.engine.dialect + + @property + def closed(self): + """Return True if this connection is closed.""" + + return not self.__invalid and '_Connection__connection' not in self.__dict__ + + @property + def invalidated(self): + """Return True if this connection was invalidated.""" + + return self.__invalid + + @property + def connection(self): + "The underlying DB-API connection managed by this Connection." + + try: + return self.__connection + except AttributeError: + if self.__invalid: + if self.__transaction is not None: + raise exc.InvalidRequestError("Can't reconnect until invalid transaction is rolled back") + self.__connection = self.engine.raw_connection() + self.__invalid = False + return self.__connection + raise exc.InvalidRequestError("This Connection is closed") + + @property + def info(self): + """A collection of per-DB-API connection instance properties.""" + + return self.connection.info + + def connect(self): + """Returns self. + + This ``Connectable`` interface method returns self, allowing + Connections to be used interchangably with Engines in most + situations that require a bind. + """ + + return self + + def contextual_connect(self, **kwargs): + """Returns self. + + This ``Connectable`` interface method returns self, allowing + Connections to be used interchangably with Engines in most + situations that require a bind. + """ + + return self + + def invalidate(self, exception=None): + """Invalidate the underlying DBAPI connection associated with this Connection. + + The underlying DB-API connection is literally closed (if + possible), and is discarded. Its source connection pool will + typically lazily create a new connection to replace it. + + Upon the next usage, this Connection will attempt to reconnect + to the pool with a new connection. + + Transactions in progress remain in an "opened" state (even though + the actual transaction is gone); these must be explicitly + rolled back before a reconnect on this Connection can proceed. This + is to prevent applications from accidentally continuing their transactional + operations in a non-transactional state. + """ + + if self.closed: + raise exc.InvalidRequestError("This Connection is closed") + + if self.__connection.is_valid: + self.__connection.invalidate(exception) + del self.__connection + self.__invalid = True + + def detach(self): + """Detach the underlying DB-API connection from its connection pool. + + This Connection instance will remain useable. When closed, + the DB-API connection will be literally closed and not + returned to its pool. The pool will typically lazily create a + new connection to replace the detached connection. + + This method can be used to insulate the rest of an application + from a modified state on a connection (such as a transaction + isolation level or similar). Also see + :class:`~sqlalchemy.interfaces.PoolListener` for a mechanism to modify + connection state when connections leave and return to their + connection pool. + """ + + self.__connection.detach() + + def begin(self): + """Begin a transaction and return a Transaction handle. + + Repeated calls to ``begin`` on the same Connection will create + a lightweight, emulated nested transaction. Only the + outermost transaction may ``commit``. Calls to ``commit`` on + inner transactions are ignored. Any transaction in the + hierarchy may ``rollback``, however. + """ + + if self.__transaction is None: + self.__transaction = RootTransaction(self) + return self.__transaction + else: + return Transaction(self, self.__transaction) + + def begin_nested(self): + """Begin a nested transaction and return a Transaction handle. + + Nested transactions require SAVEPOINT support in the + underlying database. Any transaction in the hierarchy may + ``commit`` and ``rollback``, however the outermost transaction + still controls the overall ``commit`` or ``rollback`` of the + transaction of a whole. + """ + + if self.__transaction is None: + self.__transaction = RootTransaction(self) + else: + self.__transaction = NestedTransaction(self, self.__transaction) + return self.__transaction + + def begin_twophase(self, xid=None): + """Begin a two-phase or XA transaction and return a Transaction handle. + + :param xid: the two phase transaction id. If not supplied, a random id + will be generated. + """ + + if self.__transaction is not None: + raise exc.InvalidRequestError( + "Cannot start a two phase transaction when a transaction " + "is already in progress.") + if xid is None: + xid = self.engine.dialect.create_xid(); + self.__transaction = TwoPhaseTransaction(self, xid) + return self.__transaction + + def recover_twophase(self): + return self.engine.dialect.do_recover_twophase(self) + + def rollback_prepared(self, xid, recover=False): + self.engine.dialect.do_rollback_twophase(self, xid, recover=recover) + + def commit_prepared(self, xid, recover=False): + self.engine.dialect.do_commit_twophase(self, xid, recover=recover) + + def in_transaction(self): + """Return True if a transaction is in progress.""" + + return self.__transaction is not None + + def _begin_impl(self): + if self._echo: + self.engine.logger.info("BEGIN") + try: + self.engine.dialect.do_begin(self.connection) + except Exception, e: + self._handle_dbapi_exception(e, None, None, None, None) + raise + + def _rollback_impl(self): + # use getattr() for is_valid to support exceptions raised in dialect initializer, + # where we do not yet have the pool wrappers plugged in + if not self.closed and not self.invalidated and \ + getattr(self.__connection, 'is_valid', False): + if self._echo: + self.engine.logger.info("ROLLBACK") + try: + self.engine.dialect.do_rollback(self.connection) + self.__transaction = None + except Exception, e: + self._handle_dbapi_exception(e, None, None, None, None) + raise + else: + self.__transaction = None + + def _commit_impl(self): + if self._echo: + self.engine.logger.info("COMMIT") + try: + self.engine.dialect.do_commit(self.connection) + self.__transaction = None + except Exception, e: + self._handle_dbapi_exception(e, None, None, None, None) + raise + + def _savepoint_impl(self, name=None): + if name is None: + self.__savepoint_seq += 1 + name = 'sa_savepoint_%s' % self.__savepoint_seq + if self.__connection.is_valid: + self.engine.dialect.do_savepoint(self, name) + return name + + def _rollback_to_savepoint_impl(self, name, context): + if self.__connection.is_valid: + self.engine.dialect.do_rollback_to_savepoint(self, name) + self.__transaction = context + + def _release_savepoint_impl(self, name, context): + if self.__connection.is_valid: + self.engine.dialect.do_release_savepoint(self, name) + self.__transaction = context + + def _begin_twophase_impl(self, xid): + if self.__connection.is_valid: + self.engine.dialect.do_begin_twophase(self, xid) + + def _prepare_twophase_impl(self, xid): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.engine.dialect.do_prepare_twophase(self, xid) + + def _rollback_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.engine.dialect.do_rollback_twophase(self, xid, is_prepared) + self.__transaction = None + + def _commit_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) + self.__transaction = None + + def _autorollback(self): + if not self.in_transaction(): + self._rollback_impl() + + def close(self): + """Close this Connection.""" + + try: + conn = self.__connection + except AttributeError: + return + if not self.__branch: + conn.close() + self.__invalid = False + del self.__connection + self.__transaction = None + + def scalar(self, object, *multiparams, **params): + """Executes and returns the first column of the first row. + + The underlying result/cursor is closed after execution. + """ + + return self.execute(object, *multiparams, **params).scalar() + + def execute(self, object, *multiparams, **params): + """Executes and returns a ResultProxy.""" + + for c in type(object).__mro__: + if c in Connection.executors: + return Connection.executors[c](self, object, multiparams, params) + else: + raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object))) + + def __distill_params(self, multiparams, params): + """Given arguments from the calling form *multiparams, **params, return a list + of bind parameter structures, usually a list of dictionaries. + + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. + + """ + + if not multiparams: + if params: + return [params] + else: + return [] + elif len(multiparams) == 1: + zero = multiparams[0] + if isinstance(zero, (list, tuple)): + if not zero or hasattr(zero[0], '__iter__'): + return zero + else: + return [zero] + elif hasattr(zero, 'keys'): + return [zero] + else: + return [[zero]] + else: + if hasattr(multiparams[0], '__iter__'): + return multiparams + else: + return [multiparams] + + def _execute_function(self, func, multiparams, params): + return self._execute_clauseelement(func.select(), multiparams, params) + + def _execute_default(self, default, multiparams, params): + ctx = self.__create_execution_context() + ret = ctx._exec_default(default) + if self.should_close_with_result: + self.close() + return ret + + def _execute_ddl(self, ddl, params, multiparams): + context = self.__create_execution_context( + compiled_ddl=ddl.compile(dialect=self.dialect), + parameters=None + ) + return self.__execute_context(context) + + def _execute_clauseelement(self, elem, multiparams, params): + params = self.__distill_params(multiparams, params) + if params: + keys = params[0].keys() + else: + keys = [] + + context = self.__create_execution_context( + compiled_sql=elem.compile( + dialect=self.dialect, column_keys=keys, + inline=len(params) > 1), + parameters=params + ) + return self.__execute_context(context) + + def _execute_compiled(self, compiled, multiparams, params): + """Execute a sql.Compiled object.""" + + context = self.__create_execution_context( + compiled_sql=compiled, + parameters=self.__distill_params(multiparams, params) + ) + return self.__execute_context(context) + + def _execute_text(self, statement, multiparams, params): + parameters = self.__distill_params(multiparams, params) + context = self.__create_execution_context(statement=statement, parameters=parameters) + return self.__execute_context(context) + + def __execute_context(self, context): + if context.compiled: + context.pre_exec() + + if context.executemany: + self._cursor_executemany( + context.cursor, + context.statement, + context.parameters, context=context) + else: + self._cursor_execute( + context.cursor, + context.statement, + context.parameters[0], context=context) + + if context.compiled: + context.post_exec() + + if context.isinsert and not context.executemany: + context.post_insert() + + # create a resultproxy, get rowcount/implicit RETURNING + # rows, close cursor if no further results pending + r = context.get_result_proxy()._autoclose() + + if self.__transaction is None and context.should_autocommit: + self._commit_impl() + + if r.closed and self.should_close_with_result: + self.close() + + return r + + def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): + if getattr(self, '_reentrant_error', False): + # Py3K + #raise exc.DBAPIError.instance(statement, parameters, e) from e + # Py2K + raise exc.DBAPIError.instance(statement, parameters, e), None, sys.exc_info()[2] + # end Py2K + self._reentrant_error = True + try: + if not isinstance(e, self.dialect.dbapi.Error): + return + + if context: + context.handle_dbapi_exception(e) + + is_disconnect = self.dialect.is_disconnect(e) + if is_disconnect: + self.invalidate(e) + self.engine.dispose() + else: + if cursor: + cursor.close() + self._autorollback() + if self.should_close_with_result: + self.close() + # Py3K + #raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) from e + # Py2K + raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect), None, sys.exc_info()[2] + # end Py2K + + finally: + del self._reentrant_error + + def __create_execution_context(self, **kwargs): + try: + dialect = self.engine.dialect + return dialect.execution_ctx_cls(dialect, connection=self, **kwargs) + except Exception, e: + self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None, None) + raise + + def _cursor_execute(self, cursor, statement, parameters, context=None): + if self._echo: + self.engine.logger.info(statement) + self.engine.logger.info("%r", parameters) + try: + self.dialect.do_execute(cursor, statement, parameters, context=context) + except Exception, e: + self._handle_dbapi_exception(e, statement, parameters, cursor, context) + raise + + def _cursor_executemany(self, cursor, statement, parameters, context=None): + if self._echo: + self.engine.logger.info(statement) + self.engine.logger.info("%r", parameters) + try: + self.dialect.do_executemany(cursor, statement, parameters, context=context) + except Exception, e: + self._handle_dbapi_exception(e, statement, parameters, cursor, context) + raise + + # poor man's multimethod/generic function thingy + executors = { + expression.FunctionElement: _execute_function, + expression.ClauseElement: _execute_clauseelement, + Compiled: _execute_compiled, + schema.SchemaItem: _execute_default, + schema.DDLElement: _execute_ddl, + basestring: _execute_text + } + + def create(self, entity, **kwargs): + """Create a Table or Index given an appropriate Schema object.""" + + return self.engine.create(entity, connection=self, **kwargs) + + def drop(self, entity, **kwargs): + """Drop a Table or Index given an appropriate Schema object.""" + + return self.engine.drop(entity, connection=self, **kwargs) + + def reflecttable(self, table, include_columns=None): + """Reflect the columns in the given string table name from the database.""" + + return self.engine.reflecttable(table, self, include_columns) + + def default_schema_name(self): + return self.engine.dialect.get_default_schema_name(self) + + def transaction(self, callable_, *args, **kwargs): + """Execute the given function within a transaction boundary. + + This is a shortcut for explicitly calling `begin()` and `commit()` + and optionally `rollback()` when exceptions are raised. The + given `*args` and `**kwargs` will be passed to the function. + + See also transaction() on engine. + + """ + + trans = self.begin() + try: + ret = self.run_callable(callable_, *args, **kwargs) + trans.commit() + return ret + except: + trans.rollback() + raise + + def run_callable(self, callable_, *args, **kwargs): + return callable_(self, *args, **kwargs) + + +class Transaction(object): + """Represent a Transaction in progress. + + The Transaction object is **not** threadsafe. + + .. index:: + single: thread safety; Transaction + """ + + def __init__(self, connection, parent): + self.connection = connection + self._parent = parent or self + self.is_active = True + + def close(self): + """Close this transaction. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + """ + if not self._parent.is_active: + return + if self._parent is self: + self.rollback() + + def rollback(self): + if not self._parent.is_active: + return + self._do_rollback() + self.is_active = False + + def _do_rollback(self): + self._parent.rollback() + + def commit(self): + if not self._parent.is_active: + raise exc.InvalidRequestError("This transaction is inactive") + self._do_commit() + self.is_active = False + + def _do_commit(self): + pass + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + if type is None and self.is_active: + self.commit() + else: + self.rollback() + + +class RootTransaction(Transaction): + def __init__(self, connection): + super(RootTransaction, self).__init__(connection, None) + self.connection._begin_impl() + + def _do_rollback(self): + if self.is_active: + self.connection._rollback_impl() + + def _do_commit(self): + if self.is_active: + self.connection._commit_impl() + + +class NestedTransaction(Transaction): + def __init__(self, connection, parent): + super(NestedTransaction, self).__init__(connection, parent) + self._savepoint = self.connection._savepoint_impl() + + def _do_rollback(self): + if self.is_active: + self.connection._rollback_to_savepoint_impl(self._savepoint, self._parent) + + def _do_commit(self): + if self.is_active: + self.connection._release_savepoint_impl(self._savepoint, self._parent) + + +class TwoPhaseTransaction(Transaction): + def __init__(self, connection, xid): + super(TwoPhaseTransaction, self).__init__(connection, None) + self._is_prepared = False + self.xid = xid + self.connection._begin_twophase_impl(self.xid) + + def prepare(self): + if not self._parent.is_active: + raise exc.InvalidRequestError("This transaction is inactive") + self.connection._prepare_twophase_impl(self.xid) + self._is_prepared = True + + def _do_rollback(self): + self.connection._rollback_twophase_impl(self.xid, self._is_prepared) + + def _do_commit(self): + self.connection._commit_twophase_impl(self.xid, self._is_prepared) + + +class Engine(Connectable, log.Identified): + """ + Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect` + together to provide a source of database connectivity and behavior. + + An :class:`Engine` object is instantiated publically using the :func:`~sqlalchemy.create_engine` + function. + + """ + + def __init__(self, pool, dialect, url, logging_name=None, echo=None, proxy=None): + self.pool = pool + self.url = url + self.dialect = dialect + if logging_name: + self.logging_name = logging_name + self.echo = echo + self.engine = self + self.logger = log.instance_logger(self, echoflag=echo) + if proxy: + self.Connection = _proxy_connection_cls(Connection, proxy) + else: + self.Connection = Connection + + @property + def name(self): + "String name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``." + + return self.dialect.name + + @property + def driver(self): + "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``." + + return self.dialect.driver + + echo = log.echo_property() + + def __repr__(self): + return 'Engine(%s)' % str(self.url) + + def dispose(self): + self.pool.dispose() + self.pool = self.pool.recreate() + + def create(self, entity, connection=None, **kwargs): + """Create a table or index within this engine's database connection given a schema.Table object.""" + + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs) + + def drop(self, entity, connection=None, **kwargs): + """Drop a table or index within this engine's database connection given a schema.Table object.""" + + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs) + + def _execute_default(self, default): + connection = self.contextual_connect() + try: + return connection._execute_default(default, (), {}) + finally: + connection.close() + + @property + def func(self): + return expression._FunctionGenerator(bind=self) + + def text(self, text, *args, **kwargs): + """Return a sql.text() object for performing literal queries.""" + + return expression.text(text, bind=self, *args, **kwargs) + + def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): + if connection is None: + conn = self.contextual_connect(close_with_result=False) + else: + conn = connection + try: + visitorcallable(self.dialect, conn, **kwargs).traverse(element) + finally: + if connection is None: + conn.close() + + def transaction(self, callable_, *args, **kwargs): + """Execute the given function within a transaction boundary. + + This is a shortcut for explicitly calling `begin()` and `commit()` + and optionally `rollback()` when exceptions are raised. The + given `*args` and `**kwargs` will be passed to the function. + + The connection used is that of contextual_connect(). + + See also the similar method on Connection itself. + + """ + + conn = self.contextual_connect() + try: + return conn.transaction(callable_, *args, **kwargs) + finally: + conn.close() + + def run_callable(self, callable_, *args, **kwargs): + conn = self.contextual_connect() + try: + return conn.run_callable(callable_, *args, **kwargs) + finally: + conn.close() + + def execute(self, statement, *multiparams, **params): + connection = self.contextual_connect(close_with_result=True) + return connection.execute(statement, *multiparams, **params) + + def scalar(self, statement, *multiparams, **params): + return self.execute(statement, *multiparams, **params).scalar() + + def _execute_clauseelement(self, elem, multiparams=None, params=None): + connection = self.contextual_connect(close_with_result=True) + return connection._execute_clauseelement(elem, multiparams, params) + + def _execute_compiled(self, compiled, multiparams, params): + connection = self.contextual_connect(close_with_result=True) + return connection._execute_compiled(compiled, multiparams, params) + + def connect(self, **kwargs): + """Return a newly allocated Connection object.""" + + return self.Connection(self, **kwargs) + + def contextual_connect(self, close_with_result=False, **kwargs): + """Return a Connection object which may be newly allocated, or may be part of some ongoing context. + + This Connection is meant to be used by the various "auto-connecting" operations. + """ + + return self.Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs) + + def table_names(self, schema=None, connection=None): + """Return a list of all table names available in the database. + + :param schema: Optional, retrieve names from a non-default schema. + + :param connection: Optional, use a specified connection. Default is the + ``contextual_connect`` for this ``Engine``. + """ + + if connection is None: + conn = self.contextual_connect() + else: + conn = connection + if not schema: + schema = self.dialect.default_schema_name + try: + return self.dialect.get_table_names(conn, schema) + finally: + if connection is None: + conn.close() + + def reflecttable(self, table, connection=None, include_columns=None): + """Given a Table object, reflects its columns and properties from the database.""" + + if connection is None: + conn = self.contextual_connect() + else: + conn = connection + try: + self.dialect.reflecttable(conn, table, include_columns) + finally: + if connection is None: + conn.close() + + def has_table(self, table_name, schema=None): + return self.run_callable(self.dialect.has_table, table_name, schema) + + def raw_connection(self): + """Return a DB-API connection.""" + + return self.pool.unique_connection() + + +def _proxy_connection_cls(cls, proxy): + class ProxyConnection(cls): + def execute(self, object, *multiparams, **params): + return proxy.execute(self, super(ProxyConnection, self).execute, + object, *multiparams, **params) + + def _execute_clauseelement(self, elem, multiparams=None, params=None): + return proxy.execute(self, super(ProxyConnection, self).execute, + elem, *(multiparams or []), **(params or {})) + + def _cursor_execute(self, cursor, statement, parameters, context=None): + return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, + cursor, statement, parameters, context, False) + + def _cursor_executemany(self, cursor, statement, parameters, context=None): + return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, + cursor, statement, parameters, context, True) + + def _begin_impl(self): + return proxy.begin(self, super(ProxyConnection, self)._begin_impl) + + def _rollback_impl(self): + return proxy.rollback(self, super(ProxyConnection, self)._rollback_impl) + + def _commit_impl(self): + return proxy.commit(self, super(ProxyConnection, self)._commit_impl) + + def _savepoint_impl(self, name=None): + return proxy.savepoint(self, super(ProxyConnection, self)._savepoint_impl, name=name) + + def _rollback_to_savepoint_impl(self, name, context): + return proxy.rollback_savepoint(self, + super(ProxyConnection, self)._rollback_to_savepoint_impl, + name, context) + + def _release_savepoint_impl(self, name, context): + return proxy.release_savepoint(self, + super(ProxyConnection, self)._release_savepoint_impl, + name, context) + + def _begin_twophase_impl(self, xid): + return proxy.begin_twophase(self, + super(ProxyConnection, self)._begin_twophase_impl, xid) + + def _prepare_twophase_impl(self, xid): + return proxy.prepare_twophase(self, + super(ProxyConnection, self)._prepare_twophase_impl, xid) + + def _rollback_twophase_impl(self, xid, is_prepared): + return proxy.rollback_twophase(self, + super(ProxyConnection, self)._rollback_twophase_impl, + xid, is_prepared) + + def _commit_twophase_impl(self, xid, is_prepared): + return proxy.commit_twophase(self, + super(ProxyConnection, self)._commit_twophase_impl, + xid, is_prepared) + + return ProxyConnection + +# This reconstructor is necessary so that pickles with the C extension or +# without use the same Binary format. +try: + # We need a different reconstructor on the C extension so that we can + # add extra checks that fields have correctly been initialized by + # __setstate__. + from sqlalchemy.cresultproxy import safe_rowproxy_reconstructor + + # The extra function embedding is needed so that the reconstructor function + # has the same signature whether or not the extension is present. + def rowproxy_reconstructor(cls, state): + return safe_rowproxy_reconstructor(cls, state) +except ImportError: + def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + +try: + from sqlalchemy.cresultproxy import BaseRowProxy +except ImportError: + class BaseRowProxy(object): + __slots__ = ('_parent', '_row', '_processors', '_keymap') + + def __init__(self, parent, row, processors, keymap): + """RowProxy objects are constructed by ResultProxy objects.""" + + self._parent = parent + self._row = row + self._processors = processors + self._keymap = keymap + + def __reduce__(self): + return (rowproxy_reconstructor, + (self.__class__, self.__getstate__())) + + def values(self): + """Return the values represented by this RowProxy as a list.""" + return list(self) + + def __iter__(self): + for processor, value in izip(self._processors, self._row): + if processor is None: + yield value + else: + yield processor(value) + + def __len__(self): + return len(self._row) + + def __getitem__(self, key): + try: + processor, index = self._keymap[key] + except KeyError: + processor, index = self._parent._key_fallback(key) + except TypeError: + if isinstance(key, slice): + l = [] + for processor, value in izip(self._processors[key], + self._row[key]): + if processor is None: + l.append(value) + else: + l.append(processor(value)) + return tuple(l) + else: + raise + if index is None: + raise exc.InvalidRequestError( + "Ambiguous column name '%s' in result set! " + "try 'use_labels' option on select statement." % key) + if processor is not None: + return processor(self._row[index]) + else: + return self._row[index] + + def __getattr__(self, name): + try: + # TODO: no test coverage here + return self[name] + except KeyError, e: + raise AttributeError(e.args[0]) + + +class RowProxy(BaseRowProxy): + """Proxy values from a single cursor row. + + Mostly follows "ordered dictionary" behavior, mapping result + values to the string-based column name, the integer position of + the result in the row, as well as Column instances which can be + mapped to the original Columns that produced this result set (for + results that correspond to constructed SQL expressions). + """ + __slots__ = () + + def __contains__(self, key): + return self._parent._has_key(self._row, key) + + def __getstate__(self): + return { + '_parent': self._parent, + '_row': tuple(self) + } + + def __setstate__(self, state): + self._parent = parent = state['_parent'] + self._row = state['_row'] + self._processors = parent._processors + self._keymap = parent._keymap + + __hash__ = None + + def __eq__(self, other): + return other is self or other == tuple(self) + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return repr(tuple(self)) + + def has_key(self, key): + """Return True if this RowProxy contains the given key.""" + + return self._parent._has_key(self._row, key) + + def items(self): + """Return a list of tuples, each tuple containing a key/value pair.""" + # TODO: no coverage here + return [(key, self[key]) for key in self.iterkeys()] + + def keys(self): + """Return the list of keys as strings represented by this RowProxy.""" + + return self._parent.keys + + def iterkeys(self): + return iter(self._parent.keys) + + def itervalues(self): + return iter(self) + + +class ResultMetaData(object): + """Handle cursor.description, applying additional info from an execution context.""" + + def __init__(self, parent, metadata): + self._processors = processors = [] + + # We do not strictly need to store the processor in the key mapping, + # though it is faster in the Python version (probably because of the + # saved attribute lookup self._processors) + self._keymap = keymap = {} + self.keys = [] + self._echo = parent._echo + context = parent.context + dialect = context.dialect + typemap = dialect.dbapi_type_map + + for i, (colname, coltype) in enumerate(m[0:2] for m in metadata): + if dialect.description_encoding: + colname = colname.decode(dialect.description_encoding) + + if '.' in colname: + # sqlite will in some circumstances prepend table name to + # colnames, so strip + origname = colname + colname = colname.split('.')[-1] + else: + origname = None + + if context.result_map: + try: + name, obj, type_ = context.result_map[colname.lower()] + except KeyError: + name, obj, type_ = \ + colname, None, typemap.get(coltype, types.NULLTYPE) + else: + name, obj, type_ = (colname, None, typemap.get(coltype, types.NULLTYPE)) + + processor = type_.dialect_impl(dialect).\ + result_processor(dialect, coltype) + + processors.append(processor) + rec = (processor, i) + + # indexes as keys. This is only needed for the Python version of + # RowProxy (the C version uses a faster path for integer indexes). + keymap[i] = rec + + # Column names as keys + if keymap.setdefault(name.lower(), rec) is not rec: + # We do not raise an exception directly because several + # columns colliding by name is not a problem as long as the + # user does not try to access them (ie use an index directly, + # or the more precise ColumnElement) + keymap[name.lower()] = (processor, None) + + # store the "origname" if we truncated (sqlite only) + if origname and \ + keymap.setdefault(origname.lower(), rec) is not rec: + keymap[origname.lower()] = (processor, None) + + if dialect.requires_name_normalize: + colname = dialect.normalize_name(colname) + + self.keys.append(colname) + if obj: + for o in obj: + keymap[o] = rec + + if self._echo: + self.logger = context.engine.logger + self.logger.debug( + "Col %r", tuple(x[0] for x in metadata)) + + def _key_fallback(self, key): + map = self._keymap + result = None + if isinstance(key, basestring): + result = map.get(key.lower()) + # fallback for targeting a ColumnElement to a textual expression + # this is a rare use case which only occurs when matching text() + # constructs to ColumnElements, and after a pickle/unpickle roundtrip + elif isinstance(key, expression.ColumnElement): + if key._label and key._label.lower() in map: + result = map[key._label.lower()] + elif hasattr(key, 'name') and key.name.lower() in map: + result = map[key.name.lower()] + if result is None: + raise exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" % key) + else: + map[key] = result + return result + + def _has_key(self, row, key): + if key in self._keymap: + return True + else: + try: + self._key_fallback(key) + return True + except exc.NoSuchColumnError: + return False + + def __len__(self): + return len(self.keys) + + def __getstate__(self): + return { + '_pickled_keymap': dict( + (key, index) + for key, (processor, index) in self._keymap.iteritems() + if isinstance(key, (basestring, int)) + ), + 'keys': self.keys + } + + def __setstate__(self, state): + # the row has been processed at pickling time so we don't need any + # processor anymore + self._processors = [None for _ in xrange(len(state['keys']))] + self._keymap = keymap = {} + for key, index in state['_pickled_keymap'].iteritems(): + keymap[key] = (None, index) + self.keys = state['keys'] + self._echo = False + + +class ResultProxy(object): + """Wraps a DB-API cursor object to provide easier access to row columns. + + Individual columns may be accessed by their integer position, + case-insensitive column name, or by ``schema.Column`` + object. e.g.:: + + row = fetchone() + + col1 = row[0] # access via integer position + + col2 = row['col2'] # access via name + + col3 = row[mytable.c.mycol] # access via Column object. + + ``ResultProxy`` also handles post-processing of result column + data using ``TypeEngine`` objects, which are referenced from + the originating SQL statement that produced this result set. + + """ + + _process_row = RowProxy + out_parameters = None + _can_close_connection = False + + def __init__(self, context): + self.context = context + self.dialect = context.dialect + self.closed = False + self.cursor = context.cursor + self.connection = context.root_connection + self._echo = self.connection._echo and \ + context.engine._should_log_debug() + self._init_metadata() + + def _init_metadata(self): + metadata = self._cursor_description() + if metadata is None: + self._metadata = None + else: + self._metadata = ResultMetaData(self, metadata) + + def keys(self): + """Return the current set of string keys for rows.""" + if self._metadata: + return self._metadata.keys + else: + return [] + + @util.memoized_property + def rowcount(self): + """Return the 'rowcount' for this result. + + The 'rowcount' reports the number of rows affected + by an UPDATE or DELETE statement. It has *no* other + uses and is not intended to provide the number of rows + present from a SELECT. + + Note that this row count may not be properly implemented + in some dialects; this is indicated by + :meth:`~sqlalchemy.engine.base.ResultProxy.supports_sane_rowcount()` and + :meth:`~sqlalchemy.engine.base.ResultProxy.supports_sane_multi_rowcount()`. + + ``rowcount()`` also may not work at this time for a statement + that uses ``returning()``. + + """ + return self.context.rowcount + + @property + def lastrowid(self): + """return the 'lastrowid' accessor on the DBAPI cursor. + + This is a DBAPI specific method and is only functional + for those backends which support it, for statements + where it is appropriate. It's behavior is not + consistent across backends. + + Usage of this method is normally unnecessary; the + inserted_primary_key method provides a + tuple of primary key values for a newly inserted row, + regardless of database backend. + + """ + return self.cursor.lastrowid + + def _cursor_description(self): + """May be overridden by subclasses.""" + + return self.cursor.description + + def _autoclose(self): + """called by the Connection to autoclose cursors that have no pending results + beyond those used by an INSERT/UPDATE/DELETE with no explicit RETURNING clause. + + """ + if self.context.isinsert: + if self.context._is_implicit_returning: + self.context._fetch_implicit_returning(self) + self.close(_autoclose_connection=False) + elif not self.context._is_explicit_returning: + self.close(_autoclose_connection=False) + elif self._metadata is None: + # no results, get rowcount + # (which requires open cursor on some drivers + # such as kintersbasdb, mxodbc), + self.rowcount + self.close(_autoclose_connection=False) + + return self + + def close(self, _autoclose_connection=True): + """Close this ResultProxy. + + Closes the underlying DBAPI cursor corresponding to the execution. + + Note that any data cached within this ResultProxy is still available. + For some types of results, this may include buffered rows. + + If this ResultProxy was generated from an implicit execution, + the underlying Connection will also be closed (returns the + underlying DBAPI connection to the connection pool.) + + This method is called automatically when: + + * all result rows are exhausted using the fetchXXX() methods. + * cursor.description is None. + + """ + + if not self.closed: + self.closed = True + self.cursor.close() + if _autoclose_connection and \ + self.connection.should_close_with_result: + self.connection.close() + + def __iter__(self): + while True: + row = self.fetchone() + if row is None: + raise StopIteration + else: + yield row + + @util.memoized_property + def inserted_primary_key(self): + """Return the primary key for the row just inserted. + + This only applies to single row insert() constructs which + did not explicitly specify returning(). + + """ + if not self.context.isinsert: + raise exc.InvalidRequestError("Statement is not an insert() expression construct.") + elif self.context._is_explicit_returning: + raise exc.InvalidRequestError("Can't call inserted_primary_key when returning() is used.") + + return self.context._inserted_primary_key + + @util.deprecated("Use inserted_primary_key") + def last_inserted_ids(self): + """deprecated. use inserted_primary_key.""" + + return self.inserted_primary_key + + def last_updated_params(self): + """Return ``last_updated_params()`` from the underlying ExecutionContext. + + See ExecutionContext for details. + """ + + return self.context.last_updated_params() + + def last_inserted_params(self): + """Return ``last_inserted_params()`` from the underlying ExecutionContext. + + See ExecutionContext for details. + """ + + return self.context.last_inserted_params() + + def lastrow_has_defaults(self): + """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext. + + See ExecutionContext for details. + """ + + return self.context.lastrow_has_defaults() + + def postfetch_cols(self): + """Return ``postfetch_cols()`` from the underlying ExecutionContext. + + See ExecutionContext for details. + """ + + return self.context.postfetch_cols + + def prefetch_cols(self): + return self.context.prefetch_cols + + def supports_sane_rowcount(self): + """Return ``supports_sane_rowcount`` from the dialect.""" + + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + """Return ``supports_sane_multi_rowcount`` from the dialect.""" + + return self.dialect.supports_sane_multi_rowcount + + def _fetchone_impl(self): + return self.cursor.fetchone() + + def _fetchmany_impl(self, size=None): + return self.cursor.fetchmany(size) + + def _fetchall_impl(self): + return self.cursor.fetchall() + + def process_rows(self, rows): + process_row = self._process_row + metadata = self._metadata + keymap = metadata._keymap + processors = metadata._processors + if self._echo: + log = self.context.engine.logger.debug + l = [] + for row in rows: + log("Row %r", row) + l.append(process_row(metadata, row, processors, keymap)) + return l + else: + return [process_row(metadata, row, processors, keymap) + for row in rows] + + def fetchall(self): + """Fetch all rows, just like DB-API ``cursor.fetchall()``.""" + + try: + l = self.process_rows(self._fetchall_impl()) + self.close() + return l + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) + raise + + def fetchmany(self, size=None): + """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``. + + If rows are present, the cursor remains open after this is called. + Else the cursor is automatically closed and an empty list is returned. + + """ + + try: + l = self.process_rows(self._fetchmany_impl(size)) + if len(l) == 0: + self.close() + return l + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) + raise + + def fetchone(self): + """Fetch one row, just like DB-API ``cursor.fetchone()``. + + If a row is present, the cursor remains open after this is called. + Else the cursor is automatically closed and None is returned. + + """ + + try: + row = self._fetchone_impl() + if row is not None: + return self.process_rows([row])[0] + else: + self.close() + return None + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) + raise + + def first(self): + """Fetch the first row and then close the result set unconditionally. + + Returns None if no row is present. + + """ + try: + row = self._fetchone_impl() + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) + raise + + try: + if row is not None: + return self.process_rows([row])[0] + else: + return None + finally: + self.close() + + def scalar(self): + """Fetch the first column of the first row, and close the result set. + + Returns None if no row is present. + + """ + row = self.first() + if row is not None: + return row[0] + else: + return None + +class BufferedRowResultProxy(ResultProxy): + """A ResultProxy with row buffering behavior. + + ``ResultProxy`` that buffers the contents of a selection of rows + before ``fetchone()`` is called. This is to allow the results of + ``cursor.description`` to be available immediately, when + interfacing with a DB-API that requires rows to be consumed before + this information is available (currently psycopg2, when used with + server-side cursors). + + The pre-fetching behavior fetches only one row initially, and then + grows its buffer size by a fixed amount with each successive need + for additional rows up to a size of 100. + """ + + def _init_metadata(self): + self.__buffer_rows() + super(BufferedRowResultProxy, self)._init_metadata() + + # this is a "growth chart" for the buffering of rows. + # each successive __buffer_rows call will use the next + # value in the list for the buffer size until the max + # is reached + size_growth = { + 1 : 5, + 5 : 10, + 10 : 20, + 20 : 50, + 50 : 100 + } + + def __buffer_rows(self): + size = getattr(self, '_bufsize', 1) + self.__rowbuffer = self.cursor.fetchmany(size) + self._bufsize = self.size_growth.get(size, size) + + def _fetchone_impl(self): + if self.closed: + return None + if len(self.__rowbuffer) == 0: + self.__buffer_rows() + if len(self.__rowbuffer) == 0: + return None + return self.__rowbuffer.pop(0) + + def _fetchmany_impl(self, size=None): + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + ret = self.__rowbuffer + list(self.cursor.fetchall()) + self.__rowbuffer[:] = [] + return ret + +class FullyBufferedResultProxy(ResultProxy): + """A result proxy that buffers rows fully upon creation. + + Used for operations where a result is to be delivered + after the database conversation can not be continued, + such as MSSQL INSERT...OUTPUT after an autocommit. + + """ + def _init_metadata(self): + super(FullyBufferedResultProxy, self)._init_metadata() + self.__rowbuffer = self._buffer_rows() + + def _buffer_rows(self): + return self.cursor.fetchall() + + def _fetchone_impl(self): + if self.__rowbuffer: + return self.__rowbuffer.pop(0) + else: + return None + + def _fetchmany_impl(self, size=None): + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + ret = self.__rowbuffer + self.__rowbuffer = [] + return ret + +class BufferedColumnRow(RowProxy): + def __init__(self, parent, row, processors, keymap): + # preprocess row + row = list(row) + # this is a tad faster than using enumerate + index = 0 + for processor in parent._orig_processors: + if processor is not None: + row[index] = processor(row[index]) + index += 1 + row = tuple(row) + super(BufferedColumnRow, self).__init__(parent, row, + processors, keymap) + +class BufferedColumnResultProxy(ResultProxy): + """A ResultProxy with column buffering behavior. + + ``ResultProxy`` that loads all columns into memory each time + fetchone() is called. If fetchmany() or fetchall() are called, + the full grid of results is fetched. This is to operate with + databases where result rows contain "live" results that fall out + of scope unless explicitly fetched. Currently this includes + cx_Oracle LOB objects. + + """ + + _process_row = BufferedColumnRow + + def _init_metadata(self): + super(BufferedColumnResultProxy, self)._init_metadata() + metadata = self._metadata + # orig_processors will be used to preprocess each row when they are + # constructed. + metadata._orig_processors = metadata._processors + # replace the all type processors by None processors. + metadata._processors = [None for _ in xrange(len(metadata.keys))] + keymap = {} + for k, (func, index) in metadata._keymap.iteritems(): + keymap[k] = (None, index) + self._metadata._keymap = keymap + + def fetchall(self): + # can't call cursor.fetchall(), since rows must be + # fully processed before requesting more from the DBAPI. + l = [] + while True: + row = self.fetchone() + if row is None: + break + l.append(row) + return l + + def fetchmany(self, size=None): + # can't call cursor.fetchmany(), since rows must be + # fully processed before requesting more from the DBAPI. + if size is None: + return self.fetchall() + l = [] + for i in xrange(size): + row = self.fetchone() + if row is None: + break + l.append(row) + return l + +def connection_memoize(key): + """Decorator, memoize a function in a connection.info stash. + + Only applicable to functions which take no arguments other than a + connection. The memo will be stored in ``connection.info[key]``. + """ + + @util.decorator + def decorated(fn, self, connection): + connection = connection.connect() + try: + return connection.info[key] + except KeyError: + connection.info[key] = val = fn(self, connection) + return val + + return decorated diff --git a/sqlalchemy/engine/ddl.py b/sqlalchemy/engine/ddl.py new file mode 100644 index 0000000..ef10aa5 --- /dev/null +++ b/sqlalchemy/engine/ddl.py @@ -0,0 +1,128 @@ +# engine/ddl.py +# Copyright (C) 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Routines to handle CREATE/DROP workflow.""" + +from sqlalchemy import engine, schema +from sqlalchemy.sql import util as sql_util + + +class DDLBase(schema.SchemaVisitor): + def __init__(self, connection): + self.connection = connection + +class SchemaGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables and set(tables) or None + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def _can_create(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] + + for listener in metadata.ddl_listeners['before-create']: + listener('before-create', metadata, self.connection, tables=collection) + + for table in collection: + self.traverse_single(table) + + for listener in metadata.ddl_listeners['after-create']: + listener('after-create', metadata, self.connection, tables=collection) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-create']: + listener('before-create', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.CreateTable(table)) + + if hasattr(table, 'indexes'): + for index in table.indexes: + self.traverse_single(index) + + for listener in table.ddl_listeners['after-create']: + listener('after-create', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if ((not self.dialect.sequences_optional or + not sequence.optional) and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))): + self.connection.execute(schema.CreateSequence(sequence)) + + def visit_index(self, index): + self.connection.execute(schema.CreateIndex(index)) + + +class SchemaDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] + + for listener in metadata.ddl_listeners['before-drop']: + listener('before-drop', metadata, self.connection, tables=collection) + + for table in collection: + self.traverse_single(table) + + for listener in metadata.ddl_listeners['after-drop']: + listener('after-drop', metadata, self.connection, tables=collection) + + def _can_drop(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_index(self, index): + self.connection.execute(schema.DropIndex(index)) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-drop']: + listener('before-drop', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.DropTable(table)) + + for listener in table.ddl_listeners['after-drop']: + listener('after-drop', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if ((not self.dialect.sequences_optional or + not sequence.optional) and + (not self.checkfirst or + self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))): + self.connection.execute(schema.DropSequence(sequence)) diff --git a/sqlalchemy/engine/default.py b/sqlalchemy/engine/default.py new file mode 100644 index 0000000..6fb0a14 --- /dev/null +++ b/sqlalchemy/engine/default.py @@ -0,0 +1,700 @@ +# engine/default.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Default implementations of per-dialect sqlalchemy.engine classes. + +These are semi-private implementation classes which are only of importance +to database dialect authors; dialects will usually use the classes here +as the base class for their own corresponding classes. + +""" + +import re, random +from sqlalchemy.engine import base, reflection +from sqlalchemy.sql import compiler, expression +from sqlalchemy import exc, types as sqltypes, util + +AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', + re.I | re.UNICODE) + + +class DefaultDialect(base.Dialect): + """Default implementation of Dialect""" + + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler = compiler.GenericTypeCompiler + preparer = compiler.IdentifierPreparer + supports_alter = True + + # most DBAPIs happy with this for execute(). + # not cx_oracle. + execute_sequence_format = tuple + + supports_sequences = False + sequences_optional = False + preexecute_autoincrement_sequences = False + postfetch_lastrowid = True + implicit_returning = False + + supports_native_enum = False + supports_native_boolean = False + + # if the NUMERIC type + # returns decimal.Decimal. + # *not* the FLOAT type however. + supports_native_decimal = False + + # Py3K + #supports_unicode_statements = True + #supports_unicode_binds = True + # Py2K + supports_unicode_statements = False + supports_unicode_binds = False + returns_unicode_strings = False + # end Py2K + + name = 'default' + max_identifier_length = 9999 + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + dbapi_type_map = {} + colspecs = {} + default_paramstyle = 'named' + supports_default_values = False + supports_empty_insert = True + + server_version_info = None + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + + reflection_options = () + + def __init__(self, convert_unicode=False, assert_unicode=False, + encoding='utf-8', paramstyle=None, dbapi=None, + implicit_returning=None, + label_length=None, **kwargs): + + if not getattr(self, 'ported_sqla_06', True): + util.warn( + "The %s dialect is not yet ported to SQLAlchemy 0.6" % self.name) + + self.convert_unicode = convert_unicode + if assert_unicode: + util.warn_deprecated("assert_unicode is deprecated. " + "SQLAlchemy emits a warning in all cases where it " + "would otherwise like to encode a Python unicode object " + "into a specific encoding but a plain bytestring is received. " + "This does *not* apply to DBAPIs that coerce Unicode natively." + ) + + self.encoding = encoding + self.positional = False + self._ischema = None + self.dbapi = dbapi + if paramstyle is not None: + self.paramstyle = paramstyle + elif self.dbapi is not None: + self.paramstyle = self.dbapi.paramstyle + else: + self.paramstyle = self.default_paramstyle + if implicit_returning is not None: + self.implicit_returning = implicit_returning + self.positional = self.paramstyle in ('qmark', 'format', 'numeric') + self.identifier_preparer = self.preparer(self) + self.type_compiler = self.type_compiler(self) + + if label_length and label_length > self.max_identifier_length: + raise exc.ArgumentError("Label length of %d is greater than this dialect's" + " maximum identifier length of %d" % + (label_length, self.max_identifier_length)) + self.label_length = label_length + + if not hasattr(self, 'description_encoding'): + self.description_encoding = getattr(self, 'description_encoding', encoding) + + @property + def dialect_description(self): + return self.name + "+" + self.driver + + def initialize(self, connection): + try: + self.server_version_info = self._get_server_version_info(connection) + except NotImplementedError: + self.server_version_info = None + try: + self.default_schema_name = self._get_default_schema_name(connection) + except NotImplementedError: + self.default_schema_name = None + + self.returns_unicode_strings = self._check_unicode_returns(connection) + + self.do_rollback(connection.connection) + + def on_connect(self): + """return a callable which sets up a newly created DBAPI connection. + + This is used to set dialect-wide per-connection options such as isolation + modes, unicode modes, etc. + + If a callable is returned, it will be assembled into a pool listener + that receives the direct DBAPI connection, with all wrappers removed. + + If None is returned, no listener will be generated. + + """ + return None + + def _check_unicode_returns(self, connection): + # Py2K + if self.supports_unicode_statements: + cast_to = unicode + else: + cast_to = str + # end Py2K + # Py3K + #cast_to = str + def check_unicode(type_): + cursor = connection.connection.cursor() + try: + cursor.execute( + cast_to( + expression.select( + [expression.cast( + expression.literal_column("'test unicode returns'"), type_) + ]).compile(dialect=self) + ) + ) + row = cursor.fetchone() + + return isinstance(row[0], unicode) + finally: + cursor.close() + + # detect plain VARCHAR + unicode_for_varchar = check_unicode(sqltypes.VARCHAR(60)) + + # detect if there's an NVARCHAR type with different behavior available + unicode_for_unicode = check_unicode(sqltypes.Unicode(60)) + + if unicode_for_unicode and not unicode_for_varchar: + return "conditional" + else: + return unicode_for_varchar + + def type_descriptor(self, typeobj): + """Provide a database-specific ``TypeEngine`` object, given + the generic object which comes from the types module. + + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to ``types.adapt_type()``. + + """ + return sqltypes.adapt_type(typeobj, self.colspecs) + + def reflecttable(self, connection, table, include_columns): + insp = reflection.Inspector.from_engine(connection) + return insp.reflecttable(table, include_columns) + + def validate_identifier(self, ident): + if len(ident) > self.max_identifier_length: + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" % + (ident, self.max_identifier_length) + ) + + def connect(self, *cargs, **cparams): + return self.dbapi.connect(*cargs, **cparams) + + def create_connect_args(self, url): + opts = url.translate_connect_args() + opts.update(url.query) + return [[], opts] + + def do_begin(self, connection): + """Implementations might want to put logic here for turning + autocommit on/off, etc. + """ + + pass + + def do_rollback(self, connection): + """Implementations might want to put logic here for turning + autocommit on/off, etc. + """ + + connection.rollback() + + def do_commit(self, connection): + """Implementations might want to put logic here for turning + autocommit on/off, etc. + """ + + connection.commit() + + def create_xid(self): + """Create a random two-phase transaction ID. + + This id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). Its format is unspecified. + """ + + return "_sa_%032x" % random.randint(0, 2 ** 128) + + def do_savepoint(self, connection, name): + connection.execute(expression.SavepointClause(name)) + + def do_rollback_to_savepoint(self, connection, name): + connection.execute(expression.RollbackToSavepointClause(name)) + + def do_release_savepoint(self, connection, name): + connection.execute(expression.ReleaseSavepointClause(name)) + + def do_executemany(self, cursor, statement, parameters, context=None): + cursor.executemany(statement, parameters) + + def do_execute(self, cursor, statement, parameters, context=None): + cursor.execute(statement, parameters) + + def is_disconnect(self, e): + return False + + +class DefaultExecutionContext(base.ExecutionContext): + execution_options = util.frozendict() + isinsert = False + isupdate = False + isdelete = False + isddl = False + executemany = False + result_map = None + compiled = None + statement = None + + def __init__(self, + dialect, + connection, + compiled_sql=None, + compiled_ddl=None, + statement=None, + parameters=None): + + self.dialect = dialect + self._connection = self.root_connection = connection + self.engine = connection.engine + + if compiled_ddl is not None: + self.compiled = compiled = compiled_ddl + self.isddl = True + + if compiled.statement._execution_options: + self.execution_options = compiled.statement._execution_options + if connection._execution_options: + self.execution_options = self.execution_options.union( + connection._execution_options + ) + + if not dialect.supports_unicode_statements: + self.unicode_statement = unicode(compiled) + self.statement = self.unicode_statement.encode(self.dialect.encoding) + else: + self.statement = self.unicode_statement = unicode(compiled) + + self.cursor = self.create_cursor() + self.compiled_parameters = [] + self.parameters = [self._default_params] + + elif compiled_sql is not None: + self.compiled = compiled = compiled_sql + + if not compiled.can_execute: + raise exc.ArgumentError("Not an executable clause: %s" % compiled) + + if compiled.statement._execution_options: + self.execution_options = compiled.statement._execution_options + if connection._execution_options: + self.execution_options = self.execution_options.union( + connection._execution_options + ) + + # compiled clauseelement. process bind params, process table defaults, + # track collections used by ResultProxy to target and process results + + self.processors = dict( + (key, value) for key, value in + ( (compiled.bind_names[bindparam], + bindparam.bind_processor(self.dialect)) + for bindparam in compiled.bind_names ) + if value is not None) + + self.result_map = compiled.result_map + + if not dialect.supports_unicode_statements: + self.unicode_statement = unicode(compiled) + self.statement = self.unicode_statement.encode(self.dialect.encoding) + else: + self.statement = self.unicode_statement = unicode(compiled) + + self.isinsert = compiled.isinsert + self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete + + if not parameters: + self.compiled_parameters = [compiled.construct_params()] + else: + self.compiled_parameters = [compiled.construct_params(m, _group_number=grp) for + grp,m in enumerate(parameters)] + + self.executemany = len(parameters) > 1 + + self.cursor = self.create_cursor() + if self.isinsert or self.isupdate: + self.__process_defaults() + self.parameters = self.__convert_compiled_params(self.compiled_parameters) + + elif statement is not None: + # plain text statement + if connection._execution_options: + self.execution_options = self.execution_options.union(connection._execution_options) + self.parameters = self.__encode_param_keys(parameters) + self.executemany = len(parameters) > 1 + + if isinstance(statement, unicode) and not dialect.supports_unicode_statements: + self.unicode_statement = statement + self.statement = statement.encode(self.dialect.encoding) + else: + self.statement = self.unicode_statement = statement + + self.cursor = self.create_cursor() + else: + # no statement. used for standalone ColumnDefault execution. + if connection._execution_options: + self.execution_options = self.execution_options.union(connection._execution_options) + self.cursor = self.create_cursor() + + @util.memoized_property + def is_crud(self): + return self.isinsert or self.isupdate or self.isdelete + + @util.memoized_property + def should_autocommit(self): + autocommit = self.execution_options.get('autocommit', + not self.compiled and + self.statement and + expression.PARSE_AUTOCOMMIT + or False) + + if autocommit is expression.PARSE_AUTOCOMMIT: + return self.should_autocommit_text(self.unicode_statement) + else: + return autocommit + + @util.memoized_property + def _is_explicit_returning(self): + return self.compiled and \ + getattr(self.compiled.statement, '_returning', False) + + @util.memoized_property + def _is_implicit_returning(self): + return self.compiled and \ + bool(self.compiled.returning) and \ + not self.compiled.statement._returning + + @util.memoized_property + def _default_params(self): + if self.dialect.positional: + return self.dialect.execute_sequence_format() + else: + return {} + + def _execute_scalar(self, stmt): + """Execute a string statement on the current cursor, returning a scalar result. + + Used to fire off sequences, default phrases, and "select lastrowid" + types of statements individually + or in the context of a parent INSERT or UPDATE statement. + + """ + + conn = self._connection + if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements: + stmt = stmt.encode(self.dialect.encoding) + conn._cursor_execute(self.cursor, stmt, self._default_params) + return self.cursor.fetchone()[0] + + @property + def connection(self): + return self._connection._branch() + + def __encode_param_keys(self, params): + """Apply string encoding to the keys of dictionary-based bind parameters. + + This is only used executing textual, non-compiled SQL expressions. + + """ + + if not params: + return [self._default_params] + elif isinstance(params[0], self.dialect.execute_sequence_format): + return params + elif isinstance(params[0], dict): + if self.dialect.supports_unicode_statements: + return params + else: + def proc(d): + return dict((k.encode(self.dialect.encoding), d[k]) for k in d) + return [proc(d) for d in params] or [{}] + else: + return [self.dialect.execute_sequence_format(p) for p in params] + + + def __convert_compiled_params(self, compiled_parameters): + """Convert the dictionary of bind parameter values into a dict or list + to be sent to the DBAPI's execute() or executemany() method. + """ + + processors = self.processors + parameters = [] + if self.dialect.positional: + for compiled_params in compiled_parameters: + param = [] + for key in self.compiled.positiontup: + if key in processors: + param.append(processors[key](compiled_params[key])) + else: + param.append(compiled_params[key]) + parameters.append(self.dialect.execute_sequence_format(param)) + else: + encode = not self.dialect.supports_unicode_statements + for compiled_params in compiled_parameters: + param = {} + if encode: + encoding = self.dialect.encoding + for key in compiled_params: + if key in processors: + param[key.encode(encoding)] = processors[key](compiled_params[key]) + else: + param[key.encode(encoding)] = compiled_params[key] + else: + for key in compiled_params: + if key in processors: + param[key] = processors[key](compiled_params[key]) + else: + param[key] = compiled_params[key] + parameters.append(param) + return self.dialect.execute_sequence_format(parameters) + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_REGEXP.match(statement) + + def create_cursor(self): + return self._connection.connection.cursor() + + def pre_exec(self): + pass + + def post_exec(self): + pass + + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, + issuing a new SELECT on the cursor (or a new one), + or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects + which support "implicit" primary key generation, + keep preexecute_autoincrement_sequences set to False, + and when no explicit id value was bound to the + statement. + + The function is called once, directly after + post_exec() and before the transaction is committed + or ResultProxy is generated. If the post_exec() + method assigns a value to `self._lastrowid`, the + value is used in place of calling get_lastrowid(). + + Note that this method is *not* equivalent to the + ``lastrowid`` method on ``ResultProxy``, which is a + direct proxy to the DBAPI ``lastrowid`` accessor + in all cases. + + """ + + return self.cursor.lastrowid + + def handle_dbapi_exception(self, e): + pass + + def get_result_proxy(self): + return base.ResultProxy(self) + + @property + def rowcount(self): + return self.cursor.rowcount + + def supports_sane_rowcount(self): + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + return self.dialect.supports_sane_multi_rowcount + + def post_insert(self): + if self.dialect.postfetch_lastrowid and \ + (not len(self._inserted_primary_key) or \ + None in self._inserted_primary_key): + + table = self.compiled.statement.table + lastrowid = self.get_lastrowid() + self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] + + def _fetch_implicit_returning(self, resultproxy): + table = self.compiled.statement.table + row = resultproxy.fetchone() + + self._inserted_primary_key = [v is not None and v or row[c] + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] + + def last_inserted_params(self): + return self._last_inserted_params + + def last_updated_params(self): + return self._last_updated_params + + def lastrow_has_defaults(self): + return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) + + def set_input_sizes(self, translate=None, exclude_types=None): + """Given a cursor and ClauseParameters, call the appropriate + style of ``setinputsizes()`` on the cursor, using DB-API types + from the bind parameter's ``TypeEngine`` objects. + """ + + if not hasattr(self.compiled, 'bind_names'): + return + + types = dict( + (self.compiled.bind_names[bindparam], bindparam.type) + for bindparam in self.compiled.bind_names) + + if self.dialect.positional: + inputsizes = [] + for key in self.compiled.positiontup: + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): + inputsizes.append(dbtype) + try: + self.cursor.setinputsizes(*inputsizes) + except Exception, e: + self._connection._handle_dbapi_exception(e, None, None, None, self) + raise + else: + inputsizes = {} + for key in self.compiled.bind_names.values(): + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): + if translate: + key = translate.get(key, key) + inputsizes[key.encode(self.dialect.encoding)] = dbtype + try: + self.cursor.setinputsizes(**inputsizes) + except Exception, e: + self._connection._handle_dbapi_exception(e, None, None, None, self) + raise + + def _exec_default(self, default): + if default.is_sequence: + return self.fire_sequence(default) + elif default.is_callable: + return default.arg(self) + elif default.is_clause_element: + # TODO: expensive branching here should be + # pulled into _exec_scalar() + conn = self.connection + c = expression.select([default.arg]).compile(bind=conn) + return conn._execute_compiled(c, (), {}).scalar() + else: + return default.arg + + def get_insert_default(self, column): + if column.default is None: + return None + else: + return self._exec_default(column.default) + + def get_update_default(self, column): + if column.onupdate is None: + return None + else: + return self._exec_default(column.onupdate) + + def __process_defaults(self): + """Generate default values for compiled insert/update statements, + and generate inserted_primary_key collection. + """ + + if self.executemany: + if len(self.compiled.prefetch): + scalar_defaults = {} + + # pre-determine scalar Python-side defaults + # to avoid many calls of get_insert_default()/get_update_default() + for c in self.compiled.prefetch: + if self.isinsert and c.default and c.default.is_scalar: + scalar_defaults[c] = c.default.arg + elif self.isupdate and c.onupdate and c.onupdate.is_scalar: + scalar_defaults[c] = c.onupdate.arg + + for param in self.compiled_parameters: + self.current_parameters = param + for c in self.compiled.prefetch: + if c in scalar_defaults: + val = scalar_defaults[c] + elif self.isinsert: + val = self.get_insert_default(c) + else: + val = self.get_update_default(c) + if val is not None: + param[c.key] = val + del self.current_parameters + + else: + self.current_parameters = compiled_parameters = self.compiled_parameters[0] + + for c in self.compiled.prefetch: + if self.isinsert: + val = self.get_insert_default(c) + else: + val = self.get_update_default(c) + + if val is not None: + compiled_parameters[c.key] = val + del self.current_parameters + + if self.isinsert: + self._inserted_primary_key = [compiled_parameters.get(c.key, None) + for c in self.compiled.statement.table.primary_key] + self._last_inserted_params = compiled_parameters + else: + self._last_updated_params = compiled_parameters + + self.postfetch_cols = self.compiled.postfetch + self.prefetch_cols = self.compiled.prefetch + +DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/sqlalchemy/engine/reflection.py b/sqlalchemy/engine/reflection.py new file mode 100644 index 0000000..57f2205 --- /dev/null +++ b/sqlalchemy/engine/reflection.py @@ -0,0 +1,370 @@ +"""Provides an abstraction for obtaining database schema information. + +Usage Notes: + +Here are some general conventions when accessing the low level inspector +methods such as get_table_names, get_columns, etc. + +1. Inspector methods return lists of dicts in most cases for the following + reasons: + + * They're both standard types that can be serialized. + * Using a dict instead of a tuple allows easy expansion of attributes. + * Using a list for the outer structure maintains order and is easy to work + with (e.g. list comprehension [d['name'] for d in cols]). + +2. Records that contain a name, such as the column name in a column record + use the key 'name'. So for most return values, each record will have a + 'name' attribute.. +""" + +import sqlalchemy +from sqlalchemy import exc, sql +from sqlalchemy import util +from sqlalchemy.types import TypeEngine +from sqlalchemy import schema as sa_schema + + +@util.decorator +def cache(fn, self, con, *args, **kw): + info_cache = kw.get('info_cache', None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = ( + fn.__name__, + tuple(a for a in args if isinstance(a, basestring)), + tuple((k, v) for k, v in kw.iteritems() if isinstance(v, (basestring, int, float))) + ) + ret = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + +class Inspector(object): + """Performs database schema inspection. + + The Inspector acts as a proxy to the dialects' reflection methods and + provides higher level functions for accessing database schema information. + """ + + def __init__(self, conn): + """Initialize the instance. + + :param conn: a :class:`~sqlalchemy.engine.base.Connectable` + """ + + self.conn = conn + # set the engine + if hasattr(conn, 'engine'): + self.engine = conn.engine + else: + self.engine = conn + self.dialect = self.engine.dialect + self.info_cache = {} + + @classmethod + def from_engine(cls, engine): + if hasattr(engine.dialect, 'inspector'): + return engine.dialect.inspector(engine) + return Inspector(engine) + + @property + def default_schema_name(self): + return self.dialect.default_schema_name + + def get_schema_names(self): + """Return all schema names. + """ + + if hasattr(self.dialect, 'get_schema_names'): + return self.dialect.get_schema_names(self.conn, + info_cache=self.info_cache) + return [] + + def get_table_names(self, schema=None, order_by=None): + """Return all table names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + :param order_by: Optional, may be the string "foreign_key" to sort + the result on foreign key dependencies. + + This should probably not return view names or maybe it should return + them with an indicator t or v. + """ + + if hasattr(self.dialect, 'get_table_names'): + tnames = self.dialect.get_table_names(self.conn, + schema, + info_cache=self.info_cache) + else: + tnames = self.engine.table_names(schema) + if order_by == 'foreign_key': + ordered_tnames = tnames[:] + # Order based on foreign key dependencies. + for tname in tnames: + table_pos = tnames.index(tname) + fkeys = self.get_foreign_keys(tname, schema) + for fkey in fkeys: + rtable = fkey['referred_table'] + if rtable in ordered_tnames: + ref_pos = ordered_tnames.index(rtable) + # Make sure it's lower in the list than anything it + # references. + if table_pos > ref_pos: + ordered_tnames.pop(table_pos) # rtable moves up 1 + # insert just below rtable + ordered_tnames.index(ref_pos, tname) + tnames = ordered_tnames + return tnames + + def get_table_options(self, table_name, schema=None, **kw): + if hasattr(self.dialect, 'get_table_options'): + return self.dialect.get_table_options(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + return {} + + def get_view_names(self, schema=None): + """Return all view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + """ + + return self.dialect.get_view_names(self.conn, schema, + info_cache=self.info_cache) + + def get_view_definition(self, view_name, schema=None): + """Return definition for `view_name`. + + :param schema: Optional, retrieve names from a non-default schema. + """ + + return self.dialect.get_view_definition( + self.conn, view_name, schema, info_cache=self.info_cache) + + def get_columns(self, table_name, schema=None, **kw): + """Return information about columns in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + column information as a list of dicts with these keys: + + name + the column's name + + type + :class:`~sqlalchemy.types.TypeEngine` + + nullable + boolean + + default + the column's default value + + attrs + dict containing optional column attributes + """ + + col_defs = self.dialect.get_columns(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + for col_def in col_defs: + # make this easy and only return instances for coltype + coltype = col_def['type'] + if not isinstance(coltype, TypeEngine): + col_def['type'] = coltype() + return col_defs + + def get_primary_keys(self, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + primary key information as a list of column names. + """ + + pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + + return pkeys + + def get_foreign_keys(self, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + foreign key information as a list of dicts with these keys: + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + + \**kw + other options passed to the dialect's get_foreign_keys() method. + + """ + + fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + return fk_defs + + def get_indexes(self, table_name, schema=None, **kw): + """Return information about indexes in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + index information as a list of dicts with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + + \**kw + other options passed to the dialect's get_indexes() method. + """ + + indexes = self.dialect.get_indexes(self.conn, table_name, + schema, + info_cache=self.info_cache, **kw) + return indexes + + def reflecttable(self, table, include_columns): + + dialect = self.conn.dialect + + # MySQL dialect does this. Applicable with other dialects? + if hasattr(dialect, '_connection_charset') \ + and hasattr(dialect, '_adjust_casing'): + charset = dialect._connection_charset + dialect._adjust_casing(table) + + # table attributes we might need. + reflection_options = dict( + (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs) + + schema = table.schema + table_name = table.name + + # apply table options + tbl_opts = self.get_table_options(table_name, schema, **table.kwargs) + if tbl_opts: + table.kwargs.update(tbl_opts) + + # table.kwargs will need to be passed to each reflection method. Make + # sure keywords are strings. + tblkw = table.kwargs.copy() + for (k, v) in tblkw.items(): + del tblkw[k] + tblkw[str(k)] = v + + # Py2K + if isinstance(schema, str): + schema = schema.decode(dialect.encoding) + if isinstance(table_name, str): + table_name = table_name.decode(dialect.encoding) + # end Py2K + + # columns + found_table = False + for col_d in self.get_columns(table_name, schema, **tblkw): + found_table = True + name = col_d['name'] + if include_columns and name not in include_columns: + continue + + coltype = col_d['type'] + col_kw = { + 'nullable':col_d['nullable'], + } + if 'autoincrement' in col_d: + col_kw['autoincrement'] = col_d['autoincrement'] + if 'quote' in col_d: + col_kw['quote'] = col_d['quote'] + + colargs = [] + if col_d.get('default') is not None: + # the "default" value is assumed to be a literal SQL expression, + # so is wrapped in text() so that no quoting occurs on re-issuance. + colargs.append(sa_schema.DefaultClause(sql.text(col_d['default']))) + + if 'sequence' in col_d: + # TODO: mssql, maxdb and sybase are using this. + seq = col_d['sequence'] + sequence = sa_schema.Sequence(seq['name'], 1, 1) + if 'start' in seq: + sequence.start = seq['start'] + if 'increment' in seq: + sequence.increment = seq['increment'] + colargs.append(sequence) + + col = sa_schema.Column(name, coltype, *colargs, **col_kw) + table.append_column(col) + + if not found_table: + raise exc.NoSuchTableError(table.name) + + # Primary keys + primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[ + table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw) + if pk in table.c + ]) + + table.append_constraint(primary_key_constraint) + + # Foreign keys + fkeys = self.get_foreign_keys(table_name, schema, **tblkw) + for fkey_d in fkeys: + conname = fkey_d['name'] + constrained_columns = fkey_d['constrained_columns'] + referred_schema = fkey_d['referred_schema'] + referred_table = fkey_d['referred_table'] + referred_columns = fkey_d['referred_columns'] + refspec = [] + if referred_schema is not None: + sa_schema.Table(referred_table, table.metadata, + autoload=True, schema=referred_schema, + autoload_with=self.conn, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join( + [referred_schema, referred_table, column])) + else: + sa_schema.Table(referred_table, table.metadata, autoload=True, + autoload_with=self.conn, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + table.append_constraint( + sa_schema.ForeignKeyConstraint(constrained_columns, refspec, + conname, link_to_name=True)) + # Indexes + indexes = self.get_indexes(table_name, schema) + for index_d in indexes: + name = index_d['name'] + columns = index_d['column_names'] + unique = index_d['unique'] + flavor = index_d.get('type', 'unknown type') + if include_columns and \ + not set(columns).issubset(include_columns): + util.warn( + "Omitting %s KEY for (%s), key covers omitted columns." % + (flavor, ', '.join(columns))) + continue + sa_schema.Index(name, *[table.columns[c] for c in columns], + **dict(unique=unique)) diff --git a/sqlalchemy/engine/strategies.py b/sqlalchemy/engine/strategies.py new file mode 100644 index 0000000..7fc39b9 --- /dev/null +++ b/sqlalchemy/engine/strategies.py @@ -0,0 +1,227 @@ +"""Strategies for creating new instances of Engine types. + +These are semi-private implementation classes which provide the +underlying behavior for the "strategy" keyword argument available on +:func:`~sqlalchemy.engine.create_engine`. Current available options are +``plain``, ``threadlocal``, and ``mock``. + +New strategies can be added via new ``EngineStrategy`` classes. +""" + +from operator import attrgetter + +from sqlalchemy.engine import base, threadlocal, url +from sqlalchemy import util, exc +from sqlalchemy import pool as poollib + +strategies = {} + + +class EngineStrategy(object): + """An adaptor that processes input arguements and produces an Engine. + + Provides a ``create`` method that receives input arguments and + produces an instance of base.Engine or a subclass. + + """ + + def __init__(self): + strategies[self.name] = self + + def create(self, *args, **kwargs): + """Given arguments, returns a new Engine instance.""" + + raise NotImplementedError() + + +class DefaultEngineStrategy(EngineStrategy): + """Base class for built-in stratgies.""" + + pool_threadlocal = False + + def create(self, name_or_url, **kwargs): + # create url.URL object + u = url.make_url(name_or_url) + + dialect_cls = u.get_dialect() + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kwargs: + dialect_args[k] = kwargs.pop(k) + + dbapi = kwargs.pop('module', None) + if dbapi is None: + dbapi_args = {} + for k in util.get_func_kwargs(dialect_cls.dbapi): + if k in kwargs: + dbapi_args[k] = kwargs.pop(k) + dbapi = dialect_cls.dbapi(**dbapi_args) + + dialect_args['dbapi'] = dbapi + + # create dialect + dialect = dialect_cls(**dialect_args) + + # assemble connection arguments + (cargs, cparams) = dialect.create_connect_args(u) + cparams.update(kwargs.pop('connect_args', {})) + + # look for existing pool or create + pool = kwargs.pop('pool', None) + if pool is None: + def connect(): + try: + return dialect.connect(*cargs, **cparams) + except Exception, e: + # Py3K + #raise exc.DBAPIError.instance(None, None, e) from e + # Py2K + import sys + raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2] + # end Py2K + + creator = kwargs.pop('creator', connect) + + poolclass = (kwargs.pop('poolclass', None) or + getattr(dialect_cls, 'poolclass', poollib.QueuePool)) + pool_args = {} + + # consume pool arguments from kwargs, translating a few of + # the arguments + translate = {'logging_name': 'pool_logging_name', + 'echo': 'echo_pool', + 'timeout': 'pool_timeout', + 'recycle': 'pool_recycle', + 'use_threadlocal':'pool_threadlocal'} + for k in util.get_cls_kwargs(poolclass): + tk = translate.get(k, k) + if tk in kwargs: + pool_args[k] = kwargs.pop(tk) + pool_args.setdefault('use_threadlocal', self.pool_threadlocal) + pool = poolclass(creator, **pool_args) + else: + if isinstance(pool, poollib._DBProxy): + pool = pool.get_pool(*cargs, **cparams) + else: + pool = pool + + # create engine. + engineclass = self.engine_cls + engine_args = {} + for k in util.get_cls_kwargs(engineclass): + if k in kwargs: + engine_args[k] = kwargs.pop(k) + + _initialize = kwargs.pop('_initialize', True) + + # all kwargs should be consumed + if kwargs: + raise TypeError( + "Invalid argument(s) %s sent to create_engine(), " + "using configuration %s/%s/%s. Please check that the " + "keyword arguments are appropriate for this combination " + "of components." % (','.join("'%s'" % k for k in kwargs), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__)) + + engine = engineclass(pool, dialect, u, **engine_args) + + if _initialize: + do_on_connect = dialect.on_connect() + if do_on_connect: + def on_connect(conn, rec): + conn = getattr(conn, '_sqla_unwrap', conn) + if conn is None: + return + do_on_connect(conn) + + pool.add_listener({'first_connect': on_connect, 'connect':on_connect}) + + def first_connect(conn, rec): + c = base.Connection(engine, connection=conn) + dialect.initialize(c) + pool.add_listener({'first_connect':first_connect}) + + return engine + + +class PlainEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring a regular Engine.""" + + name = 'plain' + engine_cls = base.Engine + +PlainEngineStrategy() + + +class ThreadLocalEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring an Engine with thredlocal behavior.""" + + name = 'threadlocal' + pool_threadlocal = True + engine_cls = threadlocal.TLEngine + +ThreadLocalEngineStrategy() + + +class MockEngineStrategy(EngineStrategy): + """Strategy for configuring an Engine-like object with mocked execution. + + Produces a single mock Connectable object which dispatches + statement execution to a passed-in function. + + """ + + name = 'mock' + + def create(self, name_or_url, executor, **kwargs): + # create url.URL object + u = url.make_url(name_or_url) + + dialect_cls = u.get_dialect() + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kwargs: + dialect_args[k] = kwargs.pop(k) + + # create dialect + dialect = dialect_cls(**dialect_args) + + return MockEngineStrategy.MockConnection(dialect, executor) + + class MockConnection(base.Connectable): + def __init__(self, dialect, execute): + self._dialect = dialect + self.execute = execute + + engine = property(lambda s: s) + dialect = property(attrgetter('_dialect')) + name = property(lambda s: s._dialect.name) + + def contextual_connect(self, **kwargs): + return self + + def compiler(self, statement, parameters, **kwargs): + return self._dialect.compiler( + statement, parameters, engine=self, **kwargs) + + def create(self, entity, **kwargs): + kwargs['checkfirst'] = False + from sqlalchemy.engine import ddl + + ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity) + + def drop(self, entity, **kwargs): + kwargs['checkfirst'] = False + from sqlalchemy.engine import ddl + ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity) + + def execute(self, object, *multiparams, **params): + raise NotImplementedError() + +MockEngineStrategy() diff --git a/sqlalchemy/engine/threadlocal.py b/sqlalchemy/engine/threadlocal.py new file mode 100644 index 0000000..001caee --- /dev/null +++ b/sqlalchemy/engine/threadlocal.py @@ -0,0 +1,103 @@ +"""Provides a thread-local transactional wrapper around the root Engine class. + +The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag +with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is +invoked automatically when the threadlocal engine strategy is used. +""" + +from sqlalchemy import util +from sqlalchemy.engine import base +import weakref + +class TLConnection(base.Connection): + def __init__(self, *arg, **kw): + super(TLConnection, self).__init__(*arg, **kw) + self.__opencount = 0 + + def _increment_connect(self): + self.__opencount += 1 + return self + + def close(self): + if self.__opencount == 1: + base.Connection.close(self) + self.__opencount -= 1 + + def _force_close(self): + self.__opencount = 0 + base.Connection.close(self) + + +class TLEngine(base.Engine): + """An Engine that includes support for thread-local managed transactions.""" + + + def __init__(self, *args, **kwargs): + super(TLEngine, self).__init__(*args, **kwargs) + self._connections = util.threading.local() + proxy = kwargs.get('proxy') + if proxy: + self.TLConnection = base._proxy_connection_cls(TLConnection, proxy) + else: + self.TLConnection = TLConnection + + def contextual_connect(self, **kw): + if not hasattr(self._connections, 'conn'): + connection = None + else: + connection = self._connections.conn() + + if connection is None or connection.closed: + # guards against pool-level reapers, if desired. + # or not connection.connection.is_valid: + connection = self.TLConnection(self, self.pool.connect(), **kw) + self._connections.conn = conn = weakref.ref(connection) + + return connection._increment_connect() + + def begin_twophase(self, xid=None): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid)) + + def begin_nested(self): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin_nested()) + + def begin(self): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin()) + + def prepare(self): + self._connections.trans[-1].prepare() + + def commit(self): + trans = self._connections.trans.pop(-1) + trans.commit() + + def rollback(self): + trans = self._connections.trans.pop(-1) + trans.rollback() + + def dispose(self): + self._connections = util.threading.local() + super(TLEngine, self).dispose() + + @property + def closed(self): + return not hasattr(self._connections, 'conn') or \ + self._connections.conn() is None or \ + self._connections.conn().closed + + def close(self): + if not self.closed: + self.contextual_connect().close() + connection = self._connections.conn() + connection._force_close() + del self._connections.conn + self._connections.trans = [] + + def __repr__(self): + return 'TLEngine(%s)' % str(self.url) diff --git a/sqlalchemy/engine/url.py b/sqlalchemy/engine/url.py new file mode 100644 index 0000000..5d658d7 --- /dev/null +++ b/sqlalchemy/engine/url.py @@ -0,0 +1,214 @@ +"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates +information about a database connection specification. + +The URL object is created automatically when :func:`~sqlalchemy.engine.create_engine` is called +with a string argument; alternatively, the URL is a public-facing construct which can +be used directly and is also accepted directly by ``create_engine()``. +""" + +import re, cgi, sys, urllib +from sqlalchemy import exc + + +class URL(object): + """ + Represent the components of a URL used to connect to a database. + + This object is suitable to be passed directly to a + ``create_engine()`` call. The fields of the URL are parsed from a + string by the ``module-level make_url()`` function. the string + format of the URL is an RFC-1738-style string. + + All initialization parameters are available as public attributes. + + :param drivername: the name of the database backend. + This name will correspond to a module in sqlalchemy/databases + or a third party plug-in. + + :param username: The user name. + + :param password: database password. + + :param host: The name of the host. + + :param port: The port number. + + :param database: The database name. + + :param query: A dictionary of options to be passed to the + dialect and/or the DBAPI upon connect. + + """ + + def __init__(self, drivername, username=None, password=None, + host=None, port=None, database=None, query=None): + self.drivername = drivername + self.username = username + self.password = password + self.host = host + if port is not None: + self.port = int(port) + else: + self.port = None + self.database = database + self.query = query or {} + + def __str__(self): + s = self.drivername + "://" + if self.username is not None: + s += self.username + if self.password is not None: + s += ':' + urllib.quote_plus(self.password) + s += "@" + if self.host is not None: + s += self.host + if self.port is not None: + s += ':' + str(self.port) + if self.database is not None: + s += '/' + self.database + if self.query: + keys = self.query.keys() + keys.sort() + s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) + return s + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return \ + isinstance(other, URL) and \ + self.drivername == other.drivername and \ + self.username == other.username and \ + self.password == other.password and \ + self.host == other.host and \ + self.database == other.database and \ + self.query == other.query + + def get_dialect(self): + """Return the SQLAlchemy database dialect class corresponding + to this URL's driver name. + """ + + try: + if '+' in self.drivername: + dialect, driver = self.drivername.split('+') + else: + dialect, driver = self.drivername, 'base' + + module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + module = getattr(module, dialect) + module = getattr(module, driver) + + return module.dialect + except ImportError: + module = self._load_entry_point() + if module is not None: + return module + else: + raise + + def _load_entry_point(self): + """attempt to load this url's dialect from entry points, or return None + if pkg_resources is not installed or there is no matching entry point. + + Raise ImportError if the actual load fails. + + """ + try: + import pkg_resources + except ImportError: + return None + + for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'): + if res.name == self.drivername: + return res.load() + else: + return None + + def translate_connect_args(self, names=[], **kw): + """Translate url attributes into a dictionary of connection arguments. + + Returns attributes of this url (`host`, `database`, `username`, + `password`, `port`) as a plain dictionary. The attribute names are + used as the keys by default. Unset or false attributes are omitted + from the final dictionary. + + :param \**kw: Optional, alternate key names for url attributes. + + :param names: Deprecated. Same purpose as the keyword-based alternate names, + but correlates the name to the original positionally. + """ + + translated = {} + attribute_names = ['host', 'database', 'username', 'password', 'port'] + for sname in attribute_names: + if names: + name = names.pop(0) + elif sname in kw: + name = kw[sname] + else: + name = sname + if name is not None and getattr(self, sname, False): + translated[name] = getattr(self, sname) + return translated + +def make_url(name_or_url): + """Given a string or unicode instance, produce a new URL instance. + + The given string is parsed according to the RFC 1738 spec. If an + existing URL object is passed, just returns the object. + """ + + if isinstance(name_or_url, basestring): + return _parse_rfc1738_args(name_or_url) + else: + return name_or_url + +def _parse_rfc1738_args(name): + pattern = re.compile(r''' + (?P[\w\+]+):// + (?: + (?P[^:/]*) + (?::(?P[^/]*))? + @)? + (?: + (?P[^/:]*) + (?::(?P[^/]*))? + )? + (?:/(?P.*))? + ''' + , re.X) + + m = pattern.match(name) + if m is not None: + components = m.groupdict() + if components['database'] is not None: + tokens = components['database'].split('?', 2) + components['database'] = tokens[0] + query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None + # Py2K + if query is not None: + query = dict((k.encode('ascii'), query[k]) for k in query) + # end Py2K + else: + query = None + components['query'] = query + + if components['password'] is not None: + components['password'] = urllib.unquote_plus(components['password']) + + name = components.pop('name') + return URL(name, **components) + else: + raise exc.ArgumentError( + "Could not parse rfc1738 URL from string '%s'" % name) + +def _parse_keyvalue_args(name): + m = re.match( r'(\w+)://(.*)', name) + if m is not None: + (name, args) = m.group(1, 2) + opts = dict( cgi.parse_qsl( args ) ) + return URL(name, *opts) + else: + return None diff --git a/sqlalchemy/exc.py b/sqlalchemy/exc.py new file mode 100644 index 0000000..31826f4 --- /dev/null +++ b/sqlalchemy/exc.py @@ -0,0 +1,191 @@ +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Exceptions used with SQLAlchemy. + +The base exception class is SQLAlchemyError. Exceptions which are raised as a +result of DBAPI exceptions are all subclasses of +:class:`~sqlalchemy.exc.DBAPIError`. + +""" + + +class SQLAlchemyError(Exception): + """Generic error class.""" + + +class ArgumentError(SQLAlchemyError): + """Raised when an invalid or conflicting function argument is supplied. + + This error generally corresponds to construction time state errors. + + """ + + +class CircularDependencyError(SQLAlchemyError): + """Raised by topological sorts when a circular dependency is detected""" + + +class CompileError(SQLAlchemyError): + """Raised when an error occurs during SQL compilation""" + +class IdentifierError(SQLAlchemyError): + """Raised when a schema name is beyond the max character limit""" + +# Moved to orm.exc; compatability definition installed by orm import until 0.6 +ConcurrentModificationError = None + +class DisconnectionError(SQLAlchemyError): + """A disconnect is detected on a raw DB-API connection. + + This error is raised and consumed internally by a connection pool. It can + be raised by a ``PoolListener`` so that the host pool forces a disconnect. + + """ + + +# Moved to orm.exc; compatability definition installed by orm import until 0.6 +FlushError = None + +class TimeoutError(SQLAlchemyError): + """Raised when a connection pool times out on getting a connection.""" + + +class InvalidRequestError(SQLAlchemyError): + """SQLAlchemy was asked to do something it can't do. + + This error generally corresponds to runtime state errors. + + """ + +class NoSuchColumnError(KeyError, InvalidRequestError): + """A nonexistent column is requested from a ``RowProxy``.""" + +class NoReferenceError(InvalidRequestError): + """Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" + +class NoReferencedTableError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Table`` cannot be located.""" + +class NoReferencedColumnError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Column`` cannot be located.""" + +class NoSuchTableError(InvalidRequestError): + """Table does not exist or is not visible to a connection.""" + + +class UnboundExecutionError(InvalidRequestError): + """SQL was attempted without a database connection to execute it on.""" + + +# Moved to orm.exc; compatability definition installed by orm import until 0.6 +UnmappedColumnError = None + +class DBAPIError(SQLAlchemyError): + """Raised when the execution of a database operation fails. + + ``DBAPIError`` wraps exceptions raised by the DB-API underlying the + database operation. Driver-specific implementations of the standard + DB-API exception types are wrapped by matching sub-types of SQLAlchemy's + ``DBAPIError`` when possible. DB-API's ``Error`` type maps to + ``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note + that there is no guarantee that different DB-API implementations will + raise the same exception type for any given error condition. + + If the error-raising operation occured in the execution of a SQL + statement, that statement and its parameters will be available on + the exception object in the ``statement`` and ``params`` attributes. + + The wrapped exception object is available in the ``orig`` attribute. + Its type and properties are DB-API implementation specific. + + """ + + @classmethod + def instance(cls, statement, params, orig, connection_invalidated=False): + # Don't ever wrap these, just return them directly as if + # DBAPIError didn't exist. + if isinstance(orig, (KeyboardInterrupt, SystemExit)): + return orig + + if orig is not None: + name, glob = orig.__class__.__name__, globals() + if name in glob and issubclass(glob[name], DBAPIError): + cls = glob[name] + + return cls(statement, params, orig, connection_invalidated) + + def __init__(self, statement, params, orig, connection_invalidated=False): + try: + text = str(orig) + except (KeyboardInterrupt, SystemExit): + raise + except Exception, e: + text = 'Error in str() of DB-API-generated exception: ' + str(e) + SQLAlchemyError.__init__( + self, '(%s) %s' % (orig.__class__.__name__, text)) + self.statement = statement + self.params = params + self.orig = orig + self.connection_invalidated = connection_invalidated + + def __str__(self): + if isinstance(self.params, (list, tuple)) and len(self.params) > 10 and isinstance(self.params[0], (list, dict, tuple)): + return ' '.join((SQLAlchemyError.__str__(self), + repr(self.statement), + repr(self.params[:2]), + '... and a total of %i bound parameter sets' % len(self.params))) + return ' '.join((SQLAlchemyError.__str__(self), + repr(self.statement), repr(self.params))) + + +# As of 0.4, SQLError is now DBAPIError. +# SQLError alias will be removed in 0.6. +SQLError = DBAPIError + +class InterfaceError(DBAPIError): + """Wraps a DB-API InterfaceError.""" + + +class DatabaseError(DBAPIError): + """Wraps a DB-API DatabaseError.""" + + +class DataError(DatabaseError): + """Wraps a DB-API DataError.""" + + +class OperationalError(DatabaseError): + """Wraps a DB-API OperationalError.""" + + +class IntegrityError(DatabaseError): + """Wraps a DB-API IntegrityError.""" + + +class InternalError(DatabaseError): + """Wraps a DB-API InternalError.""" + + +class ProgrammingError(DatabaseError): + """Wraps a DB-API ProgrammingError.""" + + +class NotSupportedError(DatabaseError): + """Wraps a DB-API NotSupportedError.""" + + +# Warnings + +class SADeprecationWarning(DeprecationWarning): + """Issued once per usage of a deprecated API.""" + + +class SAPendingDeprecationWarning(PendingDeprecationWarning): + """Issued once per usage of a deprecated API.""" + + +class SAWarning(RuntimeWarning): + """Issued at runtime.""" diff --git a/sqlalchemy/ext/__init__.py b/sqlalchemy/ext/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/sqlalchemy/ext/__init__.py @@ -0,0 +1 @@ + diff --git a/sqlalchemy/ext/associationproxy.py b/sqlalchemy/ext/associationproxy.py new file mode 100644 index 0000000..c7437d7 --- /dev/null +++ b/sqlalchemy/ext/associationproxy.py @@ -0,0 +1,878 @@ +"""Contain the ``AssociationProxy`` class. + +The ``AssociationProxy`` is a Python property object which provides +transparent proxied access to the endpoint of an association object. + +See the example ``examples/association/proxied_association.py``. + +""" +import itertools +import operator +import weakref +from sqlalchemy import exceptions +from sqlalchemy import orm +from sqlalchemy import util +from sqlalchemy.orm import collections +from sqlalchemy.sql import not_ + + +def association_proxy(target_collection, attr, **kw): + """Return a Python property implementing a view of *attr* over a collection. + + Implements a read/write view over an instance's *target_collection*, + extracting *attr* from each member of the collection. The property acts + somewhat like this list comprehension:: + + [getattr(member, *attr*) + for member in getattr(instance, *target_collection*)] + + Unlike the list comprehension, the collection returned by the property is + always in sync with *target_collection*, and mutations made to either + collection will be reflected in both. + + Implements a Python property representing a relationship as a collection of + simpler values. The proxied property will mimic the collection type of + the target (list, dict or set), or, in the case of a one to one relationship, + a simple scalar value. + + :param target_collection: Name of the relationship attribute we'll proxy to, + usually created with :func:`~sqlalchemy.orm.relationship`. + + :param attr: Attribute on the associated instances we'll proxy for. + + For example, given a target collection of [obj1, obj2], a list created + by this proxy property would look like [getattr(obj1, *attr*), + getattr(obj2, *attr*)] + + If the relationship is one-to-one or otherwise uselist=False, then simply: + getattr(obj, *attr*) + + :param creator: optional. + + When new items are added to this proxied collection, new instances of + the class collected by the target collection will be created. For list + and set collections, the target class constructor will be called with + the 'value' for the new instance. For dict types, two arguments are + passed: key and value. + + If you want to construct instances differently, supply a *creator* + function that takes arguments as above and returns instances. + + For scalar relationships, creator() will be called if the target is None. + If the target is present, set operations are proxied to setattr() on the + associated object. + + If you have an associated object with multiple attributes, you may set + up multiple association proxies mapping to different attributes. See + the unit tests for examples, and for examples of how creator() functions + can be used to construct the scalar relationship on-demand in this + situation. + + :param \*\*kw: Passes along any other keyword arguments to + :class:`AssociationProxy`. + + """ + return AssociationProxy(target_collection, attr, **kw) + + +class AssociationProxy(object): + """A descriptor that presents a read/write view of an object attribute.""" + + def __init__(self, target_collection, attr, creator=None, + getset_factory=None, proxy_factory=None, proxy_bulk_set=None): + """Arguments are: + + target_collection + Name of the collection we'll proxy to, usually created with + 'relationship()' in a mapper setup. + + attr + Attribute on the collected instances we'll proxy for. For example, + given a target collection of [obj1, obj2], a list created by this + proxy property would look like [getattr(obj1, attr), getattr(obj2, + attr)] + + creator + Optional. When new items are added to this proxied collection, new + instances of the class collected by the target collection will be + created. For list and set collections, the target class constructor + will be called with the 'value' for the new instance. For dict + types, two arguments are passed: key and value. + + If you want to construct instances differently, supply a 'creator' + function that takes arguments as above and returns instances. + + getset_factory + Optional. Proxied attribute access is automatically handled by + routines that get and set values based on the `attr` argument for + this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. The factory is called with two arguments, the + abstract type of the underlying collection and this proxy instance. + + proxy_factory + Optional. The type of collection to emulate is determined by + sniffing the target collection. If your collection type can't be + determined by duck typing or you'd like to use a different + collection implementation, you may supply a factory function to + produce those collections. Only applicable to non-scalar relationships. + + proxy_bulk_set + Optional, use with proxy_factory. See the _set() method for + details. + + """ + self.target_collection = target_collection + self.value_attr = attr + self.creator = creator + self.getset_factory = getset_factory + self.proxy_factory = proxy_factory + self.proxy_bulk_set = proxy_bulk_set + + self.scalar = None + self.owning_class = None + self.key = '_%s_%s_%s' % ( + type(self).__name__, target_collection, id(self)) + self.collection_class = None + + def _get_property(self): + return (orm.class_mapper(self.owning_class). + get_property(self.target_collection)) + + @property + def target_class(self): + """The class the proxy is attached to.""" + return self._get_property().mapper.class_ + + def _target_is_scalar(self): + return not self._get_property().uselist + + def __get__(self, obj, class_): + if self.owning_class is None: + self.owning_class = class_ and class_ or type(obj) + if obj is None: + return self + elif self.scalar is None: + self.scalar = self._target_is_scalar() + if self.scalar: + self._initialize_scalar_accessors() + + if self.scalar: + return self._scalar_get(getattr(obj, self.target_collection)) + else: + try: + # If the owning instance is reborn (orm session resurrect, + # etc.), refresh the proxy cache. + creator_id, proxy = getattr(obj, self.key) + if id(obj) == creator_id: + return proxy + except AttributeError: + pass + proxy = self._new(_lazy_collection(obj, self.target_collection)) + setattr(obj, self.key, (id(obj), proxy)) + return proxy + + def __set__(self, obj, values): + if self.owning_class is None: + self.owning_class = type(obj) + if self.scalar is None: + self.scalar = self._target_is_scalar() + if self.scalar: + self._initialize_scalar_accessors() + + if self.scalar: + creator = self.creator and self.creator or self.target_class + target = getattr(obj, self.target_collection) + if target is None: + setattr(obj, self.target_collection, creator(values)) + else: + self._scalar_set(target, values) + else: + proxy = self.__get__(obj, None) + if proxy is not values: + proxy.clear() + self._set(proxy, values) + + def __delete__(self, obj): + if self.owning_class is None: + self.owning_class = type(obj) + delattr(obj, self.key) + + def _initialize_scalar_accessors(self): + if self.getset_factory: + get, set = self.getset_factory(None, self) + else: + get, set = self._default_getset(None) + self._scalar_get, self._scalar_set = get, set + + def _default_getset(self, collection_class): + attr = self.value_attr + getter = operator.attrgetter(attr) + if collection_class is dict: + setter = lambda o, k, v: setattr(o, attr, v) + else: + setter = lambda o, v: setattr(o, attr, v) + return getter, setter + + def _new(self, lazy_collection): + creator = self.creator and self.creator or self.target_class + self.collection_class = util.duck_type_collection(lazy_collection()) + + if self.proxy_factory: + return self.proxy_factory(lazy_collection, creator, self.value_attr, self) + + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) + + if self.collection_class is list: + return _AssociationList(lazy_collection, creator, getter, setter, self) + elif self.collection_class is dict: + return _AssociationDict(lazy_collection, creator, getter, setter, self) + elif self.collection_class is set: + return _AssociationSet(lazy_collection, creator, getter, setter, self) + else: + raise exceptions.ArgumentError( + 'could not guess which interface to use for ' + 'collection_class "%s" backing "%s"; specify a ' + 'proxy_factory and proxy_bulk_set manually' % + (self.collection_class.__name__, self.target_collection)) + + def _inflate(self, proxy): + creator = self.creator and self.creator or self.target_class + + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + + def _set(self, proxy, values): + if self.proxy_bulk_set: + self.proxy_bulk_set(proxy, values) + elif self.collection_class is list: + proxy.extend(values) + elif self.collection_class is dict: + proxy.update(values) + elif self.collection_class is set: + proxy.update(values) + else: + raise exceptions.ArgumentError( + 'no proxy_bulk_set supplied for custom ' + 'collection_class implementation') + + @property + def _comparator(self): + return self._get_property().comparator + + def any(self, criterion=None, **kwargs): + return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + + def has(self, criterion=None, **kwargs): + return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + + def contains(self, obj): + return self._comparator.any(**{self.value_attr: obj}) + + def __eq__(self, obj): + return self._comparator.has(**{self.value_attr: obj}) + + def __ne__(self, obj): + return not_(self.__eq__(obj)) + + +class _lazy_collection(object): + def __init__(self, obj, target): + self.ref = weakref.ref(obj) + self.target = target + + def __call__(self): + obj = self.ref() + if obj is None: + raise exceptions.InvalidRequestError( + "stale association proxy, parent object has gone out of " + "scope") + return getattr(obj, self.target) + + def __getstate__(self): + return {'obj':self.ref(), 'target':self.target} + + def __setstate__(self, state): + self.ref = weakref.ref(state['obj']) + self.target = state['target'] + +class _AssociationCollection(object): + def __init__(self, lazy_collection, creator, getter, setter, parent): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. + + lazy_collection + A callable returning a list-based collection of entities (usually an + object attribute managed by a SQLAlchemy relationship()) + + creator + A function that creates new target entities. Given one parameter: + value. This assertion is assumed:: + + obj = creator(somevalue) + assert getter(obj) == somevalue + + getter + A function. Given an associated object, return the 'value'. + + setter + A function. Given an associated object and a value, store that + value on the object. + + """ + self.lazy_collection = lazy_collection + self.creator = creator + self.getter = getter + self.setter = setter + self.parent = parent + + col = property(lambda self: self.lazy_collection()) + + def __len__(self): + return len(self.col) + + def __nonzero__(self): + return bool(self.col) + + def __getstate__(self): + return {'parent':self.parent, 'lazy_collection':self.lazy_collection} + + def __setstate__(self, state): + self.parent = state['parent'] + self.lazy_collection = state['lazy_collection'] + self.parent._inflate(self) + +class _AssociationList(_AssociationCollection): + """Generic, converting, list-to-list proxy.""" + + def _create(self, value): + return self.creator(value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, value): + return self.setter(object, value) + + def __getitem__(self, index): + return self._get(self.col[index]) + + def __setitem__(self, index, value): + if not isinstance(index, slice): + self._set(self.col[index], value) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + rng = range(index.start or 0, stop, step) + if step == 1: + for i in rng: + del self[index.start] + i = index.start + for item in value: + self.insert(i, item) + i += 1 + else: + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), + len(rng))) + for i, item in zip(rng, value): + self._set(self.col[i], item) + + def __delitem__(self, index): + del self.col[index] + + def __contains__(self, value): + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __getslice__(self, start, end): + return [self._get(member) for member in self.col[start:end]] + + def __setslice__(self, start, end, values): + members = [self._create(v) for v in values] + self.col[start:end] = members + + def __delslice__(self, start, end): + del self.col[start:end] + + def __iter__(self): + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + + for member in self.col: + yield self._get(member) + raise StopIteration + + def append(self, value): + item = self._create(value) + self.col.append(item) + + def count(self, value): + return sum([1 for _ in + itertools.ifilter(lambda v: v == value, iter(self))]) + + def extend(self, values): + for v in values: + self.append(v) + + def insert(self, index, value): + self.col[index:index] = [self._create(value)] + + def pop(self, index=-1): + return self.getter(self.col.pop(index)) + + def remove(self, value): + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self): + """Not supported, use reversed(mylist)""" + + raise NotImplementedError + + def sort(self): + """Not supported, use sorted(mylist)""" + + raise NotImplementedError + + def clear(self): + del self.col[0:len(self.col)] + + def __eq__(self, other): + return list(self) == other + + def __ne__(self, other): + return list(self) != other + + def __lt__(self, other): + return list(self) < other + + def __le__(self, other): + return list(self) <= other + + def __gt__(self, other): + return list(self) > other + + def __ge__(self, other): + return list(self) >= other + + def __cmp__(self, other): + return cmp(list(self), other) + + def __add__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return list(self) + other + + def __radd__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return other + list(self) + + def __mul__(self, n): + if not isinstance(n, int): + return NotImplemented + return list(self) * n + __rmul__ = __mul__ + + def __iadd__(self, iterable): + self.extend(iterable) + return self + + def __imul__(self, n): + # unlike a regular list *=, proxied __imul__ will generate unique + # backing objects for each copy. *= on proxied lists is a bit of + # a stretch anyhow, and this interpretation of the __imul__ contract + # is more plausibly useful than copying the backing objects. + if not isinstance(n, int): + return NotImplemented + if n == 0: + self.clear() + elif n > 1: + self.extend(list(self) * (n - 1)) + return self + + def copy(self): + return list(self) + + def __repr__(self): + return repr(list(self)) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in locals().items(): + if (util.callable(func) and func.func_name == func_name and + not func.__doc__ and hasattr(list, func_name)): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +_NotProvided = util.symbol('_NotProvided') +class _AssociationDict(_AssociationCollection): + """Generic, converting, dict-to-dict proxy.""" + + def _create(self, key, value): + return self.creator(key, value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, key, value): + return self.setter(object, key, value) + + def __getitem__(self, key): + return self._get(self.col[key]) + + def __setitem__(self, key, value): + if key in self.col: + self._set(self.col[key], key, value) + else: + self.col[key] = self._create(key, value) + + def __delitem__(self, key): + del self.col[key] + + def __contains__(self, key): + # testlib.pragma exempt:__hash__ + return key in self.col + + def has_key(self, key): + # testlib.pragma exempt:__hash__ + return key in self.col + + def __iter__(self): + return self.col.iterkeys() + + def clear(self): + self.col.clear() + + def __eq__(self, other): + return dict(self) == other + + def __ne__(self, other): + return dict(self) != other + + def __lt__(self, other): + return dict(self) < other + + def __le__(self, other): + return dict(self) <= other + + def __gt__(self, other): + return dict(self) > other + + def __ge__(self, other): + return dict(self) >= other + + def __cmp__(self, other): + return cmp(dict(self), other) + + def __repr__(self): + return repr(dict(self.items())) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + if key not in self.col: + self.col[key] = self._create(key, default) + return default + else: + return self[key] + + def keys(self): + return self.col.keys() + + def iterkeys(self): + return self.col.iterkeys() + + def values(self): + return [ self._get(member) for member in self.col.values() ] + + def itervalues(self): + for key in self.col: + yield self._get(self.col[key]) + raise StopIteration + + def items(self): + return [(k, self._get(self.col[k])) for k in self] + + def iteritems(self): + for key in self.col: + yield (key, self._get(self.col[key])) + raise StopIteration + + def pop(self, key, default=_NotProvided): + if default is _NotProvided: + member = self.col.pop(key) + else: + member = self.col.pop(key, default) + return self._get(member) + + def popitem(self): + item = self.col.popitem() + return (item[0], self._get(item[1])) + + def update(self, *a, **kw): + if len(a) > 1: + raise TypeError('update expected at most 1 arguments, got %i' % + len(a)) + elif len(a) == 1: + seq_or_map = a[0] + for item in seq_or_map: + if isinstance(item, tuple): + self[item[0]] = item[1] + else: + self[item] = seq_or_map[item] + + for key, value in kw: + self[key] = value + + def copy(self): + return dict(self.items()) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in locals().items(): + if (util.callable(func) and func.func_name == func_name and + not func.__doc__ and hasattr(dict, func_name)): + func.__doc__ = getattr(dict, func_name).__doc__ + del func_name, func + + +class _AssociationSet(_AssociationCollection): + """Generic, converting, set-to-set proxy.""" + + def _create(self, value): + return self.creator(value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, value): + return self.setter(object, value) + + def __len__(self): + return len(self.col) + + def __nonzero__(self): + if self.col: + return True + else: + return False + + def __contains__(self, value): + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __iter__(self): + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or just use + the underlying collection directly from its property on the parent. + + """ + for member in self.col: + yield self._get(member) + raise StopIteration + + def add(self, value): + if value not in self: + self.col.add(self._create(value)) + + # for discard and remove, choosing a more expensive check strategy rather + # than call self.creator() + def discard(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + break + + def remove(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + return + raise KeyError(value) + + def pop(self): + if not self.col: + raise KeyError('pop from an empty set') + member = self.col.pop() + return self._get(member) + + def update(self, other): + for value in other: + self.add(value) + + def __ior__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + self.add(value) + return self + + def _set(self): + return set(iter(self)) + + def union(self, other): + return set(self).union(other) + + __or__ = union + + def difference(self, other): + return set(self).difference(other) + + __sub__ = difference + + def difference_update(self, other): + for value in other: + self.discard(value) + + def __isub__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + self.discard(value) + return self + + def intersection(self, other): + return set(self).intersection(other) + + __and__ = intersection + + def intersection_update(self, other): + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __iand__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def symmetric_difference(self, other): + return set(self).symmetric_difference(other) + + __xor__ = symmetric_difference + + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __ixor__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def issubset(self, other): + return set(self).issubset(other) + + def issuperset(self, other): + return set(self).issuperset(other) + + def clear(self): + self.col.clear() + + def copy(self): + return set(self) + + def __eq__(self, other): + return set(self) == other + + def __ne__(self, other): + return set(self) != other + + def __lt__(self, other): + return set(self) < other + + def __le__(self, other): + return set(self) <= other + + def __gt__(self, other): + return set(self) > other + + def __ge__(self, other): + return set(self) >= other + + def __repr__(self): + return repr(set(self)) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in locals().items(): + if (util.callable(func) and func.func_name == func_name and + not func.__doc__ and hasattr(set, func_name)): + func.__doc__ = getattr(set, func_name).__doc__ + del func_name, func diff --git a/sqlalchemy/ext/compiler.py b/sqlalchemy/ext/compiler.py new file mode 100644 index 0000000..3226b0e --- /dev/null +++ b/sqlalchemy/ext/compiler.py @@ -0,0 +1,194 @@ +"""Provides an API for creation of custom ClauseElements and compilers. + +Synopsis +======== + +Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement` +subclasses and one or more callables defining its compilation:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.sql.expression import ColumnClause + + class MyColumn(ColumnClause): + pass + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, +the base expression element for named column objects. The ``compiles`` +decorator registers itself with the ``MyColumn`` class so that it is invoked +when the object is compiled to a string:: + + from sqlalchemy import select + + s = select([MyColumn('x'), MyColumn('y')]) + print str(s) + +Produces:: + + SELECT [x], [y] + +Dialect-specific compilation rules +================================== + +Compilers can also be made dialect-specific. The appropriate compiler will be +invoked for the dialect in use:: + + from sqlalchemy.schema import DDLElement + + class AlterColumn(DDLElement): + + def __init__(self, column, cmd): + self.column = column + self.cmd = cmd + + @compiles(AlterColumn) + def visit_alter_column(element, compiler, **kw): + return "ALTER COLUMN %s ..." % element.column.name + + @compiles(AlterColumn, 'postgresql') + def visit_alter_column(element, compiler, **kw): + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name) + +The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used. + +Compiling sub-elements of a custom expression construct +======================================================= + +The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled` +object in use. This object can be inspected for any information about the +in-progress compilation, including ``compiler.dialect``, +``compiler.statement`` etc. The :class:`~sqlalchemy.sql.compiler.SQLCompiler` +and :class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` +method which can be used for compilation of embedded attributes:: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class InsertFromSelect(Executable, ClauseElement): + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True), + compiler.process(element.select) + ) + + insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5)) + print insert + +Produces:: + + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" + +Cross Compiling between SQL and DDL compilers +--------------------------------------------- + +SQL and DDL constructs are each compiled using different base compilers - ``SQLCompiler`` +and ``DDLCompiler``. A common need is to access the compilation rules of SQL expressions +from within a DDL expression. The ``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as below where we generate a CHECK +constraint that embeds a SQL expression:: + + @compiles(MyConstraint) + def compile_my_constraint(constraint, ddlcompiler, **kw): + return "CONSTRAINT %s CHECK (%s)" % ( + constraint.name, + ddlcompiler.sql_compiler.process(constraint.expression) + ) + +Changing the default compilation of existing constructs +======================================================= + +The compiler extension applies just as well to the existing constructs. When overriding +the compilation of a built in SQL construct, the @compiles decorator is invoked upon +the appropriate class (be sure to use the class, i.e. ``Insert`` or ``Select``, instead of the creation function such as ``insert()`` or ``select()``). + +Within the new compilation function, to get at the "original" compilation routine, +use the appropriate visit_XXX method - this because compiler.process() will call upon the +overriding routine and cause an endless loop. Such as, to add "prefix" to all insert statements:: + + from sqlalchemy.sql.expression import Insert + + @compiles(Insert) + def prefix_inserts(insert, compiler, **kw): + return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) + +The above compiler will prefix all INSERT statements with "some prefix" when compiled. + +Subclassing Guidelines +====================== + +A big part of using the compiler extension is subclassing SQLAlchemy expression constructs. To make this easier, the expression and schema packages feature a set of "bases" intended for common tasks. A synopsis is as follows: + +* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root + expression class. Any SQL expression can be derived from this base, and is + probably the best choice for longer constructs such as specialized INSERT + statements. + +* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all + "column-like" elements. Anything that you'd place in the "columns" clause of + a SELECT statement (as well as order by and group by) can derive from this - + the object will automatically have Python "comparison" behavior. + + :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a + ``type`` member which is expression's return type. This can be established + at the instance level in the constructor, or at the class level if its + generally constant:: + + class timestamp(ColumnElement): + type = TIMESTAMP() + +* :class:`~sqlalchemy.sql.expression.FunctionElement` - This is a hybrid of a + ``ColumnElement`` and a "from clause" like object, and represents a SQL + function or stored procedure type of call. Since most databases support + statements along the line of "SELECT FROM " + ``FunctionElement`` adds in the ability to be used in the FROM clause of a + ``select()`` construct. + +* :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions, + like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement`` + subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``. + ``DDLElement`` also features ``Table`` and ``MetaData`` event hooks via the + ``execute_at()`` method, allowing the construct to be invoked during CREATE + TABLE and DROP TABLE sequences. + +* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which should be + used with any expression class that represents a "standalone" SQL statement that + can be passed directly to an ``execute()`` method. It is already implicit + within ``DDLElement`` and ``FunctionElement``. + +""" + +def compiles(class_, *specs): + def decorate(fn): + existing = getattr(class_, '_compiler_dispatcher', None) + if not existing: + existing = _dispatcher() + + # TODO: why is the lambda needed ? + setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw)) + setattr(class_, '_compiler_dispatcher', existing) + + if specs: + for s in specs: + existing.specs[s] = fn + else: + existing.specs['default'] = fn + return fn + return decorate + +class _dispatcher(object): + def __init__(self): + self.specs = {} + + def __call__(self, element, compiler, **kw): + # TODO: yes, this could also switch off of DBAPI in use. + fn = self.specs.get(compiler.dialect.name, None) + if not fn: + fn = self.specs['default'] + return fn(element, compiler, **kw) + diff --git a/sqlalchemy/ext/declarative.py b/sqlalchemy/ext/declarative.py new file mode 100644 index 0000000..1f4658b --- /dev/null +++ b/sqlalchemy/ext/declarative.py @@ -0,0 +1,940 @@ +""" +Synopsis +======== + +SQLAlchemy object-relational configuration involves the use of +:class:`~sqlalchemy.schema.Table`, :func:`~sqlalchemy.orm.mapper`, and +class objects to define the three areas of configuration. +:mod:`~sqlalchemy.ext.declarative` allows all three types of +configuration to be expressed declaratively on an individual +mapped class. Regular SQLAlchemy schema elements and ORM constructs +are used in most cases. + +As a simple example:: + + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class SomeClass(Base): + __tablename__ = 'some_table' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + +Above, the :func:`declarative_base` callable returns a new base class from which +all mapped classes should inherit. When the class definition is completed, a +new :class:`~sqlalchemy.schema.Table` and +:class:`~sqlalchemy.orm.mapper` will have been generated, accessible +via the ``__table__`` and ``__mapper__`` attributes on the ``SomeClass`` class. + +Defining Attributes +=================== + +In the above example, the :class:`~sqlalchemy.schema.Column` objects are +automatically named with the name of the attribute to which they are +assigned. + +They can also be explicitly named, and that name does not have to be +the same as name assigned on the class. +The column will be assigned to the :class:`~sqlalchemy.schema.Table` using the +given name, and mapped to the class using the attribute name:: + + class SomeClass(Base): + __tablename__ = 'some_table' + id = Column("some_table_id", Integer, primary_key=True) + name = Column("name", String(50)) + +Attributes may be added to the class after its construction, and they will be +added to the underlying :class:`~sqlalchemy.schema.Table` and +:func:`~sqlalchemy.orm.mapper()` definitions as appropriate:: + + SomeClass.data = Column('data', Unicode) + SomeClass.related = relationship(RelatedInfo) + +Classes which are mapped explicitly using +:func:`~sqlalchemy.orm.mapper()` can interact freely with declarative +classes. + +It is recommended, though not required, that all tables +share the same underlying :class:`~sqlalchemy.schema.MetaData` object, +so that string-configured :class:`~sqlalchemy.schema.ForeignKey` +references can be resolved without issue. + +Association of Metadata and Engine +================================== + +The :func:`declarative_base` base class contains a +:class:`~sqlalchemy.schema.MetaData` object where newly +defined :class:`~sqlalchemy.schema.Table` objects are collected. This +is accessed via the :class:`~sqlalchemy.schema.MetaData` class level +accessor, so to create tables we can say:: + + engine = create_engine('sqlite://') + Base.metadata.create_all(engine) + +The :class:`~sqlalchemy.engine.base.Engine` created above may also be +directly associated with the declarative base class using the ``bind`` +keyword argument, where it will be associated with the underlying +:class:`~sqlalchemy.schema.MetaData` object and allow SQL operations +involving that metadata and its tables to make use of that engine +automatically:: + + Base = declarative_base(bind=create_engine('sqlite://')) + +Alternatively, by way of the normal +:class:`~sqlalchemy.schema.MetaData` behaviour, the ``bind`` attribute +of the class level accessor can be assigned at any time as follows:: + + Base.metadata.bind = create_engine('sqlite://') + +The :func:`declarative_base` can also receive a pre-created +:class:`~sqlalchemy.schema.MetaData` object, which allows a +declarative setup to be associated with an already +existing traditional collection of :class:`~sqlalchemy.schema.Table` +objects:: + + mymetadata = MetaData() + Base = declarative_base(metadata=mymetadata) + +Configuring Relationships +========================= + +Relationships to other classes are done in the usual way, with the added +feature that the class specified to :func:`~sqlalchemy.orm.relationship` +may be a string name (note that :func:`~sqlalchemy.orm.relationship` is +only available as of SQLAlchemy 0.6beta2, and in all prior versions is known +as :func:`~sqlalchemy.orm.relation`, +including 0.5 and 0.4). The "class registry" associated with ``Base`` +is used at mapper compilation time to resolve the name into the actual +class object, which is expected to have been defined once the mapper +configuration is used:: + + class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + addresses = relationship("Address", backref="user") + + class Address(Base): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + +Column constructs, since they are just that, are immediately usable, +as below where we define a primary join condition on the ``Address`` +class using them:: + + class Address(Base): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + user = relationship(User, primaryjoin=user_id == User.id) + +In addition to the main argument for :func:`~sqlalchemy.orm.relationship`, +other arguments which depend upon the columns present on an as-yet +undefined class may also be specified as strings. These strings are +evaluated as Python expressions. The full namespace available within +this evaluation includes all classes mapped for this declarative base, +as well as the contents of the ``sqlalchemy`` package, including +expression functions like :func:`~sqlalchemy.sql.expression.desc` and +:attr:`~sqlalchemy.sql.expression.func`:: + + class User(Base): + # .... + addresses = relationship("Address", + order_by="desc(Address.email)", + primaryjoin="Address.user_id==User.id") + +As an alternative to string-based attributes, attributes may also be +defined after all classes have been created. Just add them to the target +class after the fact:: + + User.addresses = relationship(Address, + primaryjoin=Address.user_id==User.id) + +Configuring Many-to-Many Relationships +====================================== + +There's nothing special about many-to-many with declarative. The +``secondary`` argument to :func:`~sqlalchemy.orm.relationship` still +requires a :class:`~sqlalchemy.schema.Table` object, not a declarative +class. The :class:`~sqlalchemy.schema.Table` should share the same +:class:`~sqlalchemy.schema.MetaData` object used by the declarative +base:: + + keywords = Table( + 'keywords', Base.metadata, + Column('author_id', Integer, ForeignKey('authors.id')), + Column('keyword_id', Integer, ForeignKey('keywords.id')) + ) + + class Author(Base): + __tablename__ = 'authors' + id = Column(Integer, primary_key=True) + keywords = relationship("Keyword", secondary=keywords) + +You should generally **not** map a class and also specify its table in +a many-to-many relationship, since the ORM may issue duplicate INSERT and +DELETE statements. + + +Defining Synonyms +================= + +Synonyms are introduced in :ref:`synonyms`. To define a getter/setter +which proxies to an underlying attribute, use +:func:`~sqlalchemy.orm.synonym` with the ``descriptor`` argument:: + + class MyClass(Base): + __tablename__ = 'sometable' + + _attr = Column('attr', String) + + def _get_attr(self): + return self._some_attr + def _set_attr(self, attr): + self._some_attr = attr + attr = synonym('_attr', descriptor=property(_get_attr, _set_attr)) + +The above synonym is then usable as an instance attribute as well as a +class-level expression construct:: + + x = MyClass() + x.attr = "some value" + session.query(MyClass).filter(MyClass.attr == 'some other value').all() + +For simple getters, the :func:`synonym_for` decorator can be used in +conjunction with ``@property``:: + + class MyClass(Base): + __tablename__ = 'sometable' + + _attr = Column('attr', String) + + @synonym_for('_attr') + @property + def attr(self): + return self._some_attr + +Similarly, :func:`comparable_using` is a front end for the +:func:`~sqlalchemy.orm.comparable_property` ORM function:: + + class MyClass(Base): + __tablename__ = 'sometable' + + name = Column('name', String) + + @comparable_using(MyUpperCaseComparator) + @property + def uc_name(self): + return self.name.upper() + +Table Configuration +=================== + +Table arguments other than the name, metadata, and mapped Column +arguments are specified using the ``__table_args__`` class attribute. +This attribute accommodates both positional as well as keyword +arguments that are normally sent to the +:class:`~sqlalchemy.schema.Table` constructor. +The attribute can be specified in one of two forms. One is as a +dictionary:: + + class MyClass(Base): + __tablename__ = 'sometable' + __table_args__ = {'mysql_engine':'InnoDB'} + +The other, a tuple of the form +``(arg1, arg2, ..., {kwarg1:value, ...})``, which allows positional +arguments to be specified as well (usually constraints):: + + class MyClass(Base): + __tablename__ = 'sometable' + __table_args__ = ( + ForeignKeyConstraint(['id'], ['remote_table.id']), + UniqueConstraint('foo'), + {'autoload':True} + ) + +Note that the keyword parameters dictionary is required in the tuple +form even if empty. + +As an alternative to ``__tablename__``, a direct +:class:`~sqlalchemy.schema.Table` construct may be used. The +:class:`~sqlalchemy.schema.Column` objects, which in this case require +their names, will be added to the mapping just like a regular mapping +to a table:: + + class MyClass(Base): + __table__ = Table('my_table', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + +Mapper Configuration +==================== + +Configuration of mappers is done with the +:func:`~sqlalchemy.orm.mapper` function and all the possible mapper +configuration parameters can be found in the documentation for that +function. + +:func:`~sqlalchemy.orm.mapper` is still used by declaratively mapped +classes and keyword parameters to the function can be passed by +placing them in the ``__mapper_args__`` class variable:: + + class Widget(Base): + __tablename__ = 'widgets' + id = Column(Integer, primary_key=True) + + __mapper_args__ = {'extension': MyWidgetExtension()} + +Inheritance Configuration +========================= + +Declarative supports all three forms of inheritance as intuitively +as possible. The ``inherits`` mapper keyword argument is not needed +as declarative will determine this from the class itself. The various +"polymorphic" keyword arguments are specified using ``__mapper_args__``. + +Joined Table Inheritance +~~~~~~~~~~~~~~~~~~~~~~~~ + +Joined table inheritance is defined as a subclass that defines its own +table:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity': 'engineer'} + id = Column(Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column(String(50)) + +Note that above, the ``Engineer.id`` attribute, since it shares the +same attribute name as the ``Person.id`` attribute, will in fact +represent the ``people.id`` and ``engineers.id`` columns together, and +will render inside a query as ``"people.id"``. +To provide the ``Engineer`` class with an attribute that represents +only the ``engineers.id`` column, give it a different attribute name:: + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity': 'engineer'} + engineer_id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column(String(50)) + +Single Table Inheritance +~~~~~~~~~~~~~~~~~~~~~~~~ + +Single table inheritance is defined as a subclass that does not have +its own table; you just leave out the ``__table__`` and ``__tablename__`` +attributes:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity': 'engineer'} + primary_language = Column(String(50)) + +When the above mappers are configured, the ``Person`` class is mapped +to the ``people`` table *before* the ``primary_language`` column is +defined, and this column will not be included in its own mapping. +When ``Engineer`` then defines the ``primary_language`` column, the +column is added to the ``people`` table so that it is included in the +mapping for ``Engineer`` and is also part of the table's full set of +columns. Columns which are not mapped to ``Person`` are also excluded +from any other single or joined inheriting classes using the +``exclude_properties`` mapper argument. Below, ``Manager`` will have +all the attributes of ``Person`` and ``Manager`` but *not* the +``primary_language`` attribute of ``Engineer``:: + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity': 'manager'} + golf_swing = Column(String(50)) + +The attribute exclusion logic is provided by the +``exclude_properties`` mapper argument, and declarative's default +behavior can be disabled by passing an explicit ``exclude_properties`` +collection (empty or otherwise) to the ``__mapper_args__``. + +Concrete Table Inheritance +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Concrete is defined as a subclass which has its own table and sets the +``concrete`` keyword argument to ``True``:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'concrete':True} + id = Column(Integer, primary_key=True) + primary_language = Column(String(50)) + name = Column(String(50)) + +Usage of an abstract base class is a little less straightforward as it +requires usage of :func:`~sqlalchemy.orm.util.polymorphic_union`:: + + engineers = Table('engineers', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('primary_language', String(50)) + ) + managers = Table('managers', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('golf_swing', String(50)) + ) + + punion = polymorphic_union({ + 'engineer':engineers, + 'manager':managers + }, 'type', 'punion') + + class Person(Base): + __table__ = punion + __mapper_args__ = {'polymorphic_on':punion.c.type} + + class Engineer(Person): + __table__ = engineers + __mapper_args__ = {'polymorphic_identity':'engineer', 'concrete':True} + + class Manager(Person): + __table__ = managers + __mapper_args__ = {'polymorphic_identity':'manager', 'concrete':True} + + +Mix-in Classes +============== + +A common need when using :mod:`~sqlalchemy.ext.declarative` is to +share some functionality, often a set of columns, across many +classes. The normal python idiom would be to put this common code into +a base class and have all the other classes subclass this class. + +When using :mod:`~sqlalchemy.ext.declarative`, this need is met by +using a "mix-in class". A mix-in class is one that isn't mapped to a +table and doesn't subclass the declarative :class:`Base`. For example:: + + class MyMixin(object): + + __table_args__ = {'mysql_engine':'InnoDB'} + __mapper_args__=dict(always_refresh=True) + id = Column(Integer, primary_key=True) + + def foo(self): + return 'bar'+str(self.id) + + class MyModel(Base,MyMixin): + __tablename__='test' + name = Column(String(1000), nullable=False, index=True) + +As the above example shows, ``__table_args__`` and ``__mapper_args__`` +can both be abstracted out into a mix-in if you use common values for +these across many classes. + +However, particularly in the case of ``__table_args__``, you may want +to combine some parameters from several mix-ins with those you wish to +define on the class iteself. To help with this, a +:func:`~sqlalchemy.util.classproperty` decorator is provided that lets +you implement a class property with a function. For example:: + + from sqlalchemy.util import classproperty + + class MySQLSettings: + __table_args__ = {'mysql_engine':'InnoDB'} + + class MyOtherMixin: + __table_args__ = {'info':'foo'} + + class MyModel(Base,MySQLSettings,MyOtherMixin): + __tablename__='my_model' + + @classproperty + def __table_args__(self): + args = dict() + args.update(MySQLSettings.__table_args__) + args.update(MyOtherMixin.__table_args__) + return args + + id = Column(Integer, primary_key=True) + +Class Constructor +================= + +As a convenience feature, the :func:`declarative_base` sets a default +constructor on classes which takes keyword arguments, and assigns them +to the named attributes:: + + e = Engineer(primary_language='python') + +Sessions +======== + +Note that ``declarative`` does nothing special with sessions, and is +only intended as an easier way to configure mappers and +:class:`~sqlalchemy.schema.Table` objects. A typical application +setup using :func:`~sqlalchemy.orm.scoped_session` might look like:: + + engine = create_engine('postgresql://scott:tiger@localhost/test') + Session = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=engine)) + Base = declarative_base() + +Mapped instances then make usage of +:class:`~sqlalchemy.orm.session.Session` in the usual way. + +""" + +from sqlalchemy.schema import Table, Column, MetaData +from sqlalchemy.orm import synonym as _orm_synonym, mapper, comparable_property, class_mapper +from sqlalchemy.orm.interfaces import MapperProperty +from sqlalchemy.orm.properties import RelationshipProperty, ColumnProperty +from sqlalchemy.orm.util import _is_mapped_class +from sqlalchemy import util, exceptions +from sqlalchemy.sql import util as sql_util + + +__all__ = 'declarative_base', 'synonym_for', 'comparable_using', 'instrument_declarative' + +def instrument_declarative(cls, registry, metadata): + """Given a class, configure the class declaratively, + using the given registry, which can be any dictionary, and + MetaData object. + + """ + if '_decl_class_registry' in cls.__dict__: + raise exceptions.InvalidRequestError( + "Class %r already has been " + "instrumented declaratively" % cls) + cls._decl_class_registry = registry + cls.metadata = metadata + _as_declarative(cls, cls.__name__, cls.__dict__) + +def _as_declarative(cls, classname, dict_): + + # dict_ will be a dictproxy, which we can't write to, and we need to! + dict_ = dict(dict_) + + column_copies = dict() + unmapped_mixins = False + for base in cls.__bases__: + names = dir(base) + if not _is_mapped_class(base): + unmapped_mixins = True + for name in names: + obj = getattr(base,name, None) + if isinstance(obj, Column): + if obj.foreign_keys: + raise exceptions.InvalidRequestError( + "Columns with foreign keys to other columns " + "are not allowed on declarative mixins at this time." + ) + dict_[name]=column_copies[obj]=obj.copy() + elif isinstance(obj, RelationshipProperty): + raise exceptions.InvalidRequestError( + "relationships are not allowed on " + "declarative mixins at this time.") + + # doing it this way enables these attributes to be descriptors + get_mapper_args = '__mapper_args__' in dict_ + get_table_args = '__table_args__' in dict_ + if unmapped_mixins: + get_mapper_args = get_mapper_args or getattr(cls,'__mapper_args__',None) + get_table_args = get_table_args or getattr(cls,'__table_args__',None) + tablename = getattr(cls,'__tablename__',None) + if tablename: + # subtle: if tablename is a descriptor here, we actually + # put the wrong value in, but it serves as a marker to get + # the right value value... + dict_['__tablename__']=tablename + + # now that we know whether or not to get these, get them from the class + # if we should, enabling them to be decorators + mapper_args = get_mapper_args and cls.__mapper_args__ or {} + table_args = get_table_args and cls.__table_args__ or None + + # make sure that column copies are used rather than the original columns + # from any mixins + for k, v in mapper_args.iteritems(): + mapper_args[k] = column_copies.get(v,v) + + cls._decl_class_registry[classname] = cls + our_stuff = util.OrderedDict() + for k in dict_: + value = dict_[k] + if (isinstance(value, tuple) and len(value) == 1 and + isinstance(value[0], (Column, MapperProperty))): + util.warn("Ignoring declarative-like tuple value of attribute " + "%s: possibly a copy-and-paste error with a comma " + "left at the end of the line?" % k) + continue + if not isinstance(value, (Column, MapperProperty)): + continue + prop = _deferred_relationship(cls, value) + our_stuff[k] = prop + + # set up attributes in the order they were created + our_stuff.sort(key=lambda key: our_stuff[key]._creation_order) + + # extract columns from the class dict + cols = [] + for key, c in our_stuff.iteritems(): + if isinstance(c, ColumnProperty): + for col in c.columns: + if isinstance(col, Column) and col.table is None: + _undefer_column_name(key, col) + cols.append(col) + elif isinstance(c, Column): + _undefer_column_name(key, c) + cols.append(c) + # if the column is the same name as the key, + # remove it from the explicit properties dict. + # the normal rules for assigning column-based properties + # will take over, including precedence of columns + # in multi-column ColumnProperties. + if key == c.key: + del our_stuff[key] + + table = None + if '__table__' not in dict_: + if '__tablename__' in dict_: + # see above: if __tablename__ is a descriptor, this + # means we get the right value used! + tablename = cls.__tablename__ + + if isinstance(table_args, dict): + args, table_kw = (), table_args + elif isinstance(table_args, tuple): + args = table_args[0:-1] + table_kw = table_args[-1] + if len(table_args) < 2 or not isinstance(table_kw, dict): + raise exceptions.ArgumentError( + "Tuple form of __table_args__ is " + "(arg1, arg2, arg3, ..., {'kw1':val1, 'kw2':val2, ...})" + ) + else: + args, table_kw = (), {} + + autoload = dict_.get('__autoload__') + if autoload: + table_kw['autoload'] = True + + cls.__table__ = table = Table(tablename, cls.metadata, + *(tuple(cols) + tuple(args)), **table_kw) + else: + table = cls.__table__ + if cols: + for c in cols: + if not table.c.contains_column(c): + raise exceptions.ArgumentError( + "Can't add additional column %r when specifying __table__" % key + ) + + if 'inherits' not in mapper_args: + for c in cls.__bases__: + if _is_mapped_class(c): + mapper_args['inherits'] = cls._decl_class_registry.get(c.__name__, None) + break + + if hasattr(cls, '__mapper_cls__'): + mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) + else: + mapper_cls = mapper + + if table is None and 'inherits' not in mapper_args: + raise exceptions.InvalidRequestError( + "Class %r does not have a __table__ or __tablename__ " + "specified and does not inherit from an existing table-mapped class." % cls + ) + + elif 'inherits' in mapper_args and not mapper_args.get('concrete', False): + inherited_mapper = class_mapper(mapper_args['inherits'], compile=False) + inherited_table = inherited_mapper.local_table + if 'inherit_condition' not in mapper_args and table is not None: + # figure out the inherit condition with relaxed rules + # about nonexistent tables, to allow for ForeignKeys to + # not-yet-defined tables (since we know for sure that our + # parent table is defined within the same MetaData) + mapper_args['inherit_condition'] = sql_util.join_condition( + mapper_args['inherits'].__table__, table, + ignore_nonexistent_tables=True) + + if table is None: + # single table inheritance. + # ensure no table args + if table_args is not None: + raise exceptions.ArgumentError( + "Can't place __table_args__ on an inherited class with no table." + ) + + # add any columns declared here to the inherited table. + for c in cols: + if c.primary_key: + raise exceptions.ArgumentError( + "Can't place primary key columns on an inherited class with no table." + ) + if c.name in inherited_table.c: + raise exceptions.ArgumentError( + "Column '%s' on class %s conflicts with existing column '%s'" % + (c, cls, inherited_table.c[c.name]) + ) + inherited_table.append_column(c) + + # single or joined inheritance + # exclude any cols on the inherited table which are not mapped on the + # parent class, to avoid + # mapping columns specific to sibling/nephew classes + inherited_mapper = class_mapper(mapper_args['inherits'], compile=False) + inherited_table = inherited_mapper.local_table + + if 'exclude_properties' not in mapper_args: + mapper_args['exclude_properties'] = exclude_properties = \ + set([c.key for c in inherited_table.c + if c not in inherited_mapper._columntoproperty]) + exclude_properties.difference_update([c.key for c in cols]) + + cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args) + +class DeclarativeMeta(type): + def __init__(cls, classname, bases, dict_): + if '_decl_class_registry' in cls.__dict__: + return type.__init__(cls, classname, bases, dict_) + + _as_declarative(cls, classname, cls.__dict__) + return type.__init__(cls, classname, bases, dict_) + + def __setattr__(cls, key, value): + if '__mapper__' in cls.__dict__: + if isinstance(value, Column): + _undefer_column_name(key, value) + cls.__table__.append_column(value) + cls.__mapper__.add_property(key, value) + elif isinstance(value, ColumnProperty): + for col in value.columns: + if isinstance(col, Column) and col.table is None: + _undefer_column_name(key, col) + cls.__table__.append_column(col) + cls.__mapper__.add_property(key, value) + elif isinstance(value, MapperProperty): + cls.__mapper__.add_property(key, _deferred_relationship(cls, value)) + else: + type.__setattr__(cls, key, value) + else: + type.__setattr__(cls, key, value) + + +class _GetColumns(object): + def __init__(self, cls): + self.cls = cls + def __getattr__(self, key): + + mapper = class_mapper(self.cls, compile=False) + if mapper: + prop = mapper.get_property(key) + if not isinstance(prop, ColumnProperty): + raise exceptions.InvalidRequestError( + "Property %r is not an instance of" + " ColumnProperty (i.e. does not correspond" + " directly to a Column)." % key) + return getattr(self.cls, key) + + +def _deferred_relationship(cls, prop): + def resolve_arg(arg): + import sqlalchemy + + def access_cls(key): + if key in cls._decl_class_registry: + return _GetColumns(cls._decl_class_registry[key]) + elif key in cls.metadata.tables: + return cls.metadata.tables[key] + else: + return sqlalchemy.__dict__[key] + + d = util.PopulateDict(access_cls) + def return_cls(): + try: + x = eval(arg, globals(), d) + + if isinstance(x, _GetColumns): + return x.cls + else: + return x + except NameError, n: + raise exceptions.InvalidRequestError( + "When compiling mapper %s, expression %r failed to locate a name (%r). " + "If this is a class name, consider adding this relationship() to the %r " + "class after both dependent classes have been defined." % ( + prop.parent, arg, n.args[0], cls)) + return return_cls + + if isinstance(prop, RelationshipProperty): + for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin', + 'secondary', '_foreign_keys', 'remote_side'): + v = getattr(prop, attr) + if isinstance(v, basestring): + setattr(prop, attr, resolve_arg(v)) + + if prop.backref and isinstance(prop.backref, tuple): + key, kwargs = prop.backref + for attr in ('primaryjoin', 'secondaryjoin', 'secondary', + 'foreign_keys', 'remote_side', 'order_by'): + if attr in kwargs and isinstance(kwargs[attr], basestring): + kwargs[attr] = resolve_arg(kwargs[attr]) + + + return prop + +def synonym_for(name, map_column=False): + """Decorator, make a Python @property a query synonym for a column. + + A decorator version of :func:`~sqlalchemy.orm.synonym`. The function being + decorated is the 'descriptor', otherwise passes its arguments through + to synonym():: + + @synonym_for('col') + @property + def prop(self): + return 'special sauce' + + The regular ``synonym()`` is also usable directly in a declarative setting + and may be convenient for read/write properties:: + + prop = synonym('col', descriptor=property(_read_prop, _write_prop)) + + """ + def decorate(fn): + return _orm_synonym(name, map_column=map_column, descriptor=fn) + return decorate + +def comparable_using(comparator_factory): + """Decorator, allow a Python @property to be used in query criteria. + + This is a decorator front end to + :func:`~sqlalchemy.orm.comparable_property` that passes + through the comparator_factory and the function being decorated:: + + @comparable_using(MyComparatorType) + @property + def prop(self): + return 'special sauce' + + The regular ``comparable_property()`` is also usable directly in a + declarative setting and may be convenient for read/write properties:: + + prop = comparable_property(MyComparatorType) + + """ + def decorate(fn): + return comparable_property(comparator_factory, fn) + return decorate + +def _declarative_constructor(self, **kwargs): + """A simple constructor that allows initialization from kwargs. + + Sets attributes on the constructed instance using the names and + values in ``kwargs``. + + Only keys that are present as + attributes of the instance's class are allowed. These could be, + for example, any mapped columns or relationships. + """ + for k in kwargs: + if not hasattr(type(self), k): + raise TypeError( + "%r is an invalid keyword argument for %s" % + (k, type(self).__name__)) + setattr(self, k, kwargs[k]) +_declarative_constructor.__name__ = '__init__' + +def declarative_base(bind=None, metadata=None, mapper=None, cls=object, + name='Base', constructor=_declarative_constructor, + metaclass=DeclarativeMeta): + """Construct a base class for declarative class definitions. + + The new base class will be given a metaclass that produces + appropriate :class:`~sqlalchemy.schema.Table` objects and makes + the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the + information provided declaratively in the class and any subclasses + of the class. + + :param bind: An optional + :class:`~sqlalchemy.engine.base.Connectable`, will be assigned + the ``bind`` attribute on the :class:`~sqlalchemy.MetaData` + instance. + + + :param metadata: + An optional :class:`~sqlalchemy.MetaData` instance. All + :class:`~sqlalchemy.schema.Table` objects implicitly declared by + subclasses of the base will share this MetaData. A MetaData instance + will be created if none is provided. The + :class:`~sqlalchemy.MetaData` instance will be available via the + `metadata` attribute of the generated declarative base class. + + :param mapper: + An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will be + used to map subclasses to their Tables. + + :param cls: + Defaults to :class:`object`. A type to use as the base for the generated + declarative base class. May be a class or tuple of classes. + + :param name: + Defaults to ``Base``. The display name for the generated + class. Customizing this is not required, but can improve clarity in + tracebacks and debugging. + + :param constructor: + Defaults to + :func:`~sqlalchemy.ext.declarative._declarative_constructor`, an + __init__ implementation that assigns \**kwargs for declared + fields and relationships to an instance. If ``None`` is supplied, + no __init__ will be provided and construction will fall back to + cls.__init__ by way of the normal Python semantics. + + :param metaclass: + Defaults to :class:`DeclarativeMeta`. A metaclass or __metaclass__ + compatible callable to use as the meta type of the generated + declarative base class. + + """ + lcl_metadata = metadata or MetaData() + if bind: + lcl_metadata.bind = bind + + bases = not isinstance(cls, tuple) and (cls,) or cls + class_dict = dict(_decl_class_registry=dict(), + metadata=lcl_metadata) + + if constructor: + class_dict['__init__'] = constructor + if mapper: + class_dict['__mapper_cls__'] = mapper + + return metaclass(name, bases, class_dict) + +def _undefer_column_name(key, column): + if column.key is None: + column.key = key + if column.name is None: + column.name = key diff --git a/sqlalchemy/ext/horizontal_shard.py b/sqlalchemy/ext/horizontal_shard.py new file mode 100644 index 0000000..78e3f59 --- /dev/null +++ b/sqlalchemy/ext/horizontal_shard.py @@ -0,0 +1,125 @@ +# horizontal_shard.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Horizontal sharding support. + +Defines a rudimental 'horizontal sharding' system which allows a Session to +distribute queries and persistence operations across multiple databases. + +For a usage example, see the :ref:`examples_sharding` example included in +the source distrbution. + +""" + +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import util +from sqlalchemy.orm.session import Session +from sqlalchemy.orm.query import Query + +__all__ = ['ShardedSession', 'ShardedQuery'] + + +class ShardedSession(Session): + def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs): + """Construct a ShardedSession. + + :param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a + SQL clause, returns a shard ID. This id may be based off of the + attributes present within the object, or on some round-robin + scheme. If the scheme is based on a selection, it should set + whatever state on the instance to mark it in the future as + participating in that shard. + + :param id_chooser: A callable, passed a query and a tuple of identity values, which + should return a list of shard ids where the ID might reside. The + databases will be queried in the order of this listing. + + :param query_chooser: For a given Query, returns the list of shard_ids where the query + should be issued. Results from all shards returned will be combined + together into a single listing. + + :param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.base.Engine` + objects. + + """ + super(ShardedSession, self).__init__(**kwargs) + self.shard_chooser = shard_chooser + self.id_chooser = id_chooser + self.query_chooser = query_chooser + self.__binds = {} + self._mapper_flush_opts = {'connection_callable':self.connection} + self._query_cls = ShardedQuery + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) + + def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance) + + if self.transaction is not None: + return self.transaction.connection(mapper, shard_id=shard_id) + else: + return self.get_bind(mapper, + shard_id=shard_id, + instance=instance).contextual_connect(**kwargs) + + def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance, clause=clause) + return self.__binds[shard_id] + + def bind_shard(self, shard_id, bind): + self.__binds[shard_id] = bind + +class ShardedQuery(Query): + def __init__(self, *args, **kwargs): + super(ShardedQuery, self).__init__(*args, **kwargs) + self.id_chooser = self.session.id_chooser + self.query_chooser = self.session.query_chooser + self._shard_id = None + + def set_shard(self, shard_id): + """return a new query, limited to a single shard ID. + + all subsequent operations with the returned query will + be against the single shard regardless of other state. + """ + + q = self._clone() + q._shard_id = shard_id + return q + + def _execute_and_instances(self, context): + if self._shard_id is not None: + result = self.session.connection( + mapper=self._mapper_zero(), + shard_id=self._shard_id).execute(context.statement, self._params) + return self.instances(result, context) + else: + partial = [] + for shard_id in self.query_chooser(self): + result = self.session.connection( + mapper=self._mapper_zero(), + shard_id=shard_id).execute(context.statement, self._params) + partial = partial + list(self.instances(result, context)) + + # if some kind of in memory 'sorting' + # were done, this is where it would happen + return iter(partial) + + def get(self, ident, **kwargs): + if self._shard_id is not None: + return super(ShardedQuery, self).get(ident) + else: + ident = util.to_list(ident) + for shard_id in self.id_chooser(self, ident): + o = self.set_shard(shard_id).get(ident, **kwargs) + if o is not None: + return o + else: + return None + diff --git a/sqlalchemy/ext/orderinglist.py b/sqlalchemy/ext/orderinglist.py new file mode 100644 index 0000000..0d2c3ae --- /dev/null +++ b/sqlalchemy/ext/orderinglist.py @@ -0,0 +1,315 @@ +"""A custom list that manages index/position information for its children. + +:author: Jason Kirtland + +``orderinglist`` is a helper for mutable ordered relationships. It will intercept +list operations performed on a relationship collection and automatically +synchronize changes in list position with an attribute on the related objects. +(See :ref:`advdatamapping_entitycollections` for more information on the general pattern.) + +Example: Two tables that store slides in a presentation. Each slide +has a number of bullet points, displayed in order by the 'position' +column on the bullets table. These bullets can be inserted and re-ordered +by your end users, and you need to update the 'position' column of all +affected rows when changes are made. + +.. sourcecode:: python+sql + + slides_table = Table('Slides', metadata, + Column('id', Integer, primary_key=True), + Column('name', String)) + + bullets_table = Table('Bullets', metadata, + Column('id', Integer, primary_key=True), + Column('slide_id', Integer, ForeignKey('Slides.id')), + Column('position', Integer), + Column('text', String)) + + class Slide(object): + pass + class Bullet(object): + pass + + mapper(Slide, slides_table, properties={ + 'bullets': relationship(Bullet, order_by=[bullets_table.c.position]) + }) + mapper(Bullet, bullets_table) + +The standard relationship mapping will produce a list-like attribute on each Slide +containing all related Bullets, but coping with changes in ordering is totally +your responsibility. If you insert a Bullet into that list, there is no +magic- it won't have a position attribute unless you assign it it one, and +you'll need to manually renumber all the subsequent Bullets in the list to +accommodate the insert. + +An ``orderinglist`` can automate this and manage the 'position' attribute on all +related bullets for you. + +.. sourcecode:: python+sql + + mapper(Slide, slides_table, properties={ + 'bullets': relationship(Bullet, + collection_class=ordering_list('position'), + order_by=[bullets_table.c.position]) + }) + mapper(Bullet, bullets_table) + + s = Slide() + s.bullets.append(Bullet()) + s.bullets.append(Bullet()) + s.bullets[1].position + >>> 1 + s.bullets.insert(1, Bullet()) + s.bullets[2].position + >>> 2 + +Use the ``ordering_list`` function to set up the ``collection_class`` on relationships +(as in the mapper example above). This implementation depends on the list +starting in the proper order, so be SURE to put an order_by on your relationship. + +.. warning:: ``ordering_list`` only provides limited functionality when a primary + key column or unique column is the target of the sort. Since changing the order of + entries often means that two rows must trade values, this is not possible when + the value is constrained by a primary key or unique constraint, since one of the rows + would temporarily have to point to a third available value so that the other row + could take its old value. ``ordering_list`` doesn't do any of this for you, + nor does SQLAlchemy itself. + +``ordering_list`` takes the name of the related object's ordering attribute as +an argument. By default, the zero-based integer index of the object's +position in the ``ordering_list`` is synchronized with the ordering attribute: +index 0 will get position 0, index 1 position 1, etc. To start numbering at 1 +or some other integer, provide ``count_from=1``. + +Ordering values are not limited to incrementing integers. Almost any scheme +can implemented by supplying a custom ``ordering_func`` that maps a Python list +index to any value you require. + + + + +""" +from sqlalchemy.orm.collections import collection +from sqlalchemy import util + +__all__ = [ 'ordering_list' ] + + +def ordering_list(attr, count_from=None, **kw): + """Prepares an OrderingList factory for use in mapper definitions. + + Returns an object suitable for use as an argument to a Mapper relationship's + ``collection_class`` option. Arguments are: + + attr + Name of the mapped attribute to use for storage and retrieval of + ordering information + + count_from (optional) + Set up an integer-based ordering, starting at ``count_from``. For + example, ``ordering_list('pos', count_from=1)`` would create a 1-based + list in SQL, storing the value in the 'pos' column. Ignored if + ``ordering_func`` is supplied. + + Passes along any keyword arguments to ``OrderingList`` constructor. + """ + + kw = _unsugar_count_from(count_from=count_from, **kw) + return lambda: OrderingList(attr, **kw) + +# Ordering utility functions +def count_from_0(index, collection): + """Numbering function: consecutive integers starting at 0.""" + + return index + +def count_from_1(index, collection): + """Numbering function: consecutive integers starting at 1.""" + + return index + 1 + +def count_from_n_factory(start): + """Numbering function: consecutive integers starting at arbitrary start.""" + + def f(index, collection): + return index + start + try: + f.__name__ = 'count_from_%i' % start + except TypeError: + pass + return f + +def _unsugar_count_from(**kw): + """Builds counting functions from keywrod arguments. + + Keyword argument filter, prepares a simple ``ordering_func`` from a + ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. + """ + + count_from = kw.pop('count_from', None) + if kw.get('ordering_func', None) is None and count_from is not None: + if count_from == 0: + kw['ordering_func'] = count_from_0 + elif count_from == 1: + kw['ordering_func'] = count_from_1 + else: + kw['ordering_func'] = count_from_n_factory(count_from) + return kw + +class OrderingList(list): + """A custom list that manages position information for its children. + + See the module and __init__ documentation for more details. The + ``ordering_list`` factory function is used to configure ``OrderingList`` + collections in ``mapper`` relationship definitions. + + """ + + def __init__(self, ordering_attr=None, ordering_func=None, + reorder_on_append=False): + """A custom list that manages position information for its children. + + ``OrderingList`` is a ``collection_class`` list implementation that + syncs position in a Python list with a position attribute on the + mapped objects. + + This implementation relies on the list starting in the proper order, + so be **sure** to put an ``order_by`` on your relationship. + + ordering_attr + Name of the attribute that stores the object's order in the + relationship. + + ordering_func + Optional. A function that maps the position in the Python list to a + value to store in the ``ordering_attr``. Values returned are + usually (but need not be!) integers. + + An ``ordering_func`` is called with two positional parameters: the + index of the element in the list, and the list itself. + + If omitted, Python list indexes are used for the attribute values. + Two basic pre-built numbering functions are provided in this module: + ``count_from_0`` and ``count_from_1``. For more exotic examples + like stepped numbering, alphabetical and Fibonacci numbering, see + the unit tests. + + reorder_on_append + Default False. When appending an object with an existing (non-None) + ordering value, that value will be left untouched unless + ``reorder_on_append`` is true. This is an optimization to avoid a + variety of dangerous unexpected database writes. + + SQLAlchemy will add instances to the list via append() when your + object loads. If for some reason the result set from the database + skips a step in the ordering (say, row '1' is missing but you get + '2', '3', and '4'), reorder_on_append=True would immediately + renumber the items to '1', '2', '3'. If you have multiple sessions + making changes, any of whom happen to load this collection even in + passing, all of the sessions would try to "clean up" the numbering + in their commits, possibly causing all but one to fail with a + concurrent modification error. Spooky action at a distance. + + Recommend leaving this with the default of False, and just call + ``reorder()`` if you're doing ``append()`` operations with + previously ordered instances or when doing some housekeeping after + manual sql operations. + + """ + self.ordering_attr = ordering_attr + if ordering_func is None: + ordering_func = count_from_0 + self.ordering_func = ordering_func + self.reorder_on_append = reorder_on_append + + # More complex serialization schemes (multi column, e.g.) are possible by + # subclassing and reimplementing these two methods. + def _get_order_value(self, entity): + return getattr(entity, self.ordering_attr) + + def _set_order_value(self, entity, value): + setattr(entity, self.ordering_attr, value) + + def reorder(self): + """Synchronize ordering for the entire collection. + + Sweeps through the list and ensures that each object has accurate + ordering information set. + + """ + for index, entity in enumerate(self): + self._order_entity(index, entity, True) + + # As of 0.5, _reorder is no longer semi-private + _reorder = reorder + + def _order_entity(self, index, entity, reorder=True): + have = self._get_order_value(entity) + + # Don't disturb existing ordering if reorder is False + if have is not None and not reorder: + return + + should_be = self.ordering_func(index, self) + if have != should_be: + self._set_order_value(entity, should_be) + + def append(self, entity): + super(OrderingList, self).append(entity) + self._order_entity(len(self) - 1, entity, self.reorder_on_append) + + def _raw_append(self, entity): + """Append without any ordering behavior.""" + + super(OrderingList, self).append(entity) + _raw_append = collection.adds(1)(_raw_append) + + def insert(self, index, entity): + super(OrderingList, self).insert(index, entity) + self._reorder() + + def remove(self, entity): + super(OrderingList, self).remove(entity) + self._reorder() + + def pop(self, index=-1): + entity = super(OrderingList, self).pop(index) + self._reorder() + return entity + + def __setitem__(self, index, entity): + if isinstance(index, slice): + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + for i in xrange(start, stop, step): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super(OrderingList, self).__setitem__(index, entity) + + def __delitem__(self, index): + super(OrderingList, self).__delitem__(index) + self._reorder() + + # Py2K + def __setslice__(self, start, end, values): + super(OrderingList, self).__setslice__(start, end, values) + self._reorder() + + def __delslice__(self, start, end): + super(OrderingList, self).__delslice__(start, end) + self._reorder() + # end Py2K + + for func_name, func in locals().items(): + if (util.callable(func) and func.func_name == func_name and + not func.__doc__ and hasattr(list, func_name)): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + diff --git a/sqlalchemy/ext/serializer.py b/sqlalchemy/ext/serializer.py new file mode 100644 index 0000000..354f28c --- /dev/null +++ b/sqlalchemy/ext/serializer.py @@ -0,0 +1,155 @@ +"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, +allowing "contextual" deserialization. + +Any SQLAlchemy query structure, either based on sqlalchemy.sql.* +or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session +etc. which are referenced by the structure are not persisted in serialized +form, but are instead re-associated with the query structure +when it is deserialized. + +Usage is nearly the same as that of the standard Python pickle module:: + + from sqlalchemy.ext.serializer import loads, dumps + metadata = MetaData(bind=some_engine) + Session = scoped_session(sessionmaker()) + + # ... define mappers + + query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) + + # pickle the query + serialized = dumps(query) + + # unpickle. Pass in metadata + scoped_session + query2 = loads(serialized, metadata, Session) + + print query2.all() + +Similar restrictions as when using raw pickle apply; mapped classes must be +themselves be pickleable, meaning they are importable from a module-level +namespace. + +The serializer module is only appropriate for query structures. It is not +needed for: + +* instances of user-defined classes. These contain no references to engines, + sessions or expression constructs in the typical case and can be serialized directly. + +* Table metadata that is to be loaded entirely from the serialized structure (i.e. is + not already declared in the application). Regular pickle.loads()/dumps() can + be used to fully dump any ``MetaData`` object, typically one which was reflected + from an existing database at some previous point in time. The serializer module + is specifically for the opposite case, where the Table metadata is already present + in memory. + +""" + +from sqlalchemy.orm import class_mapper, Query +from sqlalchemy.orm.session import Session +from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.attributes import QueryableAttribute +from sqlalchemy import Table, Column +from sqlalchemy.engine import Engine +from sqlalchemy.util import pickle +import re +import base64 +# Py3K +#from io import BytesIO as byte_buffer +# Py2K +from cStringIO import StringIO as byte_buffer +# end Py2K + +# Py3K +#def b64encode(x): +# return base64.b64encode(x).decode('ascii') +#def b64decode(x): +# return base64.b64decode(x.encode('ascii')) +# Py2K +b64encode = base64.b64encode +b64decode = base64.b64decode +# end Py2K + +__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] + + + +def Serializer(*args, **kw): + pickler = pickle.Pickler(*args, **kw) + + def persistent_id(obj): + #print "serializing:", repr(obj) + if isinstance(obj, QueryableAttribute): + cls = obj.impl.class_ + key = obj.impl.key + id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls)) + elif isinstance(obj, Mapper) and not obj.non_primary: + id = "mapper:" + b64encode(pickle.dumps(obj.class_)) + elif isinstance(obj, Table): + id = "table:" + str(obj) + elif isinstance(obj, Column) and isinstance(obj.table, Table): + id = "column:" + str(obj.table) + ":" + obj.key + elif isinstance(obj, Session): + id = "session:" + elif isinstance(obj, Engine): + id = "engine:" + else: + return None + return id + + pickler.persistent_id = persistent_id + return pickler + +our_ids = re.compile(r'(mapper|table|column|session|attribute|engine):(.*)') + +def Deserializer(file, metadata=None, scoped_session=None, engine=None): + unpickler = pickle.Unpickler(file) + + def get_engine(): + if engine: + return engine + elif scoped_session and scoped_session().bind: + return scoped_session().bind + elif metadata and metadata.bind: + return metadata.bind + else: + return None + + def persistent_load(id): + m = our_ids.match(id) + if not m: + return None + else: + type_, args = m.group(1, 2) + if type_ == 'attribute': + key, clsarg = args.split(":") + cls = pickle.loads(b64decode(clsarg)) + return getattr(cls, key) + elif type_ == "mapper": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls) + elif type_ == "table": + return metadata.tables[args] + elif type_ == "column": + table, colname = args.split(':') + return metadata.tables[table].c[colname] + elif type_ == "session": + return scoped_session() + elif type_ == "engine": + return get_engine() + else: + raise Exception("Unknown token: %s" % type_) + unpickler.persistent_load = persistent_load + return unpickler + +def dumps(obj, protocol=0): + buf = byte_buffer() + pickler = Serializer(buf, protocol) + pickler.dump(obj) + return buf.getvalue() + +def loads(data, metadata=None, scoped_session=None, engine=None): + buf = byte_buffer(data) + unpickler = Deserializer(buf, metadata, scoped_session, engine) + return unpickler.load() + + diff --git a/sqlalchemy/ext/sqlsoup.py b/sqlalchemy/ext/sqlsoup.py new file mode 100644 index 0000000..4d5f4b7 --- /dev/null +++ b/sqlalchemy/ext/sqlsoup.py @@ -0,0 +1,551 @@ +""" +Introduction +============ + +SqlSoup provides a convenient way to access existing database tables without +having to declare table or mapper classes ahead of time. It is built on top of the SQLAlchemy ORM and provides a super-minimalistic interface to an existing database. + +Suppose we have a database with users, books, and loans tables +(corresponding to the PyWebOff dataset, if you're curious). + +Creating a SqlSoup gateway is just like creating an SQLAlchemy +engine:: + + >>> from sqlalchemy.ext.sqlsoup import SqlSoup + >>> db = SqlSoup('sqlite:///:memory:') + +or, you can re-use an existing engine:: + + >>> db = SqlSoup(engine) + +You can optionally specify a schema within the database for your +SqlSoup:: + + >>> db.schema = myschemaname + +Loading objects +=============== + +Loading objects is as easy as this:: + + >>> users = db.users.all() + >>> users.sort() + >>> users + [MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0), MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)] + +Of course, letting the database do the sort is better:: + + >>> db.users.order_by(db.users.name).all() + [MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1), MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0)] + +Field access is intuitive:: + + >>> users[0].email + u'student@example.edu' + +Of course, you don't want to load all users very often. Let's add a +WHERE clause. Let's also switch the order_by to DESC while we're at +it:: + + >>> from sqlalchemy import or_, and_, desc + >>> where = or_(db.users.name=='Bhargan Basepair', db.users.email=='student@example.edu') + >>> db.users.filter(where).order_by(desc(db.users.name)).all() + [MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0), MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)] + +You can also use .first() (to retrieve only the first object from a query) or +.one() (like .first when you expect exactly one user -- it will raise an +exception if more were returned):: + + >>> db.users.filter(db.users.name=='Bhargan Basepair').one() + MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1) + +Since name is the primary key, this is equivalent to + + >>> db.users.get('Bhargan Basepair') + MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1) + +This is also equivalent to + + >>> db.users.filter_by(name='Bhargan Basepair').one() + MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1) + +filter_by is like filter, but takes kwargs instead of full clause expressions. +This makes it more concise for simple queries like this, but you can't do +complex queries like the or\_ above or non-equality based comparisons this way. + +Full query documentation +------------------------ + +Get, filter, filter_by, order_by, limit, and the rest of the +query methods are explained in detail in :ref:`ormtutorial_querying`. + +Modifying objects +================= + +Modifying objects is intuitive:: + + >>> user = _ + >>> user.email = 'basepair+nospam@example.edu' + >>> db.commit() + +(SqlSoup leverages the sophisticated SQLAlchemy unit-of-work code, so +multiple updates to a single object will be turned into a single +``UPDATE`` statement when you commit.) + +To finish covering the basics, let's insert a new loan, then delete +it:: + + >>> book_id = db.books.filter_by(title='Regional Variation in Moss').first().id + >>> db.loans.insert(book_id=book_id, user_name=user.name) + MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None) + + >>> loan = db.loans.filter_by(book_id=2, user_name='Bhargan Basepair').one() + >>> db.delete(loan) + >>> db.commit() + +You can also delete rows that have not been loaded as objects. Let's +do our insert/delete cycle once more, this time using the loans +table's delete method. (For SQLAlchemy experts: note that no flush() +call is required since this delete acts at the SQL level, not at the +Mapper level.) The same where-clause construction rules apply here as +to the select methods. + +:: + + >>> db.loans.insert(book_id=book_id, user_name=user.name) + MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None) + >>> db.loans.delete(db.loans.book_id==2) + +You can similarly update multiple rows at once. This will change the +book_id to 1 in all loans whose book_id is 2:: + + >>> db.loans.update(db.loans.book_id==2, book_id=1) + >>> db.loans.filter_by(book_id=1).all() + [MappedLoans(book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))] + + +Joins +===== + +Occasionally, you will want to pull out a lot of data from related +tables all at once. In this situation, it is far more efficient to +have the database perform the necessary join. (Here we do not have *a +lot of data* but hopefully the concept is still clear.) SQLAlchemy is +smart enough to recognize that loans has a foreign key to users, and +uses that as the join condition automatically. + +:: + + >>> join1 = db.join(db.users, db.loans, isouter=True) + >>> join1.filter_by(name='Joe Student').all() + [MappedJoin(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0,book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))] + +If you're unfortunate enough to be using MySQL with the default MyISAM +storage engine, you'll have to specify the join condition manually, +since MyISAM does not store foreign keys. Here's the same join again, +with the join condition explicitly specified:: + + >>> db.join(db.users, db.loans, db.users.name==db.loans.user_name, isouter=True) + + +You can compose arbitrarily complex joins by combining Join objects +with tables or other joins. Here we combine our first join with the +books table:: + + >>> join2 = db.join(join1, db.books) + >>> join2.all() + [MappedJoin(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0,book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0),id=1,title=u'Mustards I Have Known',published_year=u'1989',authors=u'Jones')] + +If you join tables that have an identical column name, wrap your join +with `with_labels`, to disambiguate columns with their table name +(.c is short for .columns):: + + >>> db.with_labels(join1).c.keys() + [u'users_name', u'users_email', u'users_password', u'users_classname', u'users_admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date'] + +You can also join directly to a labeled object:: + + >>> labeled_loans = db.with_labels(db.loans) + >>> db.join(db.users, labeled_loans, isouter=True).c.keys() + [u'name', u'email', u'password', u'classname', u'admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date'] + + +Relationships +============= + +You can define relationships on SqlSoup classes: + + >>> db.users.relate('loans', db.loans) + +These can then be used like a normal SA property: + + >>> db.users.get('Joe Student').loans + [MappedLoans(book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))] + + >>> db.users.filter(~db.users.loans.any()).all() + [MappedUsers(name=u'Bhargan Basepair',email='basepair+nospam@example.edu',password=u'basepair',classname=None,admin=1)] + + +relate can take any options that the relationship function accepts in normal mapper definition: + + >>> del db._cache['users'] + >>> db.users.relate('loans', db.loans, order_by=db.loans.loan_date, cascade='all, delete-orphan') + +Advanced Use +============ + +Sessions, Transations and Application Integration +------------------------------------------------- + +**Note:** please read and understand this section thoroughly before using SqlSoup in any web application. + +SqlSoup uses a ScopedSession to provide thread-local sessions. You +can get a reference to the current one like this:: + + >>> session = db.session + +The default session is available at the module level in SQLSoup, via:: + + >>> from sqlalchemy.ext.sqlsoup import Session + +The configuration of this session is ``autoflush=True``, ``autocommit=False``. +This means when you work with the SqlSoup object, you need to call ``db.commit()`` +in order to have changes persisted. You may also call ``db.rollback()`` to +roll things back. + +Since the SqlSoup object's Session automatically enters into a transaction as soon +as it's used, it is *essential* that you call ``commit()`` or ``rollback()`` +on it when the work within a thread completes. This means all the guidelines +for web application integration at :ref:`session_lifespan` must be followed. + +The SqlSoup object can have any session or scoped session configured onto it. +This is of key importance when integrating with existing code or frameworks +such as Pylons. If your application already has a ``Session`` configured, +pass it to your SqlSoup object:: + + >>> from myapplication import Session + >>> db = SqlSoup(session=Session) + +If the ``Session`` is configured with ``autocommit=True``, use ``flush()`` +instead of ``commit()`` to persist changes - in this case, the ``Session`` +closes out its transaction immediately and no external management is needed. ``rollback()`` is also not available. Configuring a new SQLSoup object in "autocommit" mode looks like:: + + >>> from sqlalchemy.orm import scoped_session, sessionmaker + >>> db = SqlSoup('sqlite://', session=scoped_session(sessionmaker(autoflush=False, expire_on_commit=False, autocommit=True))) + + +Mapping arbitrary Selectables +----------------------------- + +SqlSoup can map any SQLAlchemy ``Selectable`` with the map +method. Let's map a ``Select`` object that uses an aggregate function; +we'll use the SQLAlchemy ``Table`` that SqlSoup introspected as the +basis. (Since we're not mapping to a simple table or join, we need to +tell SQLAlchemy how to find the *primary key* which just needs to be +unique within the select, and not necessarily correspond to a *real* +PK in the database.) + +:: + + >>> from sqlalchemy import select, func + >>> b = db.books._table + >>> s = select([b.c.published_year, func.count('*').label('n')], from_obj=[b], group_by=[b.c.published_year]) + >>> s = s.alias('years_with_count') + >>> years_with_count = db.map(s, primary_key=[s.c.published_year]) + >>> years_with_count.filter_by(published_year='1989').all() + [MappedBooks(published_year=u'1989',n=1)] + +Obviously if we just wanted to get a list of counts associated with +book years once, raw SQL is going to be less work. The advantage of +mapping a Select is reusability, both standalone and in Joins. (And if +you go to full SQLAlchemy, you can perform mappings like this directly +to your object models.) + +An easy way to save mapped selectables like this is to just hang them on +your db object:: + + >>> db.years_with_count = years_with_count + +Python is flexible like that! + + +Raw SQL +------- + +SqlSoup works fine with SQLAlchemy's text construct, described in :ref:`sqlexpression_text`. +You can also execute textual SQL directly using the `execute()` method, +which corresponds to the `execute()` method on the underlying `Session`. +Expressions here are expressed like ``text()`` constructs, using named parameters +with colons:: + + >>> rp = db.execute('select name, email from users where name like :name order by name', name='%Bhargan%') + >>> for name, email in rp.fetchall(): print name, email + Bhargan Basepair basepair+nospam@example.edu + +Or you can get at the current transaction's connection using `connection()`. This is the +raw connection object which can accept any sort of SQL expression or raw SQL string passed to the database:: + + >>> conn = db.connection() + >>> conn.execute("'select name, email from users where name like ? order by name'", '%Bhargan%') + + +Dynamic table names +------------------- + +You can load a table whose name is specified at runtime with the entity() method: + + >>> tablename = 'loans' + >>> db.entity(tablename) == db.loans + True + +entity() also takes an optional schema argument. If none is specified, the +default schema is used. + + +""" + +from sqlalchemy import Table, MetaData, join +from sqlalchemy import schema, sql +from sqlalchemy.engine.base import Engine +from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \ + class_mapper, relationship, session,\ + object_session +from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE +from sqlalchemy.exceptions import SQLAlchemyError, InvalidRequestError, ArgumentError +from sqlalchemy.sql import expression + + +__all__ = ['PKNotFoundError', 'SqlSoup'] + +Session = scoped_session(sessionmaker(autoflush=True, autocommit=False)) + +class AutoAdd(MapperExtension): + def __init__(self, scoped_session): + self.scoped_session = scoped_session + + def instrument_class(self, mapper, class_): + class_.__init__ = self._default__init__(mapper) + + def _default__init__(ext, mapper): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + return __init__ + + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + session = self.scoped_session() + session._save_without_cascade(instance) + return EXT_CONTINUE + + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + sess = object_session(instance) + if sess: + sess.expunge(instance) + return EXT_CONTINUE + +class PKNotFoundError(SQLAlchemyError): + pass + +def _ddl_error(cls): + msg = 'SQLSoup can only modify mapped Tables (found: %s)' \ + % cls._table.__class__.__name__ + raise InvalidRequestError(msg) + +# metaclass is necessary to expose class methods with getattr, e.g. +# we want to pass db.users.select through to users._mapper.select +class SelectableClassType(type): + def insert(cls, **kwargs): + _ddl_error(cls) + + def __clause_element__(cls): + return cls._table + + def __getattr__(cls, attr): + if attr == '_query': + # called during mapper init + raise AttributeError() + return getattr(cls._query, attr) + +class TableClassType(SelectableClassType): + def insert(cls, **kwargs): + o = cls() + o.__dict__.update(kwargs) + return o + + def relate(cls, propname, *args, **kwargs): + class_mapper(cls)._configure_property(propname, relationship(*args, **kwargs)) + +def _is_outer_join(selectable): + if not isinstance(selectable, sql.Join): + return False + if selectable.isouter: + return True + return _is_outer_join(selectable.left) or _is_outer_join(selectable.right) + +def _selectable_name(selectable): + if isinstance(selectable, sql.Alias): + return _selectable_name(selectable.element) + elif isinstance(selectable, sql.Select): + return ''.join(_selectable_name(s) for s in selectable.froms) + elif isinstance(selectable, schema.Table): + return selectable.name.capitalize() + else: + x = selectable.__class__.__name__ + if x[0] == '_': + x = x[1:] + return x + +def _class_for_table(session, engine, selectable, **mapper_kwargs): + selectable = expression._clause_element_as_expr(selectable) + mapname = 'Mapped' + _selectable_name(selectable) + # Py2K + if isinstance(mapname, unicode): + engine_encoding = engine.dialect.encoding + mapname = mapname.encode(engine_encoding) + # end Py2K + + if isinstance(selectable, Table): + klass = TableClassType(mapname, (object,), {}) + else: + klass = SelectableClassType(mapname, (object,), {}) + + def _compare(self, o): + L = list(self.__class__.c.keys()) + L.sort() + t1 = [getattr(self, k) for k in L] + try: + t2 = [getattr(o, k) for k in L] + except AttributeError: + raise TypeError('unable to compare with %s' % o.__class__) + return t1, t2 + + # python2/python3 compatible system of + # __cmp__ - __lt__ + __eq__ + + def __lt__(self, o): + t1, t2 = _compare(self, o) + return t1 < t2 + + def __eq__(self, o): + t1, t2 = _compare(self, o) + return t1 == t2 + + def __repr__(self): + L = ["%s=%r" % (key, getattr(self, key, '')) + for key in self.__class__.c.keys()] + return '%s(%s)' % (self.__class__.__name__, ','.join(L)) + + for m in ['__eq__', '__repr__', '__lt__']: + setattr(klass, m, eval(m)) + klass._table = selectable + klass.c = expression.ColumnCollection() + mappr = mapper(klass, + selectable, + extension=AutoAdd(session), + **mapper_kwargs) + + for k in mappr.iterate_properties: + klass.c[k.key] = k.columns[0] + + klass._query = session.query_property() + return klass + +class SqlSoup(object): + def __init__(self, engine_or_metadata, **kw): + """Initialize a new ``SqlSoup``. + + `args` may either be an ``SQLEngine`` or a set of arguments + suitable for passing to ``create_engine``. + """ + + self.session = kw.pop('session', Session) + + if isinstance(engine_or_metadata, MetaData): + self._metadata = engine_or_metadata + elif isinstance(engine_or_metadata, (basestring, Engine)): + self._metadata = MetaData(engine_or_metadata) + else: + raise ArgumentError("invalid engine or metadata argument %r" % engine_or_metadata) + + self._cache = {} + self.schema = None + + @property + def engine(self): + return self._metadata.bind + + bind = engine + + def delete(self, *args, **kwargs): + self.session.delete(*args, **kwargs) + + def execute(self, stmt, **params): + return self.session.execute(sql.text(stmt, bind=self.bind), **params) + + @property + def _underlying_session(self): + if isinstance(self.session, session.Session): + return self.session + else: + return self.session() + + def connection(self): + return self._underlying_session._connection_for_bind(self.bind) + + def flush(self): + self.session.flush() + + def rollback(self): + self.session.rollback() + + def commit(self): + self.session.commit() + + def clear(self): + self.session.expunge_all() + + def expunge(self, *args, **kw): + self.session.expunge(*args, **kw) + + def expunge_all(self): + self.session.expunge_all() + + def map(self, selectable, **kwargs): + try: + t = self._cache[selectable] + except KeyError: + t = _class_for_table(self.session, self.engine, selectable, **kwargs) + self._cache[selectable] = t + return t + + def with_labels(self, item): + # TODO give meaningful aliases + return self.map( + expression._clause_element_as_expr(item). + select(use_labels=True). + alias('foo')) + + def join(self, *args, **kwargs): + j = join(*args, **kwargs) + return self.map(j) + + def entity(self, attr, schema=None): + try: + t = self._cache[attr] + except KeyError, ke: + table = Table(attr, self._metadata, autoload=True, autoload_with=self.bind, schema=schema or self.schema) + if not table.primary_key.columns: + raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys()))) + if table.columns: + t = _class_for_table(self.session, self.engine, table) + else: + t = None + self._cache[attr] = t + return t + + def __getattr__(self, attr): + return self.entity(attr) + + def __repr__(self): + return 'SqlSoup(%r)' % self._metadata + diff --git a/sqlalchemy/interfaces.py b/sqlalchemy/interfaces.py new file mode 100644 index 0000000..c2a267d --- /dev/null +++ b/sqlalchemy/interfaces.py @@ -0,0 +1,205 @@ +# interfaces.py +# Copyright (C) 2007 Jason Kirtland jek@discorporate.us +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Interfaces and abstract types.""" + + +class PoolListener(object): + """Hooks into the lifecycle of connections in a ``Pool``. + + Usage:: + + class MyListener(PoolListener): + def connect(self, dbapi_con, con_record): + '''perform connect operations''' + # etc. + + # create a new pool with a listener + p = QueuePool(..., listeners=[MyListener()]) + + # add a listener after the fact + p.add_listener(MyListener()) + + # usage with create_engine() + e = create_engine("url://", listeners=[MyListener()]) + + All of the standard connection :class:`~sqlalchemy.pool.Pool` types can + accept event listeners for key connection lifecycle events: + creation, pool check-out and check-in. There are no events fired + when a connection closes. + + For any given DB-API connection, there will be one ``connect`` + event, `n` number of ``checkout`` events, and either `n` or `n - 1` + ``checkin`` events. (If a ``Connection`` is detached from its + pool via the ``detach()`` method, it won't be checked back in.) + + These are low-level events for low-level objects: raw Python + DB-API connections, without the conveniences of the SQLAlchemy + ``Connection`` wrapper, ``Dialect`` services or ``ClauseElement`` + execution. If you execute SQL through the connection, explicitly + closing all cursors and other resources is recommended. + + Events also receive a ``_ConnectionRecord``, a long-lived internal + ``Pool`` object that basically represents a "slot" in the + connection pool. ``_ConnectionRecord`` objects have one public + attribute of note: ``info``, a dictionary whose contents are + scoped to the lifetime of the DB-API connection managed by the + record. You can use this shared storage area however you like. + + There is no need to subclass ``PoolListener`` to handle events. + Any class that implements one or more of these methods can be used + as a pool listener. The ``Pool`` will inspect the methods + provided by a listener object and add the listener to one or more + internal event queues based on its capabilities. In terms of + efficiency and function call overhead, you're much better off only + providing implementations for the hooks you'll be using. + + """ + + def connect(self, dbapi_con, con_record): + """Called once for each new DB-API connection or Pool's ``creator()``. + + dbapi_con + A newly connected raw DB-API connection (not a SQLAlchemy + ``Connection`` wrapper). + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + + def first_connect(self, dbapi_con, con_record): + """Called exactly once for the first DB-API connection. + + dbapi_con + A newly connected raw DB-API connection (not a SQLAlchemy + ``Connection`` wrapper). + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + + def checkout(self, dbapi_con, con_record, con_proxy): + """Called when a connection is retrieved from the Pool. + + dbapi_con + A raw DB-API connection + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + con_proxy + The ``_ConnectionFairy`` which manages the connection for the span of + the current checkout. + + If you raise an ``exc.DisconnectionError``, the current + connection will be disposed and a fresh connection retrieved. + Processing of all checkout listeners will abort and restart + using the new connection. + """ + + def checkin(self, dbapi_con, con_record): + """Called when a connection returns to the pool. + + Note that the connection may be closed, and may be None if the + connection has been invalidated. ``checkin`` will not be called + for detached connections. (They do not return to the pool.) + + dbapi_con + A raw DB-API connection + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + +class ConnectionProxy(object): + """Allows interception of statement execution by Connections. + + Either or both of the ``execute()`` and ``cursor_execute()`` + may be implemented to intercept compiled statement and + cursor level executions, e.g.:: + + class MyProxy(ConnectionProxy): + def execute(self, conn, execute, clauseelement, *multiparams, **params): + print "compiled statement:", clauseelement + return execute(clauseelement, *multiparams, **params) + + def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + print "raw statement:", statement + return execute(cursor, statement, parameters, context) + + The ``execute`` argument is a function that will fulfill the default + execution behavior for the operation. The signature illustrated + in the example should be used. + + The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via + the ``proxy`` argument:: + + e = create_engine('someurl://', proxy=MyProxy()) + + """ + def execute(self, conn, execute, clauseelement, *multiparams, **params): + """Intercept high level execute() events.""" + + return execute(clauseelement, *multiparams, **params) + + def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + """Intercept low-level cursor execute() events.""" + + return execute(cursor, statement, parameters, context) + + def begin(self, conn, begin): + """Intercept begin() events.""" + + return begin() + + def rollback(self, conn, rollback): + """Intercept rollback() events.""" + + return rollback() + + def commit(self, conn, commit): + """Intercept commit() events.""" + + return commit() + + def savepoint(self, conn, savepoint, name=None): + """Intercept savepoint() events.""" + + return savepoint(name=name) + + def rollback_savepoint(self, conn, rollback_savepoint, name, context): + """Intercept rollback_savepoint() events.""" + + return rollback_savepoint(name, context) + + def release_savepoint(self, conn, release_savepoint, name, context): + """Intercept release_savepoint() events.""" + + return release_savepoint(name, context) + + def begin_twophase(self, conn, begin_twophase, xid): + """Intercept begin_twophase() events.""" + + return begin_twophase(xid) + + def prepare_twophase(self, conn, prepare_twophase, xid): + """Intercept prepare_twophase() events.""" + + return prepare_twophase(xid) + + def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared): + """Intercept rollback_twophase() events.""" + + return rollback_twophase(xid, is_prepared) + + def commit_twophase(self, conn, commit_twophase, xid, is_prepared): + """Intercept commit_twophase() events.""" + + return commit_twophase(xid, is_prepared) + diff --git a/sqlalchemy/log.py b/sqlalchemy/log.py new file mode 100644 index 0000000..49c779f --- /dev/null +++ b/sqlalchemy/log.py @@ -0,0 +1,119 @@ +# log.py - adapt python logging module to SQLAlchemy +# Copyright (C) 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Logging control and utilities. + +Control of logging for SA can be performed from the regular python logging +module. The regular dotted module namespace is used, starting at +'sqlalchemy'. For class-level logging, the class name is appended. + +The "echo" keyword parameter which is available on SQLA ``Engine`` +and ``Pool`` objects corresponds to a logger specific to that +instance only. + +E.g.:: + + engine.echo = True + +is equivalent to:: + + import logging + logger = logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine))) + logger.setLevel(logging.DEBUG) + +""" + +import logging +import sys +from sqlalchemy import util + +rootlogger = logging.getLogger('sqlalchemy') +if rootlogger.level == logging.NOTSET: + rootlogger.setLevel(logging.WARN) + +default_enabled = False +def default_logging(name): + global default_enabled + if logging.getLogger(name).getEffectiveLevel() < logging.WARN: + default_enabled = True + if not default_enabled: + default_enabled = True + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)s %(name)s %(message)s')) + rootlogger.addHandler(handler) + +_logged_classes = set() +def class_logger(cls, enable=False): + logger = logging.getLogger(cls.__module__ + "." + cls.__name__) + if enable == 'debug': + logger.setLevel(logging.DEBUG) + elif enable == 'info': + logger.setLevel(logging.INFO) + cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG) + cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO) + cls.logger = logger + _logged_classes.add(cls) + + +class Identified(object): + @util.memoized_property + def logging_name(self): + # limit the number of loggers by chopping off the hex(id). + # some novice users unfortunately create an unlimited number + # of Engines in their applications which would otherwise + # cause the app to run out of memory. + return "0x...%s" % hex(id(self))[-4:] + + +def instance_logger(instance, echoflag=None): + """create a logger for an instance that implements :class:`Identified`. + + Warning: this is an expensive call which also results in a permanent + increase in memory overhead for each call. Use only for + low-volume, long-time-spanning objects. + + """ + + name = "%s.%s.%s" % (instance.__class__.__module__, + instance.__class__.__name__, instance.logging_name) + + if echoflag is not None: + l = logging.getLogger(name) + if echoflag == 'debug': + default_logging(name) + l.setLevel(logging.DEBUG) + elif echoflag is True: + default_logging(name) + l.setLevel(logging.INFO) + elif echoflag is False: + l.setLevel(logging.WARN) + else: + l = logging.getLogger(name) + instance._should_log_debug = lambda: l.isEnabledFor(logging.DEBUG) + instance._should_log_info = lambda: l.isEnabledFor(logging.INFO) + return l + +class echo_property(object): + __doc__ = """\ + When ``True``, enable log output for this element. + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. + """ + + def __get__(self, instance, owner): + if instance is None: + return self + else: + return instance._should_log_debug() and 'debug' or \ + (instance._should_log_info() and True or False) + + def __set__(self, instance, value): + instance_logger(instance, echoflag=value) diff --git a/sqlalchemy/orm/__init__.py b/sqlalchemy/orm/__init__.py new file mode 100644 index 0000000..206c8d0 --- /dev/null +++ b/sqlalchemy/orm/__init__.py @@ -0,0 +1,1176 @@ +# sqlalchemy/orm/__init__.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Functional constructs for ORM configuration. + +See the SQLAlchemy object relational tutorial and mapper configuration +documentation for an overview of how this module is used. + +""" + +from sqlalchemy.orm import exc +from sqlalchemy.orm.mapper import ( + Mapper, + _mapper_registry, + class_mapper, + ) +from sqlalchemy.orm.interfaces import ( + EXT_CONTINUE, + EXT_STOP, + ExtensionOption, + InstrumentationManager, + MapperExtension, + PropComparator, + SessionExtension, + AttributeExtension, + ) +from sqlalchemy.orm.util import ( + AliasedClass as aliased, + Validator, + join, + object_mapper, + outerjoin, + polymorphic_union, + with_parent, + ) +from sqlalchemy.orm.properties import ( + ColumnProperty, + ComparableProperty, + CompositeProperty, + RelationshipProperty, + PropertyLoader, + SynonymProperty, + ) +from sqlalchemy.orm import mapper as mapperlib +from sqlalchemy.orm.mapper import reconstructor, validates +from sqlalchemy.orm import strategies +from sqlalchemy.orm.query import AliasOption, Query +from sqlalchemy.sql import util as sql_util +from sqlalchemy.orm.session import Session as _Session +from sqlalchemy.orm.session import object_session, sessionmaker, make_transient +from sqlalchemy.orm.scoping import ScopedSession +from sqlalchemy import util as sa_util + + +__all__ = ( + 'EXT_CONTINUE', + 'EXT_STOP', + 'InstrumentationManager', + 'MapperExtension', + 'AttributeExtension', + 'Validator', + 'PropComparator', + 'Query', + 'aliased', + 'backref', + 'class_mapper', + 'clear_mappers', + 'column_property', + 'comparable_property', + 'compile_mappers', + 'composite', + 'contains_alias', + 'contains_eager', + 'create_session', + 'defer', + 'deferred', + 'dynamic_loader', + 'eagerload', + 'eagerload_all', + 'extension', + 'join', + 'joinedload', + 'joinedload_all', + 'lazyload', + 'mapper', + 'make_transient', + 'noload', + 'object_mapper', + 'object_session', + 'outerjoin', + 'polymorphic_union', + 'reconstructor', + 'relationship', + 'relation', + 'scoped_session', + 'sessionmaker', + 'subqueryload', + 'subqueryload_all', + 'synonym', + 'undefer', + 'undefer_group', + 'validates' + ) + + +def scoped_session(session_factory, scopefunc=None): + """Provides thread-local management of Sessions. + + This is a front-end function to + :class:`~sqlalchemy.orm.scoping.ScopedSession`. + + :param session_factory: a callable function that produces + :class:`Session` instances, such as :func:`sessionmaker` or + :func:`create_session`. + + :param scopefunc: optional, TODO + + :returns: an :class:`~sqlalchemy.orm.scoping.ScopedSession` instance + + Usage:: + + Session = scoped_session(sessionmaker(autoflush=True)) + + To instantiate a Session object which is part of the scoped context, + instantiate normally:: + + session = Session() + + Most session methods are available as classmethods from the scoped + session:: + + Session.commit() + Session.close() + + To map classes so that new instances are saved in the current Session + automatically, as well as to provide session-aware class attributes such + as "query", use the `mapper` classmethod from the scoped session:: + + mapper = Session.mapper + mapper(Class, table, ...) + + """ + return ScopedSession(session_factory, scopefunc=scopefunc) + +def create_session(bind=None, **kwargs): + """Create a new :class:`~sqlalchemy.orm.session.Session`. + + :param bind: optional, a single Connectable to use for all + database access in the created + :class:`~sqlalchemy.orm.session.Session`. + + :param \*\*kwargs: optional, passed through to the + :class:`Session` constructor. + + :returns: an :class:`~sqlalchemy.orm.session.Session` instance + + The defaults of create_session() are the opposite of that of + :func:`sessionmaker`; ``autoflush`` and ``expire_on_commit`` are + False, ``autocommit`` is True. In this sense the session acts + more like the "classic" SQLAlchemy 0.3 session with these. + + Usage:: + + >>> from sqlalchemy.orm import create_session + >>> session = create_session() + + It is recommended to use :func:`sessionmaker` instead of + create_session(). + + """ + kwargs.setdefault('autoflush', False) + kwargs.setdefault('autocommit', True) + kwargs.setdefault('expire_on_commit', False) + return _Session(bind=bind, **kwargs) + +def relationship(argument, secondary=None, **kwargs): + """Provide a relationship of a primary Mapper to a secondary Mapper. + + .. note:: This function is known as :func:`relation` in all versions + of SQLAlchemy prior to version 0.6beta2, including the 0.5 and 0.4 series. + :func:`~sqlalchemy.orm.relationship()` is only available starting with + SQLAlchemy 0.6beta2. The :func:`relation` name will remain available for + the foreseeable future in order to enable cross-compatibility. + + This corresponds to a parent-child or associative table relationship. The + constructed class is an instance of :class:`RelationshipProperty`. + + A typical :func:`relationship`:: + + mapper(Parent, properties={ + 'children': relationship(Children) + }) + + :param argument: + a class or :class:`Mapper` instance, representing the target of + the relationship. + + :param secondary: + for a many-to-many relationship, specifies the intermediary + table. The *secondary* keyword argument should generally only + be used for a table that is not otherwise expressed in any class + mapping. In particular, using the Association Object Pattern is + generally mutually exclusive with the use of the *secondary* + keyword argument. + + :param backref: + indicates the string name of a property to be placed on the related + mapper's class that will handle this relationship in the other + direction. The other property will be created automatically + when the mappers are configured. Can also be passed as a + :func:`backref` object to control the configuration of the + new relationship. + + :param back_populates: + Takes a string name and has the same meaning as ``backref``, + except the complementing property is **not** created automatically, + and instead must be configured explicitly on the other mapper. The + complementing property should also indicate ``back_populates`` + to this relationship to ensure proper functioning. + + :param cascade: + a comma-separated list of cascade rules which determines how + Session operations should be "cascaded" from parent to child. + This defaults to ``False``, which means the default cascade + should be used. The default value is ``"save-update, merge"``. + + Available cascades are: + + * ``save-update`` - cascade the :meth:`~sqlalchemy.orm.session.Session.add` + operation. This cascade applies both to future and + past calls to :meth:`~sqlalchemy.orm.session.Session.add`, + meaning new items added to a collection or scalar relationship + get placed into the same session as that of the parent, and + also applies to items which have been removed from this + relationship but are still part of unflushed history. + + * ``merge`` - cascade the :meth:`~sqlalchemy.orm.session.Session.merge` + operation + + * ``expunge`` - cascade the :meth:`~sqlalchemy.orm.session.Session.expunge` + operation + + * ``delete`` - cascade the :meth:`~sqlalchemy.orm.session.Session.delete` + operation + + * ``delete-orphan`` - if an item of the child's type with no + parent is detected, mark it for deletion. Note that this + option prevents a pending item of the child's class from being + persisted without a parent present. + + * ``refresh-expire`` - cascade the :meth:`~sqlalchemy.orm.session.Session.expire` + and :meth:`~sqlalchemy.orm.session.Session.refresh` operations + + * ``all`` - shorthand for "save-update,merge, refresh-expire, + expunge, delete" + + :param collection_class: + a class or callable that returns a new list-holding object. will + be used in place of a plain list for storing elements. + + :param comparator_factory: + a class which extends :class:`RelationshipProperty.Comparator` which + provides custom SQL clause generation for comparison operations. + + :param extension: + an :class:`AttributeExtension` instance, or list of extensions, + which will be prepended to the list of attribute listeners for + the resulting descriptor placed on the class. These listeners + will receive append and set events before the operation + proceeds, and may be used to halt (via exception throw) or + change the value used in the operation. + + :param foreign_keys: + a list of columns which are to be used as "foreign key" columns. + this parameter should be used in conjunction with explicit + ``primaryjoin`` and ``secondaryjoin`` (if needed) arguments, and + the columns within the ``foreign_keys`` list should be present + within those join conditions. Normally, ``relationship()`` will + inspect the columns within the join conditions to determine + which columns are the "foreign key" columns, based on + information in the ``Table`` metadata. Use this argument when no + ForeignKey's are present in the join condition, or to override + the table-defined foreign keys. + + :param innerjoin=False: + when ``True``, joined eager loads will use an inner join to join + against related tables instead of an outer join. The purpose + of this option is strictly one of performance, as inner joins + generally perform better than outer joins. This flag can + be set to ``True`` when the relationship references an object + via many-to-one using local foreign keys that are not nullable, + or when the reference is one-to-one or a collection that is + guaranteed to have one or at least one entry. + + :param join_depth: + when non-``None``, an integer value indicating how many levels + deep "eager" loaders should join on a self-referring or cyclical + relationship. The number counts how many times the same Mapper + shall be present in the loading condition along a particular join + branch. When left at its default of ``None``, eager loaders + will stop chaining when they encounter a the same target mapper + which is already higher up in the chain. This option applies + both to joined- and subquery- eager loaders. + + :param lazy=('select'|'joined'|'subquery'|'noload'|'dynamic'): specifies + how the related items should be loaded. Values include: + + * 'select' - items should be loaded lazily when the property is first + accessed. + + * 'joined' - items should be loaded "eagerly" in the same query as + that of the parent, using a JOIN or LEFT OUTER JOIN. + + * 'subquery' - items should be loaded "eagerly" within the same + query as that of the parent, using a second SQL statement + which issues a JOIN to a subquery of the original + statement. + + * 'noload' - no loading should occur at any time. This is to + support "write-only" attributes, or attributes which are + populated in some manner specific to the application. + + * 'dynamic' - the attribute will return a pre-configured + :class:`~sqlalchemy.orm.query.Query` object for all read + operations, onto which further filtering operations can be + applied before iterating the results. The dynamic + collection supports a limited set of mutation operations, + allowing ``append()`` and ``remove()``. Changes to the + collection will not be visible until flushed + to the database, where it is then refetched upon iteration. + + * True - a synonym for 'select' + + * False - a synonyn for 'joined' + + * None - a synonym for 'noload' + + :param order_by: + indicates the ordering that should be applied when loading these + items. + + :param passive_deletes=False: + Indicates loading behavior during delete operations. + + A value of True indicates that unloaded child items should not + be loaded during a delete operation on the parent. Normally, + when a parent item is deleted, all child items are loaded so + that they can either be marked as deleted, or have their + foreign key to the parent set to NULL. Marking this flag as + True usually implies an ON DELETE rule is in + place which will handle updating/deleting child rows on the + database side. + + Additionally, setting the flag to the string value 'all' will + disable the "nulling out" of the child foreign keys, when there + is no delete or delete-orphan cascade enabled. This is + typically used when a triggering or error raise scenario is in + place on the database side. Note that the foreign key + attributes on in-session child objects will not be changed + after a flush occurs so this is a very special use-case + setting. + + :param passive_updates=True: + Indicates loading and INSERT/UPDATE/DELETE behavior when the + source of a foreign key value changes (i.e. an "on update" + cascade), which are typically the primary key columns of the + source row. + + When True, it is assumed that ON UPDATE CASCADE is configured on + the foreign key in the database, and that the database will + handle propagation of an UPDATE from a source column to + dependent rows. Note that with databases which enforce + referential integrity (i.e. PostgreSQL, MySQL with InnoDB tables), + ON UPDATE CASCADE is required for this operation. The + relationship() will update the value of the attribute on related + items which are locally present in the session during a flush. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The relationship() will issue the + appropriate UPDATE statements to the database in response to the + change of a referenced key, and items locally present in the + session during a flush will also be refreshed. + + This flag should probably be set to False if primary key changes + are expected and the database in use doesn't support CASCADE + (i.e. SQLite, MySQL MyISAM tables). + + Also see the passive_updates flag on ``mapper()``. + + A future SQLAlchemy release will provide a "detect" feature for + this flag. + + :param post_update: + this indicates that the relationship should be handled by a + second UPDATE statement after an INSERT or before a + DELETE. Currently, it also will issue an UPDATE after the + instance was UPDATEd as well, although this technically should + be improved. This flag is used to handle saving bi-directional + dependencies between two individual rows (i.e. each row + references the other), where it would otherwise be impossible to + INSERT or DELETE both rows fully since one row exists before the + other. Use this flag when a particular mapping arrangement will + incur two rows that are dependent on each other, such as a table + that has a one-to-many relationship to a set of child rows, and + also has a column that references a single child row within that + list (i.e. both tables contain a foreign key to each other). If + a ``flush()`` operation returns an error that a "cyclical + dependency" was detected, this is a cue that you might want to + use ``post_update`` to "break" the cycle. + + :param primaryjoin: + a ColumnElement (i.e. WHERE criterion) that will be used as the primary + join of this child object against the parent object, or in a + many-to-many relationship the join of the primary object to the + association table. By default, this value is computed based on the + foreign key relationships of the parent and child tables (or association + table). + + :param remote_side: + used for self-referential relationships, indicates the column or + list of columns that form the "remote side" of the relationship. + + :param secondaryjoin: + a ColumnElement (i.e. WHERE criterion) that will be used as the join of + an association table to the child object. By default, this value is + computed based on the foreign key relationships of the association and + child tables. + + :param single_parent=(True|False): + when True, installs a validator which will prevent objects + from being associated with more than one parent at a time. + This is used for many-to-one or many-to-many relationships that + should be treated either as one-to-one or one-to-many. Its + usage is optional unless delete-orphan cascade is also + set on this relationship(), in which case its required (new in 0.5.2). + + :param uselist=(True|False): + a boolean that indicates if this property should be loaded as a + list or a scalar. In most cases, this value is determined + automatically by ``relationship()``, based on the type and direction + of the relationship - one to many forms a list, many to one + forms a scalar, many to many is a list. If a scalar is desired + where normally a list would be present, such as a bi-directional + one-to-one relationship, set uselist to False. + + :param viewonly=False: + when set to True, the relationship is used only for loading objects + within the relationship, and has no effect on the unit-of-work + flush process. Relationships with viewonly can specify any kind of + join conditions to provide additional views of related objects + onto a parent object. Note that the functionality of a viewonly + relationship has its limits - complicated join conditions may + not compile into eager or lazy loaders properly. If this is the + case, use an alternative method. + + """ + return RelationshipProperty(argument, secondary=secondary, **kwargs) + +def relation(*arg, **kw): + """A synonym for :func:`relationship`.""" + + return relationship(*arg, **kw) + +def dynamic_loader(argument, secondary=None, primaryjoin=None, + secondaryjoin=None, foreign_keys=None, backref=None, + post_update=False, cascade=False, remote_side=None, + enable_typechecks=True, passive_deletes=False, + order_by=None, comparator_factory=None, query_class=None): + """Construct a dynamically-loading mapper property. + + This property is similar to :func:`relationship`, except read + operations return an active :class:`Query` object which reads from + the database when accessed. Items may be appended to the + attribute via ``append()``, or removed via ``remove()``; changes + will be persisted to the database during a :meth:`Sesion.flush`. + However, no other Python list or collection mutation operations + are available. + + A subset of arguments available to :func:`relationship` are available + here. + + :param argument: + a class or :class:`Mapper` instance, representing the target of + the relationship. + + :param secondary: + for a many-to-many relationship, specifies the intermediary + table. The *secondary* keyword argument should generally only + be used for a table that is not otherwise expressed in any class + mapping. In particular, using the Association Object Pattern is + generally mutually exclusive with the use of the *secondary* + keyword argument. + + :param query_class: + Optional, a custom Query subclass to be used as the basis for + dynamic collection. + + """ + from sqlalchemy.orm.dynamic import DynaLoader + + return RelationshipProperty( + argument, secondary=secondary, primaryjoin=primaryjoin, + secondaryjoin=secondaryjoin, foreign_keys=foreign_keys, backref=backref, + post_update=post_update, cascade=cascade, remote_side=remote_side, + enable_typechecks=enable_typechecks, passive_deletes=passive_deletes, + order_by=order_by, comparator_factory=comparator_factory, + strategy_class=DynaLoader, query_class=query_class) + +def column_property(*args, **kwargs): + """Provide a column-level property for use with a Mapper. + + Column-based properties can normally be applied to the mapper's + ``properties`` dictionary using the ``schema.Column`` element directly. + Use this function when the given column is not directly present within the + mapper's selectable; examples include SQL expressions, functions, and + scalar SELECT queries. + + Columns that aren't present in the mapper's selectable won't be persisted + by the mapper and are effectively "read-only" attributes. + + \*cols + list of Column objects to be mapped. + + comparator_factory + a class which extends ``sqlalchemy.orm.properties.ColumnProperty.Comparator`` + which provides custom SQL clause generation for comparison operations. + + group + a group name for this property when marked as deferred. + + deferred + when True, the column property is "deferred", meaning that + it does not load immediately, and is instead loaded when the + attribute is first accessed on an instance. See also + :func:`~sqlalchemy.orm.deferred`. + + extension + an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance, + or list of extensions, which will be prepended to the list of + attribute listeners for the resulting descriptor placed on the class. + These listeners will receive append and set events before the + operation proceeds, and may be used to halt (via exception throw) + or change the value used in the operation. + + """ + + return ColumnProperty(*args, **kwargs) + +def composite(class_, *cols, **kwargs): + """Return a composite column-based property for use with a Mapper. + + This is very much like a column-based property except the given class is + used to represent "composite" values composed of one or more columns. + + The class must implement a constructor with positional arguments matching + the order of columns supplied here, as well as a __composite_values__() + method which returns values in the same order. + + A simple example is representing separate two columns in a table as a + single, first-class "Point" object:: + + class Point(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return self.x, self.y + def __eq__(self, other): + return other is not None and self.x == other.x and self.y == other.y + + # and then in the mapping: + ... composite(Point, mytable.c.x, mytable.c.y) ... + + The composite object may have its attributes populated based on the names + of the mapped columns. To override the way internal state is set, + additionally implement ``__set_composite_values__``:: + + class Point(object): + def __init__(self, x, y): + self.some_x = x + self.some_y = y + def __composite_values__(self): + return self.some_x, self.some_y + def __set_composite_values__(self, x, y): + self.some_x = x + self.some_y = y + def __eq__(self, other): + return other is not None and self.some_x == other.x and self.some_y == other.y + + Arguments are: + + class\_ + The "composite type" class. + + \*cols + List of Column objects to be mapped. + + group + A group name for this property when marked as deferred. + + deferred + When True, the column property is "deferred", meaning that it does not + load immediately, and is instead loaded when the attribute is first + accessed on an instance. See also :func:`~sqlalchemy.orm.deferred`. + + comparator_factory + a class which extends ``sqlalchemy.orm.properties.CompositeProperty.Comparator`` + which provides custom SQL clause generation for comparison operations. + + extension + an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance, + or list of extensions, which will be prepended to the list of + attribute listeners for the resulting descriptor placed on the class. + These listeners will receive append and set events before the + operation proceeds, and may be used to halt (via exception throw) + or change the value used in the operation. + + """ + return CompositeProperty(class_, *cols, **kwargs) + + +def backref(name, **kwargs): + """Create a back reference with explicit arguments, which are the same + arguments one can send to ``relationship()``. + + Used with the `backref` keyword argument to ``relationship()`` in + place of a string argument. + + """ + return (name, kwargs) + +def deferred(*columns, **kwargs): + """Return a ``DeferredColumnProperty``, which indicates this + object attributes should only be loaded from its corresponding + table column when first accessed. + + Used with the `properties` dictionary sent to ``mapper()``. + + """ + return ColumnProperty(deferred=True, *columns, **kwargs) + +def mapper(class_, local_table=None, *args, **params): + """Return a new :class:`~sqlalchemy.orm.Mapper` object. + + :param class\_: The class to be mapped. + + :param local_table: The table to which the class is mapped, or None if this mapper + inherits from another mapper using concrete table inheritance. + + :param always_refresh: If True, all query operations for this mapped class will overwrite all + data within object instances that already exist within the session, + erasing any in-memory changes with whatever information was loaded + from the database. Usage of this flag is highly discouraged; as an + alternative, see the method `populate_existing()` on + :class:`~sqlalchemy.orm.query.Query`. + + :param allow_null_pks: This flag is deprecated - this is stated as allow_partial_pks + which defaults to True. + + :param allow_partial_pks: Defaults to True. Indicates that a composite primary key with + some NULL values should be considered as possibly existing + within the database. This affects whether a mapper will assign + an incoming row to an existing identity, as well as if + session.merge() will check the database first for a particular + primary key value. A "partial primary key" can occur if one + has mapped to an OUTER JOIN, for example. + + :param batch: Indicates that save operations of multiple entities can be batched + together for efficiency. setting to False indicates that an instance + will be fully saved before saving the next instance, which includes + inserting/updating all table rows corresponding to the entity as well + as calling all ``MapperExtension`` methods corresponding to the save + operation. + + :param column_prefix: A string which will be prepended to the `key` name of all Columns when + creating column-based properties from the given Table. Does not + affect explicitly specified column-based properties + + :param concrete: If True, indicates this mapper should use concrete table inheritance + with its parent mapper. + + :param exclude_properties: A list of properties not to map. Columns present in the mapped table + and present in this list will not be automatically converted into + properties. Note that neither this option nor include_properties will + allow an end-run around Python inheritance. If mapped class ``B`` + inherits from mapped class ``A``, no combination of includes or + excludes will allow ``B`` to have fewer properties than its + superclass, ``A``. + + + :param extension: A :class:`~sqlalchemy.orm.interfaces.MapperExtension` instance or list of + :class:`~sqlalchemy.orm.interfaces.MapperExtension` instances which will be applied to all + operations by this :class:`~sqlalchemy.orm.mapper.Mapper`. + + :param include_properties: An inclusive list of properties to map. Columns present in the mapped + table but not present in this list will not be automatically converted + into properties. + + :param inherits: Another :class:`~sqlalchemy.orm.Mapper` for which + this :class:`~sqlalchemy.orm.Mapper` will have an inheritance + relationship with. + + + :param inherit_condition: For joined table inheritance, a SQL expression (constructed + ``ClauseElement``) which will define how the two tables are joined; + defaults to a natural join between the two tables. + + :param inherit_foreign_keys: When inherit_condition is used and the condition contains no + ForeignKey columns, specify the "foreign" columns of the join + condition in this list. else leave as None. + + :param non_primary: Construct a ``Mapper`` that will define only the selection of + instances, not their persistence. Any number of non_primary mappers + may be created for a particular class. + + :param order_by: A single ``Column`` or list of ``Columns`` for which + selection operations should use as the default ordering for entities. + Defaults to the OID/ROWID of the table if any, or the first primary + key column of the table. + + :param passive_updates: Indicates UPDATE behavior of foreign keys when a primary key changes + on a joined-table inheritance or other joined table mapping. + + When True, it is assumed that ON UPDATE CASCADE is configured on + the foreign key in the database, and that the database will + handle propagation of an UPDATE from a source column to + dependent rows. Note that with databases which enforce + referential integrity (i.e. PostgreSQL, MySQL with InnoDB tables), + ON UPDATE CASCADE is required for this operation. The + relationship() will update the value of the attribute on related + items which are locally present in the session during a flush. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The relationship() will issue the + appropriate UPDATE statements to the database in response to the + change of a referenced key, and items locally present in the + session during a flush will also be refreshed. + + This flag should probably be set to False if primary key changes + are expected and the database in use doesn't support CASCADE + (i.e. SQLite, MySQL MyISAM tables). + + Also see the passive_updates flag on :func:`relationship()`. + + A future SQLAlchemy release will provide a "detect" feature for + this flag. + + :param polymorphic_on: Used with mappers in an inheritance relationship, a ``Column`` which + will identify the class/mapper combination to be used with a + particular row. Requires the ``polymorphic_identity`` value to be set + for all mappers in the inheritance hierarchy. The column specified by + ``polymorphic_on`` is usually a column that resides directly within + the base mapper's mapped table; alternatively, it may be a column that + is only present within the portion of the + ``with_polymorphic`` argument. + + :param polymorphic_identity: A value which will be stored in the Column denoted by polymorphic_on, + corresponding to the *class identity* of this mapper. + + :param properties: A dictionary mapping the string names of object attributes to + ``MapperProperty`` instances, which define the persistence behavior of + that attribute. Note that the columns in the mapped table are + automatically converted into ``ColumnProperty`` instances based on the + `key` property of each ``Column`` (although they can be overridden + using this dictionary). + + :param primary_key: A list of ``Column`` objects which define the *primary key* to be used + against this mapper's selectable unit. This is normally simply the + primary key of the `local_table`, but can be overridden here. + + :param version_id_col: A ``Column`` which must have an integer type that will be used to keep + a running *version id* of mapped entities in the database. this is + used during save operations to ensure that no other thread or process + has updated the instance during the lifetime of the entity, else a + ``ConcurrentModificationError`` exception is thrown. + + :param version_id_generator: A callable which defines the algorithm used to generate new version + ids. Defaults to an integer generator. Can be replaced with one that + generates timestamps, uuids, etc. e.g.:: + + import uuid + + mapper(Cls, table, + version_id_col=table.c.version_uuid, + version_id_generator=lambda version:uuid.uuid4().hex + ) + + The callable receives the current version identifier as its + single argument. + + :param with_polymorphic: A tuple in the form ``(, )`` indicating the + default style of "polymorphic" loading, that is, which tables are + queried at once. is any single or list of mappers and/or + classes indicating the inherited classes that should be loaded at + once. The special value ``'*'`` may be used to indicate all descending + classes should be loaded immediately. The second tuple argument + indicates a selectable that will be used to query for + multiple classes. Normally, it is left as None, in which case this + mapper will form an outer join from the base mapper's table to that of + all desired sub-mappers. When specified, it provides the selectable + to be used for polymorphic loading. When with_polymorphic includes + mappers which load from a "concrete" inheriting table, the + argument is required, since it usually requires more + complex UNION queries. + + + """ + return Mapper(class_, local_table, *args, **params) + +def synonym(name, map_column=False, descriptor=None, comparator_factory=None): + """Set up `name` as a synonym to another mapped property. + + Used with the ``properties`` dictionary sent to :func:`~sqlalchemy.orm.mapper`. + + Any existing attributes on the class which map the key name sent + to the ``properties`` dictionary will be used by the synonym to provide + instance-attribute behavior (that is, any Python property object, provided + by the ``property`` builtin or providing a ``__get__()``, ``__set__()`` + and ``__del__()`` method). If no name exists for the key, the + ``synonym()`` creates a default getter/setter object automatically and + applies it to the class. + + `name` refers to the name of the existing mapped property, which can be + any other ``MapperProperty`` including column-based properties and + relationships. + + If `map_column` is ``True``, an additional ``ColumnProperty`` is created + on the mapper automatically, using the synonym's name as the keyname of + the property, and the keyname of this ``synonym()`` as the name of the + column to map. For example, if a table has a column named ``status``:: + + class MyClass(object): + def _get_status(self): + return self._status + def _set_status(self, value): + self._status = value + status = property(_get_status, _set_status) + + mapper(MyClass, sometable, properties={ + "status":synonym("_status", map_column=True) + }) + + The column named ``status`` will be mapped to the attribute named + ``_status``, and the ``status`` attribute on ``MyClass`` will be used to + proxy access to the column-based attribute. + + """ + return SynonymProperty(name, map_column=map_column, descriptor=descriptor, comparator_factory=comparator_factory) + +def comparable_property(comparator_factory, descriptor=None): + """Provide query semantics for an unmanaged attribute. + + Allows a regular Python @property (descriptor) to be used in Queries and + SQL constructs like a managed attribute. comparable_property wraps a + descriptor with a proxy that directs operator overrides such as == + (__eq__) to the supplied comparator but proxies everything else through to + the original descriptor:: + + class MyClass(object): + @property + def myprop(self): + return 'foo' + + class MyComparator(sqlalchemy.orm.interfaces.PropComparator): + def __eq__(self, other): + .... + + mapper(MyClass, mytable, properties=dict( + 'myprop': comparable_property(MyComparator))) + + Used with the ``properties`` dictionary sent to :func:`~sqlalchemy.orm.mapper`. + + comparator_factory + A PropComparator subclass or factory that defines operator behavior + for this property. + + descriptor + Optional when used in a ``properties={}`` declaration. The Python + descriptor or property to layer comparison behavior on top of. + + The like-named descriptor will be automatically retreived from the + mapped class if left blank in a ``properties`` declaration. + + """ + return ComparableProperty(comparator_factory, descriptor) + +def compile_mappers(): + """Compile all mappers that have been defined. + + This is equivalent to calling ``compile()`` on any individual mapper. + + """ + for m in list(_mapper_registry): + m.compile() + +def clear_mappers(): + """Remove all mappers that have been created thus far. + + The mapped classes will return to their initial "unmapped" state and can + be re-mapped with new mappers. + + """ + mapperlib._COMPILE_MUTEX.acquire() + try: + while _mapper_registry: + try: + # can't even reliably call list(weakdict) in jython + mapper, b = _mapper_registry.popitem() + mapper.dispose() + except KeyError: + pass + finally: + mapperlib._COMPILE_MUTEX.release() + +def extension(ext): + """Return a ``MapperOption`` that will insert the given + ``MapperExtension`` to the beginning of the list of extensions + that will be called in the context of the ``Query``. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + """ + return ExtensionOption(ext) + +@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') +def joinedload(*keys, **kw): + """Return a ``MapperOption`` that will convert the property of the given + name into an joined eager load. + + .. note:: This function is known as :func:`eagerload` in all versions + of SQLAlchemy prior to version 0.6beta3, including the 0.5 and 0.4 series. + :func:`eagerload` will remain available for + the foreseeable future in order to enable cross-compatibility. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + examples:: + + # joined-load the "orders" colleciton on "User" + query(User).options(joinedload(User.orders)) + + # joined-load the "keywords" collection on each "Item", + # but not the "items" collection on "Order" - those + # remain lazily loaded. + query(Order).options(joinedload(Order.items, Item.keywords)) + + # to joined-load across both, use joinedload_all() + query(Order).options(joinedload_all(Order.items, Item.keywords)) + + :func:`joinedload` also accepts a keyword argument `innerjoin=True` which + indicates using an inner join instead of an outer:: + + query(Order).options(joinedload(Order.user, innerjoin=True)) + + Note that the join created by :func:`joinedload` is aliased such that + no other aspects of the query will affect what it loads. To use joined eager + loading with a join that is constructed manually using :meth:`~sqlalchemy.orm.query.Query.join` + or :func:`~sqlalchemy.orm.join`, see :func:`contains_eager`. + + See also: :func:`subqueryload`, :func:`lazyload` + + """ + innerjoin = kw.pop('innerjoin', None) + if innerjoin is not None: + return ( + strategies.EagerLazyOption(keys, lazy='joined'), + strategies.EagerJoinOption(keys, innerjoin) + ) + else: + return strategies.EagerLazyOption(keys, lazy='joined') + +@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') +def joinedload_all(*keys, **kw): + """Return a ``MapperOption`` that will convert all properties along the + given dot-separated path into an joined eager load. + + .. note:: This function is known as :func:`eagerload_all` in all versions + of SQLAlchemy prior to version 0.6beta3, including the 0.5 and 0.4 series. + :func:`eagerload_all` will remain available for + the foreseeable future in order to enable cross-compatibility. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + For example:: + + query.options(joinedload_all('orders.items.keywords'))... + + will set all of 'orders', 'orders.items', and 'orders.items.keywords' to + load in one joined eager load. + + Individual descriptors are accepted as arguments as well:: + + query.options(joinedload_all(User.orders, Order.items, Item.keywords)) + + The keyword arguments accept a flag `innerjoin=True|False` which will + override the value of the `innerjoin` flag specified on the relationship(). + + See also: :func:`subqueryload_all`, :func:`lazyload` + + """ + innerjoin = kw.pop('innerjoin', None) + if innerjoin is not None: + return ( + strategies.EagerLazyOption(keys, lazy='joined', chained=True), + strategies.EagerJoinOption(keys, innerjoin, chained=True) + ) + else: + return strategies.EagerLazyOption(keys, lazy='joined', chained=True) + +def eagerload(*args, **kwargs): + """A synonym for :func:`joinedload()`.""" + return joinedload(*args, **kwargs) + +def eagerload_all(*args, **kwargs): + """A synonym for :func:`joinedload_all()`""" + return joinedload_all(*args, **kwargs) + +def subqueryload(*keys): + """Return a ``MapperOption`` that will convert the property + of the given name into an subquery eager load. + + .. note:: This function is new as of SQLAlchemy version 0.6beta3. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + examples:: + + # subquery-load the "orders" colleciton on "User" + query(User).options(subqueryload(User.orders)) + + # subquery-load the "keywords" collection on each "Item", + # but not the "items" collection on "Order" - those + # remain lazily loaded. + query(Order).options(subqueryload(Order.items, Item.keywords)) + + # to subquery-load across both, use subqueryload_all() + query(Order).options(subqueryload_all(Order.items, Item.keywords)) + + See also: :func:`joinedload`, :func:`lazyload` + + """ + return strategies.EagerLazyOption(keys, lazy="subquery") + +def subqueryload_all(*keys): + """Return a ``MapperOption`` that will convert all properties along the + given dot-separated path into a subquery eager load. + + .. note:: This function is new as of SQLAlchemy version 0.6beta3. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + For example:: + + query.options(subqueryload_all('orders.items.keywords'))... + + will set all of 'orders', 'orders.items', and 'orders.items.keywords' to + load in one subquery eager load. + + Individual descriptors are accepted as arguments as well:: + + query.options(subqueryload_all(User.orders, Order.items, Item.keywords)) + + See also: :func:`joinedload_all`, :func:`lazyload` + + """ + return strategies.EagerLazyOption(keys, lazy="subquery", chained=True) + +@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') +def lazyload(*keys): + """Return a ``MapperOption`` that will convert the property of the given + name into a lazy load. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + See also: :func:`eagerload`, :func:`subqueryload` + + """ + return strategies.EagerLazyOption(keys, lazy=True) + +def noload(*keys): + """Return a ``MapperOption`` that will convert the property of the + given name into a non-load. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + See also: :func:`lazyload`, :func:`eagerload`, :func:`subqueryload` + + """ + return strategies.EagerLazyOption(keys, lazy=None) + +def contains_alias(alias): + """Return a ``MapperOption`` that will indicate to the query that + the main table has been aliased. + + `alias` is the string name or ``Alias`` object representing the + alias. + + """ + return AliasOption(alias) + +@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') +def contains_eager(*keys, **kwargs): + """Return a ``MapperOption`` that will indicate to the query that + the given attribute should be eagerly loaded from columns currently + in the query. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + The option is used in conjunction with an explicit join that loads + the desired rows, i.e.:: + + sess.query(Order).\\ + join(Order.user).\\ + options(contains_eager(Order.user)) + + The above query would join from the ``Order`` entity to its related + ``User`` entity, and the returned ``Order`` objects would have the + ``Order.user`` attribute pre-populated. + + :func:`contains_eager` also accepts an `alias` argument, which + is the string name of an alias, an :func:`~sqlalchemy.sql.expression.alias` + construct, or an :func:`~sqlalchemy.orm.aliased` construct. Use this + when the eagerly-loaded rows are to come from an aliased table:: + + user_alias = aliased(User) + sess.query(Order).\\ + join((user_alias, Order.user)).\\ + options(contains_eager(Order.user, alias=user_alias)) + + See also :func:`eagerload` for the "automatic" version of this + functionality. + + """ + alias = kwargs.pop('alias', None) + if kwargs: + raise exceptions.ArgumentError("Invalid kwargs for contains_eager: %r" % kwargs.keys()) + + return ( + strategies.EagerLazyOption(keys, lazy='joined', propagate_to_loaders=False), + strategies.LoadEagerFromAliasOption(keys, alias=alias) + ) + +@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') +def defer(*keys): + """Return a ``MapperOption`` that will convert the column property of the + given name into a deferred load. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + """ + return strategies.DeferredOption(keys, defer=True) + +@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') +def undefer(*keys): + """Return a ``MapperOption`` that will convert the column property of the + given name into a non-deferred (regular column) load. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + """ + return strategies.DeferredOption(keys, defer=False) + +def undefer_group(name): + """Return a ``MapperOption`` that will convert the given group of deferred + column properties into a non-deferred (regular column) load. + + Used with :meth:`~sqlalchemy.orm.query.Query.options`. + + """ + return strategies.UndeferGroupOption(name) diff --git a/sqlalchemy/orm/attributes.py b/sqlalchemy/orm/attributes.py new file mode 100644 index 0000000..887d9a9 --- /dev/null +++ b/sqlalchemy/orm/attributes.py @@ -0,0 +1,1708 @@ +# attributes.py - manages object attributes +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Defines SQLAlchemy's system of class instrumentation.. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +SQLA's instrumentation system is completely customizable, in which +case an understanding of the general mechanics of this module is helpful. +An example of full customization is in /examples/custom_attributes. + +""" + +import operator +from operator import attrgetter, itemgetter +import types +import weakref + +from sqlalchemy import util +from sqlalchemy.orm import interfaces, collections, exc +import sqlalchemy.exceptions as sa_exc + +# lazy imports +_entity_info = None +identity_equal = None +state = None + +PASSIVE_NO_RESULT = util.symbol('PASSIVE_NO_RESULT') +ATTR_WAS_SET = util.symbol('ATTR_WAS_SET') +NO_VALUE = util.symbol('NO_VALUE') +NEVER_SET = util.symbol('NEVER_SET') + +# "passive" get settings +# TODO: the True/False values need to be factored out +# of the rest of ORM code +# don't fire off any callables, and don't initialize the attribute to +# an empty value +PASSIVE_NO_INITIALIZE = True #util.symbol('PASSIVE_NO_INITIALIZE') + +# don't fire off any callables, but if no callables present +# then initialize to an empty value/collection +# this is used by backrefs. +PASSIVE_NO_FETCH = util.symbol('PASSIVE_NO_FETCH') + +# fire callables/initialize as needed +PASSIVE_OFF = False #util.symbol('PASSIVE_OFF') + +INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__' +"""Attribute, elects custom instrumentation when present on a mapped class. + +Allows a class to specify a slightly or wildly different technique for +tracking changes made to mapped attributes and collections. + +Only one instrumentation implementation is allowed in a given object +inheritance hierarchy. + +The value of this attribute must be a callable and will be passed a class +object. The callable must return one of: + + - An instance of an interfaces.InstrumentationManager or subclass + - An object implementing all or some of InstrumentationManager (TODO) + - A dictionary of callables, implementing all or some of the above (TODO) + - An instance of a ClassManager or subclass + +interfaces.InstrumentationManager is public API and will remain stable +between releases. ClassManager is not public and no guarantees are made +about stability. Caveat emptor. + +This attribute is consulted by the default SQLAlchemy instrumentation +resolution code. If custom finders are installed in the global +instrumentation_finders list, they may or may not choose to honor this +attribute. + +""" + +instrumentation_finders = [] +"""An extensible sequence of instrumentation implementation finding callables. + +Finders callables will be passed a class object. If None is returned, the +next finder in the sequence is consulted. Otherwise the return must be an +instrumentation factory that follows the same guidelines as +INSTRUMENTATION_MANAGER. + +By default, the only finder is find_native_user_instrumentation_hook, which +searches for INSTRUMENTATION_MANAGER. If all finders return None, standard +ClassManager instrumentation is used. + +""" + +class QueryableAttribute(interfaces.PropComparator): + + def __init__(self, key, impl=None, comparator=None, parententity=None): + """Construct an InstrumentedAttribute. + + comparator + a sql.Comparator to which class-level compare/math events will be sent + """ + self.key = key + self.impl = impl + self.comparator = comparator + self.parententity = parententity + + def get_history(self, instance, **kwargs): + return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs) + + def __selectable__(self): + # TODO: conditionally attach this method based on clause_element ? + return self + + def __clause_element__(self): + return self.comparator.__clause_element__() + + def label(self, name): + return self.__clause_element__().label(name) + + def operate(self, op, *other, **kwargs): + return op(self.comparator, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return op(other, self.comparator, **kwargs) + + def hasparent(self, state, optimistic=False): + return self.impl.hasparent(state, optimistic=optimistic) + + def __getattr__(self, key): + try: + return getattr(self.comparator, key) + except AttributeError: + raise AttributeError( + 'Neither %r object nor %r object has an attribute %r' % ( + type(self).__name__, + type(self.comparator).__name__, + key) + ) + + def __str__(self): + return repr(self.parententity) + "." + self.property.key + + @property + def property(self): + return self.comparator.property + + +class InstrumentedAttribute(QueryableAttribute): + """Public-facing descriptor, placed in the mapped class dictionary.""" + + def __set__(self, instance, value): + self.impl.set(instance_state(instance), instance_dict(instance), value, None) + + def __delete__(self, instance): + self.impl.delete(instance_state(instance), instance_dict(instance)) + + def __get__(self, instance, owner): + if instance is None: + return self + return self.impl.get(instance_state(instance), instance_dict(instance)) + +class _ProxyImpl(object): + accepts_scalar_loader = False + expire_missing = True + + def __init__(self, key): + self.key = key + +def proxied_attribute_factory(descriptor): + """Create an InstrumentedAttribute / user descriptor hybrid. + + Returns a new InstrumentedAttribute type that delegates descriptor + behavior and getattr() to the given descriptor. + """ + + class Proxy(InstrumentedAttribute): + """A combination of InsturmentedAttribute and a regular descriptor.""" + + def __init__(self, key, descriptor, comparator, parententity): + self.key = key + # maintain ProxiedAttribute.user_prop compatability. + self.descriptor = self.user_prop = descriptor + self._comparator = comparator + self._parententity = parententity + self.impl = _ProxyImpl(key) + + @util.memoized_property + def comparator(self): + if util.callable(self._comparator): + self._comparator = self._comparator() + return self._comparator + + def __get__(self, instance, owner): + """Delegate __get__ to the original descriptor.""" + if instance is None: + descriptor.__get__(instance, owner) + return self + return descriptor.__get__(instance, owner) + + def __set__(self, instance, value): + """Delegate __set__ to the original descriptor.""" + return descriptor.__set__(instance, value) + + def __delete__(self, instance): + """Delegate __delete__ to the original descriptor.""" + return descriptor.__delete__(instance) + + def __getattr__(self, attribute): + """Delegate __getattr__ to the original descriptor and/or comparator.""" + + try: + return getattr(descriptor, attribute) + except AttributeError: + try: + return getattr(self._comparator, attribute) + except AttributeError: + raise AttributeError( + 'Neither %r object nor %r object has an attribute %r' % ( + type(descriptor).__name__, + type(self._comparator).__name__, + attribute) + ) + + Proxy.__name__ = type(descriptor).__name__ + 'Proxy' + + util.monkeypatch_proxied_specials(Proxy, type(descriptor), + name='descriptor', + from_instance=descriptor) + return Proxy + +class AttributeImpl(object): + """internal implementation for instrumented attributes.""" + + def __init__(self, class_, key, + callable_, trackparent=False, extension=None, + compare_function=None, active_history=False, + parent_token=None, expire_missing=True, + **kwargs): + """Construct an AttributeImpl. + + \class_ + associated class + + key + string name of the attribute + + \callable_ + optional function which generates a callable based on a parent + instance, which produces the "default" values for a scalar or + collection attribute when it's first accessed, if not present + already. + + trackparent + if True, attempt to track if an instance has a parent attached + to it via this attribute. + + extension + a single or list of AttributeExtension object(s) which will + receive set/delete/append/remove/etc. events. + + compare_function + a function that compares two values which are normally + assignable to this attribute. + + active_history + indicates that get_history() should always return the "old" value, + even if it means executing a lazy callable upon attribute change. + + parent_token + Usually references the MapperProperty, used as a key for + the hasparent() function to identify an "owning" attribute. + Allows multiple AttributeImpls to all match a single + owner attribute. + + expire_missing + if False, don't add an "expiry" callable to this attribute + during state.expire_attributes(None), if no value is present + for this key. + + """ + self.class_ = class_ + self.key = key + self.callable_ = callable_ + self.trackparent = trackparent + self.parent_token = parent_token or self + if compare_function is None: + self.is_equal = operator.eq + else: + self.is_equal = compare_function + self.extensions = util.to_list(extension or []) + for e in self.extensions: + if e.active_history: + active_history = True + break + self.active_history = active_history + self.expire_missing = expire_missing + + def hasparent(self, state, optimistic=False): + """Return the boolean value of a `hasparent` flag attached to + the given state. + + The `optimistic` flag determines what the default return value + should be if no `hasparent` flag can be located. + + As this function is used to determine if an instance is an + *orphan*, instances that were loaded from storage should be + assumed to not be orphans, until a True/False value for this + flag is set. + + An instance attribute that is loaded by a callable function + will also not have a `hasparent` flag. + + """ + return state.parents.get(id(self.parent_token), optimistic) + + def sethasparent(self, state, value): + """Set a boolean flag on the given item corresponding to + whether or not it is attached to a parent object via the + attribute represented by this ``InstrumentedAttribute``. + + """ + state.parents[id(self.parent_token)] = value + + def set_callable(self, state, callable_): + """Set a callable function for this attribute on the given object. + + This callable will be executed when the attribute is next + accessed, and is assumed to construct part of the instances + previously stored state. When its value or values are loaded, + they will be established as part of the instance's *committed + state*. While *trackparent* information will be assembled for + these instances, attribute-level event handlers will not be + fired. + + The callable overrides the class level callable set in the + ``InstrumentedAttribute`` constructor. + + """ + state.callables[self.key] = callable_ + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + raise NotImplementedError() + + def _get_callable(self, state): + if self.key in state.callables: + return state.callables[self.key] + elif self.callable_ is not None: + return self.callable_(state) + else: + return None + + def initialize(self, state, dict_): + """Initialize the given state's attribute with an empty value.""" + + dict_[self.key] = None + return None + + def get(self, state, dict_, passive=PASSIVE_OFF): + """Retrieve a value from the given object. + + If a callable is assembled on this object's attribute, and + passive is False, the callable will be executed and the + resulting value will be set as the new value for this attribute. + """ + + try: + return dict_[self.key] + except KeyError: + # if no history, check for lazy callables, etc. + if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET: + if passive is PASSIVE_NO_INITIALIZE: + return PASSIVE_NO_RESULT + + callable_ = self._get_callable(state) + if callable_ is not None: + #if passive is not PASSIVE_OFF: + # return PASSIVE_NO_RESULT + value = callable_(passive=passive) + if value is PASSIVE_NO_RESULT: + return value + elif value is not ATTR_WAS_SET: + return self.set_committed_value(state, dict_, value) + else: + if self.key not in dict_: + return self.get(state, dict_, passive=passive) + return dict_[self.key] + + # Return a new, empty value + return self.initialize(state, dict_) + + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, value, initiator, passive=passive) + + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, None, initiator, passive=passive) + + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + raise NotImplementedError() + + def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): + """return the unchanged value of this attribute""" + + if self.key in state.committed_state: + if state.committed_state[self.key] is NO_VALUE: + return None + else: + return state.committed_state.get(self.key) + else: + return self.get(state, dict_, passive=passive) + + def set_committed_value(self, state, dict_, value): + """set an attribute value on the given instance and 'commit' it.""" + + state.commit(dict_, [self.key]) + + state.callables.pop(self.key, None) + state.dict[self.key] = value + + return value + +class ScalarAttributeImpl(AttributeImpl): + """represents a scalar value-holding InstrumentedAttribute.""" + + accepts_scalar_loader = True + uses_objects = False + + def delete(self, state, dict_): + + # TODO: catch key errors, convert to attributeerror? + if self.active_history: + old = self.get(state, dict_) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.extensions: + self.fire_remove_event(state, dict_, old, None) + state.modified_event(dict_, self, False, old) + del dict_[self.key] + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + return History.from_attribute( + self, state, dict_.get(self.key, NO_VALUE)) + + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + if initiator is self: + return + + if self.active_history: + old = self.get(state, dict_) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.extensions: + value = self.fire_replace_event(state, dict_, value, old, initiator) + state.modified_event(dict_, self, False, old) + dict_[self.key] = value + + def fire_replace_event(self, state, dict_, value, previous, initiator): + for ext in self.extensions: + value = ext.set(state, value, previous, initiator or self) + return value + + def fire_remove_event(self, state, dict_, value, initiator): + for ext in self.extensions: + ext.remove(state, value, initiator or self) + + @property + def type(self): + self.property.columns[0].type + + +class MutableScalarAttributeImpl(ScalarAttributeImpl): + """represents a scalar value-holding InstrumentedAttribute, which can detect + changes within the value itself. + """ + + uses_objects = False + + def __init__(self, class_, key, callable_, + class_manager, copy_function=None, + compare_function=None, **kwargs): + super(ScalarAttributeImpl, self).__init__( + class_, + key, + callable_, + compare_function=compare_function, + **kwargs) + class_manager.mutable_attributes.add(key) + if copy_function is None: + raise sa_exc.ArgumentError( + "MutableScalarAttributeImpl requires a copy function") + self.copy = copy_function + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if not dict_: + v = state.committed_state.get(self.key, NO_VALUE) + else: + v = dict_.get(self.key, NO_VALUE) + + return History.from_attribute( + self, state, v) + + def check_mutable_modified(self, state, dict_): + added, \ + unchanged, \ + deleted = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE) + return bool(added or deleted) + + def get(self, state, dict_, passive=PASSIVE_OFF): + if self.key not in state.mutable_dict: + ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive) + if ret is not PASSIVE_NO_RESULT: + state.mutable_dict[self.key] = ret + return ret + else: + return state.mutable_dict[self.key] + + def delete(self, state, dict_): + ScalarAttributeImpl.delete(self, state, dict_) + state.mutable_dict.pop(self.key) + + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + if initiator is self: + return + + if self.extensions: + old = self.get(state, dict_) + value = self.fire_replace_event(state, dict_, value, old, initiator) + + state.modified_event(dict_, self, True, NEVER_SET) + dict_[self.key] = value + state.mutable_dict[self.key] = value + + +class ScalarObjectAttributeImpl(ScalarAttributeImpl): + """represents a scalar-holding InstrumentedAttribute, + where the target object is also instrumented. + + Adds events to delete/set operations. + + """ + + accepts_scalar_loader = False + uses_objects = True + + def __init__(self, class_, key, callable_, + trackparent=False, extension=None, copy_function=None, + compare_function=None, **kwargs): + super(ScalarObjectAttributeImpl, self).__init__( + class_, + key, + callable_, + trackparent=trackparent, + extension=extension, + compare_function=compare_function, + **kwargs) + if compare_function is None: + self.is_equal = identity_equal + + def delete(self, state, dict_): + old = self.get(state, dict_) + self.fire_remove_event(state, dict_, old, self) + del dict_[self.key] + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return History.from_attribute(self, state, dict_[self.key]) + else: + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_attribute(self, state, current) + + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + """Set a value on the given InstanceState. + + `initiator` is the ``InstrumentedAttribute`` that initiated the + ``set()`` operation and is used to control the depth of a circular + setter operation. + + """ + if initiator is self: + return + + if self.active_history: + old = self.get(state, dict_) + else: + old = self.get(state, dict_, passive=PASSIVE_NO_FETCH) + + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value + + def fire_remove_event(self, state, dict_, value, initiator): + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), False) + + for ext in self.extensions: + ext.remove(state, value, initiator or self) + + state.modified_event(dict_, self, False, value) + + def fire_replace_event(self, state, dict_, value, previous, initiator): + if self.trackparent: + if (previous is not value and + previous is not None and + previous is not PASSIVE_NO_RESULT): + self.sethasparent(instance_state(previous), False) + + for ext in self.extensions: + value = ext.set(state, value, previous, initiator or self) + + state.modified_event(dict_, self, False, previous) + + if self.trackparent: + if value is not None: + self.sethasparent(instance_state(value), True) + + return value + + +class CollectionAttributeImpl(AttributeImpl): + """A collection-holding attribute that instruments changes in membership. + + Only handles collections of instrumented objects. + + InstrumentedCollectionAttribute holds an arbitrary, user-specified + container object (defaulting to a list) and brokers access to the + CollectionAdapter, a "view" onto that object that presents consistent + bag semantics to the orm layer independent of the user data implementation. + + """ + accepts_scalar_loader = False + uses_objects = True + + def __init__(self, class_, key, callable_, + typecallable=None, trackparent=False, extension=None, + copy_function=None, compare_function=None, **kwargs): + super(CollectionAttributeImpl, self).__init__( + class_, + key, + callable_, + trackparent=trackparent, + extension=extension, + compare_function=compare_function, + **kwargs) + + if copy_function is None: + copy_function = self.__copy + self.copy = copy_function + self.collection_factory = typecallable + + def __copy(self, item): + return [y for y in list(collections.collection_adapter(item))] + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_attribute(self, state, current) + + def fire_append_event(self, state, dict_, value, initiator): + for ext in self.extensions: + value = ext.append(state, value, initiator or self) + + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), True) + + return value + + def fire_pre_remove_event(self, state, dict_, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + + def fire_remove_event(self, state, dict_, value, initiator): + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), False) + + for ext in self.extensions: + ext.remove(state, value, initiator or self) + + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + + def delete(self, state, dict_): + if self.key not in dict_: + return + + state.modified_event(dict_, self, True, NEVER_SET) + + collection = self.get_collection(state, state.dict) + collection.clear_with_event() + # TODO: catch key errors, convert to attributeerror? + del dict_[self.key] + + def initialize(self, state, dict_): + """Initialize this attribute with an empty collection.""" + + _, user_data = self._initialize_collection(state) + dict_[self.key] = user_data + return user_data + + def _initialize_collection(self, state): + return state.manager.initialize_collection( + self.key, state, self.collection_factory) + + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + if initiator is self: + return + + collection = self.get_collection(state, dict_, passive=passive) + if collection is PASSIVE_NO_RESULT: + value = self.fire_append_event(state, dict_, value, initiator) + assert self.key not in dict_, "Collection was loaded during event handling." + state.get_pending(self.key).append(value) + else: + collection.append_with_event(value, initiator) + + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + if initiator is self: + return + + collection = self.get_collection(state, state.dict, passive=passive) + if collection is PASSIVE_NO_RESULT: + self.fire_remove_event(state, dict_, value, initiator) + assert self.key not in dict_, "Collection was loaded during event handling." + state.get_pending(self.key).remove(value) + else: + collection.remove_with_event(value, initiator) + + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + """Set a value on the given object. + + `initiator` is the ``InstrumentedAttribute`` that initiated the + ``set()`` operation and is used to control the depth of a circular + setter operation. + """ + + if initiator is self: + return + + self._set_iterable( + state, dict_, value, + lambda adapter, i: adapter.adapt_like_to_iterable(i)) + + def _set_iterable(self, state, dict_, iterable, adapter=None): + """Set a collection value from an iterable of state-bearers. + + ``adapter`` is an optional callable invoked with a CollectionAdapter + and the iterable. Should return an iterable of state-bearing + instances suitable for appending via a CollectionAdapter. Can be used + for, e.g., adapting an incoming dictionary into an iterator of values + rather than keys. + + """ + # pulling a new collection first so that an adaptation exception does + # not trigger a lazy load of the old collection. + new_collection, user_data = self._initialize_collection(state) + if adapter: + new_values = list(adapter(new_collection, iterable)) + else: + new_values = list(iterable) + + old = self.get(state, dict_) + + # ignore re-assignment of the current collection, as happens + # implicitly with in-place operators (foo.collection |= other) + if old is iterable: + return + + state.modified_event(dict_, self, True, old) + + old_collection = self.get_collection(state, dict_, old) + + dict_[self.key] = user_data + + collections.bulk_replace(new_values, old_collection, new_collection) + old_collection.unlink(old) + + + def set_committed_value(self, state, dict_, value): + """Set an attribute value on the given instance and 'commit' it.""" + + collection, user_data = self._initialize_collection(state) + + if value: + for item in value: + collection.append_without_event(item) + + state.callables.pop(self.key, None) + state.dict[self.key] = user_data + + state.commit(dict_, [self.key]) + + if self.key in state.pending: + + # pending items exist. issue a modified event, + # add/remove new items. + state.modified_event(dict_, self, True, user_data) + + pending = state.pending.pop(self.key) + added = pending.added_items + removed = pending.deleted_items + for item in added: + collection.append_without_event(item) + for item in removed: + collection.remove_without_event(item) + + return user_data + + def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF): + """Retrieve the CollectionAdapter associated with the given state. + + Creates a new CollectionAdapter if one does not exist. + + """ + if user_data is None: + user_data = self.get(state, dict_, passive=passive) + if user_data is PASSIVE_NO_RESULT: + return user_data + + return getattr(user_data, '_sa_adapter') + +class GenericBackrefExtension(interfaces.AttributeExtension): + """An extension which synchronizes a two-way relationship. + + A typical two-way relationship is a parent object containing a list of + child objects, where each child object references the parent. The other + are two objects which contain scalar references to each other. + + """ + + active_history = False + + def __init__(self, key): + self.key = key + + def set(self, state, child, oldchild, initiator): + if oldchild is child: + return child + + if oldchild is not None and oldchild is not PASSIVE_NO_RESULT: + # With lazy=None, there's no guarantee that the full collection is + # present when updating via a backref. + old_state, old_dict = instance_state(oldchild), instance_dict(oldchild) + impl = old_state.get_impl(self.key) + try: + impl.remove(old_state, + old_dict, + state.obj(), + initiator, passive=PASSIVE_NO_FETCH) + except (ValueError, KeyError, IndexError): + pass + + if child is not None: + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).append( + child_state, + child_dict, + state.obj(), + initiator, passive=PASSIVE_NO_FETCH) + return child + + def append(self, state, child, initiator): + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).append( + child_state, + child_dict, + state.obj(), + initiator, passive=PASSIVE_NO_FETCH) + return child + + def remove(self, state, child, initiator): + if child is not None: + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).remove( + child_state, + child_dict, + state.obj(), + initiator, passive=PASSIVE_NO_FETCH) + + +class Events(object): + def __init__(self): + self.original_init = object.__init__ + # Initialize to tuples instead of lists to minimize the memory + # footprint + self.on_init = () + self.on_init_failure = () + self.on_load = () + self.on_resurrect = () + + def run(self, event, *args): + for fn in getattr(self, event): + fn(*args) + + def add_listener(self, event, listener): + # not thread safe... problem? mb: nope + bucket = getattr(self, event) + if bucket == (): + setattr(self, event, [listener]) + else: + bucket.append(listener) + + def remove_listener(self, event, listener): + bucket = getattr(self, event) + bucket.remove(listener) + + +class ClassManager(dict): + """tracks state information at the class level.""" + + MANAGER_ATTR = '_sa_class_manager' + STATE_ATTR = '_sa_instance_state' + + event_registry_factory = Events + deferred_scalar_loader = None + + def __init__(self, class_): + self.class_ = class_ + self.factory = None # where we came from, for inheritance bookkeeping + self.info = {} + self.mapper = None + self.new_init = None + self.mutable_attributes = set() + self.local_attrs = {} + self.originals = {} + for base in class_.__mro__[-2:0:-1]: # reverse, skipping 1st and last + if not isinstance(base, type): + continue + cls_state = manager_of_class(base) + if cls_state: + self.update(cls_state) + self.events = self.event_registry_factory() + self.manage() + self._instrument_init() + + def _configure_create_arguments(self, + _source=None, + deferred_scalar_loader=None): + """Accept extra **kw arguments passed to create_manager_for_cls. + + The current contract of ClassManager and other managers is that they + take a single "cls" argument in their constructor (as per + test/orm/instrumentation.py InstrumentationCollisionTest). This + is to provide consistency with the current API of "class manager" + callables and such which may return various ClassManager and + ClassManager-like instances. So create_manager_for_cls sends + in ClassManager-specific arguments via this method once the + non-proxied ClassManager is available. + + """ + if _source: + deferred_scalar_loader = _source.deferred_scalar_loader + + if deferred_scalar_loader: + self.deferred_scalar_loader = deferred_scalar_loader + + def _subclass_manager(self, cls): + """Create a new ClassManager for a subclass of this ClassManager's class. + + This is called automatically when attributes are instrumented so that + the attributes can be propagated to subclasses against their own + class-local manager, without the need for mappers etc. to have already + pre-configured managers for the full class hierarchy. Mappers + can post-configure the auto-generated ClassManager when needed. + + """ + manager = manager_of_class(cls) + if manager is None: + manager = _create_manager_for_cls(cls, _source=self) + return manager + + def _instrument_init(self): + # TODO: self.class_.__init__ is often the already-instrumented + # __init__ from an instrumented superclass. We still need to make + # our own wrapper, but it would + # be nice to wrap the original __init__ and not our existing wrapper + # of such, since this adds method overhead. + self.events.original_init = self.class_.__init__ + self.new_init = _generate_init(self.class_, self) + self.install_member('__init__', self.new_init) + + def _uninstrument_init(self): + if self.new_init: + self.uninstall_member('__init__') + self.new_init = None + + def _create_instance_state(self, instance): + if self.mutable_attributes: + return state.MutableAttrInstanceState(instance, self) + else: + return state.InstanceState(instance, self) + + def manage(self): + """Mark this instance as the manager for its class.""" + + setattr(self.class_, self.MANAGER_ATTR, self) + + def dispose(self): + """Dissasociate this manager from its class.""" + + delattr(self.class_, self.MANAGER_ATTR) + + def manager_getter(self): + return attrgetter(self.MANAGER_ATTR) + + def instrument_attribute(self, key, inst, propagated=False): + if propagated: + if key in self.local_attrs: + return # don't override local attr with inherited attr + else: + self.local_attrs[key] = inst + self.install_descriptor(key, inst) + self[key] = inst + + for cls in self.class_.__subclasses__(): + manager = self._subclass_manager(cls) + manager.instrument_attribute(key, inst, True) + + def post_configure_attribute(self, key): + pass + + def uninstrument_attribute(self, key, propagated=False): + if key not in self: + return + if propagated: + if key in self.local_attrs: + return # don't get rid of local attr + else: + del self.local_attrs[key] + self.uninstall_descriptor(key) + del self[key] + if key in self.mutable_attributes: + self.mutable_attributes.remove(key) + for cls in self.class_.__subclasses__(): + manager = self._subclass_manager(cls) + manager.uninstrument_attribute(key, True) + + def unregister(self): + """remove all instrumentation established by this ClassManager.""" + + self._uninstrument_init() + + self.mapper = self.events = None + self.info.clear() + + for key in list(self): + if key in self.local_attrs: + self.uninstrument_attribute(key) + + def install_descriptor(self, key, inst): + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError("%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key) + setattr(self.class_, key, inst) + + def uninstall_descriptor(self, key): + delattr(self.class_, key) + + def install_member(self, key, implementation): + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError("%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key) + self.originals.setdefault(key, getattr(self.class_, key, None)) + setattr(self.class_, key, implementation) + + def uninstall_member(self, key): + original = self.originals.pop(key, None) + if original is not None: + setattr(self.class_, key, original) + + def instrument_collection_class(self, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def initialize_collection(self, key, state, factory): + user_data = factory() + adapter = collections.CollectionAdapter( + self.get_impl(key), state, user_data) + return adapter, user_data + + def is_instrumented(self, key, search=False): + if search: + return key in self + else: + return key in self.local_attrs + + def get_impl(self, key): + return self[key].impl + + @property + def attributes(self): + return self.itervalues() + + ## InstanceState management + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) + return instance + + def setup_instance(self, instance, state=None): + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) + + def teardown_instance(self, instance): + delattr(instance, self.STATE_ATTR) + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + + """ + if hasattr(instance, self.STATE_ATTR): + return False + else: + state = self._create_instance_state(instance) + setattr(instance, self.STATE_ATTR, state) + return state + + def state_getter(self): + """Return a (instance) -> InstanceState callable. + + "state getter" callables should raise either KeyError or + AttributeError if no InstanceState could be found for the + instance. + """ + + return attrgetter(self.STATE_ATTR) + + def dict_getter(self): + return attrgetter('__dict__') + + def has_state(self, instance): + return hasattr(instance, self.STATE_ATTR) + + def has_parent(self, state, key, optimistic=False): + """TODO""" + return self.get_impl(key).hasparent(state, optimistic=optimistic) + + def __nonzero__(self): + """All ClassManagers are non-zero regardless of attribute state.""" + return True + + def __repr__(self): + return '<%s of %r at %x>' % ( + self.__class__.__name__, self.class_, id(self)) + +class _ClassInstrumentationAdapter(ClassManager): + """Adapts a user-defined InstrumentationManager to a ClassManager.""" + + def __init__(self, class_, override, **kw): + self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + + ClassManager.__init__(self, class_, **kw) + + def manage(self): + self._adapted.manage(self.class_, self) + + def dispose(self): + self._adapted.dispose(self.class_) + + def manager_getter(self): + return self._adapted.manager_getter(self.class_) + + def instrument_attribute(self, key, inst, propagated=False): + ClassManager.instrument_attribute(self, key, inst, propagated) + if not propagated: + self._adapted.instrument_attribute(self.class_, key, inst) + + def post_configure_attribute(self, key): + self._adapted.post_configure_attribute(self.class_, key, self[key]) + + def install_descriptor(self, key, inst): + self._adapted.install_descriptor(self.class_, key, inst) + + def uninstall_descriptor(self, key): + self._adapted.uninstall_descriptor(self.class_, key) + + def install_member(self, key, implementation): + self._adapted.install_member(self.class_, key, implementation) + + def uninstall_member(self, key): + self._adapted.uninstall_member(self.class_, key) + + def instrument_collection_class(self, key, collection_class): + return self._adapted.instrument_collection_class( + self.class_, key, collection_class) + + def initialize_collection(self, key, state, factory): + delegate = getattr(self._adapted, 'initialize_collection', None) + if delegate: + return delegate(key, state, factory) + else: + return ClassManager.initialize_collection(self, key, state, factory) + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + self.setup_instance(instance, state) + return instance + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + """ + if self.has_state(instance): + return False + else: + return self.setup_instance(instance) + + def setup_instance(self, instance, state=None): + self._adapted.initialize_instance_dict(self.class_, instance) + + if state is None: + state = self._create_instance_state(instance) + + # the given instance is assumed to have no state + self._adapted.install_state(self.class_, instance, state) + return state + + def teardown_instance(self, instance): + self._adapted.remove_state(self.class_, instance) + + def has_state(self, instance): + try: + state = self._get_state(instance) + except exc.NO_STATE: + return False + else: + return True + + def state_getter(self): + return self._get_state + + def dict_getter(self): + return self._get_dict + +class History(tuple): + """A 3-tuple of added, unchanged and deleted values. + + Each tuple member is an iterable sequence. + + """ + + __slots__ = () + + added = property(itemgetter(0)) + unchanged = property(itemgetter(1)) + deleted = property(itemgetter(2)) + + def __new__(cls, added, unchanged, deleted): + return tuple.__new__(cls, (added, unchanged, deleted)) + + def __nonzero__(self): + return self != HISTORY_BLANK + + def sum(self): + return (self.added or []) +\ + (self.unchanged or []) +\ + (self.deleted or []) + + def non_deleted(self): + return (self.added or []) +\ + (self.unchanged or []) + + def non_added(self): + return (self.unchanged or []) +\ + (self.deleted or []) + + def has_changes(self): + return bool(self.added or self.deleted) + + def as_state(self): + return History( + [(c is not None and c is not PASSIVE_NO_RESULT) + and instance_state(c) or None + for c in self.added], + [(c is not None and c is not PASSIVE_NO_RESULT) + and instance_state(c) or None + for c in self.unchanged], + [(c is not None and c is not PASSIVE_NO_RESULT) + and instance_state(c) or None + for c in self.deleted], + ) + + @classmethod + def from_attribute(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, NEVER_SET) + + if hasattr(attribute, 'get_collection'): + current = attribute.get_collection(state, state.dict, current) + if original is NO_VALUE: + return cls(list(current), (), ()) + elif original is NEVER_SET: + return cls((), list(current), ()) + else: + current_set = util.IdentitySet(current) + original_set = util.IdentitySet(original) + + # ensure duplicates are maintained + return cls( + [x for x in current if x not in original_set], + [x for x in current if x in original_set], + [x for x in original if x not in current_set] + ) + else: + if current is NO_VALUE: + if (original is not None and + original is not NEVER_SET and + original is not NO_VALUE): + deleted = [original] + else: + deleted = () + return cls((), (), deleted) + elif original is NO_VALUE: + return cls([current], (), ()) + elif (original is NEVER_SET or + attribute.is_equal(current, original) is True): + # dont let ClauseElement expressions here trip things up + return cls((), [current], ()) + else: + if original is not None: + deleted = [original] + else: + deleted = () + return cls([current], (), deleted) + +HISTORY_BLANK = History(None, None, None) + +def get_history(obj, key, **kwargs): + """Return a History record for the given object and attribute key. + + obj is an instrumented object instance. An InstanceState + is accepted directly for backwards compatibility but + this usage is deprecated. + + """ + return get_state_history(instance_state(obj), key, **kwargs) + +def get_state_history(state, key, **kwargs): + return state.get_history(key, **kwargs) + +def has_parent(cls, obj, key, optimistic=False): + """TODO""" + manager = manager_of_class(cls) + state = instance_state(obj) + return manager.has_parent(state, key, optimistic) + +def register_class(class_, **kw): + """Register class instrumentation. + + Returns the existing or newly created class manager. + """ + + manager = manager_of_class(class_) + if manager is None: + manager = _create_manager_for_cls(class_, **kw) + return manager + +def unregister_class(class_): + """Unregister class instrumentation.""" + + instrumentation_registry.unregister(class_) + +def register_attribute(class_, key, **kw): + + proxy_property = kw.pop('proxy_property', None) + + comparator = kw.pop('comparator', None) + parententity = kw.pop('parententity', None) + register_descriptor(class_, key, proxy_property, comparator, parententity) + if not proxy_property: + register_attribute_impl(class_, key, **kw) + +def register_attribute_impl(class_, key, + uselist=False, callable_=None, + useobject=False, mutable_scalars=False, + impl_class=None, **kw): + + manager = manager_of_class(class_) + if uselist: + factory = kw.pop('typecallable', None) + typecallable = manager.instrument_collection_class( + key, factory or list) + else: + typecallable = kw.pop('typecallable', None) + + if impl_class: + impl = impl_class(class_, key, typecallable, **kw) + elif uselist: + impl = CollectionAttributeImpl(class_, key, callable_, + typecallable=typecallable, **kw) + elif useobject: + impl = ScalarObjectAttributeImpl(class_, key, callable_, **kw) + elif mutable_scalars: + impl = MutableScalarAttributeImpl(class_, key, callable_, + class_manager=manager, **kw) + else: + impl = ScalarAttributeImpl(class_, key, callable_, **kw) + + manager[key].impl = impl + + manager.post_configure_attribute(key) + +def register_descriptor(class_, key, proxy_property=None, comparator=None, parententity=None, property_=None): + manager = manager_of_class(class_) + + if proxy_property: + proxy_type = proxied_attribute_factory(proxy_property) + descriptor = proxy_type(key, proxy_property, comparator, parententity) + else: + descriptor = InstrumentedAttribute(key, comparator=comparator, parententity=parententity) + + manager.instrument_attribute(key, descriptor) + +def unregister_attribute(class_, key): + manager_of_class(class_).uninstrument_attribute(key) + +def init_collection(obj, key): + """Initialize a collection attribute and return the collection adapter. + + This function is used to provide direct access to collection internals + for a previously unloaded attribute. e.g.:: + + collection_adapter = init_collection(someobject, 'elements') + for elem in values: + collection_adapter.append_without_event(elem) + + For an easier way to do the above, see :func:`~sqlalchemy.orm.attributes.set_committed_value`. + + obj is an instrumented object instance. An InstanceState + is accepted directly for backwards compatibility but + this usage is deprecated. + + """ + state = instance_state(obj) + dict_ = state.dict + return init_state_collection(state, dict_, key) + +def init_state_collection(state, dict_, key): + """Initialize a collection attribute and return the collection adapter.""" + + attr = state.get_impl(key) + user_data = attr.initialize(state, dict_) + return attr.get_collection(state, dict_, user_data) + +def set_committed_value(instance, key, value): + """Set the value of an attribute with no history events. + + Cancels any previous history present. The value should be + a scalar value for scalar-holding attributes, or + an iterable for any collection-holding attribute. + + This is the same underlying method used when a lazy loader + fires off and loads additional data from the database. + In particular, this method can be used by application code + which has loaded additional attributes or collections through + separate queries, which can then be attached to an instance + as though it were part of its original loaded state. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set_committed_value(state, dict_, value) + +def set_attribute(instance, key, value): + """Set the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set(state, dict_, value, None) + +def get_attribute(instance, key): + """Get the value of an attribute, firing any callables required. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to make usage of attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + return state.get_impl(key).get(state, dict_) + +def del_attribute(instance, key): + """Delete the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).delete(state, dict_) + +def is_instrumented(instance, key): + """Return True if the given attribute on the given instance is instrumented + by the attributes package. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + + """ + return manager_of_class(instance.__class__).is_instrumented(key, search=True) + +class InstrumentationRegistry(object): + """Private instrumentation registration singleton. + + All classes are routed through this registry + when first instrumented, however the InstrumentationRegistry + is not actually needed unless custom ClassManagers are in use. + + """ + + _manager_finders = weakref.WeakKeyDictionary() + _state_finders = util.WeakIdentityMapping() + _dict_finders = util.WeakIdentityMapping() + _extended = False + + def create_manager_for_cls(self, class_, **kw): + assert class_ is not None + assert manager_of_class(class_) is None + + for finder in instrumentation_finders: + factory = finder(class_) + if factory is not None: + break + else: + factory = ClassManager + + existing_factories = self._collect_management_factories_for(class_).\ + difference([factory]) + if existing_factories: + raise TypeError( + "multiple instrumentation implementations specified " + "in %s inheritance hierarchy: %r" % ( + class_.__name__, list(existing_factories))) + + manager = factory(class_) + if not isinstance(manager, ClassManager): + manager = _ClassInstrumentationAdapter(class_, manager) + + if factory != ClassManager and not self._extended: + # somebody invoked a custom ClassManager. + # reinstall global "getter" functions with the more + # expensive ones. + self._extended = True + _install_lookup_strategy(self) + + manager._configure_create_arguments(**kw) + + manager.factory = factory + self._manager_finders[class_] = manager.manager_getter() + self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() + return manager + + def _collect_management_factories_for(self, cls): + """Return a collection of factories in play or specified for a hierarchy. + + Traverses the entire inheritance graph of a cls and returns a collection + of instrumentation factories for those classes. Factories are extracted + from active ClassManagers, if available, otherwise + instrumentation_finders is consulted. + + """ + hierarchy = util.class_hierarchy(cls) + factories = set() + for member in hierarchy: + manager = manager_of_class(member) + if manager is not None: + factories.add(manager.factory) + else: + for finder in instrumentation_finders: + factory = finder(member) + if factory is not None: + break + else: + factory = None + factories.add(factory) + factories.discard(None) + return factories + + def manager_of_class(self, cls): + # this is only called when alternate instrumentation has been established + if cls is None: + return None + try: + finder = self._manager_finders[cls] + except KeyError: + return None + else: + return finder(cls) + + def state_of(self, instance): + # this is only called when alternate instrumentation has been established + if instance is None: + raise AttributeError("None has no persistent state.") + try: + return self._state_finders[instance.__class__](instance) + except KeyError: + raise AttributeError("%r is not instrumented" % instance.__class__) + + def dict_of(self, instance): + # this is only called when alternate instrumentation has been established + if instance is None: + raise AttributeError("None has no persistent state.") + try: + return self._dict_finders[instance.__class__](instance) + except KeyError: + raise AttributeError("%r is not instrumented" % instance.__class__) + + def unregister(self, class_): + if class_ in self._manager_finders: + manager = self.manager_of_class(class_) + manager.unregister() + manager.dispose() + del self._manager_finders[class_] + del self._state_finders[class_] + del self._dict_finders[class_] + if ClassManager.MANAGER_ATTR in class_.__dict__: + delattr(class_, ClassManager.MANAGER_ATTR) + +instrumentation_registry = InstrumentationRegistry() + +def _install_lookup_strategy(implementation): + """Replace global class/object management functions + with either faster or more comprehensive implementations, + based on whether or not extended class instrumentation + has been detected. + + This function is called only by InstrumentationRegistry() + and unit tests specific to this behavior. + + """ + global instance_state, instance_dict, manager_of_class + if implementation is util.symbol('native'): + instance_state = attrgetter(ClassManager.STATE_ATTR) + instance_dict = attrgetter("__dict__") + def manager_of_class(cls): + return cls.__dict__.get(ClassManager.MANAGER_ATTR, None) + else: + instance_state = instrumentation_registry.state_of + instance_dict = instrumentation_registry.dict_of + manager_of_class = instrumentation_registry.manager_of_class + +_create_manager_for_cls = instrumentation_registry.create_manager_for_cls + +# Install default "lookup" strategies. These are basically +# very fast attrgetters for key attributes. +# When a custom ClassManager is installed, more expensive per-class +# strategies are copied over these. +_install_lookup_strategy(util.symbol('native')) + +def find_native_user_instrumentation_hook(cls): + """Find user-specified instrumentation management for a class.""" + return getattr(cls, INSTRUMENTATION_MANAGER, None) +instrumentation_finders.append(find_native_user_instrumentation_hook) + +def _generate_init(class_, class_manager): + """Build an __init__ decorator that triggers ClassManager events.""" + + # TODO: we should use the ClassManager's notion of the + # original '__init__' method, once ClassManager is fixed + # to always reference that. + original__init__ = class_.__init__ + assert original__init__ + + # Go through some effort here and don't change the user's __init__ + # calling signature. + # FIXME: need to juggle local names to avoid constructor argument + # clashes. + func_body = """\ +def __init__(%(apply_pos)s): + new_state = class_manager._new_state_if_none(%(self_arg)s) + if new_state: + return new_state.initialize_instance(%(apply_kw)s) + else: + return original__init__(%(apply_kw)s) +""" + func_vars = util.format_argspec_init(original__init__, grouped=False) + func_text = func_body % func_vars + + # Py3K + #func_defaults = getattr(original__init__, '__defaults__', None) + # Py2K + func = getattr(original__init__, 'im_func', original__init__) + func_defaults = getattr(func, 'func_defaults', None) + # end Py2K + + env = locals().copy() + exec func_text in env + __init__ = env['__init__'] + __init__.__doc__ = original__init__.__doc__ + if func_defaults: + __init__.func_defaults = func_defaults + return __init__ diff --git a/sqlalchemy/orm/collections.py b/sqlalchemy/orm/collections.py new file mode 100644 index 0000000..616f251 --- /dev/null +++ b/sqlalchemy/orm/collections.py @@ -0,0 +1,1438 @@ +"""Support for collections of mapped entities. + +The collections package supplies the machinery used to inform the ORM of +collection membership changes. An instrumentation via decoration approach is +used, allowing arbitrary types (including built-ins) to be used as entity +collections without requiring inheritance from a base class. + +Instrumentation decoration relays membership change events to the +``InstrumentedCollectionAttribute`` that is currently managing the collection. +The decorators observe function call arguments and return values, tracking +entities entering or leaving the collection. Two decorator approaches are +provided. One is a bundle of generic decorators that map function arguments +and return values to events:: + + from sqlalchemy.orm.collections import collection + class MyClass(object): + # ... + + @collection.adds(1) + def store(self, item): + self.data.append(item) + + @collection.removes_return() + def pop(self): + return self.data.pop() + + +The second approach is a bundle of targeted decorators that wrap appropriate +append and remove notifiers around the mutation methods present in the +standard Python ``list``, ``set`` and ``dict`` interfaces. These could be +specified in terms of generic decorator recipes, but are instead hand-tooled +for increased efficiency. The targeted decorators occasionally implement +adapter-like behavior, such as mapping bulk-set methods (``extend``, +``update``, ``__setslice__``, etc.) into the series of atomic mutation events +that the ORM requires. + +The targeted decorators are used internally for automatic instrumentation of +entity collection classes. Every collection class goes through a +transformation process roughly like so: + +1. If the class is a built-in, substitute a trivial sub-class +2. Is this class already instrumented? +3. Add in generic decorators +4. Sniff out the collection interface through duck-typing +5. Add targeted decoration to any undecorated interface method + +This process modifies the class at runtime, decorating methods and adding some +bookkeeping properties. This isn't possible (or desirable) for built-in +classes like ``list``, so trivial sub-classes are substituted to hold +decoration:: + + class InstrumentedList(list): + pass + +Collection classes can be specified in ``relationship(collection_class=)`` as +types or a function that returns an instance. Collection classes are +inspected and instrumented during the mapper compilation phase. The +collection_class callable will be executed once to produce a specimen +instance, and the type of that specimen will be instrumented. Functions that +return built-in types like ``lists`` will be adapted to produce instrumented +instances. + +When extending a known type like ``list``, additional decorations are not +generally not needed. Odds are, the extension method will delegate to a +method that's already instrumented. For example:: + + class QueueIsh(list): + def push(self, item): + self.append(item) + def shift(self): + return self.pop(0) + +There's no need to decorate these methods. ``append`` and ``pop`` are already +instrumented as part of the ``list`` interface. Decorating them would fire +duplicate events, which should be avoided. + +The targeted decoration tries not to rely on other methods in the underlying +collection class, but some are unavoidable. Many depend on 'read' methods +being present to properly instrument a 'write', for example, ``__setitem__`` +needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also +reimplemented in terms of atomic appends and removes, so the ``extend`` +decoration will actually perform many ``append`` operations and not call the +underlying method at all. + +Tight control over bulk operation and the firing of events is also possible by +implementing the instrumentation internally in your methods. The basic +instrumentation package works under the general assumption that collection +mutation will not raise unusual exceptions. If you want to closely +orchestrate append and remove events with exception management, internal +instrumentation may be the answer. Within your method, +``collection_adapter(self)`` will retrieve an object that you can use for +explicit control over triggering append and remove events. + +The owning object and InstrumentedCollectionAttribute are also reachable +through the adapter, allowing for some very sophisticated behavior. + +""" + +import copy +import inspect +import operator +import sys +import weakref + +import sqlalchemy.exceptions as sa_exc +from sqlalchemy.sql import expression +from sqlalchemy import schema, util + + +__all__ = ['collection', 'collection_adapter', + 'mapped_collection', 'column_mapped_collection', + 'attribute_mapped_collection'] + +__instrumentation_mutex = util.threading.Lock() + + +def column_mapped_collection(mapping_spec): + """A dictionary-based collection type with column-based keying. + + Returns a MappedCollection factory with a keying function generated + from mapping_spec, which may be a Column or a sequence of Columns. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + from sqlalchemy.orm.util import _state_mapper + from sqlalchemy.orm.attributes import instance_state + + cols = [expression._no_literals(q) for q in util.to_list(mapping_spec)] + if len(cols) == 1: + def keyfunc(value): + state = instance_state(value) + m = _state_mapper(state) + return m._get_state_attr_by_column(state, cols[0]) + else: + mapping_spec = tuple(cols) + def keyfunc(value): + state = instance_state(value) + m = _state_mapper(state) + return tuple(m._get_state_attr_by_column(state, c) + for c in mapping_spec) + return lambda: MappedCollection(keyfunc) + +def attribute_mapped_collection(attr_name): + """A dictionary-based collection type with attribute-based keying. + + Returns a MappedCollection factory with a keying based on the + 'attr_name' attribute of entities in the collection. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + return lambda: MappedCollection(operator.attrgetter(attr_name)) + + +def mapped_collection(keyfunc): + """A dictionary-based collection type with arbitrary keying. + + Returns a MappedCollection factory with a keying function generated + from keyfunc, a callable that takes an entity and returns a key value. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + return lambda: MappedCollection(keyfunc) + +class collection(object): + """Decorators for entity collection classes. + + The decorators fall into two groups: annotations and interception recipes. + + The annotating decorators (appender, remover, iterator, + internally_instrumented, on_link) indicate the method's purpose and take no + arguments. They are not written with parens:: + + @collection.appender + def append(self, append): ... + + The recipe decorators all require parens, even those that take no + arguments:: + + @collection.adds('entity'): + def insert(self, position, entity): ... + + @collection.removes_return() + def popitem(self): ... + + Decorators can be specified in long-hand for Python 2.3, or with + the class-level dict attribute '__instrumentation__'- see the source + for details. + + """ + # Bundled as a class solely for ease of use: packaging, doc strings, + # importability. + + @staticmethod + def appender(fn): + """Tag the method as the collection appender. + + The appender method is called with one positional argument: the value + to append. The method will be automatically decorated with 'adds(1)' + if not already decorated:: + + @collection.appender + def add(self, append): ... + + # or, equivalently + @collection.appender + @collection.adds(1) + def add(self, append): ... + + # for mapping type, an 'append' may kick out a previous value + # that occupies that slot. consider d['a'] = 'foo'- any previous + # value in d['a'] is discarded. + @collection.appender + @collection.replaces(1) + def add(self, entity): + key = some_key_func(entity) + previous = None + if key in self: + previous = self[key] + self[key] = entity + return previous + + If the value to append is not allowed in the collection, you may + raise an exception. Something to remember is that the appender + will be called for each object mapped by a database query. If the + database contains rows that violate your collection semantics, you + will need to get creative to fix the problem, as access via the + collection will not work. + + If the appender method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + setattr(fn, '_sa_instrument_role', 'appender') + return fn + + @staticmethod + def remover(fn): + """Tag the method as the collection remover. + + The remover method is called with one positional argument: the value + to remove. The method will be automatically decorated with + 'removes_return()' if not already decorated:: + + @collection.remover + def zap(self, entity): ... + + # or, equivalently + @collection.remover + @collection.removes_return() + def zap(self, ): ... + + If the value to remove is not present in the collection, you may + raise an exception or return None to ignore the error. + + If the remove method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + setattr(fn, '_sa_instrument_role', 'remover') + return fn + + @staticmethod + def iterator(fn): + """Tag the method as the collection remover. + + The iterator method is called with no arguments. It is expected to + return an iterator over all collection members:: + + @collection.iterator + def __iter__(self): ... + + """ + setattr(fn, '_sa_instrument_role', 'iterator') + return fn + + @staticmethod + def internally_instrumented(fn): + """Tag the method as instrumented. + + This tag will prevent any decoration from being applied to the method. + Use this if you are orchestrating your own calls to collection_adapter + in one of the basic SQLAlchemy interface methods, or to prevent + an automatic ABC method decoration from wrapping your implementation:: + + # normally an 'extend' method on a list-like class would be + # automatically intercepted and re-implemented in terms of + # SQLAlchemy events and append(). your implementation will + # never be called, unless: + @collection.internally_instrumented + def extend(self, items): ... + + """ + setattr(fn, '_sa_instrumented', True) + return fn + + @staticmethod + def on_link(fn): + """Tag the method as a the "linked to attribute" event handler. + + This optional event handler will be called when the collection class + is linked to or unlinked from the InstrumentedAttribute. It is + invoked immediately after the '_sa_adapter' property is set on + the instance. A single argument is passed: the collection adapter + that has been linked, or None if unlinking. + + """ + setattr(fn, '_sa_instrument_role', 'on_link') + return fn + + @staticmethod + def converter(fn): + """Tag the method as the collection converter. + + This optional method will be called when a collection is being + replaced entirely, as in:: + + myobj.acollection = [newvalue1, newvalue2] + + The converter method will receive the object being assigned and should + return an iterable of values suitable for use by the ``appender`` + method. A converter must not assign values or mutate the collection, + it's sole job is to adapt the value the user provides into an iterable + of values for the ORM's use. + + The default converter implementation will use duck-typing to do the + conversion. A dict-like collection will be convert into an iterable + of dictionary values, and other types will simply be iterated. + + @collection.converter + def convert(self, other): ... + + If the duck-typing of the object does not match the type of this + collection, a TypeError is raised. + + Supply an implementation of this method if you want to expand the + range of possible types that can be assigned in bulk or perform + validation on the values about to be assigned. + + """ + setattr(fn, '_sa_instrument_role', 'converter') + return fn + + @staticmethod + def adds(arg): + """Mark the method as adding an entity to the collection. + + Adds "add to collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value. Arguments can be specified positionally (i.e. integer) or by + name:: + + @collection.adds(1) + def push(self, item): ... + + @collection.adds('entity') + def do_stuff(self, thing, entity=None): ... + + """ + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) + return fn + return decorator + + @staticmethod + def replaces(arg): + """Mark the method as replacing an entity in the collection. + + Adds "add to collection" and "remove from collection" handling to + the method. The decorator argument indicates which method argument + holds the SQLAlchemy-relevant value to be added, and return value, if + any will be considered the value to remove. + + Arguments can be specified positionally (i.e. integer) or by name:: + + @collection.replaces(2) + def __setitem__(self, index, item): ... + + """ + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) + setattr(fn, '_sa_instrument_after', 'fire_remove_event') + return fn + return decorator + + @staticmethod + def removes(arg): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value to be removed. Arguments can be specified positionally (i.e. + integer) or by name:: + + @collection.removes(1) + def zap(self, item): ... + + For methods where the value to remove is not known at call-time, use + collection.removes_return. + + """ + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg)) + return fn + return decorator + + @staticmethod + def removes_return(): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The return value + of the method, if any, is considered the value to remove. The method + arguments are not inspected:: + + @collection.removes_return() + def pop(self): ... + + For methods where the value to remove is known at call-time, use + collection.remove. + + """ + def decorator(fn): + setattr(fn, '_sa_instrument_after', 'fire_remove_event') + return fn + return decorator + + +# public instrumentation interface for 'internally instrumented' +# implementations +def collection_adapter(collection): + """Fetch the CollectionAdapter for a collection.""" + return getattr(collection, '_sa_adapter', None) + +def collection_iter(collection): + """Iterate over an object supporting the @iterator or __iter__ protocols. + + If the collection is an ORM collection, it need not be attached to an + object to be iterable. + + """ + try: + return getattr(collection, '_sa_iterator', + getattr(collection, '__iter__'))() + except AttributeError: + raise TypeError("'%s' object is not iterable" % + type(collection).__name__) + + +class CollectionAdapter(object): + """Bridges between the ORM and arbitrary Python collections. + + Proxies base-level collection operations (append, remove, iterate) + to the underlying Python collection, and emits add/remove events for + entities entering or leaving the collection. + + The ORM uses an CollectionAdapter exclusively for interaction with + entity collections. + + """ + def __init__(self, attr, owner_state, data): + self.attr = attr + # TODO: figure out what this being a weakref buys us + self._data = weakref.ref(data) + self.owner_state = owner_state + self.link_to_self(data) + + data = property(lambda s: s._data(), + doc="The entity collection being adapted.") + + def link_to_self(self, data): + """Link a collection to this adapter, and fire a link event.""" + setattr(data, '_sa_adapter', self) + if hasattr(data, '_sa_on_link'): + getattr(data, '_sa_on_link')(self) + + def unlink(self, data): + """Unlink a collection from any adapter, and fire a link event.""" + setattr(data, '_sa_adapter', None) + if hasattr(data, '_sa_on_link'): + getattr(data, '_sa_on_link')(None) + + def adapt_like_to_iterable(self, obj): + """Converts collection-compatible objects to an iterable of values. + + Can be passed any type of object, and if the underlying collection + determines that it can be adapted into a stream of values it can + use, returns an iterable of values suitable for append()ing. + + This method may raise TypeError or any other suitable exception + if adaptation fails. + + If a converter implementation is not supplied on the collection, + a default duck-typing-based implementation is used. + + """ + converter = getattr(self._data(), '_sa_converter', None) + if converter is not None: + return converter(obj) + + setting_type = util.duck_type_collection(obj) + receiving_type = util.duck_type_collection(self._data()) + + if obj is None or setting_type != receiving_type: + given = obj is None and 'None' or obj.__class__.__name__ + if receiving_type is None: + wanted = self._data().__class__.__name__ + else: + wanted = receiving_type.__name__ + + raise TypeError( + "Incompatible collection type: %s is not %s-like" % ( + given, wanted)) + + # If the object is an adapted collection, return the (iterable) + # adapter. + if getattr(obj, '_sa_adapter', None) is not None: + return getattr(obj, '_sa_adapter') + elif setting_type == dict: + # Py3K + #return obj.values() + # Py2K + return getattr(obj, 'itervalues', getattr(obj, 'values'))() + # end Py2K + else: + return iter(obj) + + def append_with_event(self, item, initiator=None): + """Add an entity to the collection, firing mutation events.""" + getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator) + + def append_without_event(self, item): + """Add or restore an entity to the collection, firing no events.""" + getattr(self._data(), '_sa_appender')(item, _sa_initiator=False) + + def remove_with_event(self, item, initiator=None): + """Remove an entity from the collection, firing mutation events.""" + getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator) + + def remove_without_event(self, item): + """Remove an entity from the collection, firing no events.""" + getattr(self._data(), '_sa_remover')(item, _sa_initiator=False) + + def clear_with_event(self, initiator=None): + """Empty the collection, firing a mutation event for each entity.""" + for item in list(self): + self.remove_with_event(item, initiator) + + def clear_without_event(self): + """Empty the collection, firing no events.""" + for item in list(self): + self.remove_without_event(item) + + def __iter__(self): + """Iterate over entities in the collection.""" + + # Py3K requires iter() here + return iter(getattr(self._data(), '_sa_iterator')()) + + def __len__(self): + """Count entities in the collection.""" + return len(list(getattr(self._data(), '_sa_iterator')())) + + def __nonzero__(self): + return True + + def fire_append_event(self, item, initiator=None): + """Notify that a entity has entered the collection. + + Initiator is the InstrumentedAttribute that initiated the membership + mutation, and should be left as None unless you are passing along + an initiator value from a chained operation. + + """ + if initiator is not False and item is not None: + return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator) + else: + return item + + def fire_remove_event(self, item, initiator=None): + """Notify that a entity has been removed from the collection. + + Initiator is the InstrumentedAttribute that initiated the membership + mutation, and should be left as None unless you are passing along + an initiator value from a chained operation. + + """ + if initiator is not False and item is not None: + self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator) + + def fire_pre_remove_event(self, initiator=None): + """Notify that an entity is about to be removed from the collection. + + Only called if the entity cannot be removed after calling + fire_remove_event(). + + """ + self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator) + + def __getstate__(self): + return {'key': self.attr.key, + 'owner_state': self.owner_state, + 'data': self.data} + + def __setstate__(self, d): + self.attr = getattr(d['owner_state'].obj().__class__, d['key']).impl + self.owner_state = d['owner_state'] + self._data = weakref.ref(d['data']) + + +def bulk_replace(values, existing_adapter, new_adapter): + """Load a new collection, firing events based on prior like membership. + + Appends instances in ``values`` onto the ``new_adapter``. Events will be + fired for any instance not present in the ``existing_adapter``. Any + instances in ``existing_adapter`` not present in ``values`` will have + remove events fired upon them. + + values + An iterable of collection member instances + + existing_adapter + A CollectionAdapter of instances to be replaced + + new_adapter + An empty CollectionAdapter to load with ``values`` + + + """ + if not isinstance(values, list): + values = list(values) + + idset = util.IdentitySet + constants = idset(existing_adapter or ()).intersection(values or ()) + additions = idset(values or ()).difference(constants) + removals = idset(existing_adapter or ()).difference(constants) + + for member in values or (): + if member in additions: + new_adapter.append_with_event(member) + elif member in constants: + new_adapter.append_without_event(member) + + if existing_adapter: + for member in removals: + existing_adapter.remove_with_event(member) + +def prepare_instrumentation(factory): + """Prepare a callable for future use as a collection class factory. + + Given a collection class factory (either a type or no-arg callable), + return another factory that will produce compatible instances when + called. + + This function is responsible for converting collection_class=list + into the run-time behavior of collection_class=InstrumentedList. + + """ + # Convert a builtin to 'Instrumented*' + if factory in __canned_instrumentation: + factory = __canned_instrumentation[factory] + + # Create a specimen + cls = type(factory()) + + # Did factory callable return a builtin? + if cls in __canned_instrumentation: + # Wrap it so that it returns our 'Instrumented*' + factory = __converting_factory(factory) + cls = factory() + + # Instrument the class if needed. + if __instrumentation_mutex.acquire(): + try: + if getattr(cls, '_sa_instrumented', None) != id(cls): + _instrument_class(cls) + finally: + __instrumentation_mutex.release() + + return factory + +def __converting_factory(original_factory): + """Convert the type returned by collection factories on the fly. + + Given a collection factory that returns a builtin type (e.g. a list), + return a wrapped function that converts that type to one of our + instrumented types. + + """ + def wrapper(): + collection = original_factory() + type_ = type(collection) + if type_ in __canned_instrumentation: + # return an instrumented type initialized from the factory's + # collection + return __canned_instrumentation[type_](collection) + else: + raise sa_exc.InvalidRequestError( + "Collection class factories must produce instances of a " + "single class.") + try: + # often flawed but better than nothing + wrapper.__name__ = "%sWrapper" % original_factory.__name__ + wrapper.__doc__ = original_factory.__doc__ + except: + pass + return wrapper + +def _instrument_class(cls): + """Modify methods in a class and install instrumentation.""" + + # TODO: more formally document this as a decoratorless/Python 2.3 + # option for specifying instrumentation. (likely doc'd here in code only, + # not in online docs.) Useful for C types too. + # + # __instrumentation__ = { + # 'rolename': 'methodname', # ... + # 'methods': { + # 'methodname': ('fire_{append,remove}_event', argspec, + # 'fire_{append,remove}_event'), + # 'append': ('fire_append_event', 1, None), + # '__setitem__': ('fire_append_event', 1, 'fire_remove_event'), + # 'pop': (None, None, 'fire_remove_event'), + # } + # } + + # In the normal call flow, a request for any of the 3 basic collection + # types is transformed into one of our trivial subclasses + # (e.g. InstrumentedList). Catch anything else that sneaks in here... + if cls.__module__ == '__builtin__': + raise sa_exc.ArgumentError( + "Can not instrument a built-in type. Use a " + "subclass, even a trivial one.") + + collection_type = util.duck_type_collection(cls) + if collection_type in __interfaces: + roles = __interfaces[collection_type].copy() + decorators = roles.pop('_decorators', {}) + else: + roles, decorators = {}, {} + + if hasattr(cls, '__instrumentation__'): + roles.update(copy.deepcopy(getattr(cls, '__instrumentation__'))) + + methods = roles.pop('methods', {}) + + for name in dir(cls): + method = getattr(cls, name, None) + if not util.callable(method): + continue + + # note role declarations + if hasattr(method, '_sa_instrument_role'): + role = method._sa_instrument_role + assert role in ('appender', 'remover', 'iterator', + 'on_link', 'converter') + roles[role] = name + + # transfer instrumentation requests from decorated function + # to the combined queue + before, after = None, None + if hasattr(method, '_sa_instrument_before'): + op, argument = method._sa_instrument_before + assert op in ('fire_append_event', 'fire_remove_event') + before = op, argument + if hasattr(method, '_sa_instrument_after'): + op = method._sa_instrument_after + assert op in ('fire_append_event', 'fire_remove_event') + after = op + if before: + methods[name] = before[0], before[1], after + elif after: + methods[name] = None, None, after + + # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): + fn = getattr(cls, method, None) + if (fn and method not in methods and + not hasattr(fn, '_sa_instrumented')): + setattr(cls, method, decorator(fn)) + + # ensure all roles are present, and apply implicit instrumentation if + # needed + if 'appender' not in roles or not hasattr(cls, roles['appender']): + raise sa_exc.ArgumentError( + "Type %s must elect an appender method to be " + "a collection class" % cls.__name__) + elif (roles['appender'] not in methods and + not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')): + methods[roles['appender']] = ('fire_append_event', 1, None) + + if 'remover' not in roles or not hasattr(cls, roles['remover']): + raise sa_exc.ArgumentError( + "Type %s must elect a remover method to be " + "a collection class" % cls.__name__) + elif (roles['remover'] not in methods and + not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')): + methods[roles['remover']] = ('fire_remove_event', 1, None) + + if 'iterator' not in roles or not hasattr(cls, roles['iterator']): + raise sa_exc.ArgumentError( + "Type %s must elect an iterator method to be " + "a collection class" % cls.__name__) + + # apply ad-hoc instrumentation from decorators, class-level defaults + # and implicit role declarations + for method, (before, argument, after) in methods.items(): + setattr(cls, method, + _instrument_membership_mutator(getattr(cls, method), + before, argument, after)) + # intern the role map + for role, method in roles.items(): + setattr(cls, '_sa_%s' % role, getattr(cls, method)) + + setattr(cls, '_sa_instrumented', id(cls)) + +def _instrument_membership_mutator(method, before, argument, after): + """Route method args and/or return value through the collection adapter.""" + # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' + if before: + fn_args = list(util.flatten_iterator(inspect.getargspec(method)[0])) + if type(argument) is int: + pos_arg = argument + named_arg = len(fn_args) > argument and fn_args[argument] or None + else: + if argument in fn_args: + pos_arg = fn_args.index(argument) + else: + pos_arg = None + named_arg = argument + del fn_args + + def wrapper(*args, **kw): + if before: + if pos_arg is None: + if named_arg not in kw: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument) + value = kw[named_arg] + else: + if len(args) > pos_arg: + value = args[pos_arg] + elif named_arg in kw: + value = kw[named_arg] + else: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument) + + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = getattr(args[0], '_sa_adapter', None) + + if before and executor: + getattr(executor, before)(value, initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res + try: + wrapper._sa_instrumented = True + wrapper.__name__ = method.__name__ + wrapper.__doc__ = method.__doc__ + except: + pass + return wrapper + +def __set(collection, item, _sa_initiator=None): + """Run set events, may eventually be inlined into decorators.""" + + if _sa_initiator is not False and item is not None: + executor = getattr(collection, '_sa_adapter', None) + if executor: + item = getattr(executor, 'fire_append_event')(item, _sa_initiator) + return item + +def __del(collection, item, _sa_initiator=None): + """Run del events, may eventually be inlined into decorators.""" + if _sa_initiator is not False and item is not None: + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_remove_event')(item, _sa_initiator) + +def __before_delete(collection, _sa_initiator=None): + """Special method to run 'commit existing value' methods""" + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_pre_remove_event')(_sa_initiator) + +def _list_decorators(): + """Tailored instrumentation wrappers for any list-like class.""" + + def _tidy(fn): + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__') + + def append(fn): + def append(self, item, _sa_initiator=None): + item = __set(self, item, _sa_initiator) + fn(self, item) + _tidy(append) + return append + + def remove(fn): + def remove(self, value, _sa_initiator=None): + __before_delete(self, _sa_initiator) + # testlib.pragma exempt:__eq__ + fn(self, value) + __del(self, value, _sa_initiator) + _tidy(remove) + return remove + + def insert(fn): + def insert(self, index, value): + value = __set(self, value) + fn(self, index, value) + _tidy(insert) + return insert + + def __setitem__(fn): + def __setitem__(self, index, value): + if not isinstance(index, slice): + existing = self[index] + if existing is not None: + __del(self, existing) + value = __set(self, value) + fn(self, index, value) + else: + # slice assignment requires __delitem__, insert, __len__ + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + if step == 1: + for i in xrange(start, stop, step): + if len(self) > start: + del self[start] + + for i, item in enumerate(value): + self.insert(i + start, item) + else: + rng = range(start, stop, step) + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), + len(rng))) + for i, item in zip(rng, value): + self.__setitem__(i, item) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, index): + if not isinstance(index, slice): + item = self[index] + __del(self, item) + fn(self, index) + else: + # slice deletion requires __getslice__ and a slice-groking + # __getitem__ for stepped deletion + # note: not breaking this into atomic dels + for item in self[index]: + __del(self, item) + fn(self, index) + _tidy(__delitem__) + return __delitem__ + + # Py2K + def __setslice__(fn): + def __setslice__(self, start, end, values): + for value in self[start:end]: + __del(self, value) + values = [__set(self, value) for value in values] + fn(self, start, end, values) + _tidy(__setslice__) + return __setslice__ + + def __delslice__(fn): + def __delslice__(self, start, end): + for value in self[start:end]: + __del(self, value) + fn(self, start, end) + _tidy(__delslice__) + return __delslice__ + # end Py2K + + def extend(fn): + def extend(self, iterable): + for value in iterable: + self.append(value) + _tidy(extend) + return extend + + def __iadd__(fn): + def __iadd__(self, iterable): + # list.__iadd__ takes any iterable and seems to let TypeError raise + # as-is instead of returning NotImplemented + for value in iterable: + self.append(value) + return self + _tidy(__iadd__) + return __iadd__ + + def pop(fn): + def pop(self, index=-1): + __before_delete(self) + item = fn(self, index) + __del(self, item) + return item + _tidy(pop) + return pop + + # __imul__ : not wrapping this. all members of the collection are already + # present, so no need to fire appends... wrapping it with an explicit + # decorator is still possible, so events on *= can be had if they're + # desired. hard to imagine a use case for __imul__, though. + + l = locals().copy() + l.pop('_tidy') + return l + +def _dict_decorators(): + """Tailored instrumentation wrappers for any dict-like mapping class.""" + + def _tidy(fn): + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') + + Unspecified = util.symbol('Unspecified') + + def __setitem__(fn): + def __setitem__(self, key, value, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + value = __set(self, value, _sa_initiator) + fn(self, key, value) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, key, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + fn(self, key) + _tidy(__delitem__) + return __delitem__ + + def clear(fn): + def clear(self): + for key in self: + __del(self, self[key]) + fn(self) + _tidy(clear) + return clear + + def pop(fn): + def pop(self, key, default=Unspecified): + if key in self: + __del(self, self[key]) + if default is Unspecified: + return fn(self, key) + else: + return fn(self, key, default) + _tidy(pop) + return pop + + def popitem(fn): + def popitem(self): + __before_delete(self) + item = fn(self) + __del(self, item[1]) + return item + _tidy(popitem) + return popitem + + def setdefault(fn): + def setdefault(self, key, default=None): + if key not in self: + self.__setitem__(key, default) + return default + else: + return self.__getitem__(key) + _tidy(setdefault) + return setdefault + + if sys.version_info < (2, 4): + def update(fn): + def update(self, other): + for key in other.keys(): + if key not in self or self[key] is not other[key]: + self[key] = other[key] + _tidy(update) + return update + else: + def update(fn): + def update(self, __other=Unspecified, **kw): + if __other is not Unspecified: + if hasattr(__other, 'keys'): + for key in __other.keys(): + if (key not in self or + self[key] is not __other[key]): + self[key] = __other[key] + else: + for key, value in __other: + if key not in self or self[key] is not value: + self[key] = value + for key in kw: + if key not in self or self[key] is not kw[key]: + self[key] = kw[key] + _tidy(update) + return update + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + +if util.py3k: + _set_binop_bases = (set, frozenset) +else: + import sets + _set_binop_bases = (set, frozenset, sets.BaseSet) + +def _set_binops_check_strict(self, obj): + """Allow only set, frozenset and self.__class__-derived objects in binops.""" + return isinstance(obj, _set_binop_bases + (self.__class__,)) + +def _set_binops_check_loose(self, obj): + """Allow anything set-like to participate in set binops.""" + return (isinstance(obj, _set_binop_bases + (self.__class__,)) or + util.duck_type_collection(obj) == set) + + +def _set_decorators(): + """Tailored instrumentation wrappers for any set-like class.""" + + def _tidy(fn): + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__') + + Unspecified = util.symbol('Unspecified') + + def add(fn): + def add(self, value, _sa_initiator=None): + if value not in self: + value = __set(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + _tidy(add) + return add + + if sys.version_info < (2, 4): + def discard(fn): + def discard(self, value, _sa_initiator=None): + if value in self: + self.remove(value, _sa_initiator) + _tidy(discard) + return discard + else: + def discard(fn): + def discard(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + _tidy(discard) + return discard + + def remove(fn): + def remove(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + _tidy(remove) + return remove + + def pop(fn): + def pop(self): + __before_delete(self) + item = fn(self) + __del(self, item) + return item + _tidy(pop) + return pop + + def clear(fn): + def clear(self): + for item in list(self): + self.remove(item) + _tidy(clear) + return clear + + def update(fn): + def update(self, value): + for item in value: + self.add(item) + _tidy(update) + return update + + def __ior__(fn): + def __ior__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.add(item) + return self + _tidy(__ior__) + return __ior__ + + def difference_update(fn): + def difference_update(self, value): + for item in value: + self.discard(item) + _tidy(difference_update) + return difference_update + + def __isub__(fn): + def __isub__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.discard(item) + return self + _tidy(__isub__) + return __isub__ + + def intersection_update(fn): + def intersection_update(self, other): + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(intersection_update) + return intersection_update + + def __iand__(fn): + def __iand__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + _tidy(__iand__) + return __iand__ + + def symmetric_difference_update(fn): + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(symmetric_difference_update) + return symmetric_difference_update + + def __ixor__(fn): + def __ixor__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + _tidy(__ixor__) + return __ixor__ + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + + +class InstrumentedList(list): + """An instrumented version of the built-in list.""" + + __instrumentation__ = { + 'appender': 'append', + 'remover': 'remove', + 'iterator': '__iter__', } + +class InstrumentedSet(set): + """An instrumented version of the built-in set.""" + + __instrumentation__ = { + 'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__', } + +class InstrumentedDict(dict): + """An instrumented version of the built-in dict.""" + + # Py3K + #__instrumentation__ = { + # 'iterator': 'values', } + # Py2K + __instrumentation__ = { + 'iterator': 'itervalues', } + # end Py2K + +__canned_instrumentation = { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } + +__interfaces = { + list: {'appender': 'append', + 'remover': 'remove', + 'iterator': '__iter__', + '_decorators': _list_decorators(), }, + set: {'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__', + '_decorators': _set_decorators(), }, + # decorators are required for dicts and object collections. + # Py3K + #dict: {'iterator': 'values', + # '_decorators': _dict_decorators(), }, + # Py2K + dict: {'iterator': 'itervalues', + '_decorators': _dict_decorators(), }, + # end Py2K + # < 0.4 compatible naming, deprecated- use decorators instead. + None: {} + } + +class MappedCollection(dict): + """A basic dictionary-based collection class. + + Extends dict with the minimal bag semantics that collection classes require. + ``set`` and ``remove`` are implemented in terms of a keying function: any + callable that takes an object and returns an object for use as a dictionary + key. + + """ + + def __init__(self, keyfunc): + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable any callable that takes an object and + returns an object for use as a dictionary key. + + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + + """ + self.keyfunc = keyfunc + + def set(self, value, _sa_initiator=None): + """Add an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + self.__setitem__(key, value, _sa_initiator) + set = collection.internally_instrumented(set) + set = collection.appender(set) + + def remove(self, value, _sa_initiator=None): + """Remove an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ + if self[key] != value: + raise sa_exc.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the MappedCollection key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % + (value, self[key], key)) + self.__delitem__(key, _sa_initiator) + remove = collection.internally_instrumented(remove) + remove = collection.remover(remove) + + def _convert(self, dictlike): + """Validate and convert a dict-like object into values for set()ing. + + This is called behind the scenes when a MappedCollection is replaced + entirely by another collection, as in:: + + myobj.mappedcollection = {'a':obj1, 'b': obj2} # ... + + Raises a TypeError if the key in any (key, value) pair in the dictlike + object does not match the key that this collection's keyfunc would + have assigned for that value. + + """ + for incoming_key, value in util.dictlike_iteritems(dictlike): + new_key = self.keyfunc(value) + if incoming_key != new_key: + raise TypeError( + "Found incompatible key %r for value %r; this collection's " + "keying function requires a key of %r for this value." % ( + incoming_key, value, new_key)) + yield value + _convert = collection.converter(_convert) diff --git a/sqlalchemy/orm/dependency.py b/sqlalchemy/orm/dependency.py new file mode 100644 index 0000000..cbbfb08 --- /dev/null +++ b/sqlalchemy/orm/dependency.py @@ -0,0 +1,575 @@ +# orm/dependency.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Relationship dependencies. + +Bridges the ``PropertyLoader`` (i.e. a ``relationship()``) and the +``UOWTransaction`` together to allow processing of relationship()-based +dependencies at flush time. + +""" + +from sqlalchemy import sql, util +import sqlalchemy.exceptions as sa_exc +from sqlalchemy.orm import attributes, exc, sync +from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY + + +def create_dependency_processor(prop): + types = { + ONETOMANY : OneToManyDP, + MANYTOONE: ManyToOneDP, + MANYTOMANY : ManyToManyDP, + } + return types[prop.direction](prop) + +class DependencyProcessor(object): + has_dependencies = True + + def __init__(self, prop): + self.prop = prop + self.cascade = prop.cascade + self.mapper = prop.mapper + self.parent = prop.parent + self.secondary = prop.secondary + self.direction = prop.direction + self.post_update = prop.post_update + self.passive_deletes = prop.passive_deletes + self.passive_updates = prop.passive_updates + self.enable_typechecks = prop.enable_typechecks + self.key = prop.key + self.dependency_marker = MapperStub(self.parent, self.mapper, self.key) + if not self.prop.synchronize_pairs: + raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relationship %s. " + "No target attributes to populate between parent and child are present" % self.prop) + + def _get_instrumented_attribute(self): + """Return the ``InstrumentedAttribute`` handled by this + ``DependencyProecssor``. + + """ + return self.parent.class_manager.get_impl(self.key) + + def hasparent(self, state): + """return True if the given object instance has a parent, + according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``. + + """ + # TODO: use correct API for this + return self._get_instrumented_attribute().hasparent(state) + + def register_dependencies(self, uowcommit): + """Tell a ``UOWTransaction`` what mappers are dependent on + which, with regards to the two or three mappers handled by + this ``DependencyProcessor``. + + """ + + raise NotImplementedError() + + def register_processors(self, uowcommit): + """Tell a ``UOWTransaction`` about this object as a processor, + which will be executed after that mapper's objects have been + saved or before they've been deleted. The process operation + manages attributes and dependent operations between two mappers. + + """ + raise NotImplementedError() + + def whose_dependent_on_who(self, state1, state2): + """Given an object pair assuming `obj2` is a child of `obj1`, + return a tuple with the dependent object second, or None if + there is no dependency. + + """ + if state1 is state2: + return None + elif self.direction == ONETOMANY: + return (state1, state2) + else: + return (state2, state1) + + def process_dependencies(self, task, deplist, uowcommit, delete = False): + """This method is called during a flush operation to + synchronize data between a parent and child object. + + It is called within the context of the various mappers and + sometimes individual objects sorted according to their + insert/update/delete order (topological sort). + + """ + raise NotImplementedError() + + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + """Used before the flushes' topological sort to traverse + through related objects and ensure every instance which will + require save/update/delete is properly added to the + UOWTransaction. + + """ + raise NotImplementedError() + + def _verify_canload(self, state): + if state is not None and \ + not self.mapper._canload(state, allow_subtypes=not self.enable_typechecks): + if self.mapper._canload(state, allow_subtypes=True): + raise exc.FlushError( + "Attempting to flush an item of type %s on collection '%s', " + "which is not the expected type %s. Configure mapper '%s' to " + "load this subtype polymorphically, or set " + "enable_typechecks=False to allow subtypes. " + "Mismatched typeloading may cause bi-directional relationships " + "(backrefs) to not function properly." % + (state.class_, self.prop, self.mapper.class_, self.mapper)) + else: + raise exc.FlushError( + "Attempting to flush an item of type %s on collection '%s', " + "whose mapper does not inherit from that of %s." % + (state.class_, self.prop, self.mapper.class_)) + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + """Called during a flush to synchronize primary key identifier + values between a parent/child object, as well as to an + associationrow in the case of many-to-many. + + """ + raise NotImplementedError() + + def _check_reverse_action(self, uowcommit, parent, child, action): + """Determine if an action has been performed by the 'reverse' property of this property. + + this is used to ensure that only one side of a bidirectional relationship + issues a certain operation for a parent/child pair. + + """ + for r in self.prop._reverse_property: + if not r.viewonly and (r._dependency_processor, action, parent, child) in uowcommit.attributes: + return True + return False + + def _performed_action(self, uowcommit, parent, child, action): + """Establish that an action has been performed for a certain parent/child pair. + + Used only for actions that are sensitive to bidirectional double-action, + i.e. manytomany, post_update. + + """ + uowcommit.attributes[(self, action, parent, child)] = True + + def _conditional_post_update(self, state, uowcommit, related): + """Execute a post_update call. + + For relationships that contain the post_update flag, an additional + ``UPDATE`` statement may be associated after an ``INSERT`` or + before a ``DELETE`` in order to resolve circular row + dependencies. + + This method will check for the post_update flag being set on a + particular relationship, and given a target object and list of + one or more related objects, and execute the ``UPDATE`` if the + given related object list contains ``INSERT``s or ``DELETE``s. + + """ + if state is not None and self.post_update: + for x in related: + if x is not None and not self._check_reverse_action(uowcommit, x, state, "postupdate"): + uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs]) + self._performed_action(uowcommit, x, state, "postupdate") + break + + def _pks_changed(self, uowcommit, state): + raise NotImplementedError() + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.prop) + +class OneToManyDP(DependencyProcessor): + def register_dependencies(self, uowcommit): + if self.post_update: + uowcommit.register_dependency(self.mapper, self.dependency_marker) + uowcommit.register_dependency(self.parent, self.dependency_marker) + else: + uowcommit.register_dependency(self.parent, self.mapper) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: + uowcommit.register_processor(self.parent, self, self.parent) + + def process_dependencies(self, task, deplist, uowcommit, delete = False): + if delete: + # head object is being deleted, and we manage its list of child objects + # the child objects have to have their foreign key to the parent set to NULL + # this phase can be called safely for any cascade but is unnecessary if delete cascade + # is on. + if self.post_update or not self.passive_deletes == 'all': + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) + if history: + for child in history.deleted: + if child is not None and self.hasparent(child) is False: + self._synchronize(state, child, None, True, uowcommit) + self._conditional_post_update(child, uowcommit, [state]) + if self.post_update or not self.cascade.delete: + for child in history.unchanged: + if child is not None: + self._synchronize(state, child, None, True, uowcommit) + self._conditional_post_update(child, uowcommit, [state]) + else: + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key, passive=True) + if history: + for child in history.added: + self._synchronize(state, child, None, False, uowcommit) + if child is not None: + self._conditional_post_update(child, uowcommit, [state]) + + for child in history.deleted: + if not self.cascade.delete_orphan and not self.hasparent(child): + self._synchronize(state, child, None, True, uowcommit) + + if self._pks_changed(uowcommit, state): + for child in history.unchanged: + self._synchronize(state, child, None, False, uowcommit) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + if delete: + # head object is being deleted, and we manage its list of child objects + # the child objects have to have their foreign key to the parent set to NULL + if not self.post_update: + should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all' + for state in deplist: + history = uowcommit.get_attribute_history( + state, self.key, passive=self.passive_deletes) + if history: + for child in history.deleted: + if child is not None and self.hasparent(child) is False: + if self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=True) + else: + uowcommit.register_object(child) + if should_null_fks: + for child in history.unchanged: + if child is not None: + uowcommit.register_object(child) + else: + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key, passive=True) + if history: + for child in history.added: + if child is not None: + uowcommit.register_object(child) + for child in history.deleted: + if not self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=False) + elif self.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object( + attributes.instance_state(c), + isdelete=True) + if self._pks_changed(uowcommit, state): + if not history: + history = uowcommit.get_attribute_history( + state, self.key, passive=self.passive_updates) + if history: + for child in history.unchanged: + if child is not None: + uowcommit.register_object(child) + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + source = state + dest = child + if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): + return + self._verify_canload(child) + if clearkeys: + sync.clear(dest, self.mapper, self.prop.synchronize_pairs) + else: + sync.populate(source, self.parent, dest, self.mapper, + self.prop.synchronize_pairs, uowcommit, + self.passive_updates) + + def _pks_changed(self, uowcommit, state): + return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs) + +class DetectKeySwitch(DependencyProcessor): + """a special DP that works for many-to-one relationships, fires off for + child items who have changed their referenced key.""" + + has_dependencies = False + + def register_dependencies(self, uowcommit): + pass + + def register_processors(self, uowcommit): + uowcommit.register_processor(self.parent, self, self.mapper) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): + # for non-passive updates, register in the preprocess stage + # so that mapper save_obj() gets a hold of changes + if not delete and not self.passive_updates: + self._process_key_switches(deplist, uowcommit) + + def process_dependencies(self, task, deplist, uowcommit, delete=False): + # for passive updates, register objects in the process stage + # so that we avoid ManyToOneDP's registering the object without + # the listonly flag in its own preprocess stage (results in UPDATE) + # statements being emitted + if not delete and self.passive_updates: + self._process_key_switches(deplist, uowcommit) + + def _process_key_switches(self, deplist, uowcommit): + switchers = set(s for s in deplist if self._pks_changed(uowcommit, s)) + if switchers: + # yes, we're doing a linear search right now through the UOW. only + # takes effect when primary key values have actually changed. + # a possible optimization might be to enhance the "hasparents" capability of + # attributes to actually store all parent references, but this introduces + # more complicated attribute accounting. + for s in [elem for elem in uowcommit.session.identity_map.all_states() + if issubclass(elem.class_, self.parent.class_) and + self.key in elem.dict and + elem.dict[self.key] is not None and + attributes.instance_state(elem.dict[self.key]) in switchers + ]: + uowcommit.register_object(s) + sync.populate( + attributes.instance_state(s.dict[self.key]), + self.mapper, s, self.parent, self.prop.synchronize_pairs, + uowcommit, self.passive_updates) + + def _pks_changed(self, uowcommit, state): + return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs) + +class ManyToOneDP(DependencyProcessor): + def __init__(self, prop): + DependencyProcessor.__init__(self, prop) + self.mapper._dependency_processors.append(DetectKeySwitch(prop)) + + def register_dependencies(self, uowcommit): + if self.post_update: + uowcommit.register_dependency(self.mapper, self.dependency_marker) + uowcommit.register_dependency(self.parent, self.dependency_marker) + else: + uowcommit.register_dependency(self.mapper, self.parent) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: + uowcommit.register_processor(self.mapper, self, self.parent) + + def process_dependencies(self, task, deplist, uowcommit, delete=False): + if delete: + if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all': + # post_update means we have to update our row to not reference the child object + # before we can DELETE the row + for state in deplist: + self._synchronize(state, None, None, True, uowcommit) + history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) + if history: + self._conditional_post_update(state, uowcommit, history.sum()) + else: + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key, passive=True) + if history: + for child in history.added: + self._synchronize(state, child, None, False, uowcommit) + self._conditional_post_update(state, uowcommit, history.sum()) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): + if self.post_update: + return + if delete: + if self.cascade.delete or self.cascade.delete_orphan: + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) + if history: + if self.cascade.delete_orphan: + todelete = history.sum() + else: + todelete = history.non_deleted() + for child in todelete: + if child is None: + continue + uowcommit.register_object(child, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object( + attributes.instance_state(c), isdelete=True) + else: + for state in deplist: + uowcommit.register_object(state) + if self.cascade.delete_orphan: + history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object( + attributes.instance_state(c), + isdelete=True) + + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + if state is None or (not self.post_update and uowcommit.is_deleted(state)): + return + + if clearkeys or child is None: + sync.clear(state, self.parent, self.prop.synchronize_pairs) + else: + self._verify_canload(child) + sync.populate(child, self.mapper, state, + self.parent, self.prop.synchronize_pairs, uowcommit, + self.passive_updates + ) + +class ManyToManyDP(DependencyProcessor): + def register_dependencies(self, uowcommit): + # many-to-many. create a "Stub" mapper to represent the + # "middle table" in the relationship. This stub mapper doesnt save + # or delete any objects, but just marks a dependency on the two + # related mappers. its dependency processor then populates the + # association table. + + uowcommit.register_dependency(self.parent, self.dependency_marker) + uowcommit.register_dependency(self.mapper, self.dependency_marker) + + def register_processors(self, uowcommit): + uowcommit.register_processor(self.dependency_marker, self, self.parent) + + def process_dependencies(self, task, deplist, uowcommit, delete = False): + connection = uowcommit.transaction.connection(self.mapper) + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + if delete: + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) + if history: + for child in history.non_added(): + if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"): + continue + associationrow = {} + self._synchronize(state, child, associationrow, False, uowcommit) + secondary_delete.append(associationrow) + self._performed_action(uowcommit, state, child, "manytomany") + else: + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key) + if history: + for child in history.added: + if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"): + continue + associationrow = {} + self._synchronize(state, child, associationrow, False, uowcommit) + self._performed_action(uowcommit, state, child, "manytomany") + secondary_insert.append(associationrow) + for child in history.deleted: + if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"): + continue + associationrow = {} + self._synchronize(state, child, associationrow, False, uowcommit) + self._performed_action(uowcommit, state, child, "manytomany") + secondary_delete.append(associationrow) + + if not self.passive_updates and self._pks_changed(uowcommit, state): + if not history: + history = uowcommit.get_attribute_history(state, self.key, passive=False) + + for child in history.unchanged: + associationrow = {} + sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs) + sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs) + + #self.syncrules.update(associationrow, state, child, "old_") + secondary_update.append(associationrow) + + if secondary_delete: + statement = self.secondary.delete(sql.and_(*[ + c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow + ])) + result = connection.execute(statement, secondary_delete) + if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete): + raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of " + "secondary table rows deleted from table '%s': %d" % + (result.rowcount, self.secondary.description, len(secondary_delete))) + + if secondary_update: + statement = self.secondary.update(sql.and_(*[ + c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow + ])) + result = connection.execute(statement, secondary_update) + if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update): + raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of " + "secondary table rows updated from table '%s': %d" % + (result.rowcount, self.secondary.description, len(secondary_update))) + + if secondary_insert: + statement = self.secondary.insert() + connection.execute(statement, secondary_insert) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + if not delete: + for state in deplist: + history = uowcommit.get_attribute_history(state, self.key, passive=True) + if history: + for child in history.deleted: + if self.cascade.delete_orphan and self.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object( + attributes.instance_state(c), isdelete=True) + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + if associationrow is None: + return + self._verify_canload(child) + + sync.populate_dict(state, self.parent, associationrow, + self.prop.synchronize_pairs) + sync.populate_dict(child, self.mapper, associationrow, + self.prop.secondary_synchronize_pairs) + + def _pks_changed(self, uowcommit, state): + return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs) + +class MapperStub(object): + """Represent a many-to-many dependency within a flush + context. + + The UOWTransaction corresponds dependencies to mappers. + MapperStub takes the place of the "association table" + so that a depedendency can be corresponded to it. + + """ + + def __init__(self, parent, mapper, key): + self.mapper = mapper + self.base_mapper = self + self.class_ = mapper.class_ + self._inheriting_mappers = [] + + def polymorphic_iterator(self): + return iter((self,)) + + def _register_dependencies(self, uowcommit): + pass + + def _register_procesors(self, uowcommit): + pass + + def _save_obj(self, *args, **kwargs): + pass + + def _delete_obj(self, *args, **kwargs): + pass + + def primary_mapper(self): + return self diff --git a/sqlalchemy/orm/dynamic.py b/sqlalchemy/orm/dynamic.py new file mode 100644 index 0000000..d796040 --- /dev/null +++ b/sqlalchemy/orm/dynamic.py @@ -0,0 +1,293 @@ +# dynamic.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Dynamic collection API. + +Dynamic collections act like Query() objects for read operations and support +basic add/delete mutation. + +""" + +from sqlalchemy import log, util +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import exc as sa_exc +from sqlalchemy.sql import operators +from sqlalchemy.orm import ( + attributes, object_session, util as mapperutil, strategies, object_mapper + ) +from sqlalchemy.orm.query import Query +from sqlalchemy.orm.util import _state_has_identity, has_identity +from sqlalchemy.orm import attributes, collections + +class DynaLoader(strategies.AbstractRelationshipLoader): + def init_class_attribute(self, mapper): + self.is_class_level = True + + strategies._register_attribute(self, + mapper, + useobject=True, + impl_class=DynamicAttributeImpl, + target_mapper=self.parent_property.mapper, + order_by=self.parent_property.order_by, + query_class=self.parent_property.query_class + ) + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + return (None, None) + +log.class_logger(DynaLoader) + +class DynamicAttributeImpl(attributes.AttributeImpl): + uses_objects = True + accepts_scalar_loader = False + + def __init__(self, class_, key, typecallable, + target_mapper, order_by, query_class=None, **kwargs): + super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs) + self.target_mapper = target_mapper + self.order_by = order_by + if not query_class: + self.query_class = AppenderQuery + elif AppenderMixin in query_class.mro(): + self.query_class = query_class + else: + self.query_class = mixin_user_query(query_class) + + def get(self, state, dict_, passive=False): + if passive: + return self._get_collection_history(state, passive=True).added_items + else: + return self.query_class(self, state) + + def get_collection(self, state, dict_, user_data=None, passive=True): + if passive: + return self._get_collection_history(state, passive=passive).added_items + else: + history = self._get_collection_history(state, passive=passive) + return history.added_items + history.unchanged_items + + def fire_append_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) + collection_history.added_items.append(value) + + for ext in self.extensions: + ext.append(state, value, initiator or self) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), True) + + def fire_remove_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) + collection_history.deleted_items.append(value) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), False) + + for ext in self.extensions: + ext.remove(state, value, initiator or self) + + def _modified_event(self, state, dict_): + + if self.key not in state.committed_state: + state.committed_state[self.key] = CollectionHistory(self, state) + + state.modified_event(dict_, + self, + False, + attributes.NEVER_SET, + passive=attributes.PASSIVE_NO_INITIALIZE) + + # this is a hack to allow the _base.ComparableEntity fixture + # to work + dict_[self.key] = True + return state.committed_state[self.key] + + def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF): + if initiator is self: + return + + self._set_iterable(state, dict_, value) + + def _set_iterable(self, state, dict_, iterable, adapter=None): + + collection_history = self._modified_event(state, dict_) + new_values = list(iterable) + + if _state_has_identity(state): + old_collection = list(self.get(state, dict_)) + else: + old_collection = [] + + collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values)) + + def delete(self, *args, **kwargs): + raise NotImplementedError() + + def get_history(self, state, dict_, passive=False): + c = self._get_collection_history(state, passive) + return attributes.History(c.added_items, c.unchanged_items, c.deleted_items) + + def _get_collection_history(self, state, passive=False): + if self.key in state.committed_state: + c = state.committed_state[self.key] + else: + c = CollectionHistory(self, state) + + if not passive: + return CollectionHistory(self, state, apply_to=c) + else: + return c + + def append(self, state, dict_, value, initiator, passive=False): + if initiator is not self: + self.fire_append_event(state, dict_, value, initiator) + + def remove(self, state, dict_, value, initiator, passive=False): + if initiator is not self: + self.fire_remove_event(state, dict_, value, initiator) + +class DynCollectionAdapter(object): + """the dynamic analogue to orm.collections.CollectionAdapter""" + + def __init__(self, attr, owner_state, data): + self.attr = attr + self.state = owner_state + self.data = data + + def __iter__(self): + return iter(self.data) + + def append_with_event(self, item, initiator=None): + self.attr.append(self.state, self.state.dict, item, initiator) + + def remove_with_event(self, item, initiator=None): + self.attr.remove(self.state, self.state.dict, item, initiator) + + def append_without_event(self, item): + pass + + def remove_without_event(self, item): + pass + +class AppenderMixin(object): + query_class = None + + def __init__(self, attr, state): + Query.__init__(self, attr.target_mapper, None) + self.instance = instance = state.obj() + self.attr = attr + + mapper = object_mapper(instance) + prop = mapper.get_property(self.attr.key, resolve_synonyms=True) + self._criterion = prop.compare( + operators.eq, + instance, + value_is_parent=True, + alias_secondary=False) + + if self.attr.order_by: + self._order_by = self.attr.order_by + + def __session(self): + sess = object_session(self.instance) + if sess is not None and self.autoflush and sess.autoflush and self.instance in sess: + sess.flush() + if not has_identity(self.instance): + return None + else: + return sess + + def session(self): + return self.__session() + session = property(session, lambda s, x:None) + + def __iter__(self): + sess = self.__session() + if sess is None: + return iter(self.attr._get_collection_history( + attributes.instance_state(self.instance), + passive=True).added_items) + else: + return iter(self._clone(sess)) + + def __getitem__(self, index): + sess = self.__session() + if sess is None: + return self.attr._get_collection_history( + attributes.instance_state(self.instance), + passive=True).added_items.__getitem__(index) + else: + return self._clone(sess).__getitem__(index) + + def count(self): + sess = self.__session() + if sess is None: + return len(self.attr._get_collection_history( + attributes.instance_state(self.instance), + passive=True).added_items) + else: + return self._clone(sess).count() + + def _clone(self, sess=None): + # note we're returning an entirely new Query class instance + # here without any assignment capabilities; the class of this + # query is determined by the session. + instance = self.instance + if sess is None: + sess = object_session(instance) + if sess is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session, and no " + "contextual session is established; lazy load operation " + "of attribute '%s' cannot proceed" % ( + mapperutil.instance_str(instance), self.attr.key)) + + if self.query_class: + query = self.query_class(self.attr.target_mapper, session=sess) + else: + query = sess.query(self.attr.target_mapper) + + query._criterion = self._criterion + query._order_by = self._order_by + + return query + + def append(self, item): + self.attr.append( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), item, None) + + def remove(self, item): + self.attr.remove( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), item, None) + + +class AppenderQuery(AppenderMixin, Query): + """A dynamic query that supports basic collection storage operations.""" + + +def mixin_user_query(cls): + """Return a new class with AppenderQuery functionality layered over.""" + name = 'Appender' + cls.__name__ + return type(name, (AppenderMixin, cls), {'query_class': cls}) + +class CollectionHistory(object): + """Overrides AttributeHistory to receive append/remove events directly.""" + + def __init__(self, attr, state, apply_to=None): + if apply_to: + deleted = util.IdentitySet(apply_to.deleted_items) + added = apply_to.added_items + coll = AppenderQuery(attr, state).autoflush(False) + self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted] + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items + else: + self.deleted_items = [] + self.added_items = [] + self.unchanged_items = [] + diff --git a/sqlalchemy/orm/evaluator.py b/sqlalchemy/orm/evaluator.py new file mode 100644 index 0000000..3ee7078 --- /dev/null +++ b/sqlalchemy/orm/evaluator.py @@ -0,0 +1,104 @@ +import operator +from sqlalchemy.sql import operators, functions +from sqlalchemy.sql import expression as sql + + +class UnevaluatableError(Exception): + pass + +_straight_ops = set(getattr(operators, op) + for op in ('add', 'mul', 'sub', + # Py2K + 'div', + # end Py2K + 'mod', 'truediv', + 'lt', 'le', 'ne', 'gt', 'ge', 'eq')) + + +_notimplemented_ops = set(getattr(operators, op) + for op in ('like_op', 'notlike_op', 'ilike_op', + 'notilike_op', 'between_op', 'in_op', + 'notin_op', 'endswith_op', 'concat_op')) + +class EvaluatorCompiler(object): + def process(self, clause): + meth = getattr(self, "visit_%s" % clause.__visit_name__, None) + if not meth: + raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__) + return meth(clause) + + def visit_grouping(self, clause): + return self.process(clause.element) + + def visit_null(self, clause): + return lambda obj: None + + def visit_column(self, clause): + if 'parentmapper' in clause._annotations: + key = clause._annotations['parentmapper']._get_col_to_prop(clause).key + else: + key = clause.key + get_corresponding_attr = operator.attrgetter(key) + return lambda obj: get_corresponding_attr(obj) + + def visit_clauselist(self, clause): + evaluators = map(self.process, clause.clauses) + if clause.operator is operators.or_: + def evaluate(obj): + has_null = False + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value: + return True + has_null = has_null or value is None + if has_null: + return None + return False + elif clause.operator is operators.and_: + def evaluate(obj): + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if not value: + if value is None: + return None + return False + return True + else: + raise UnevaluatableError("Cannot evaluate clauselist with operator %s" % clause.operator) + + return evaluate + + def visit_binary(self, clause): + eval_left,eval_right = map(self.process, [clause.left, clause.right]) + operator = clause.operator + if operator is operators.is_: + def evaluate(obj): + return eval_left(obj) == eval_right(obj) + elif operator is operators.isnot: + def evaluate(obj): + return eval_left(obj) != eval_right(obj) + elif operator in _straight_ops: + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is None or right_val is None: + return None + return operator(eval_left(obj), eval_right(obj)) + else: + raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) + return evaluate + + def visit_unary(self, clause): + eval_inner = self.process(clause.element) + if clause.operator is operators.inv: + def evaluate(obj): + value = eval_inner(obj) + if value is None: + return None + return not value + return evaluate + raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) + + def visit_bindparam(self, clause): + val = clause.value + return lambda obj: val diff --git a/sqlalchemy/orm/exc.py b/sqlalchemy/orm/exc.py new file mode 100644 index 0000000..431acc1 --- /dev/null +++ b/sqlalchemy/orm/exc.py @@ -0,0 +1,98 @@ +# exc.py - ORM exceptions +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQLAlchemy ORM exceptions.""" + +import sqlalchemy as sa + + +NO_STATE = (AttributeError, KeyError) +"""Exception types that may be raised by instrumentation implementations.""" + +class ConcurrentModificationError(sa.exc.SQLAlchemyError): + """Rows have been modified outside of the unit of work.""" + + +class FlushError(sa.exc.SQLAlchemyError): + """A invalid condition was detected during flush().""" + + +class UnmappedError(sa.exc.InvalidRequestError): + """TODO""" + +class DetachedInstanceError(sa.exc.SQLAlchemyError): + """An attempt to access unloaded attributes on a mapped instance that is detached.""" + +class UnmappedInstanceError(UnmappedError): + """An mapping operation was requested for an unknown instance.""" + + def __init__(self, obj, msg=None): + if not msg: + try: + mapper = sa.orm.class_mapper(type(obj)) + name = _safe_cls_name(type(obj)) + msg = ("Class %r is mapped, but this instance lacks " + "instrumentation. This occurs when the instance is created " + "before sqlalchemy.orm.mapper(%s) was called." % (name, name)) + except UnmappedClassError: + msg = _default_unmapped(type(obj)) + if isinstance(obj, type): + msg += ( + '; was a class (%s) supplied where an instance was ' + 'required?' % _safe_cls_name(obj)) + UnmappedError.__init__(self, msg) + + +class UnmappedClassError(UnmappedError): + """An mapping operation was requested for an unknown class.""" + + def __init__(self, cls, msg=None): + if not msg: + msg = _default_unmapped(cls) + UnmappedError.__init__(self, msg) + + +class ObjectDeletedError(sa.exc.InvalidRequestError): + """An refresh() operation failed to re-retrieve an object's row.""" + + +class UnmappedColumnError(sa.exc.InvalidRequestError): + """Mapping operation was requested on an unknown column.""" + + +class NoResultFound(sa.exc.InvalidRequestError): + """A database result was required but none was found.""" + + +class MultipleResultsFound(sa.exc.InvalidRequestError): + """A single database result was required but more than one were found.""" + + +# Legacy compat until 0.6. +sa.exc.ConcurrentModificationError = ConcurrentModificationError +sa.exc.FlushError = FlushError +sa.exc.UnmappedColumnError + +def _safe_cls_name(cls): + try: + cls_name = '.'.join((cls.__module__, cls.__name__)) + except AttributeError: + cls_name = getattr(cls, '__name__', None) + if cls_name is None: + cls_name = repr(cls) + return cls_name + +def _default_unmapped(cls): + try: + mappers = sa.orm.attributes.manager_of_class(cls).mappers + except NO_STATE: + mappers = {} + except TypeError: + mappers = {} + name = _safe_cls_name(cls) + + if not mappers: + return "Class '%s' is not mapped" % name diff --git a/sqlalchemy/orm/identity.py b/sqlalchemy/orm/identity.py new file mode 100644 index 0000000..4650b06 --- /dev/null +++ b/sqlalchemy/orm/identity.py @@ -0,0 +1,251 @@ +# identity.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import weakref + +from sqlalchemy import util as base_util +from sqlalchemy.orm import attributes + + +class IdentityMap(dict): + def __init__(self): + self._mutable_attrs = set() + self._modified = set() + self._wr = weakref.ref(self) + + def replace(self, state): + raise NotImplementedError() + + def add(self, state): + raise NotImplementedError() + + def remove(self, state): + raise NotImplementedError() + + def update(self, dict): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def clear(self): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def _manage_incoming_state(self, state): + state._instance_dict = self._wr + + if state.modified: + self._modified.add(state) + if state.manager.mutable_attributes: + self._mutable_attrs.add(state) + + def _manage_removed_state(self, state): + del state._instance_dict + self._mutable_attrs.discard(state) + self._modified.discard(state) + + def _dirty_states(self): + return self._modified.union(s for s in self._mutable_attrs.copy() + if s.modified) + + def check_modified(self): + """return True if any InstanceStates present have been marked as 'modified'.""" + + if self._modified: + return True + else: + for state in self._mutable_attrs.copy(): + if state.modified: + return True + return False + + def has_key(self, key): + return key in self + + def popitem(self): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def pop(self, key, *args): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def setdefault(self, key, default=None): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def copy(self): + raise NotImplementedError() + + def __setitem__(self, key, value): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def __delitem__(self, key): + raise NotImplementedError("IdentityMap uses remove() to remove data") + +class WeakInstanceDict(IdentityMap): + + def __getitem__(self, key): + state = dict.__getitem__(self, key) + o = state.obj() + if o is None: + o = state._is_really_none() + if o is None: + raise KeyError, key + return o + + def __contains__(self, key): + try: + if dict.__contains__(self, key): + state = dict.__getitem__(self, key) + o = state.obj() + if o is None: + o = state._is_really_none() + else: + return False + except KeyError: + return False + else: + return o is not None + + def contains_state(self, state): + return dict.get(self, state.key) is state + + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state) + self._manage_incoming_state(state) + + def add(self, state): + if state.key in self: + if dict.__getitem__(self, state.key) is not state: + raise AssertionError("A conflicting state is already " + "present in the identity map for key %r" + % (state.key, )) + else: + dict.__setitem__(self, state.key, state) + self._manage_incoming_state(state) + + def remove_key(self, key): + state = dict.__getitem__(self, key) + self.remove(state) + + def remove(self, state): + if dict.pop(self, state.key) is not state: + raise AssertionError("State %s is not present in this identity map" % state) + self._manage_removed_state(state) + + def discard(self, state): + if self.contains_state(state): + dict.__delitem__(self, state.key) + self._manage_removed_state(state) + + def get(self, key, default=None): + state = dict.get(self, key, default) + if state is default: + return default + o = state.obj() + if o is None: + o = state._is_really_none() + if o is None: + return default + return o + + # Py2K + def items(self): + return list(self.iteritems()) + + def iteritems(self): + for state in dict.itervalues(self): + # end Py2K + # Py3K + #def items(self): + # for state in dict.values(self): + value = state.obj() + if value is not None: + yield state.key, value + + # Py2K + def values(self): + return list(self.itervalues()) + + def itervalues(self): + for state in dict.itervalues(self): + # end Py2K + # Py3K + #def values(self): + # for state in dict.values(self): + instance = state.obj() + if instance is not None: + yield instance + + def all_states(self): + # Py3K + # return list(dict.values(self)) + + # Py2K + return dict.values(self) + # end Py2K + + def prune(self): + return 0 + +class StrongInstanceDict(IdentityMap): + def all_states(self): + return [attributes.instance_state(o) for o in self.itervalues()] + + def contains_state(self, state): + return state.key in self and attributes.instance_state(self[state.key]) is state + + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + existing = attributes.instance_state(existing) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + + def add(self, state): + if state.key in self: + if attributes.instance_state(dict.__getitem__(self, state.key)) is not state: + raise AssertionError("A conflicting state is already present in the identity map for key %r" % (state.key, )) + else: + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + + def remove(self, state): + if attributes.instance_state(dict.pop(self, state.key)) is not state: + raise AssertionError("State %s is not present in this identity map" % state) + self._manage_removed_state(state) + + def discard(self, state): + if self.contains_state(state): + dict.__delitem__(self, state.key) + self._manage_removed_state(state) + + def remove_key(self, key): + state = attributes.instance_state(dict.__getitem__(self, key)) + self.remove(state) + + def prune(self): + """prune unreferenced, non-dirty states.""" + + ref_count = len(self) + dirty = [s.obj() for s in self.all_states() if s.modified] + + # work around http://bugs.python.org/issue6149 + keepers = weakref.WeakValueDictionary() + keepers.update(self) + + dict.clear(self) + dict.update(self, keepers) + self.modified = bool(dirty) + return ref_count - len(self) + diff --git a/sqlalchemy/orm/interfaces.py b/sqlalchemy/orm/interfaces.py new file mode 100644 index 0000000..7fbb086 --- /dev/null +++ b/sqlalchemy/orm/interfaces.py @@ -0,0 +1,1098 @@ +# interfaces.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +Semi-private module containing various base classes used throughout the ORM. + +Defines the extension classes :class:`MapperExtension`, +:class:`SessionExtension`, and :class:`AttributeExtension` as +well as other user-subclassable extension objects. + +""" + +from itertools import chain + +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import log, util +from sqlalchemy.sql import expression + +class_mapper = None +collections = None + +__all__ = ( + 'AttributeExtension', + 'EXT_CONTINUE', + 'EXT_STOP', + 'ExtensionOption', + 'InstrumentationManager', + 'LoaderStrategy', + 'MapperExtension', + 'MapperOption', + 'MapperProperty', + 'PropComparator', + 'PropertyOption', + 'SessionExtension', + 'StrategizedOption', + 'StrategizedProperty', + 'build_path', + ) + +EXT_CONTINUE = util.symbol('EXT_CONTINUE') +EXT_STOP = util.symbol('EXT_STOP') + +ONETOMANY = util.symbol('ONETOMANY') +MANYTOONE = util.symbol('MANYTOONE') +MANYTOMANY = util.symbol('MANYTOMANY') + +class MapperExtension(object): + """Base implementation for customizing ``Mapper`` behavior. + + New extension classes subclass ``MapperExtension`` and are specified + using the ``extension`` mapper() argument, which is a single + ``MapperExtension`` or a list of such. A single mapper + can maintain a chain of ``MapperExtension`` objects. When a + particular mapping event occurs, the corresponding method + on each ``MapperExtension`` is invoked serially, and each method + has the ability to halt the chain from proceeding further. + + Each ``MapperExtension`` method returns the symbol + EXT_CONTINUE by default. This symbol generally means "move + to the next ``MapperExtension`` for processing". For methods + that return objects like translated rows or new object + instances, EXT_CONTINUE means the result of the method + should be ignored. In some cases it's required for a + default mapper activity to be performed, such as adding a + new instance to a result list. + + The symbol EXT_STOP has significance within a chain + of ``MapperExtension`` objects that the chain will be stopped + when this symbol is returned. Like EXT_CONTINUE, it also + has additional significance in some cases that a default + mapper activity will not be performed. + + """ + def instrument_class(self, mapper, class_): + """Receive a class when the mapper is first constructed, and has + applied instrumentation to the mapped class. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + """Receive an instance when it's constructor is called. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + """Receive an instance when it's constructor has been called, + and raised an exception. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def translate_row(self, mapper, context, row): + """Perform pre-processing on the given result row and return a + new row instance. + + This is called when the mapper first receives a row, before + the object identity or the instance itself has been derived + from that row. The given row may or may not be a + ``RowProxy`` object - it will always be a dictionary-like + object which contains mapped columns as keys. The + returned object should also be a dictionary-like object + which recognizes mapped columns as keys. + + If the ultimate return value is EXT_CONTINUE, the row + is not translated. + + """ + return EXT_CONTINUE + + def create_instance(self, mapper, selectcontext, row, class_): + """Receive a row when a new object instance is about to be + created from that row. + + The method can choose to create the instance itself, or it can return + EXT_CONTINUE to indicate normal object creation should take place. + + mapper + The mapper doing the operation + + selectcontext + The QueryContext generated from the Query. + + row + The result row from the database + + class\_ + The class we are mapping. + + return value + A new object instance, or EXT_CONTINUE + + """ + return EXT_CONTINUE + + def append_result(self, mapper, selectcontext, row, instance, result, **flags): + """Receive an object instance before that instance is appended + to a result list. + + If this method returns EXT_CONTINUE, result appending will proceed + normally. if this method returns any other value or None, + result appending will not proceed for this instance, giving + this extension an opportunity to do the appending itself, if + desired. + + mapper + The mapper doing the operation. + + selectcontext + The QueryContext generated from the Query. + + row + The result row from the database. + + instance + The object instance to be appended to the result. + + result + List to which results are being appended. + + \**flags + extra information about the row, same as criterion in + ``create_row_processor()`` method of :class:`~sqlalchemy.orm.interfaces.MapperProperty` + """ + + return EXT_CONTINUE + + def populate_instance(self, mapper, selectcontext, row, instance, **flags): + """Receive an instance before that instance has + its attributes populated. + + This usually corresponds to a newly loaded instance but may + also correspond to an already-loaded instance which has + unloaded attributes to be populated. The method may be called + many times for a single instance, as multiple result rows are + used to populate eagerly loaded collections. + + If this method returns EXT_CONTINUE, instance population will + proceed normally. If any other value or None is returned, + instance population will not proceed, giving this extension an + opportunity to populate the instance itself, if desired. + + As of 0.5, most usages of this hook are obsolete. For a + generic "object has been newly created from a row" hook, use + ``reconstruct_instance()``, or the ``@orm.reconstructor`` + decorator. + + """ + return EXT_CONTINUE + + def reconstruct_instance(self, mapper, instance): + """Receive an object instance after it has been created via + ``__new__``, and after initial attribute population has + occurred. + + This typically occurs when the instance is created based on + incoming result rows, and is only called once for that + instance's lifetime. + + Note that during a result-row load, this method is called upon + the first row received for this instance. Note that some + attributes and collections may or may not be loaded or even + initialized, depending on what's present in the result rows. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def before_insert(self, mapper, connection, instance): + """Receive an object instance before that instance is inserted + into its table. + + This is a good place to set up primary key values and such + that aren't handled otherwise. + + Column-based attributes can be modified within this method + which will result in the new value being inserted. However + *no* changes to the overall flush plan can be made, and + manipulation of the ``Session`` will not have the desired effect. + To manipulate the ``Session`` within an extension, use + ``SessionExtension``. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def after_insert(self, mapper, connection, instance): + """Receive an object instance after that instance is inserted. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def before_update(self, mapper, connection, instance): + """Receive an object instance before that instance is updated. + + Note that this method is called for all instances that are marked as + "dirty", even those which have no net changes to their column-based + attributes. An object is marked as dirty when any of its column-based + attributes have a "set attribute" operation called or when any of its + collections are modified. If, at update time, no column-based attributes + have any net changes, no UPDATE statement will be issued. This means + that an instance being sent to before_update is *not* a guarantee that + an UPDATE statement will be issued (although you can affect the outcome + here). + + To detect if the column-based attributes on the object have net changes, + and will therefore generate an UPDATE statement, use + ``object_session(instance).is_modified(instance, include_collections=False)``. + + Column-based attributes can be modified within this method + which will result in the new value being updated. However + *no* changes to the overall flush plan can be made, and + manipulation of the ``Session`` will not have the desired effect. + To manipulate the ``Session`` within an extension, use + ``SessionExtension``. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def after_update(self, mapper, connection, instance): + """Receive an object instance after that instance is updated. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def before_delete(self, mapper, connection, instance): + """Receive an object instance before that instance is deleted. + + Note that *no* changes to the overall flush plan can be made + here; and manipulation of the ``Session`` will not have the + desired effect. To manipulate the ``Session`` within an + extension, use ``SessionExtension``. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def after_delete(self, mapper, connection, instance): + """Receive an object instance after that instance is deleted. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + +class SessionExtension(object): + """An extension hook object for Sessions. Subclasses may be installed into a Session + (or sessionmaker) using the ``extension`` keyword argument. + """ + + def before_commit(self, session): + """Execute right before commit is called. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def after_commit(self, session): + """Execute after a commit has occured. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def after_rollback(self, session): + """Execute after a rollback has occured. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def before_flush(self, session, flush_context, instances): + """Execute before flush process has started. + + `instances` is an optional list of objects which were passed to the ``flush()`` + method. + """ + + def after_flush(self, session, flush_context): + """Execute after flush has completed, but before commit has been called. + + Note that the session's state is still in pre-flush, i.e. 'new', 'dirty', + and 'deleted' lists still show pre-flush state as well as the history + settings on instance attributes.""" + + def after_flush_postexec(self, session, flush_context): + """Execute after flush has completed, and after the post-exec state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in their final + state. An actual commit() may or may not have occured, depending on whether or not + the flush started its own transaction or participated in a larger transaction. + """ + + def after_begin(self, session, transaction, connection): + """Execute after a transaction is begun on a connection + + `transaction` is the SessionTransaction. This method is called after an + engine level transaction is begun on a connection. + """ + + def after_attach(self, session, instance): + """Execute after an instance is attached to a session. + + This is called after an add, delete or merge. + """ + + def after_bulk_update(self, session, query, query_context, result): + """Execute after a bulk update operation to the session. + + This is called after a session.query(...).update() + + `query` is the query object that this update operation was called on. + `query_context` was the query context object. + `result` is the result object returned from the bulk operation. + """ + + def after_bulk_delete(self, session, query, query_context, result): + """Execute after a bulk delete operation to the session. + + This is called after a session.query(...).delete() + + `query` is the query object that this delete operation was called on. + `query_context` was the query context object. + `result` is the result object returned from the bulk operation. + """ + +class MapperProperty(object): + """Manage the relationship of a ``Mapper`` to a single class + attribute, as well as that attribute as it appears on individual + instances of the class, including attribute instrumentation, + attribute access, loading behavior, and dependency calculations. + """ + + def setup(self, context, entity, path, adapter, **kwargs): + """Called by Query for the purposes of constructing a SQL statement. + + Each MapperProperty associated with the target mapper processes the + statement referenced by the query context, adding columns and/or + criterion as appropriate. + """ + + pass + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + """Return a 2-tuple consiting of two row processing functions and + an instance post-processing function. + + Input arguments are the query.SelectionContext and the *first* + applicable row of a result set obtained within + query.Query.instances(), called only the first time a particular + mapper's populate_instance() method is invoked for the overall result. + + The settings contained within the SelectionContext as well as the + columns present in the row (which will be the same columns present in + all rows) are used to determine the presence and behavior of the + returned callables. The callables will then be used to process all + rows and instances. + + Callables are of the following form:: + + def new_execute(state, dict_, row, isnew): + # process incoming instance state and given row. the instance is + # "new" and was just created upon receipt of this row. + "isnew" indicates if the instance was newly created as a + result of reading this row + + def existing_execute(state, dict_, row): + # process incoming instance state and given row. the instance is + # "existing" and was created based on a previous row. + + return (new_execute, existing_execute) + + Either of the three tuples can be ``None`` in which case no function + is called. + """ + + raise NotImplementedError() + + def cascade_iterator(self, type_, state, visited_instances=None, halt_on=None): + """Iterate through instances related to the given instance for + a particular 'cascade', starting with this MapperProperty. + + See PropertyLoader for the related instance implementation. + """ + + return iter(()) + + def set_parent(self, parent): + self.parent = parent + + def instrument_class(self, mapper): + raise NotImplementedError() + + _compile_started = False + _compile_finished = False + + def init(self): + """Called after all mappers are created to assemble + relationships between mappers and perform other post-mapper-creation + initialization steps. + + """ + self._compile_started = True + self.do_init() + self._compile_finished = True + + @property + def class_attribute(self): + """Return the class-bound descriptor corresponding to this MapperProperty.""" + + return getattr(self.parent.class_, self.key) + + def do_init(self): + """Perform subclass-specific initialization post-mapper-creation steps. + + This is a *template* method called by the + ``MapperProperty`` object's init() method. + + """ + pass + + def post_instrument_class(self, mapper): + """Perform instrumentation adjustments that need to occur + after init() has completed. + + """ + pass + + def register_dependencies(self, *args, **kwargs): + """Called by the ``Mapper`` in response to the UnitOfWork + calling the ``Mapper``'s register_dependencies operation. + Establishes a topological dependency between two mappers + which will affect the order in which mappers persist data. + + """ + + pass + + def register_processors(self, *args, **kwargs): + """Called by the ``Mapper`` in response to the UnitOfWork + calling the ``Mapper``'s register_processors operation. + Establishes a processor object between two mappers which + will link data and state between parent/child objects. + + """ + + pass + + def is_primary(self): + """Return True if this ``MapperProperty``'s mapper is the + primary mapper for its class. + + This flag is used to indicate that the ``MapperProperty`` can + define attribute instrumentation for the class at the class + level (as opposed to the individual instance level). + """ + + return not self.parent.non_primary + + def merge(self, session, source, dest, load, _recursive): + """Merge the attribute represented by this ``MapperProperty`` + from source to destination object""" + + raise NotImplementedError() + + def compare(self, operator, value): + """Return a compare operation for the columns represented by + this ``MapperProperty`` to the given value, which may be a + column value or an instance. 'operator' is an operator from + the operators module, or from sql.Comparator. + + By default uses the PropComparator attached to this MapperProperty + under the attribute name "comparator". + """ + + return operator(self.comparator, value) + +class PropComparator(expression.ColumnOperators): + """defines comparison operations for MapperProperty objects. + + PropComparator instances should also define an accessor 'property' + which returns the MapperProperty associated with this + PropComparator. + """ + + def __init__(self, prop, mapper, adapter=None): + self.prop = self.property = prop + self.mapper = mapper + self.adapter = adapter + + def __clause_element__(self): + raise NotImplementedError("%r" % self) + + def adapted(self, adapter): + """Return a copy of this PropComparator which will use the given adaption function + on the local side of generated expressions. + + """ + return self.__class__(self.prop, self.mapper, adapter) + + @staticmethod + def any_op(a, b, **kwargs): + return a.any(b, **kwargs) + + @staticmethod + def has_op(a, b, **kwargs): + return a.has(b, **kwargs) + + @staticmethod + def of_type_op(a, class_): + return a.of_type(class_) + + def of_type(self, class_): + """Redefine this object in terms of a polymorphic subclass. + + Returns a new PropComparator from which further criterion can be evaluated. + + e.g.:: + + query.join(Company.employees.of_type(Engineer)).\\ + filter(Engineer.name=='foo') + + \class_ + a class or mapper indicating that criterion will be against + this specific subclass. + + + """ + + return self.operate(PropComparator.of_type_op, class_) + + def any(self, criterion=None, **kwargs): + """Return true if this collection contains any member that meets the given criterion. + + criterion + an optional ClauseElement formulated against the member class' table + or attributes. + + \**kwargs + key/value pairs corresponding to member class attribute names which + will be compared via equality to the corresponding values. + """ + + return self.operate(PropComparator.any_op, criterion, **kwargs) + + def has(self, criterion=None, **kwargs): + """Return true if this element references a member which meets the given criterion. + + criterion + an optional ClauseElement formulated against the member class' table + or attributes. + + \**kwargs + key/value pairs corresponding to member class attribute names which + will be compared via equality to the corresponding values. + """ + + return self.operate(PropComparator.has_op, criterion, **kwargs) + + +class StrategizedProperty(MapperProperty): + """A MapperProperty which uses selectable strategies to affect + loading behavior. + + There is a single strategy selected by default. Alternate + strategies can be selected at Query time through the usage of + ``StrategizedOption`` objects via the Query.options() method. + + """ + + def _get_context_strategy(self, context, path): + cls = context.attributes.get(("loaderstrategy", _reduce_path(path)), None) + if cls: + try: + return self.__all_strategies[cls] + except KeyError: + return self.__init_strategy(cls) + else: + return self.strategy + + def _get_strategy(self, cls): + try: + return self.__all_strategies[cls] + except KeyError: + return self.__init_strategy(cls) + + def __init_strategy(self, cls): + self.__all_strategies[cls] = strategy = cls(self) + strategy.init() + return strategy + + def setup(self, context, entity, path, adapter, **kwargs): + self._get_context_strategy(context, path + (self.key,)).\ + setup_query(context, entity, path, adapter, **kwargs) + + def create_row_processor(self, context, path, mapper, row, adapter): + return self._get_context_strategy(context, path + (self.key,)).\ + create_row_processor(context, path, mapper, row, adapter) + + def do_init(self): + self.__all_strategies = {} + self.strategy = self.__init_strategy(self.strategy_class) + + def post_instrument_class(self, mapper): + if self.is_primary(): + self.strategy.init_class_attribute(mapper) + +def build_path(entity, key, prev=None): + if prev: + return prev + (entity, key) + else: + return (entity, key) + +def serialize_path(path): + if path is None: + return None + + return zip( + [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], + [path[i] for i in range(1, len(path), 2)] + [None] + ) + +def deserialize_path(path): + if path is None: + return None + + global class_mapper + if class_mapper is None: + from sqlalchemy.orm import class_mapper + + p = tuple(chain(*[(class_mapper(cls), key) for cls, key in path])) + if p and p[-1] is None: + p = p[0:-1] + return p + +class MapperOption(object): + """Describe a modification to a Query.""" + + propagate_to_loaders = False + """if True, indicate this option should be carried along + Query object generated by scalar or object lazy loaders. + """ + + def process_query(self, query): + pass + + def process_query_conditionally(self, query): + """same as process_query(), except that this option may not apply + to the given query. + + Used when secondary loaders resend existing options to a new + Query.""" + self.process_query(query) + +class ExtensionOption(MapperOption): + """a MapperOption that applies a MapperExtension to a query operation.""" + + def __init__(self, ext): + self.ext = ext + + def process_query(self, query): + entity = query._generate_mapper_zero() + entity.extension = entity.extension.copy() + entity.extension.push(self.ext) + +class PropertyOption(MapperOption): + """A MapperOption that is applied to a property off the mapper or + one of its child mappers, identified by a dot-separated key. + """ + + def __init__(self, key, mapper=None): + self.key = key + self.mapper = mapper + + def process_query(self, query): + self._process(query, True) + + def process_query_conditionally(self, query): + self._process(query, False) + + def _process(self, query, raiseerr): + paths, mappers = self._get_paths(query, raiseerr) + if paths: + self.process_query_property(query, paths, mappers) + + def process_query_property(self, query, paths, mappers): + pass + + def __getstate__(self): + d = self.__dict__.copy() + d['key'] = ret = [] + for token in util.to_list(self.key): + if isinstance(token, PropComparator): + ret.append((token.mapper.class_, token.key)) + else: + ret.append(token) + return d + + def __setstate__(self, state): + ret = [] + for key in state['key']: + if isinstance(key, tuple): + cls, propkey = key + ret.append(getattr(cls, propkey)) + else: + ret.append(key) + state['key'] = tuple(ret) + self.__dict__ = state + + def _find_entity(self, query, mapper, raiseerr): + from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class + + if _is_aliased_class(mapper): + searchfor = mapper + isa = False + else: + searchfor = _class_to_mapper(mapper) + isa = True + + for ent in query._mapper_entities: + if searchfor is ent.path_entity or ( + isa and + searchfor.common_parent(ent.path_entity)): + return ent + else: + if raiseerr: + raise sa_exc.ArgumentError( + "Can't find entity %s in Query. Current list: %r" + % (searchfor, [ + str(m.path_entity) for m in query._entities + ])) + else: + return None + + def _get_paths(self, query, raiseerr): + path = None + entity = None + l = [] + mappers = [] + + # _current_path implies we're in a secondary load + # with an existing path + current_path = list(query._current_path) + + tokens = [] + for key in util.to_list(self.key): + if isinstance(key, basestring): + tokens += key.split('.') + else: + tokens += [key] + + for token in tokens: + if isinstance(token, basestring): + if not entity: + if current_path: + if current_path[1] == token: + current_path = current_path[2:] + continue + + entity = query._entity_zero() + path_element = entity.path_entity + mapper = entity.mapper + mappers.append(mapper) + prop = mapper.get_property( + token, + resolve_synonyms=True, + raiseerr=raiseerr) + key = token + elif isinstance(token, PropComparator): + prop = token.property + if not entity: + if current_path: + if current_path[0:2] == [token.parententity, prop.key]: + current_path = current_path[2:] + continue + + entity = self._find_entity( + query, + token.parententity, + raiseerr) + if not entity: + return [], [] + path_element = entity.path_entity + mapper = entity.mapper + mappers.append(prop.parent) + key = prop.key + else: + raise sa_exc.ArgumentError("mapper option expects string key " + "or list of attributes") + + if prop is None: + return [], [] + + path = build_path(path_element, prop.key, path) + l.append(path) + if getattr(token, '_of_type', None): + path_element = mapper = token._of_type + else: + path_element = mapper = getattr(prop, 'mapper', None) + + if path_element: + path_element = path_element + + + # if current_path tokens remain, then + # we didn't have an exact path match. + if current_path: + return [], [] + + return l, mappers + +class AttributeExtension(object): + """An event handler for individual attribute change events. + + AttributeExtension is assembled within the descriptors associated + with a mapped class. + + """ + + active_history = True + """indicates that the set() method would like to receive the 'old' value, + even if it means firing lazy callables. + """ + + def append(self, state, value, initiator): + """Receive a collection append event. + + The returned value will be used as the actual value to be + appended. + + """ + return value + + def remove(self, state, value, initiator): + """Receive a remove event. + + No return value is defined. + + """ + pass + + def set(self, state, value, oldvalue, initiator): + """Receive a set event. + + The returned value will be used as the actual value to be + set. + + """ + return value + + +class StrategizedOption(PropertyOption): + """A MapperOption that affects which LoaderStrategy will be used + for an operation by a StrategizedProperty. + """ + + is_chained = False + + def process_query_property(self, query, paths, mappers): + # _get_context_strategy may receive the path in terms of + # a base mapper - e.g. options(eagerload_all(Company.employees, Engineer.machines)) + # in the polymorphic tests leads to "(Person, 'machines')" in + # the path due to the mechanics of how the eager strategy builds + # up the path + if self.is_chained: + for path in paths: + query._attributes[("loaderstrategy", _reduce_path(path))] = \ + self.get_strategy_class() + else: + query._attributes[("loaderstrategy", _reduce_path(paths[-1]))] = \ + self.get_strategy_class() + + def get_strategy_class(self): + raise NotImplementedError() + +def _reduce_path(path): + """Convert a (mapper, path) path to use base mappers. + + This is used to allow more open ended selection of loader strategies, i.e. + Mapper -> prop1 -> Subclass -> prop2, where Subclass is a sub-mapper + of the mapper referened by Mapper.prop1. + + """ + return tuple([i % 2 != 0 and + path[i] or + getattr(path[i], 'base_mapper', path[i]) + for i in xrange(len(path))]) + +class LoaderStrategy(object): + """Describe the loading behavior of a StrategizedProperty object. + + The ``LoaderStrategy`` interacts with the querying process in three + ways: + + * it controls the configuration of the ``InstrumentedAttribute`` + placed on a class to handle the behavior of the attribute. this + may involve setting up class-level callable functions to fire + off a select operation when the attribute is first accessed + (i.e. a lazy load) + + * it processes the ``QueryContext`` at statement construction time, + where it can modify the SQL statement that is being produced. + simple column attributes may add their represented column to the + list of selected columns, *eager loading* properties may add + ``LEFT OUTER JOIN`` clauses to the statement. + + * it processes the ``SelectionContext`` at row-processing time. This + includes straight population of attributes corresponding to rows, + setting instance-level lazyloader callables on newly + constructed instances, and appending child items to scalar/collection + attributes in response to eagerly-loaded relations. + """ + + def __init__(self, parent): + self.parent_property = parent + self.is_class_level = False + self.parent = self.parent_property.parent + self.key = self.parent_property.key + + def init(self): + raise NotImplementedError("LoaderStrategy") + + def init_class_attribute(self, mapper): + pass + + def setup_query(self, context, entity, path, adapter, **kwargs): + pass + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + """Return row processing functions which fulfill the contract specified + by MapperProperty.create_row_processor. + + StrategizedProperty delegates its create_row_processor method directly + to this method. + """ + + raise NotImplementedError() + + def __str__(self): + return str(self.parent_property) + + def debug_callable(self, fn, logger, announcement, logfn): + if announcement: + logger.debug(announcement) + if logfn: + def call(*args, **kwargs): + logger.debug(logfn(*args, **kwargs)) + return fn(*args, **kwargs) + return call + else: + return fn + +class InstrumentationManager(object): + """User-defined class instrumentation extension. + + The API for this class should be considered as semi-stable, + and may change slightly with new releases. + + """ + + # r4361 added a mandatory (cls) constructor to this interface. + # given that, perhaps class_ should be dropped from all of these + # signatures. + + def __init__(self, class_): + pass + + def manage(self, class_, manager): + setattr(class_, '_default_class_manager', manager) + + def dispose(self, class_, manager): + delattr(class_, '_default_class_manager') + + def manager_getter(self, class_): + def get(cls): + return cls._default_class_manager + return get + + def instrument_attribute(self, class_, key, inst): + pass + + def post_configure_attribute(self, class_, key, inst): + pass + + def install_descriptor(self, class_, key, inst): + setattr(class_, key, inst) + + def uninstall_descriptor(self, class_, key): + delattr(class_, key) + + def install_member(self, class_, key, implementation): + setattr(class_, key, implementation) + + def uninstall_member(self, class_, key): + delattr(class_, key) + + def instrument_collection_class(self, class_, key, collection_class): + global collections + if collections is None: + from sqlalchemy.orm import collections + return collections.prepare_instrumentation(collection_class) + + def get_instance_dict(self, class_, instance): + return instance.__dict__ + + def initialize_instance_dict(self, class_, instance): + pass + + def install_state(self, class_, instance, state): + setattr(instance, '_default_state', state) + + def remove_state(self, class_, instance): + delattr(instance, '_default_state', state) + + def state_getter(self, class_): + return lambda instance: getattr(instance, '_default_state') + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + \ No newline at end of file diff --git a/sqlalchemy/orm/mapper.py b/sqlalchemy/orm/mapper.py new file mode 100644 index 0000000..8f0f212 --- /dev/null +++ b/sqlalchemy/orm/mapper.py @@ -0,0 +1,1958 @@ +# mapper.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Logic to map Python classes to and from selectables. + +Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central configurational +unit which associates a class with a database table. + +This is a semi-private module; the main configurational API of the ORM is +available in :class:`~sqlalchemy.orm.`. + +""" + +import types +import weakref +import operator +from itertools import chain +deque = __import__('collections').deque + +from sqlalchemy import sql, util, log, exc as sa_exc +from sqlalchemy.sql import expression, visitors, operators, util as sqlutil +from sqlalchemy.orm import attributes, sync, exc as orm_exc +from sqlalchemy.orm.interfaces import ( + MapperProperty, EXT_CONTINUE, PropComparator + ) +from sqlalchemy.orm.util import ( + ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _state_has_identity, + _state_mapper, class_mapper, instance_str, state_str, + ) + +__all__ = ( + 'Mapper', + '_mapper_registry', + 'class_mapper', + 'object_mapper', + ) + +_mapper_registry = weakref.WeakKeyDictionary() +_new_mappers = False +_already_compiling = False +_none_set = frozenset([None]) + +# a list of MapperExtensions that will be installed in all mappers by default +global_extensions = [] + +# a constant returned by _get_attr_by_column to indicate +# this mapper is not handling an attribute for a particular +# column +NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE') + +# lock used to synchronize the "mapper compile" step +_COMPILE_MUTEX = util.threading.RLock() + +# initialize these lazily +ColumnProperty = None +SynonymProperty = None +ComparableProperty = None +RelationshipProperty = None +ConcreteInheritedProperty = None +_expire_state = None +_state_session = None + +class Mapper(object): + """Define the correlation of class attributes to database table + columns. + + Instances of this class should be constructed via the + :func:`~sqlalchemy.orm.mapper` function. + + """ + def __init__(self, + class_, + local_table, + properties = None, + primary_key = None, + non_primary = False, + inherits = None, + inherit_condition = None, + inherit_foreign_keys = None, + extension = None, + order_by = False, + always_refresh = False, + version_id_col = None, + version_id_generator = None, + polymorphic_on=None, + _polymorphic_map=None, + polymorphic_identity=None, + concrete=False, + with_polymorphic=None, + allow_null_pks=None, + allow_partial_pks=True, + batch=True, + column_prefix=None, + include_properties=None, + exclude_properties=None, + passive_updates=True, + eager_defaults=False): + """Construct a new mapper. + + Mappers are normally constructed via the :func:`~sqlalchemy.orm.mapper` + function. See for details. + + """ + + self.class_ = util.assert_arg_type(class_, type, 'class_') + + self.class_manager = None + + self.primary_key_argument = primary_key + self.non_primary = non_primary + + if order_by is not False: + self.order_by = util.to_list(order_by) + else: + self.order_by = order_by + + self.always_refresh = always_refresh + self.version_id_col = version_id_col + self.version_id_generator = version_id_generator or (lambda x:(x or 0) + 1) + self.concrete = concrete + self.single = False + self.inherits = inherits + self.local_table = local_table + self.inherit_condition = inherit_condition + self.inherit_foreign_keys = inherit_foreign_keys + self.extension = extension + self._init_properties = properties or {} + self.delete_orphans = [] + self.batch = batch + self.eager_defaults = eager_defaults + self.column_prefix = column_prefix + self.polymorphic_on = polymorphic_on + self._dependency_processors = [] + self._validators = {} + self.passive_updates = passive_updates + self._clause_adapter = None + self._requires_row_aliasing = False + self._inherits_equated_pairs = None + + if allow_null_pks: + util.warn_deprecated('the allow_null_pks option to Mapper() is ' + 'deprecated. It is now allow_partial_pks=False|True, ' + 'defaults to True.') + allow_partial_pks = allow_null_pks + + self.allow_partial_pks = allow_partial_pks + + if with_polymorphic == '*': + self.with_polymorphic = ('*', None) + elif isinstance(with_polymorphic, (tuple, list)): + if isinstance(with_polymorphic[0], (basestring, tuple, list)): + self.with_polymorphic = with_polymorphic + else: + self.with_polymorphic = (with_polymorphic, None) + elif with_polymorphic is not None: + raise sa_exc.ArgumentError("Invalid setting for with_polymorphic") + else: + self.with_polymorphic = None + + if isinstance(self.local_table, expression._SelectBaseMixin): + raise sa_exc.InvalidRequestError( + "When mapping against a select() construct, map against " + "an alias() of the construct instead." + "This because several databases don't allow a " + "SELECT from a subquery that does not have an alias." + ) + + if self.with_polymorphic and \ + isinstance(self.with_polymorphic[1], expression._SelectBaseMixin): + self.with_polymorphic = (self.with_polymorphic[0], self.with_polymorphic[1].alias()) + + # our 'polymorphic identity', a string name that when located in a result set row + # indicates this Mapper should be used to construct the object instance for that row. + self.polymorphic_identity = polymorphic_identity + + # a dictionary of 'polymorphic identity' names, associating those names with + # Mappers that will be used to construct object instances upon a select operation. + if _polymorphic_map is None: + self.polymorphic_map = {} + else: + self.polymorphic_map = _polymorphic_map + + self.include_properties = include_properties + self.exclude_properties = exclude_properties + + self.compiled = False + + # prevent this mapper from being constructed + # while a compile() is occuring (and defer a compile() + # until construction succeeds) + _COMPILE_MUTEX.acquire() + try: + self._configure_inheritance() + self._configure_extensions() + self._configure_class_instrumentation() + self._configure_properties() + self._configure_pks() + global _new_mappers + _new_mappers = True + self._log("constructed") + finally: + _COMPILE_MUTEX.release() + + def _configure_inheritance(self): + """Configure settings related to inherting and/or inherited mappers being present.""" + + # a set of all mappers which inherit from this one. + self._inheriting_mappers = set() + + if self.inherits: + if isinstance(self.inherits, type): + self.inherits = class_mapper(self.inherits, compile=False) + if not issubclass(self.class_, self.inherits.class_): + raise sa_exc.ArgumentError( + "Class '%s' does not inherit from '%s'" % + (self.class_.__name__, self.inherits.class_.__name__)) + if self.non_primary != self.inherits.non_primary: + np = not self.non_primary and "primary" or "non-primary" + raise sa_exc.ArgumentError("Inheritance of %s mapper for class '%s' is " + "only allowed from a %s mapper" % (np, self.class_.__name__, np)) + # inherit_condition is optional. + if self.local_table is None: + self.local_table = self.inherits.local_table + self.mapped_table = self.inherits.mapped_table + self.single = True + elif not self.local_table is self.inherits.local_table: + if self.concrete: + self.mapped_table = self.local_table + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + else: + if self.inherit_condition is None: + # figure out inherit condition from our table to the immediate table + # of the inherited mapper, not its full table which could pull in other + # stuff we dont want (allows test/inheritance.InheritTest4 to pass) + self.inherit_condition = sqlutil.join_condition(self.inherits.local_table, self.local_table) + self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition) + + fks = util.to_set(self.inherit_foreign_keys) + self._inherits_equated_pairs = \ + sqlutil.criterion_as_pairs(self.mapped_table.onclause, + consider_as_foreign_keys=fks) + else: + self.mapped_table = self.local_table + + if self.polymorphic_identity is not None and not self.concrete: + self._identity_class = self.inherits._identity_class + else: + self._identity_class = self.class_ + + if self.version_id_col is None: + self.version_id_col = self.inherits.version_id_col + self.version_id_generator = self.inherits.version_id_generator + + for mapper in self.iterate_to_root(): + util.reset_memoized(mapper, '_equivalent_columns') + util.reset_memoized(mapper, '_sorted_tables') + + if self.order_by is False and not self.concrete and self.inherits.order_by is not False: + self.order_by = self.inherits.order_by + + self.polymorphic_map = self.inherits.polymorphic_map + self.batch = self.inherits.batch + self.inherits._inheriting_mappers.add(self) + self.base_mapper = self.inherits.base_mapper + self.passive_updates = self.inherits.passive_updates + self._all_tables = self.inherits._all_tables + + if self.polymorphic_identity is not None: + self.polymorphic_map[self.polymorphic_identity] = self + if self.polymorphic_on is None: + for mapper in self.iterate_to_root(): + # try to set up polymorphic on using correesponding_column(); else leave + # as None + if mapper.polymorphic_on is not None: + self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on) + break + else: + self._all_tables = set() + self.base_mapper = self + self.mapped_table = self.local_table + if self.polymorphic_identity is not None: + self.polymorphic_map[self.polymorphic_identity] = self + self._identity_class = self.class_ + + if self.mapped_table is None: + raise sa_exc.ArgumentError("Mapper '%s' does not have a mapped_table specified." % self) + + def _configure_extensions(self): + """Go through the global_extensions list as well as the list + of ``MapperExtensions`` specified for this ``Mapper`` and + creates a linked list of those extensions. + + """ + extlist = util.OrderedSet() + + extension = self.extension + if extension: + for ext_obj in util.to_list(extension): + # local MapperExtensions have already instrumented the class + extlist.add(ext_obj) + + if self.inherits: + for ext in self.inherits.extension: + if ext not in extlist: + extlist.add(ext) + else: + for ext in global_extensions: + if isinstance(ext, type): + ext = ext() + if ext not in extlist: + extlist.add(ext) + + self.extension = ExtensionCarrier() + for ext in extlist: + self.extension.append(ext) + + def _configure_class_instrumentation(self): + """If this mapper is to be a primary mapper (i.e. the + non_primary flag is not set), associate this Mapper with the + given class_ and entity name. + + Subsequent calls to ``class_mapper()`` for the class_/entity + name combination will return this mapper. Also decorate the + `__init__` method on the mapped class to include optional + auto-session attachment logic. + + """ + manager = attributes.manager_of_class(self.class_) + + if self.non_primary: + if not manager or manager.mapper is None: + raise sa_exc.InvalidRequestError( + "Class %s has no primary mapper configured. Configure " + "a primary mapper first before setting up a non primary " + "Mapper.") + self.class_manager = manager + _mapper_registry[self] = True + return + + if manager is not None: + assert manager.class_ is self.class_ + if manager.mapper: + raise sa_exc.ArgumentError( + "Class '%s' already has a primary mapper defined. " + "Use non_primary=True to " + "create a non primary Mapper. clear_mappers() will " + "remove *all* current mappers from all classes." % + self.class_) + #else: + # a ClassManager may already exist as + # ClassManager.instrument_attribute() creates + # new managers for each subclass if they don't yet exist. + + _mapper_registry[self] = True + + self.extension.instrument_class(self, self.class_) + + if manager is None: + manager = attributes.register_class(self.class_, + deferred_scalar_loader = _load_scalar_attributes + ) + + self.class_manager = manager + + manager.mapper = self + + # The remaining members can be added by any mapper, e_name None or not. + if manager.info.get(_INSTRUMENTOR, False): + return + + event_registry = manager.events + event_registry.add_listener('on_init', _event_on_init) + event_registry.add_listener('on_init_failure', _event_on_init_failure) + event_registry.add_listener('on_resurrect', _event_on_resurrect) + + for key, method in util.iterate_attributes(self.class_): + if isinstance(method, types.FunctionType): + if hasattr(method, '__sa_reconstructor__'): + event_registry.add_listener('on_load', method) + elif hasattr(method, '__sa_validators__'): + for name in method.__sa_validators__: + self._validators[name] = method + + if 'reconstruct_instance' in self.extension: + def reconstruct(instance): + self.extension.reconstruct_instance(self, instance) + event_registry.add_listener('on_load', reconstruct) + + manager.info[_INSTRUMENTOR] = self + + def dispose(self): + # Disable any attribute-based compilation. + self.compiled = True + + if hasattr(self, '_compile_failed'): + del self._compile_failed + + if not self.non_primary and self.class_manager.mapper is self: + attributes.unregister_class(self.class_) + + def _configure_pks(self): + + self.tables = sqlutil.find_tables(self.mapped_table) + + if not self.tables: + raise sa_exc.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) + + self._pks_by_table = {} + self._cols_by_table = {} + + all_cols = util.column_set(chain(*[col.proxy_set for col in self._columntoproperty])) + pk_cols = util.column_set(c for c in all_cols if c.primary_key) + + # identify primary key columns which are also mapped by this mapper. + tables = set(self.tables + [self.mapped_table]) + self._all_tables.update(tables) + for t in tables: + if t.primary_key and pk_cols.issuperset(t.primary_key): + # ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get()) + self._pks_by_table[t] = util.ordered_column_set(t.primary_key).intersection(pk_cols) + self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(all_cols) + + # determine cols that aren't expressed within our tables; mark these + # as "read only" properties which are refreshed upon INSERT/UPDATE + self._readonly_props = set( + self._columntoproperty[col] + for col in self._columntoproperty + if not hasattr(col, 'table') or col.table not in self._cols_by_table) + + # if explicit PK argument sent, add those columns to the primary key mappings + if self.primary_key_argument: + for k in self.primary_key_argument: + if k.table not in self._pks_by_table: + self._pks_by_table[k.table] = util.OrderedSet() + self._pks_by_table[k.table].add(k) + + if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0: + raise sa_exc.ArgumentError("Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" % (self, self.mapped_table.description)) + + if self.inherits and not self.concrete and not self.primary_key_argument: + # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit) + self.primary_key = self.inherits.primary_key + else: + # determine primary key from argument or mapped_table pks - reduce to the minimal set of columns + if self.primary_key_argument: + primary_key = sqlutil.reduce_columns( + [self.mapped_table.corresponding_column(c) for c in self.primary_key_argument], + ignore_nonexistent_tables=True) + else: + primary_key = sqlutil.reduce_columns( + self._pks_by_table[self.mapped_table], ignore_nonexistent_tables=True) + + if len(primary_key) == 0: + raise sa_exc.ArgumentError("Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" % (self, self.mapped_table.description)) + + self.primary_key = primary_key + self._log("Identified primary key columns: %s", primary_key) + + def _configure_properties(self): + + # Column and other ClauseElement objects which are mapped + self.columns = self.c = util.OrderedProperties() + + # object attribute names mapped to MapperProperty objects + self._props = util.OrderedDict() + + # table columns mapped to lists of MapperProperty objects + # using a list allows a single column to be defined as + # populating multiple object attributes + self._columntoproperty = util.column_dict() + + # load custom properties + if self._init_properties: + for key, prop in self._init_properties.iteritems(): + self._configure_property(key, prop, False) + + # pull properties from the inherited mapper if any. + if self.inherits: + for key, prop in self.inherits._props.iteritems(): + if key not in self._props and not self._should_exclude(key, key, local=False): + self._adapt_inherited_property(key, prop, False) + + # create properties for each column in the mapped table, + # for those columns which don't already map to a property + for column in self.mapped_table.columns: + if column in self._columntoproperty: + continue + + column_key = (self.column_prefix or '') + column.key + + if self._should_exclude(column.key, column_key, local=self.local_table.c.contains_column(column)): + continue + + # adjust the "key" used for this column to that + # of the inheriting mapper + for mapper in self.iterate_to_root(): + if column in mapper._columntoproperty: + column_key = mapper._columntoproperty[column].key + + self._configure_property(column_key, column, init=False, setparent=True) + + # do a special check for the "discriminiator" column, as it may only be present + # in the 'with_polymorphic' selectable but we need it for the base mapper + if self.polymorphic_on is not None and self.polymorphic_on not in self._columntoproperty: + col = self.mapped_table.corresponding_column(self.polymorphic_on) + if col is None: + instrument = False + col = self.polymorphic_on + else: + instrument = True + if self._should_exclude(col.key, col.key, local=False): + raise sa_exc.InvalidRequestError("Cannot exclude or override the discriminator column %r" % col.key) + self._configure_property(col.key, ColumnProperty(col, _instrument=instrument), init=False, setparent=True) + + def _adapt_inherited_property(self, key, prop, init): + if not self.concrete: + self._configure_property(key, prop, init=False, setparent=False) + elif key not in self._props: + self._configure_property(key, ConcreteInheritedProperty(), init=init, setparent=True) + + def _configure_property(self, key, prop, init=True, setparent=True): + self._log("_configure_property(%s, %s)", key, prop.__class__.__name__) + + if not isinstance(prop, MapperProperty): + # we were passed a Column or a list of Columns; generate a ColumnProperty + columns = util.to_list(prop) + column = columns[0] + if not expression.is_column(column): + raise sa_exc.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop)) + + prop = self._props.get(key, None) + + if isinstance(prop, ColumnProperty): + # TODO: the "property already exists" case is still not well defined here. + # assuming single-column, etc. + + if prop.parent is not self: + # existing ColumnProperty from an inheriting mapper. + # make a copy and append our column to it + prop = prop.copy() + prop.columns.append(column) + self._log("appending to existing ColumnProperty %s" % (key)) + elif prop is None or isinstance(prop, ConcreteInheritedProperty): + mapped_column = [] + for c in columns: + mc = self.mapped_table.corresponding_column(c) + if mc is None: + mc = self.local_table.corresponding_column(c) + if mc is not None: + # if the column is in the local table but not the mapped table, + # this corresponds to adding a column after the fact to the local table. + # [ticket:1523] + self.mapped_table._reset_exported() + mc = self.mapped_table.corresponding_column(c) + if mc is None: + raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. " + "Use the `column_property()` function to force this column " + "to be mapped as a read-only attribute." % c) + mapped_column.append(mc) + prop = ColumnProperty(*mapped_column) + else: + raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%r'. " + "To resolve this, map the column to the class under a different " + "name in the 'properties' dictionary. Or, to remove all awareness " + "of the column entirely (including its availability as a foreign key), " + "use the 'include_properties' or 'exclude_properties' mapper arguments " + "to control specifically which table columns get mapped." % (column.key, prop)) + + if isinstance(prop, ColumnProperty): + col = self.mapped_table.corresponding_column(prop.columns[0]) + + # if the column is not present in the mapped table, + # test if a column has been added after the fact to the parent table + # (or their parent, etc.) + # [ticket:1570] + if col is None and self.inherits: + path = [self] + for m in self.inherits.iterate_to_root(): + col = m.local_table.corresponding_column(prop.columns[0]) + if col is not None: + for m2 in path: + m2.mapped_table._reset_exported() + col = self.mapped_table.corresponding_column(prop.columns[0]) + break + path.append(m) + + # otherwise, col might not be present! the selectable given + # to the mapper need not include "deferred" + # columns (included in zblog tests) + if col is None: + col = prop.columns[0] + + # column is coming in after _readonly_props was initialized; check + # for 'readonly' + if hasattr(self, '_readonly_props') and \ + (not hasattr(col, 'table') or col.table not in self._cols_by_table): + self._readonly_props.add(prop) + + else: + # if column is coming in after _cols_by_table was initialized, ensure the col is in the + # right set + if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]: + self._cols_by_table[col.table].add(col) + + # if this ColumnProperty represents the "polymorphic discriminator" + # column, mark it. We'll need this when rendering columns + # in SELECT statements. + if not hasattr(prop, '_is_polymorphic_discriminator'): + prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on) + + self.columns[key] = col + for col in prop.columns: + for col in col.proxy_set: + self._columntoproperty[col] = prop + + elif isinstance(prop, (ComparableProperty, SynonymProperty)) and setparent: + if prop.descriptor is None: + desc = getattr(self.class_, key, None) + if self._is_userland_descriptor(desc): + prop.descriptor = desc + if getattr(prop, 'map_column', False): + if key not in self.mapped_table.c: + raise sa_exc.ArgumentError( + "Can't compile synonym '%s': no column on table '%s' named '%s'" + % (prop.name, self.mapped_table.description, key)) + elif self.mapped_table.c[key] in self._columntoproperty and \ + self._columntoproperty[self.mapped_table.c[key]].key == prop.name: + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name %r " + "for column %r" % + (key, prop.name, prop.name, key) + ) + p = ColumnProperty(self.mapped_table.c[key]) + self._configure_property(prop.name, p, init=init, setparent=setparent) + p._mapped_by_synonym = key + + if key in self._props and getattr(self._props[key], '_mapped_by_synonym', False): + syn = self._props[key]._mapped_by_synonym + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name " + "%r for column %r" % (syn, key, key, syn) + ) + + self._props[key] = prop + prop.key = key + + if setparent: + prop.set_parent(self) + + if not self.non_primary: + prop.instrument_class(self) + + for mapper in self._inheriting_mappers: + mapper._adapt_inherited_property(key, prop, init) + + if init: + prop.init() + prop.post_instrument_class(self) + + + def compile(self): + """Compile this mapper and all other non-compiled mappers. + + This method checks the local compiled status as well as for + any new mappers that have been defined, and is safe to call + repeatedly. + + """ + global _new_mappers + if self.compiled and not _new_mappers: + return self + + _COMPILE_MUTEX.acquire() + try: + try: + global _already_compiling + if _already_compiling: + return + _already_compiling = True + try: + + # double-check inside mutex + if self.compiled and not _new_mappers: + return self + + # initialize properties on all mappers + # note that _mapper_registry is unordered, which + # may randomly conceal/reveal issues related to + # the order of mapper compilation + for mapper in list(_mapper_registry): + if getattr(mapper, '_compile_failed', False): + raise sa_exc.InvalidRequestError( + "One or more mappers failed to compile. " + "Exception was probably " + "suppressed within a hasattr() call. " + "Message was: %s" % mapper._compile_failed) + if not mapper.compiled: + mapper._post_configure_properties() + + _new_mappers = False + return self + finally: + _already_compiling = False + except: + import sys + exc = sys.exc_info()[1] + self._compile_failed = exc + raise + finally: + _COMPILE_MUTEX.release() + + def _post_configure_properties(self): + """Call the ``init()`` method on all ``MapperProperties`` + attached to this mapper. + + This is a deferred configuration step which is intended + to execute once all mappers have been constructed. + + """ + + self._log("_post_configure_properties() started") + l = [(key, prop) for key, prop in self._props.iteritems()] + for key, prop in l: + self._log("initialize prop %s", key) + + if prop.parent is self and not prop._compile_started: + prop.init() + + if prop._compile_finished: + prop.post_instrument_class(self) + + self._log("_post_configure_properties() complete") + self.compiled = True + + def add_properties(self, dict_of_properties): + """Add the given dictionary of properties to this mapper, + using `add_property`. + + """ + for key, value in dict_of_properties.iteritems(): + self.add_property(key, value) + + def add_property(self, key, prop): + """Add an individual MapperProperty to this mapper. + + If the mapper has not been compiled yet, just adds the + property to the initial properties dictionary sent to the + constructor. If this Mapper has already been compiled, then + the given MapperProperty is compiled immediately. + + """ + self._init_properties[key] = prop + self._configure_property(key, prop, init=self.compiled) + + + def _log(self, msg, *args): + self.logger.info( + "(" + self.class_.__name__ + + "|" + + (self.local_table is not None and + self.local_table.description or + str(self.local_table)) + + (self.non_primary and "|non-primary" or "") + ") " + + msg, *args) + + def _log_debug(self, msg, *args): + self.logger.debug( + "(" + self.class_.__name__ + + "|" + + (self.local_table is not None and + self.local_table.description + or str(self.local_table)) + + (self.non_primary and "|non-primary" or "") + ") " + + msg, *args) + + def __repr__(self): + return '' % ( + id(self), self.class_.__name__) + + def __str__(self): + return "Mapper|%s|%s%s" % ( + self.class_.__name__, + self.local_table is not None and self.local_table.description or None, + self.non_primary and "|non-primary" or "" + ) + + def _is_orphan(self, state): + o = False + for mapper in self.iterate_to_root(): + for (key, cls) in mapper.delete_orphans: + if attributes.manager_of_class(cls).has_parent( + state, key, optimistic=_state_has_identity(state)): + return False + o = o or bool(mapper.delete_orphans) + return o + + def has_property(self, key): + return key in self._props + + def get_property(self, key, resolve_synonyms=False, raiseerr=True): + """return a MapperProperty associated with the given key.""" + + if not self.compiled: + self.compile() + return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr) + + def _get_property(self, key, resolve_synonyms=False, raiseerr=True): + prop = self._props.get(key, None) + if resolve_synonyms: + while isinstance(prop, SynonymProperty): + prop = self._props.get(prop.name, None) + if prop is None and raiseerr: + raise sa_exc.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key)) + return prop + + @property + def iterate_properties(self): + """return an iterator of all MapperProperty objects.""" + if not self.compiled: + self.compile() + return self._props.itervalues() + + def _mappers_from_spec(self, spec, selectable): + """given a with_polymorphic() argument, return the set of mappers it represents. + + Trims the list of mappers to just those represented within the given selectable, if present. + This helps some more legacy-ish mappings. + + """ + if spec == '*': + mappers = list(self.polymorphic_iterator()) + elif spec: + mappers = [_class_to_mapper(m) for m in util.to_list(spec)] + for m in mappers: + if not m.isa(self): + raise sa_exc.InvalidRequestError("%r does not inherit from %r" % (m, self)) + else: + mappers = [] + + if selectable is not None: + tables = set(sqlutil.find_tables(selectable, include_aliases=True)) + mappers = [m for m in mappers if m.local_table in tables] + + return mappers + + def _selectable_from_mappers(self, mappers): + """given a list of mappers (assumed to be within this mapper's inheritance hierarchy), + construct an outerjoin amongst those mapper's mapped tables. + + """ + + from_obj = self.mapped_table + for m in mappers: + if m is self: + continue + if m.concrete: + raise sa_exc.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") + elif not m.single: + from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition) + + return from_obj + + @property + def _single_table_criterion(self): + if self.single and \ + self.inherits and \ + self.polymorphic_on is not None and \ + self.polymorphic_identity is not None: + return self.polymorphic_on.in_( + m.polymorphic_identity + for m in self.polymorphic_iterator()) + else: + return None + + + @util.memoized_property + def _with_polymorphic_mappers(self): + if not self.with_polymorphic: + return [self] + return self._mappers_from_spec(*self.with_polymorphic) + + @util.memoized_property + def _with_polymorphic_selectable(self): + if not self.with_polymorphic: + return self.mapped_table + + spec, selectable = self.with_polymorphic + if selectable is not None: + return selectable + else: + return self._selectable_from_mappers(self._mappers_from_spec(spec, selectable)) + + def _with_polymorphic_args(self, spec=None, selectable=False): + if self.with_polymorphic: + if not spec: + spec = self.with_polymorphic[0] + if selectable is False: + selectable = self.with_polymorphic[1] + + mappers = self._mappers_from_spec(spec, selectable) + if selectable is not None: + return mappers, selectable + else: + return mappers, self._selectable_from_mappers(mappers) + + def _iterate_polymorphic_properties(self, mappers=None): + """Return an iterator of MapperProperty objects which will render into a SELECT.""" + + if mappers is None: + mappers = self._with_polymorphic_mappers + + if not mappers: + for c in self.iterate_properties: + yield c + else: + # in the polymorphic case, filter out discriminator columns + # from other mappers, as these are sometimes dependent on that + # mapper's polymorphic selectable (which we don't want rendered) + for c in util.unique_list( + chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers]) + ): + if getattr(c, '_is_polymorphic_discriminator', False) and \ + (self.polymorphic_on is None or c.columns[0] is not self.polymorphic_on): + continue + yield c + + @property + def properties(self): + raise NotImplementedError("Public collection of MapperProperty objects is " + "provided by the get_property() and iterate_properties accessors.") + + @util.memoized_property + def _get_clause(self): + """create a "get clause" based on the primary key. this is used + by query.get() and many-to-one lazyloads to load this item + by primary key. + + """ + params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key] + return sql.and_(*[k==v for (k, v) in params]), util.column_dict(params) + + @util.memoized_property + def _equivalent_columns(self): + """Create a map of all *equivalent* columns, based on + the determination of column pairs that are equated to + one another based on inherit condition. This is designed + to work with the queries that util.polymorphic_union + comes up with, which often don't include the columns from + the base table directly (including the subclass table columns + only). + + The resulting structure is a dictionary of columns mapped + to lists of equivalent columns, i.e. + + { + tablea.col1: + set([tableb.col1, tablec.col1]), + tablea.col2: + set([tabled.col2]) + } + + """ + result = util.column_dict() + def visit_binary(binary): + if binary.operator == operators.eq: + if binary.left in result: + result[binary.left].add(binary.right) + else: + result[binary.left] = util.column_set((binary.right,)) + if binary.right in result: + result[binary.right].add(binary.left) + else: + result[binary.right] = util.column_set((binary.left,)) + for mapper in self.base_mapper.polymorphic_iterator(): + if mapper.inherit_condition is not None: + visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary}) + + return result + + def _is_userland_descriptor(self, obj): + return not isinstance(obj, (MapperProperty, attributes.InstrumentedAttribute)) and hasattr(obj, '__get__') + + def _should_exclude(self, name, assigned_name, local): + """determine whether a particular property should be implicitly present on the class. + + This occurs when properties are propagated from an inherited class, or are + applied from the columns present in the mapped table. + + """ + + # check for descriptors, either local or from + # an inherited class + if local: + if self.class_.__dict__.get(assigned_name, None) is not None\ + and self._is_userland_descriptor(self.class_.__dict__[assigned_name]): + return True + else: + if getattr(self.class_, assigned_name, None) is not None\ + and self._is_userland_descriptor(getattr(self.class_, assigned_name)): + return True + + if (self.include_properties is not None and + name not in self.include_properties): + self._log("not including property %s" % (name)) + return True + + if (self.exclude_properties is not None and + name in self.exclude_properties): + self._log("excluding property %s" % (name)) + return True + + return False + + def common_parent(self, other): + """Return true if the given mapper shares a common inherited parent as this mapper.""" + + return self.base_mapper is other.base_mapper + + def _canload(self, state, allow_subtypes): + s = self.primary_mapper() + if self.polymorphic_on is not None or allow_subtypes: + return _state_mapper(state).isa(s) + else: + return _state_mapper(state) is s + + def isa(self, other): + """Return True if the this mapper inherits from the given mapper.""" + + m = self + while m and m is not other: + m = m.inherits + return bool(m) + + def iterate_to_root(self): + m = self + while m: + yield m + m = m.inherits + + def polymorphic_iterator(self): + """Iterate through the collection including this mapper and + all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + To iterate through an entire hierarchy, use + ``mapper.base_mapper.polymorphic_iterator()``. + + """ + stack = deque([self]) + while stack: + item = stack.popleft() + yield item + stack.extend(item._inheriting_mappers) + + def primary_mapper(self): + """Return the primary mapper corresponding to this mapper's class key (class).""" + + return self.class_manager.mapper + + def identity_key_from_row(self, row, adapter=None): + """Return an identity-map key for use in storing/retrieving an + item from the identity map. + + row + A ``sqlalchemy.engine.base.RowProxy`` instance or a + dictionary corresponding result-set ``ColumnElement`` + instances to their values within a row. + + """ + pk_cols = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + + return (self._identity_class, tuple(row[column] for column in pk_cols)) + + def identity_key_from_primary_key(self, primary_key): + """Return an identity-map key for use in storing/retrieving an + item from an identity map. + + primary_key + A list of values indicating the identifier. + + """ + return (self._identity_class, tuple(util.to_list(primary_key))) + + def identity_key_from_instance(self, instance): + """Return the identity key for the given instance, based on + its primary key attributes. + + This value is typically also found on the instance state under the + attribute name `key`. + + """ + return self.identity_key_from_primary_key(self.primary_key_from_instance(instance)) + + def _identity_key_from_state(self, state): + return self.identity_key_from_primary_key(self._primary_key_from_state(state)) + + def primary_key_from_instance(self, instance): + """Return the list of primary key values for the given + instance. + + """ + state = attributes.instance_state(instance) + return self._primary_key_from_state(state) + + def _primary_key_from_state(self, state): + return [self._get_state_attr_by_column(state, column) for column in self.primary_key] + + def _get_col_to_prop(self, column): + try: + return self._columntoproperty[column] + except KeyError: + prop = self._props.get(column.key, None) + if prop: + raise orm_exc.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) + else: + raise orm_exc.UnmappedColumnError("No column %s is configured on mapper %s..." % (column, self)) + + # TODO: improve names? + def _get_state_attr_by_column(self, state, column): + return self._get_col_to_prop(column).getattr(state, column) + + def _set_state_attr_by_column(self, state, column, value): + return self._get_col_to_prop(column).setattr(state, value, column) + + def _get_committed_attr_by_column(self, obj, column): + state = attributes.instance_state(obj) + return self._get_committed_state_attr_by_column(state, column) + + def _get_committed_state_attr_by_column(self, state, column, passive=False): + return self._get_col_to_prop(column).getcommitted(state, column, passive=passive) + + def _optimized_get_statement(self, state, attribute_names): + """assemble a WHERE clause which retrieves a given state by primary key, using a minimized set of tables. + + Applies to a joined-table inheritance mapper where the + requested attribute names are only present on joined tables, + not the base table. The WHERE clause attempts to include + only those tables to minimize joins. + + """ + props = self._props + + tables = set(chain(* + (sqlutil.find_tables(props[key].columns[0], check_columns=True) + for key in attribute_names) + )) + + if self.base_mapper.local_table in tables: + return None + + class ColumnsNotAvailable(Exception): + pass + + def visit_binary(binary): + leftcol = binary.left + rightcol = binary.right + if leftcol is None or rightcol is None: + return + + if leftcol.table not in tables: + leftval = self._get_committed_state_attr_by_column(state, leftcol, passive=True) + if leftval is attributes.PASSIVE_NO_RESULT: + raise ColumnsNotAvailable() + binary.left = sql.bindparam(None, leftval, type_=binary.right.type) + elif rightcol.table not in tables: + rightval = self._get_committed_state_attr_by_column(state, rightcol, passive=True) + if rightval is attributes.PASSIVE_NO_RESULT: + raise ColumnsNotAvailable() + binary.right = sql.bindparam(None, rightval, type_=binary.right.type) + + allconds = [] + + try: + start = False + for mapper in reversed(list(self.iterate_to_root())): + if mapper.local_table in tables: + start = True + if start and not mapper.single: + allconds.append(visitors.cloned_traverse(mapper.inherit_condition, {}, {'binary':visit_binary})) + except ColumnsNotAvailable: + return None + + cond = sql.and_(*allconds) + + cols = [] + for key in attribute_names: + cols.extend(props[key].columns) + return sql.select(cols, cond, use_labels=True) + + def cascade_iterator(self, type_, state, halt_on=None): + """Iterate each element and its mapper in an object graph, + for all relationships that meet the given cascade rule. + + ``type\_``: + The name of the cascade rule (i.e. save-update, delete, + etc.) + + ``state``: + The lead InstanceState. child items will be processed per + the relationships defined for this object's mapper. + + the return value are object instances; this provides a strong + reference so that they don't fall out of scope immediately. + + """ + visited_instances = util.IdentitySet() + visitables = [(self._props.itervalues(), 'property', state)] + + while visitables: + iterator, item_type, parent_state = visitables[-1] + try: + if item_type == 'property': + prop = iterator.next() + visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None)) + elif item_type == 'mapper': + instance, instance_mapper, corresponding_state = iterator.next() + yield (instance, instance_mapper) + visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state)) + except StopIteration: + visitables.pop() + + @util.memoized_property + def _sorted_tables(self): + table_to_mapper = {} + for mapper in self.base_mapper.polymorphic_iterator(): + for t in mapper.tables: + table_to_mapper[t] = mapper + + sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys()) + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + + def _save_obj(self, states, uowtransaction, postupdate=False, + post_update_cols=None, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + `_save_obj` issues SQL statements not just for instances mapped + directly by this mapper, but for instances mapped by all + inheriting mappers as well. This is to maintain proper insert + ordering among a polymorphic chain of instances. Therefore + _save_obj is typically called only on a *base mapper*, or a + mapper which does not inherit from any other mapper. + + """ + # if batch=false, call _save_obj separately for each object + if not single and not self.batch: + for state in _sort_states(states): + self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) + return + + # if session has a connection callable, + # organize individual states with the connection to use for insert/update + tups = [] + if 'connection_callable' in uowtransaction.mapper_flush_opts: + connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection_callable(self, state.obj()), + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) + else: + connection = uowtransaction.transaction.connection(self) + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection, + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) + + if not postupdate: + # call before_XXX extensions + for state, mapper, connection, has_identity, instance_key in tups: + if not has_identity: + if 'before_insert' in mapper.extension: + mapper.extension.before_insert(mapper, connection, state.obj()) + else: + if 'before_update' in mapper.extension: + mapper.extension.before_update(mapper, connection, state.obj()) + + row_switches = {} + if not postupdate: + for state, mapper, connection, has_identity, instance_key in tups: + # detect if we have a "pending" instance (i.e. has no instance_key attached to it), + # and another instance with the same identity key already exists as persistent. convert to an + # UPDATE if so. + if not has_identity and instance_key in uowtransaction.session.identity_map: + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise orm_exc.FlushError( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" % + (state_str(state), instance_key, state_str(existing))) + + self._log_debug( + "detected row switch for identity %s. will update %s, remove %s from " + "transaction", instance_key, state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.set_row_switch(existing) + row_switches[state] = existing + + table_to_mapper = self._sorted_tables + + for table in table_to_mapper.iterkeys(): + insert = [] + update = [] + + for state, mapper, connection, has_identity, instance_key in tups: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + isinsert = not has_identity and not postupdate and state not in row_switches + + params = {} + value_params = {} + hasdata = False + + if isinsert: + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col.key] = mapper.version_id_generator(None) + elif mapper.polymorphic_on is not None and \ + mapper.polymorphic_on.shares_lineage(col): + value = mapper.polymorphic_identity + if ((col.default is None and + col.server_default is None) or + value is not None): + params[col.key] = value + elif col in pks: + value = mapper._get_state_attr_by_column(state, col) + if value is not None: + params[col.key] = value + else: + value = mapper._get_state_attr_by_column(state, col) + if ((col.default is None and + col.server_default is None) or + value is not None): + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + insert.append((state, params, mapper, connection, value_params)) + else: + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col) + params[col.key] = mapper.version_id_generator(params[col._label]) + for prop in mapper._columntoproperty.itervalues(): + history = attributes.get_state_history(state, prop.key, passive=True) + if history.added: + hasdata = True + elif mapper.polymorphic_on is not None and \ + mapper.polymorphic_on.shares_lineage(col) and col not in pks: + pass + else: + if post_update_cols is not None and col not in post_update_cols: + if col in pks: + params[col._label] = mapper._get_state_attr_by_column(state, col) + continue + + prop = mapper._columntoproperty[col] + history = attributes.get_state_history(state, prop.key, passive=True) + if history.added: + if isinstance(history.added[0], sql.ClauseElement): + value_params[col] = history.added[0] + else: + params[col.key] = prop.get_col_value(col, history.added[0]) + + if col in pks: + if history.deleted: + # if passive_updates and sync detected this was a + # pk->pk sync, use the new value to locate the row, + # since the DB would already have set this + if ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + params[col._label] = \ + prop.get_col_value(col, history.added[0]) + else: + # use the old value to locate the row + params[col._label] = \ + prop.get_col_value(col, history.deleted[0]) + hasdata = True + else: + # row switch logic can reach us here + # remove the pk from the update params so the update doesn't + # attempt to include the pk in the update statement + del params[col.key] + params[col._label] = \ + prop.get_col_value(col, history.added[0]) + else: + hasdata = True + elif col in pks: + params[col._label] = mapper._get_state_attr_by_column(state, col) + if hasdata: + update.append((state, params, mapper, connection, value_params)) + + if update: + mapper = table_to_mapper[table] + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) + + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + + clause.clauses.append(mapper.version_id_col ==\ + sql.bindparam(mapper.version_id_col._label, type_=col.type)) + + statement = table.update(clause) + + rows = 0 + for state, params, mapper, connection, value_params in update: + c = connection.execute(statement.values(value_params), params) + mapper._postfetch(uowtransaction, connection, table, + state, c, c.last_updated_params(), value_params) + + rows += c.rowcount + + if connection.dialect.supports_sane_rowcount: + if rows != len(update): + raise orm_exc.ConcurrentModificationError( + "Updated rowcount %d does not match number of objects updated %d" % + (rows, len(update))) + + elif mapper.version_id_col is not None: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % c.dialect.dialect_description, + stacklevel=12) + + if insert: + statement = table.insert() + for state, params, mapper, connection, value_params in insert: + c = connection.execute(statement.values(value_params), params) + primary_key = c.inserted_primary_key + + if primary_key is not None: + # set primary key attributes + for i, col in enumerate(mapper._pks_by_table[table]): + if mapper._get_state_attr_by_column(state, col) is None and \ + len(primary_key) > i: + mapper._set_state_attr_by_column(state, col, primary_key[i]) + + mapper._postfetch(uowtransaction, connection, table, + state, c, c.last_inserted_params(), value_params) + + + if not postupdate: + for state, mapper, connection, has_identity, instance_key in tups: + + # expire readonly attributes + readonly = state.unmodified.intersection( + p.key for p in mapper._readonly_props + ) + + if readonly: + _expire_state(state, state.dict, readonly) + + # if specified, eagerly refresh whatever has + # been expired. + if self.eager_defaults and state.unloaded: + state.key = self._identity_key_from_state(state) + uowtransaction.session.query(self)._get( + state.key, refresh_state=state, + only_load_props=state.unloaded) + + # call after_XXX extensions + if not has_identity: + if 'after_insert' in mapper.extension: + mapper.extension.after_insert(mapper, connection, state.obj()) + else: + if 'after_update' in mapper.extension: + mapper.extension.after_update(mapper, connection, state.obj()) + + def _postfetch(self, uowtransaction, connection, table, + state, resultproxy, params, value_params): + """Expire attributes in need of newly persisted database state.""" + + postfetch_cols = resultproxy.postfetch_cols() + generated_cols = list(resultproxy.prefetch_cols()) + + if self.polymorphic_on is not None: + po = table.corresponding_column(self.polymorphic_on) + if po is not None: + generated_cols.append(po) + + if self.version_id_col is not None: + generated_cols.append(self.version_id_col) + + for c in generated_cols: + if c.key in params and c in self._columntoproperty: + self._set_state_attr_by_column(state, c, params[c.key]) + + deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] + + if deferred_props: + _expire_state(state, state.dict, deferred_props) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + cols = set(table.c) + for m in self.iterate_to_root(): + if m._inherits_equated_pairs and \ + cols.intersection([l for l, r in m._inherits_equated_pairs]): + sync.populate(state, m, state, m, + m._inherits_equated_pairs, + uowtransaction, + self.passive_updates) + + def _delete_obj(self, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + """ + if 'connection_callable' in uowtransaction.mapper_flush_opts: + connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] + tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)] + else: + connection = uowtransaction.transaction.connection(self) + tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)] + + for state, mapper, connection in tups: + if 'before_delete' in mapper.extension: + mapper.extension.before_delete(mapper, connection, state.obj()) + + table_to_mapper = self._sorted_tables + + for table in reversed(table_to_mapper.keys()): + delete = {} + for state, mapper, connection in tups: + if table not in mapper._pks_by_table: + continue + + params = {} + if not _state_has_identity(state): + continue + else: + delete.setdefault(connection, []).append(params) + for col in mapper._pks_by_table[table]: + params[col.key] = mapper._get_state_attr_by_column(state, col) + if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) + + for connection, del_objects in delete.iteritems(): + mapper = table_to_mapper[table] + clause = sql.and_() + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) + if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): + clause.clauses.append( + mapper.version_id_col == + sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type)) + statement = table.delete(clause) + c = connection.execute(statement, del_objects) + if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): + raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match " + "number of objects deleted %d" % (c.rowcount, len(del_objects))) + + for state, mapper, connection in tups: + if 'after_delete' in mapper.extension: + mapper.extension.after_delete(mapper, connection, state.obj()) + + def _register_dependencies(self, uowcommit): + """Register ``DependencyProcessor`` instances with a + ``unitofwork.UOWTransaction``. + + This call `register_dependencies` on all attached + ``MapperProperty`` instances. + + """ + for dep in self._props.values() + self._dependency_processors: + dep.register_dependencies(uowcommit) + + def _register_processors(self, uowcommit): + for dep in self._props.values() + self._dependency_processors: + dep.register_processors(uowcommit) + + def _instance_processor(self, context, path, adapter, + polymorphic_from=None, extension=None, + only_load_props=None, refresh_state=None, + polymorphic_discriminator=None): + + """Produce a mapper level row processor callable + which processes rows into mapped instances.""" + + pk_cols = self.primary_key + + if polymorphic_from or refresh_state: + polymorphic_on = None + else: + if polymorphic_discriminator is not None: + polymorphic_on = polymorphic_discriminator + else: + polymorphic_on = self.polymorphic_on + polymorphic_instances = util.PopulateDict( + self._configure_subclass_mapper(context, path, adapter) + ) + + version_id_col = self.version_id_col + + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + if polymorphic_on is not None: + polymorphic_on = adapter.columns[polymorphic_on] + if version_id_col is not None: + version_id_col = adapter.columns[version_id_col] + + identity_class = self._identity_class + def identity_key(row): + return (identity_class, tuple([row[column] for column in pk_cols])) + + new_populators = [] + existing_populators = [] + load_path = context.query._current_path + path + + def populate_state(state, dict_, row, isnew, only_load_props): + if isnew: + if context.propagate_options: + state.load_options = context.propagate_options + if state.load_options: + state.load_path = load_path + + if not new_populators: + new_populators[:], existing_populators[:] = \ + self._populators(context, path, row, adapter) + + if isnew: + populators = new_populators + else: + populators = existing_populators + + if only_load_props: + populators = [p for p in populators if p[0] in only_load_props] + + for key, populator in populators: + populator(state, dict_, row) + + session_identity_map = context.session.identity_map + + if not extension: + extension = self.extension + + translate_row = extension.get('translate_row', None) + create_instance = extension.get('create_instance', None) + populate_instance = extension.get('populate_instance', None) + append_result = extension.get('append_result', None) + populate_existing = context.populate_existing or self.always_refresh + if self.allow_partial_pks: + is_not_primary_key = _none_set.issuperset + else: + is_not_primary_key = _none_set.issubset + + def _instance(row, result): + if translate_row: + ret = translate_row(self, context, row) + if ret is not EXT_CONTINUE: + row = ret + + if polymorphic_on is not None: + discriminator = row[polymorphic_on] + if discriminator is not None: + _instance = polymorphic_instances[discriminator] + if _instance: + return _instance(row, result) + + # determine identity key + if refresh_state: + identitykey = refresh_state.key + if identitykey is None: + # super-rare condition; a refresh is being called + # on a non-instance-key instance; this is meant to only + # occur within a flush() + identitykey = self._identity_key_from_state(refresh_state) + else: + identitykey = identity_key(row) + + instance = session_identity_map.get(identitykey) + if instance is not None: + state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) + + isnew = state.runid != context.runid + currentload = not isnew + loaded_instance = False + + if not currentload and \ + version_id_col is not None and \ + context.version_check and \ + self._get_state_attr_by_column( + state, + self.version_id_col) != row[version_id_col]: + + raise orm_exc.ConcurrentModificationError( + "Instance '%s' version of %s does not match %s" + % (state_str(state), + self._get_state_attr_by_column(state, self.version_id_col), + row[version_id_col])) + elif refresh_state: + # out of band refresh_state detected (i.e. its not in the session.identity_map) + # honor it anyway. this can happen if a _get() occurs within save_obj(), such as + # when eager_defaults is True. + state = refresh_state + instance = state.obj() + dict_ = attributes.instance_dict(instance) + isnew = state.runid != context.runid + currentload = True + loaded_instance = False + else: + # check for non-NULL values in the primary key columns, + # else no entity is returned for the row + if is_not_primary_key(identitykey[1]): + return None + + isnew = True + currentload = True + loaded_instance = True + + if create_instance: + instance = create_instance(self, context, row, self.class_) + if instance is EXT_CONTINUE: + instance = self.class_manager.new_instance() + else: + manager = attributes.manager_of_class(instance.__class__) + # TODO: if manager is None, raise a friendly error about + # returning instances of unmapped types + manager.setup_instance(instance) + else: + instance = self.class_manager.new_instance() + + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + state.key = identitykey + + # manually adding instance to session. for a complete add, + # session._finalize_loaded() must be called. + state.session_id = context.session.hash_key + session_identity_map.add(state) + + if currentload or populate_existing: + if isnew: + state.runid = context.runid + context.progress[state] = dict_ + + if not populate_instance or \ + populate_instance(self, context, row, instance, + only_load_props=only_load_props, + instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + populate_state(state, dict_, row, isnew, only_load_props) + + else: + # populate attributes on non-loading instances which have been expired + # TODO: apply eager loads to un-lazy loaded collections ? + if state in context.partials or state.unloaded: + + if state in context.partials: + isnew = False + (d_, attrs) = context.partials[state] + else: + isnew = True + attrs = state.unloaded + # allow query.instances to commit the subset of attrs + context.partials[state] = (dict_, attrs) + + if not populate_instance or \ + populate_instance(self, context, row, instance, + only_load_props=attrs, + instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + populate_state(state, dict_, row, isnew, attrs) + + if loaded_instance: + state._run_on_load(instance) + + if result is not None and \ + (not append_result or + append_result(self, context, row, instance, + result, instancekey=identitykey, isnew=isnew) + is EXT_CONTINUE): + result.append(instance) + + return instance + return _instance + + def _populators(self, context, path, row, adapter): + """Produce a collection of attribute level row processor callables.""" + + new_populators, existing_populators = [], [] + for prop in self._props.itervalues(): + newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter) + if newpop: + new_populators.append((prop.key, newpop)) + if existingpop: + existing_populators.append((prop.key, existingpop)) + return new_populators, existing_populators + + def _configure_subclass_mapper(self, context, path, adapter): + """Produce a mapper level row processor callable factory for mappers inheriting this one.""" + + def configure_subclass_mapper(discriminator): + try: + mapper = self.polymorphic_map[discriminator] + except KeyError: + raise AssertionError("No such polymorphic_identity %r is defined" % discriminator) + if mapper is self: + return None + + # replace the tip of the path info with the subclass mapper being used. + # that way accurate "load_path" info is available for options + # invoked during deferred loads. + # we lose AliasedClass path elements this way, but currently, + # those are not needed at this stage. + + # this asserts to true + #assert mapper.isa(_class_to_mapper(path[-1])) + + return mapper._instance_processor(context, path[0:-1] + (mapper,), + adapter, polymorphic_from=self) + return configure_subclass_mapper + +log.class_logger(Mapper) + + +def reconstructor(fn): + """Decorate a method as the 'reconstructor' hook. + + Designates a method as the "reconstructor", an ``__init__``-like + method that will be called by the ORM after the instance has been + loaded from the database or otherwise reconstituted. + + The reconstructor will be invoked with no arguments. Scalar + (non-collection) database-mapped attributes of the instance will + be available for use within the function. Eagerly-loaded + collections are generally not yet available and will usually only + contain the first element. ORM state changes made to objects at + this stage will not be recorded for the next flush() operation, so + the activity within a reconstructor should be conservative. + + """ + fn.__sa_reconstructor__ = True + return fn + +def validates(*names): + """Decorate a method as a 'validator' for one or more named properties. + + Designates a method as a validator, a method which receives the + name of the attribute as well as a value to be assigned, or in the + case of a collection to be added to the collection. The function + can then raise validation exceptions to halt the process from continuing, + or can modify or replace the value before proceeding. The function + should otherwise return the given value. + + """ + def wrap(fn): + fn.__sa_validators__ = names + return fn + return wrap + +def _event_on_init(state, instance, args, kwargs): + """Trigger mapper compilation and run init_instance hooks.""" + + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + # compile() always compiles all mappers + instrumenting_mapper.compile() + if 'init_instance' in instrumenting_mapper.extension: + instrumenting_mapper.extension.init_instance( + instrumenting_mapper, instrumenting_mapper.class_, + state.manager.events.original_init, + instance, args, kwargs) + +def _event_on_init_failure(state, instance, args, kwargs): + """Run init_failed hooks.""" + + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + if 'init_failed' in instrumenting_mapper.extension: + util.warn_exception( + instrumenting_mapper.extension.init_failed, + instrumenting_mapper, instrumenting_mapper.class_, + state.manager.events.original_init, instance, args, kwargs) + +def _event_on_resurrect(state, instance): + # re-populate the primary key elements + # of the dict based on the mapping. + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): + instrumenting_mapper._set_state_attr_by_column(state, col, val) + + +def _sort_states(states): + return sorted(states, key=operator.attrgetter('sort_key')) + +def _load_scalar_attributes(state, attribute_names): + """initiate a column-based attribute refresh operation.""" + + mapper = _state_mapper(state) + session = _state_session(state) + if not session: + raise orm_exc.DetachedInstanceError("Instance %s is not bound to a Session; " + "attribute refresh operation cannot proceed" % (state_str(state))) + + has_key = _state_has_identity(state) + + result = False + if mapper.inherits and not mapper.concrete: + statement = mapper._optimized_get_statement(state, attribute_names) + if statement is not None: + result = session.query(mapper).from_statement(statement).\ + _get(None, + only_load_props=attribute_names, + refresh_state=state) + + if result is False: + if has_key: + identity_key = state.key + else: + identity_key = mapper._identity_key_from_state(state) + result = session.query(mapper)._get( + identity_key, + refresh_state=state, + only_load_props=attribute_names) + + # if instance is pending, a refresh operation + # may not complete (even if PK attributes are assigned) + if has_key and result is None: + raise orm_exc.ObjectDeletedError("Instance '%s' has been deleted." % state_str(state)) diff --git a/sqlalchemy/orm/properties.py b/sqlalchemy/orm/properties.py new file mode 100644 index 0000000..80d101b --- /dev/null +++ b/sqlalchemy/orm/properties.py @@ -0,0 +1,1205 @@ +# properties.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""MapperProperty implementations. + +This is a private module which defines the behavior of invidual ORM-mapped +attributes. + +""" + +from sqlalchemy import sql, util, log +import sqlalchemy.exceptions as sa_exc +from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, join_condition +from sqlalchemy.sql import operators, expression +from sqlalchemy.orm import ( + attributes, dependency, mapper, object_mapper, strategies, + ) +from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate, _orm_deannotate +from sqlalchemy.orm.interfaces import ( + MANYTOMANY, MANYTOONE, MapperProperty, ONETOMANY, PropComparator, + StrategizedProperty, + ) +NoneType = type(None) + +__all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty', + 'ComparableProperty', 'RelationshipProperty', 'RelationProperty', 'BackRef') + + +class ColumnProperty(StrategizedProperty): + """Describes an object attribute that corresponds to a table column.""" + + def __init__(self, *columns, **kwargs): + """Construct a ColumnProperty. + + :param \*columns: The list of `columns` describes a single + object property. If there are multiple tables joined + together for the mapper, this list represents the equivalent + column as it appears across each table. + + :param group: + + :param deferred: + + :param comparator_factory: + + :param descriptor: + + :param extension: + + """ + self.columns = [expression._labeled(c) for c in columns] + self.group = kwargs.pop('group', None) + self.deferred = kwargs.pop('deferred', False) + self.instrument = kwargs.pop('_instrument', True) + self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator) + self.descriptor = kwargs.pop('descriptor', None) + self.extension = kwargs.pop('extension', None) + if kwargs: + raise TypeError( + "%s received unexpected keyword argument(s): %s" % ( + self.__class__.__name__, ', '.join(sorted(kwargs.keys())))) + + util.set_creation_order(self) + if not self.instrument: + self.strategy_class = strategies.UninstrumentedColumnLoader + elif self.deferred: + self.strategy_class = strategies.DeferredColumnLoader + else: + self.strategy_class = strategies.ColumnLoader + + def instrument_class(self, mapper): + if not self.instrument: + return + + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + property_=self + ) + + def do_init(self): + super(ColumnProperty, self).do_init() + if len(self.columns) > 1 and self.parent.primary_key.issuperset(self.columns): + util.warn( + ("On mapper %s, primary key column '%s' is being combined " + "with distinct primary key column '%s' in attribute '%s'. " + "Use explicit properties to give each column its own mapped " + "attribute name.") % (self.parent, self.columns[1], + self.columns[0], self.key)) + + def copy(self): + return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) + + def getattr(self, state, column): + return state.get_impl(self.key).get(state, state.dict) + + def getcommitted(self, state, column, passive=False): + return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) + + def setattr(self, state, value, column): + state.get_impl(self.key).set(state, state.dict, value, None) + + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): + if self.key in source_dict: + value = source_dict[self.key] + + if not load: + dest_dict[self.key] = value + else: + impl = dest_state.get_impl(self.key) + impl.set(dest_state, dest_dict, value, None) + else: + if self.key not in dest_dict: + dest_state.expire_attributes(dest_dict, [self.key]) + + def get_col_value(self, column, value): + return value + + class Comparator(PropComparator): + @util.memoized_instancemethod + def __clause_element__(self): + if self.adapter: + return self.adapter(self.prop.columns[0]) + else: + return self.prop.columns[0]._annotate({"parententity": self.mapper, "parentmapper":self.mapper}) + + def operate(self, op, *other, **kwargs): + return op(self.__clause_element__(), *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) + + # TODO: legacy..do we need this ? (0.5) + ColumnComparator = Comparator + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + +log.class_logger(ColumnProperty) + +class CompositeProperty(ColumnProperty): + """subclasses ColumnProperty to provide composite type support.""" + + def __init__(self, class_, *columns, **kwargs): + super(CompositeProperty, self).__init__(*columns, **kwargs) + self._col_position_map = util.column_dict((c, i) for i, c in enumerate(columns)) + self.composite_class = class_ + self.strategy_class = strategies.CompositeColumnLoader + + def copy(self): + return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) + + def do_init(self): + # skip over ColumnProperty's do_init(), + # which issues assertions that do not apply to CompositeColumnProperty + super(ColumnProperty, self).do_init() + + def getattr(self, state, column): + obj = state.get_impl(self.key).get(state, state.dict) + return self.get_col_value(column, obj) + + def getcommitted(self, state, column, passive=False): + # TODO: no coverage here + obj = state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) + return self.get_col_value(column, obj) + + def setattr(self, state, value, column): + + obj = state.get_impl(self.key).get(state, state.dict) + if obj is None: + obj = self.composite_class(*[None for c in self.columns]) + state.get_impl(self.key).set(state, state.dict, obj, None) + + if hasattr(obj, '__set_composite_values__'): + values = list(obj.__composite_values__()) + values[self._col_position_map[column]] = value + obj.__set_composite_values__(*values) + else: + setattr(obj, column.key, value) + + def get_col_value(self, column, value): + if value is None: + return None + for a, b in zip(self.columns, value.__composite_values__()): + if a is column: + return b + + class Comparator(PropComparator): + def __clause_element__(self): + if self.adapter: + # TODO: test coverage for adapted composite comparison + return expression.ClauseList(*[self.adapter(x) for x in self.prop.columns]) + else: + return expression.ClauseList(*self.prop.columns) + + __hash__ = None + + def __eq__(self, other): + if other is None: + values = [None] * len(self.prop.columns) + else: + values = other.__composite_values__() + return sql.and_(*[a==b for a, b in zip(self.prop.columns, values)]) + + def __ne__(self, other): + return sql.not_(self.__eq__(other)) + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + +class ConcreteInheritedProperty(MapperProperty): + extension = None + + def setup(self, context, entity, path, adapter, **kwargs): + pass + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + return (None, None) + + def instrument_class(self, mapper): + def warn(): + raise AttributeError("Concrete %s does not implement attribute %r at " + "the instance level. Add this property explicitly to %s." % + (self.parent, self.key, self.parent)) + + class NoninheritedConcreteProp(object): + def __set__(s, obj, value): + warn() + def __delete__(s, obj): + warn() + def __get__(s, obj, owner): + warn() + + comparator_callable = None + # TODO: put this process into a deferred callable? + for m in self.parent.iterate_to_root(): + p = m._get_property(self.key) + if not isinstance(p, ConcreteInheritedProperty): + comparator_callable = p.comparator_factory + break + + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=comparator_callable(self, mapper), + parententity=mapper, + property_=self, + proxy_property=NoninheritedConcreteProp() + ) + + +class SynonymProperty(MapperProperty): + + extension = None + + def __init__(self, name, map_column=None, descriptor=None, comparator_factory=None): + self.name = name + self.map_column = map_column + self.descriptor = descriptor + self.comparator_factory = comparator_factory + util.set_creation_order(self) + + def setup(self, context, entity, path, adapter, **kwargs): + pass + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + return (None, None) + + def instrument_class(self, mapper): + class_ = self.parent.class_ + + if self.descriptor is None: + class SynonymProp(object): + def __set__(s, obj, value): + setattr(obj, self.name, value) + def __delete__(s, obj): + delattr(obj, self.name) + def __get__(s, obj, owner): + if obj is None: + return s + return getattr(obj, self.name) + + self.descriptor = SynonymProp() + + def comparator_callable(prop, mapper): + def comparator(): + prop = self.parent._get_property(self.key, resolve_synonyms=True) + if self.comparator_factory: + return self.comparator_factory(prop, mapper) + else: + return prop.comparator_factory(prop, mapper) + return comparator + + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=comparator_callable(self, mapper), + parententity=mapper, + property_=self, + proxy_property=self.descriptor + ) + + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): + pass + +log.class_logger(SynonymProperty) + +class ComparableProperty(MapperProperty): + """Instruments a Python property for use in query expressions.""" + + extension = None + + def __init__(self, comparator_factory, descriptor=None): + self.descriptor = descriptor + self.comparator_factory = comparator_factory + util.set_creation_order(self) + + def instrument_class(self, mapper): + """Set up a proxy to the unmanaged descriptor.""" + + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + property_=self, + proxy_property=self.descriptor + ) + + def setup(self, context, entity, path, adapter, **kwargs): + pass + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + return (None, None) + + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): + pass + + +class RelationshipProperty(StrategizedProperty): + """Describes an object property that holds a single item or list + of items that correspond to a related database table. + """ + + def __init__(self, argument, + secondary=None, primaryjoin=None, + secondaryjoin=None, + foreign_keys=None, + uselist=None, + order_by=False, + backref=None, + back_populates=None, + post_update=False, + cascade=False, extension=None, + viewonly=False, lazy=True, + collection_class=None, passive_deletes=False, + passive_updates=True, remote_side=None, + enable_typechecks=True, join_depth=None, + comparator_factory=None, + single_parent=False, innerjoin=False, + strategy_class=None, _local_remote_pairs=None, query_class=None): + + self.uselist = uselist + self.argument = argument + self.secondary = secondary + self.primaryjoin = primaryjoin + self.secondaryjoin = secondaryjoin + self.post_update = post_update + self.direction = None + self.viewonly = viewonly + self.lazy = lazy + self.single_parent = single_parent + self._foreign_keys = foreign_keys + self.collection_class = collection_class + self.passive_deletes = passive_deletes + self.passive_updates = passive_updates + self.remote_side = remote_side + self.enable_typechecks = enable_typechecks + self.query_class = query_class + self.innerjoin = innerjoin + + self.join_depth = join_depth + self.local_remote_pairs = _local_remote_pairs + self.extension = extension + self.comparator_factory = comparator_factory or RelationshipProperty.Comparator + self.comparator = self.comparator_factory(self, None) + util.set_creation_order(self) + + if strategy_class: + self.strategy_class = strategy_class + elif self.lazy== 'dynamic': + from sqlalchemy.orm import dynamic + self.strategy_class = dynamic.DynaLoader + else: + self.strategy_class = strategies.factory(self.lazy) + + self._reverse_property = set() + + if cascade is not False: + self.cascade = CascadeOptions(cascade) + else: + self.cascade = CascadeOptions("save-update, merge") + + if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade): + raise sa_exc.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade") + + self.order_by = order_by + + self.back_populates = back_populates + + if self.back_populates: + if backref: + raise sa_exc.ArgumentError("backref and back_populates keyword arguments are mutually exclusive") + self.backref = None + else: + self.backref = backref + + def instrument_class(self, mapper): + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + property_=self + ) + + class Comparator(PropComparator): + def __init__(self, prop, mapper, of_type=None, adapter=None): + self.prop = prop + self.mapper = mapper + self.adapter = adapter + if of_type: + self._of_type = _class_to_mapper(of_type) + + def adapted(self, adapter): + """Return a copy of this PropComparator which will use the given adaption function + on the local side of generated expressions. + + """ + return self.__class__(self.property, self.mapper, getattr(self, '_of_type', None), adapter) + + @property + def parententity(self): + return self.property.parent + + def __clause_element__(self): + elem = self.property.parent._with_polymorphic_selectable + if self.adapter: + return self.adapter(elem) + else: + return elem + + def operate(self, op, *other, **kwargs): + return op(self, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return op(self, *other, **kwargs) + + def of_type(self, cls): + return RelationshipProperty.Comparator(self.property, self.mapper, cls, adapter=self.adapter) + + def in_(self, other): + raise NotImplementedError("in_() not yet supported for relationships. For a " + "simple many-to-one, use in_() against the set of foreign key values.") + + __hash__ = None + + def __eq__(self, other): + if isinstance(other, (NoneType, expression._Null)): + if self.property.direction in [ONETOMANY, MANYTOMANY]: + return ~self._criterion_exists() + else: + return _orm_annotate(self.property._optimized_compare(None, adapt_source=self.adapter)) + elif self.property.uselist: + raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") + else: + return _orm_annotate(self.property._optimized_compare(other, adapt_source=self.adapter)) + + def _criterion_exists(self, criterion=None, **kwargs): + if getattr(self, '_of_type', None): + target_mapper = self._of_type + to_selectable = target_mapper._with_polymorphic_selectable + if self.property._is_self_referential(): + to_selectable = to_selectable.alias() + + single_crit = target_mapper._single_table_criterion + if single_crit is not None: + if criterion is not None: + criterion = single_crit & criterion + else: + criterion = single_crit + else: + to_selectable = None + + if self.adapter: + source_selectable = self.__clause_element__() + else: + source_selectable = None + + pj, sj, source, dest, secondary, target_adapter = \ + self.property._create_joins(dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable) + + for k in kwargs: + crit = self.property.mapper.class_manager[k] == kwargs[k] + if criterion is None: + criterion = crit + else: + criterion = criterion & crit + + # annotate the *local* side of the join condition, in the case of pj + sj this + # is the full primaryjoin, in the case of just pj its the local side of + # the primaryjoin. + if sj is not None: + j = _orm_annotate(pj) & sj + else: + j = _orm_annotate(pj, exclude=self.property.remote_side) + + if criterion is not None and target_adapter: + # limit this adapter to annotated only? + criterion = target_adapter.traverse(criterion) + + # only have the "joined left side" of what we return be subject to Query adaption. The right + # side of it is used for an exists() subquery and should not correlate or otherwise reach out + # to anything in the enclosing query. + if criterion is not None: + criterion = criterion._annotate({'_halt_adapt': True}) + + crit = j & criterion + + return sql.exists([1], crit, from_obj=dest).correlate(source) + + def any(self, criterion=None, **kwargs): + if not self.property.uselist: + raise sa_exc.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") + + return self._criterion_exists(criterion, **kwargs) + + def has(self, criterion=None, **kwargs): + if self.property.uselist: + raise sa_exc.InvalidRequestError("'has()' not implemented for collections. Use any().") + return self._criterion_exists(criterion, **kwargs) + + def contains(self, other, **kwargs): + if not self.property.uselist: + raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") + clause = self.property._optimized_compare(other, adapt_source=self.adapter) + + if self.property.secondaryjoin is not None: + clause.negation_clause = self.__negated_contains_or_equals(other) + + return clause + + def __negated_contains_or_equals(self, other): + if self.property.direction == MANYTOONE: + state = attributes.instance_state(other) + strategy = self.property._get_strategy(strategies.LazyLoader) + + def state_bindparam(state, col): + o = state.obj() # strong ref + return lambda: self.property.mapper._get_committed_attr_by_column(o, col) + + def adapt(col): + if self.adapter: + return self.adapter(col) + else: + return col + + if strategy.use_get: + return sql.and_(*[ + sql.or_( + adapt(x) != state_bindparam(state, y), + adapt(x) == None) + for (x, y) in self.property.local_remote_pairs]) + + criterion = sql.and_(*[x==y for (x, y) in zip(self.property.mapper.primary_key, self.property.mapper.primary_key_from_instance(other))]) + return ~self._criterion_exists(criterion) + + def __ne__(self, other): + if isinstance(other, (NoneType, expression._Null)): + if self.property.direction == MANYTOONE: + return sql.or_(*[x!=None for x in self.property._foreign_keys]) + else: + return self._criterion_exists() + elif self.property.uselist: + raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") + else: + return self.__negated_contains_or_equals(other) + + @util.memoized_property + def property(self): + self.prop.parent.compile() + return self.prop + + def compare(self, op, value, value_is_parent=False, alias_secondary=True): + if op == operators.eq: + if value is None: + if self.uselist: + return ~sql.exists([1], self.primaryjoin) + else: + return self._optimized_compare(None, + value_is_parent=value_is_parent, + alias_secondary=alias_secondary) + else: + return self._optimized_compare(value, + value_is_parent=value_is_parent, + alias_secondary=alias_secondary) + else: + return op(self.comparator, value) + + def _optimized_compare(self, value, value_is_parent=False, + adapt_source=None, alias_secondary=True): + if value is not None: + value = attributes.instance_state(value) + return self._get_strategy(strategies.LazyLoader).\ + lazy_clause(value, + reverse_direction=not value_is_parent, + alias_secondary=alias_secondary, adapt_source=adapt_source) + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): + if load: + # TODO: no test coverage for recursive check + for r in self._reverse_property: + if (source_state, r) in _recursive: + return + + if not "merge" in self.cascade: + return + + if self.key not in source_dict: + return + + if self.uselist: + instances = source_state.get_impl(self.key).\ + get(source_state, source_dict) + if hasattr(instances, '_sa_adapter'): + # convert collections to adapters to get a true iterator + instances = instances._sa_adapter + + if load: + # for a full merge, pre-load the destination collection, + # so that individual _merge of each item pulls from identity + # map for those already present. + # also assumes CollectionAttrbiuteImpl behavior of loading + # "old" list in any case + dest_state.get_impl(self.key).get(dest_state, dest_dict) + + dest_list = [] + for current in instances: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive) + if obj is not None: + dest_list.append(obj) + + if not load: + coll = attributes.init_state_collection(dest_state, dest_dict, self.key) + for c in dest_list: + coll.append_without_event(c) + else: + dest_state.get_impl(self.key)._set_iterable(dest_state, dest_dict, dest_list) + else: + current = source_dict[self.key] + if current is not None: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive) + else: + obj = None + + if not load: + dest_dict[self.key] = obj + else: + dest_state.get_impl(self.key).set(dest_state, dest_dict, obj, None) + + def cascade_iterator(self, type_, state, visited_instances, halt_on=None): + if not type_ in self.cascade: + return + + # only actively lazy load on the 'delete' cascade + if type_ != 'delete' or self.passive_deletes: + passive = attributes.PASSIVE_NO_INITIALIZE + else: + passive = attributes.PASSIVE_OFF + + if type_ == 'save-update': + instances = attributes.get_state_history(state, self.key, passive=passive).sum() + else: + instances = state.value_as_iterable(self.key, passive=passive) + + if instances: + for c in instances: + if c is not None and \ + c is not attributes.PASSIVE_NO_RESULT and \ + c not in visited_instances and \ + (halt_on is None or not halt_on(c)): + + if not isinstance(c, self.mapper.class_): + raise AssertionError("Attribute '%s' on class '%s' " + "doesn't handle objects " + "of type '%s'" % ( + self.key, + str(self.parent.class_), + str(c.__class__) + )) + visited_instances.add(c) + + # cascade using the mapper local to this + # object, so that its individual properties are located + instance_mapper = object_mapper(c) + yield (c, instance_mapper, attributes.instance_state(c)) + + def _add_reverse_property(self, key): + other = self.mapper._get_property(key) + self._reverse_property.add(other) + other._reverse_property.add(self) + + if not other._get_target().common_parent(self.parent): + raise sa_exc.ArgumentError("reverse_property %r on relationship %s references " + "relationship %s, which does not reference mapper %s" % (key, self, other, self.parent)) + + if self.direction in (ONETOMANY, MANYTOONE) and self.direction == other.direction: + raise sa_exc.ArgumentError("%s and back-reference %s are both of the same direction %r." + " Did you mean to set remote_side on the many-to-one side ?" % (other, self, self.direction)) + + def do_init(self): + self._get_target() + self._assert_is_primary() + self._process_dependent_arguments() + self._determine_joins() + self._determine_synchronize_pairs() + self._determine_direction() + self._determine_local_remote_pairs() + self._post_init() + self._generate_backref() + super(RelationshipProperty, self).do_init() + + def _get_target(self): + if not hasattr(self, 'mapper'): + if isinstance(self.argument, type): + self.mapper = mapper.class_mapper(self.argument, compile=False) + elif isinstance(self.argument, mapper.Mapper): + self.mapper = self.argument + elif util.callable(self.argument): + # accept a callable to suit various deferred-configurational schemes + self.mapper = mapper.class_mapper(self.argument(), compile=False) + else: + raise sa_exc.ArgumentError("relationship '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument))) + assert isinstance(self.mapper, mapper.Mapper), self.mapper + return self.mapper + + def _process_dependent_arguments(self): + + # accept callables for other attributes which may require deferred initialization + for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'): + if util.callable(getattr(self, attr)): + setattr(self, attr, getattr(self, attr)()) + + # in the case that InstrumentedAttributes were used to construct + # primaryjoin or secondaryjoin, remove the "_orm_adapt" annotation so these + # interact with Query in the same way as the original Table-bound Column objects + for attr in ('primaryjoin', 'secondaryjoin'): + val = getattr(self, attr) + if val is not None: + util.assert_arg_type(val, sql.ColumnElement, attr) + setattr(self, attr, _orm_deannotate(val)) + + if self.order_by is not False and self.order_by is not None: + self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)] + + self._foreign_keys = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self._foreign_keys)) + self.remote_side = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self.remote_side)) + + if not self.parent.concrete: + for inheriting in self.parent.iterate_to_root(): + if inheriting is not self.parent and inheriting._get_property(self.key, raiseerr=False): + util.warn( + ("Warning: relationship '%s' on mapper '%s' supercedes " + "the same relationship on inherited mapper '%s'; this " + "can cause dependency issues during flush") % + (self.key, self.parent, inheriting)) + + # TODO: remove 'self.table' + self.target = self.table = self.mapper.mapped_table + + if self.cascade.delete_orphan: + if self.parent.class_ is self.mapper.class_: + raise sa_exc.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade " + "rule on a self-referential relationship. " + "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self))) + self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_)) + + def _determine_joins(self): + if self.secondaryjoin is not None and self.secondary is None: + raise sa_exc.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") + # if join conditions were not specified, figure them out based on foreign keys + + def _search_for_join(mapper, table): + # find a join between the given mapper's mapped table and the given table. + # will try the mapper's local table first for more specificity, then if not + # found will try the more general mapped table, which in the case of inheritance + # is a join. + try: + return join_condition(mapper.local_table, table) + except sa_exc.ArgumentError, e: + return join_condition(mapper.mapped_table, table) + + try: + if self.secondary is not None: + if self.secondaryjoin is None: + self.secondaryjoin = _search_for_join(self.mapper, self.secondary) + if self.primaryjoin is None: + self.primaryjoin = _search_for_join(self.parent, self.secondary) + else: + if self.primaryjoin is None: + self.primaryjoin = _search_for_join(self.parent, self.target) + except sa_exc.ArgumentError, e: + raise sa_exc.ArgumentError("Could not determine join condition between " + "parent/child tables on relationship %s. " + "Specify a 'primaryjoin' expression. If this is a " + "many-to-many relationship, 'secondaryjoin' is needed as well." % (self)) + + def _col_is_part_of_mappings(self, column): + if self.secondary is None: + return self.parent.mapped_table.c.contains_column(column) or \ + self.target.c.contains_column(column) + else: + return self.parent.mapped_table.c.contains_column(column) or \ + self.target.c.contains_column(column) or \ + self.secondary.c.contains_column(column) is not None + + def _determine_synchronize_pairs(self): + + if self.local_remote_pairs: + if not self._foreign_keys: + raise sa_exc.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument") + + self.synchronize_pairs = [] + + for l, r in self.local_remote_pairs: + if r in self._foreign_keys: + self.synchronize_pairs.append((l, r)) + elif l in self._foreign_keys: + self.synchronize_pairs.append((r, l)) + else: + eq_pairs = criterion_as_pairs( + self.primaryjoin, + consider_as_foreign_keys=self._foreign_keys, + any_operator=self.viewonly + ) + eq_pairs = [ + (l, r) for l, r in eq_pairs if + (self._col_is_part_of_mappings(l) and + self._col_is_part_of_mappings(r)) + or self.viewonly and r in self._foreign_keys + ] + + if not eq_pairs: + if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True): + raise sa_exc.ArgumentError("Could not locate any equated, locally " + "mapped column pairs for primaryjoin condition '%s' on relationship %s. " + "For more relaxed rules on join conditions, the relationship may be " + "marked as viewonly=True." % (self.primaryjoin, self) + ) + else: + if self._foreign_keys: + raise sa_exc.ArgumentError("Could not determine relationship direction for " + "primaryjoin condition '%s', on relationship %s. " + "Do the columns in 'foreign_keys' represent only the 'foreign' columns " + "in this join condition ?" % (self.primaryjoin, self)) + else: + raise sa_exc.ArgumentError("Could not determine relationship direction for " + "primaryjoin condition '%s', on relationship %s. " + "Specify the 'foreign_keys' argument to indicate which columns " + "on the relationship are foreign." % (self.primaryjoin, self)) + + self.synchronize_pairs = eq_pairs + + if self.secondaryjoin is not None: + sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=self.viewonly) + sq_pairs = [(l, r) for l, r in sq_pairs if (self._col_is_part_of_mappings(l) and self._col_is_part_of_mappings(r)) or r in self._foreign_keys] + + if not sq_pairs: + if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True): + raise sa_exc.ArgumentError("Could not locate any equated, locally mapped " + "column pairs for secondaryjoin condition '%s' on relationship %s. " + "For more relaxed rules on join conditions, the " + "relationship may be marked as viewonly=True." % (self.secondaryjoin, self) + ) + else: + raise sa_exc.ArgumentError("Could not determine relationship direction " + "for secondaryjoin condition '%s', on relationship %s. " + "Specify the foreign_keys argument to indicate which " + "columns on the relationship are foreign." % (self.secondaryjoin, self)) + + self.secondary_synchronize_pairs = sq_pairs + else: + self.secondary_synchronize_pairs = None + + self._foreign_keys = util.column_set(r for l, r in self.synchronize_pairs) + if self.secondary_synchronize_pairs: + self._foreign_keys.update(r for l, r in self.secondary_synchronize_pairs) + + def _determine_direction(self): + if self.secondaryjoin is not None: + self.direction = MANYTOMANY + + elif self._refers_to_parent_table(): + # self referential defaults to ONETOMANY unless the "remote" side is present + # and does not reference any foreign key columns + + if self.local_remote_pairs: + remote = [r for l, r in self.local_remote_pairs] + elif self.remote_side: + remote = self.remote_side + else: + remote = None + + if not remote or self._foreign_keys.\ + difference(l for l, r in self.synchronize_pairs).\ + intersection(remote): + self.direction = ONETOMANY + else: + self.direction = MANYTOONE + + else: + foreign_keys = [f for c, f in self.synchronize_pairs] + + parentcols = util.column_set(self.parent.mapped_table.c) + targetcols = util.column_set(self.mapper.mapped_table.c) + + # fk collection which suggests ONETOMANY. + onetomany_fk = targetcols.intersection(foreign_keys) + + # fk collection which suggests MANYTOONE. + manytoone_fk = parentcols.intersection(foreign_keys) + + if not onetomany_fk and not manytoone_fk: + raise sa_exc.ArgumentError( + "Can't determine relationship direction for relationship '%s' " + "- foreign key columns are present in neither the " + "parent nor the child's mapped tables" % self ) + + elif onetomany_fk and manytoone_fk: + # fks on both sides. do the same + # test only based on the local side. + referents = [c for c, f in self.synchronize_pairs] + onetomany_local = parentcols.intersection(referents) + manytoone_local = targetcols.intersection(referents) + + if onetomany_local and not manytoone_local: + self.direction = ONETOMANY + elif manytoone_local and not onetomany_local: + self.direction = MANYTOONE + elif onetomany_fk: + self.direction = ONETOMANY + elif manytoone_fk: + self.direction = MANYTOONE + + if not self.direction: + raise sa_exc.ArgumentError( + "Can't determine relationship direction for relationship '%s' " + "- foreign key columns are present in both the parent and " + "the child's mapped tables. Specify 'foreign_keys' " + "argument." % self) + + if self.cascade.delete_orphan and not self.single_parent and \ + (self.direction is MANYTOMANY or self.direction is MANYTOONE): + util.warn("On %s, delete-orphan cascade is not supported on a " + "many-to-many or many-to-one relationship when single_parent is not set. " + " Set single_parent=True on the relationship()." % self) + + def _determine_local_remote_pairs(self): + if not self.local_remote_pairs: + if self.remote_side: + if self.direction is MANYTOONE: + self.local_remote_pairs = [ + (r, l) for l, r in + criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True) + ] + else: + self.local_remote_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True) + + if not self.local_remote_pairs: + raise sa_exc.ArgumentError("Relationship %s could not determine any local/remote column pairs from remote side argument %r" % (self, self.remote_side)) + + else: + if self.viewonly: + eq_pairs = self.synchronize_pairs + if self.secondaryjoin is not None: + eq_pairs += self.secondary_synchronize_pairs + else: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True) + if self.secondaryjoin is not None: + eq_pairs += criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True) + eq_pairs = [(l, r) for l, r in eq_pairs if self._col_is_part_of_mappings(l) and self._col_is_part_of_mappings(r)] + + if self.direction is MANYTOONE: + self.local_remote_pairs = [(r, l) for l, r in eq_pairs] + else: + self.local_remote_pairs = eq_pairs + elif self.remote_side: + raise sa_exc.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.") + + for l, r in self.local_remote_pairs: + + if self.direction is ONETOMANY and not self._col_is_part_of_mappings(l): + raise sa_exc.ArgumentError("Local column '%s' is not part of mapping %s. " + "Specify remote_side argument to indicate which column " + "lazy join condition should compare against." % (l, self.parent)) + + elif self.direction is MANYTOONE and not self._col_is_part_of_mappings(r): + raise sa_exc.ArgumentError("Remote column '%s' is not part of mapping %s. " + "Specify remote_side argument to indicate which column lazy " + "join condition should bind." % (r, self.mapper)) + + self.local_side, self.remote_side = [util.ordered_column_set(x) for x in zip(*list(self.local_remote_pairs))] + + def _assert_is_primary(self): + if not self.is_primary() and \ + not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False): + + raise sa_exc.ArgumentError("Attempting to assign a new relationship '%s' to " + "a non-primary mapper on class '%s'. New relationships can only be " + "added to the primary mapper, i.e. the very first " + "mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__)) + + def _generate_backref(self): + if not self.is_primary(): + return + + if self.backref is not None and not self.back_populates: + if isinstance(self.backref, basestring): + backref_key, kwargs = self.backref, {} + else: + backref_key, kwargs = self.backref + + mapper = self.mapper.primary_mapper() + if mapper._get_property(backref_key, raiseerr=False) is not None: + raise sa_exc.ArgumentError("Error creating backref '%s' on relationship '%s': " + "property of that name exists on mapper '%s'" % (backref_key, self, mapper)) + + if self.secondary is not None: + pj = kwargs.pop('primaryjoin', self.secondaryjoin) + sj = kwargs.pop('secondaryjoin', self.primaryjoin) + else: + pj = kwargs.pop('primaryjoin', self.primaryjoin) + sj = kwargs.pop('secondaryjoin', None) + if sj: + raise sa_exc.InvalidRequestError( + "Can't assign 'secondaryjoin' on a backref against " + "a non-secondary relationship.") + + foreign_keys = kwargs.pop('foreign_keys', self._foreign_keys) + + parent = self.parent.primary_mapper() + kwargs.setdefault('viewonly', self.viewonly) + kwargs.setdefault('post_update', self.post_update) + + self.back_populates = backref_key + relationship = RelationshipProperty( + parent, + self.secondary, + pj, + sj, + foreign_keys=foreign_keys, + back_populates=self.key, + **kwargs) + + mapper._configure_property(backref_key, relationship) + + + if self.back_populates: + self.extension = list(util.to_list(self.extension, default=[])) + self.extension.append(attributes.GenericBackrefExtension(self.back_populates)) + self._add_reverse_property(self.back_populates) + + + def _post_init(self): + self.logger.info("%s setup primary join %s", self, self.primaryjoin) + self.logger.info("%s setup secondary join %s", self, self.secondaryjoin) + self.logger.info("%s synchronize pairs [%s]", self, ",".join("(%s => %s)" % (l, r) for l, r in self.synchronize_pairs)) + self.logger.info("%s secondary synchronize pairs [%s]", self, ",".join(("(%s => %s)" % (l, r) for l, r in self.secondary_synchronize_pairs or []))) + self.logger.info("%s local/remote pairs [%s]", self, ",".join("(%s / %s)" % (l, r) for l, r in self.local_remote_pairs)) + self.logger.info("%s relationship direction %s", self, self.direction) + + if self.uselist is None: + self.uselist = self.direction is not MANYTOONE + + if not self.viewonly: + self._dependency_processor = dependency.create_dependency_processor(self) + + def _refers_to_parent_table(self): + for c, f in self.synchronize_pairs: + if c.table is f.table: + return True + else: + return False + + def _is_self_referential(self): + return self.mapper.common_parent(self.parent) + + def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None, of_type=None): + if source_selectable is None: + if source_polymorphic and self.parent.with_polymorphic: + source_selectable = self.parent._with_polymorphic_selectable + + aliased = False + if dest_selectable is None: + if dest_polymorphic and self.mapper.with_polymorphic: + dest_selectable = self.mapper._with_polymorphic_selectable + aliased = True + else: + dest_selectable = self.mapper.mapped_table + + if self._is_self_referential() and source_selectable is None: + dest_selectable = dest_selectable.alias() + aliased = True + else: + aliased = True + + aliased = aliased or (source_selectable is not None) + + primaryjoin, secondaryjoin, secondary = self.primaryjoin, self.secondaryjoin, self.secondary + + # adjust the join condition for single table inheritance, + # in the case that the join is to a subclass + # this is analgous to the "_adjust_for_single_table_inheritance()" + # method in Query. + + dest_mapper = of_type or self.mapper + + single_crit = dest_mapper._single_table_criterion + if single_crit is not None: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & single_crit + else: + primaryjoin = primaryjoin & single_crit + + + if aliased: + if secondary is not None: + secondary = secondary.alias() + primary_aliasizer = ClauseAdapter(secondary) + if dest_selectable is not None: + secondary_aliasizer = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns).chain(primary_aliasizer) + else: + secondary_aliasizer = primary_aliasizer + + if source_selectable is not None: + primary_aliasizer = ClauseAdapter(secondary).chain(ClauseAdapter(source_selectable, equivalents=self.parent._equivalent_columns)) + + secondaryjoin = secondary_aliasizer.traverse(secondaryjoin) + else: + if dest_selectable is not None: + primary_aliasizer = ClauseAdapter(dest_selectable, exclude=self.local_side, equivalents=self.mapper._equivalent_columns) + if source_selectable is not None: + primary_aliasizer.chain(ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns)) + elif source_selectable is not None: + primary_aliasizer = ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns) + + secondary_aliasizer = None + + primaryjoin = primary_aliasizer.traverse(primaryjoin) + target_adapter = secondary_aliasizer or primary_aliasizer + target_adapter.include = target_adapter.exclude = None + else: + target_adapter = None + + if source_selectable is None: + source_selectable = self.parent.local_table + + if dest_selectable is None: + dest_selectable = self.mapper.local_table + + return (primaryjoin, secondaryjoin, + source_selectable, + dest_selectable, secondary, target_adapter) + + def register_dependencies(self, uowcommit): + if not self.viewonly: + self._dependency_processor.register_dependencies(uowcommit) + + def register_processors(self, uowcommit): + if not self.viewonly: + self._dependency_processor.register_processors(uowcommit) + +PropertyLoader = RelationProperty = RelationshipProperty +log.class_logger(RelationshipProperty) + +mapper.ColumnProperty = ColumnProperty +mapper.SynonymProperty = SynonymProperty +mapper.ComparableProperty = ComparableProperty +mapper.RelationshipProperty = RelationshipProperty +mapper.ConcreteInheritedProperty = ConcreteInheritedProperty diff --git a/sqlalchemy/orm/query.py b/sqlalchemy/orm/query.py new file mode 100644 index 0000000..e98ad89 --- /dev/null +++ b/sqlalchemy/orm/query.py @@ -0,0 +1,2469 @@ +# orm/query.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The Query class and support. + +Defines the :class:`~sqlalchemy.orm.query.Query` class, the central construct used by +the ORM to construct database queries. + +The ``Query`` class should not be confused with the +:class:`~sqlalchemy.sql.expression.Select` class, which defines database SELECT +operations at the SQL (non-ORM) level. ``Query`` differs from ``Select`` in +that it returns ORM-mapped objects and interacts with an ORM session, whereas +the ``Select`` construct interacts directly with the database to return +iterable result sets. + +""" + +from itertools import chain +from operator import itemgetter + +from sqlalchemy import sql, util, log, schema +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import expression, visitors, operators +from sqlalchemy.orm import ( + attributes, interfaces, mapper, object_mapper, evaluator, + ) +from sqlalchemy.orm.util import ( + AliasedClass, ORMAdapter, _entity_descriptor, _entity_info, + _is_aliased_class, _is_mapped_class, _orm_columns, _orm_selectable, + join as orm_join, + ) + + +__all__ = ['Query', 'QueryContext', 'aliased'] + + +aliased = AliasedClass + +def _generative(*assertions): + """Mark a method as generative.""" + + @util.decorator + def generate(fn, *args, **kw): + self = args[0]._clone() + for assertion in assertions: + assertion(self, fn.func_name) + fn(self, *args[1:], **kw) + return self + return generate + +class Query(object): + """ORM-level SQL construction object.""" + + _enable_eagerloads = True + _enable_assertions = True + _with_labels = False + _criterion = None + _yield_per = None + _lockmode = None + _order_by = False + _group_by = False + _having = None + _distinct = False + _offset = None + _limit = None + _statement = None + _correlate = frozenset() + _populate_existing = False + _version_check = False + _autoflush = True + _current_path = () + _only_load_props = None + _refresh_state = None + _from_obj = () + _filter_aliases = None + _from_obj_alias = None + _joinpath = _joinpoint = util.frozendict() + _execution_options = util.frozendict() + _params = util.frozendict() + _attributes = util.frozendict() + _with_options = () + _with_hints = () + + def __init__(self, entities, session=None): + self.session = session + self._polymorphic_adapters = {} + self._set_entities(entities) + + def _set_entities(self, entities, entity_wrapper=None): + if entity_wrapper is None: + entity_wrapper = _QueryEntity + self._entities = [] + for ent in util.to_list(entities): + entity_wrapper(self, ent) + + self._setup_aliasizers(self._entities) + + def _setup_aliasizers(self, entities): + if hasattr(self, '_mapper_adapter_map'): + # usually safe to share a single map, but copying to prevent + # subtle leaks if end-user is reusing base query with arbitrary + # number of aliased() objects + self._mapper_adapter_map = d = self._mapper_adapter_map.copy() + else: + self._mapper_adapter_map = d = {} + + for ent in entities: + for entity in ent.entities: + if entity not in d: + mapper, selectable, is_aliased_class = _entity_info(entity) + if not is_aliased_class and mapper.with_polymorphic: + with_polymorphic = mapper._with_polymorphic_mappers + if mapper.mapped_table not in self._polymorphic_adapters: + self.__mapper_loads_polymorphically_with(mapper, + sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)) + adapter = None + elif is_aliased_class: + adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns) + with_polymorphic = None + else: + with_polymorphic = adapter = None + + d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic) + ent.setup_entity(entity, *d[entity]) + + def __mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers: + self._polymorphic_adapters[m2] = adapter + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter + + def _set_select_from(self, *obj): + + fa = [] + for from_obj in obj: + if isinstance(from_obj, expression._SelectBaseMixin): + from_obj = from_obj.alias() + fa.append(from_obj) + + self._from_obj = tuple(fa) + + if len(self._from_obj) == 1 and \ + isinstance(self._from_obj[0], expression.Alias): + equivs = self.__all_equivs() + self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj[0], equivs) + + def _get_polymorphic_adapter(self, entity, selectable): + self.__mapper_loads_polymorphically_with(entity.mapper, + sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns)) + + def _reset_polymorphic_adapter(self, mapper): + for m2 in mapper._with_polymorphic_mappers: + self._polymorphic_adapters.pop(m2, None) + for m in m2.iterate_to_root(): + self._polymorphic_adapters.pop(m.mapped_table, None) + self._polymorphic_adapters.pop(m.local_table, None) + + def __adapt_polymorphic_element(self, element): + if isinstance(element, expression.FromClause): + search = element + elif hasattr(element, 'table'): + search = element.table + else: + search = None + + if search is not None: + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + def __replace_element(self, adapters): + def replace(elem): + if '_halt_adapt' in elem._annotations: + return elem + + for adapter in adapters: + e = adapter(elem) + if e is not None: + return e + return replace + + def __replace_orm_element(self, adapters): + def replace(elem): + if '_halt_adapt' in elem._annotations: + return elem + + if "_orm_adapt" in elem._annotations or "parententity" in elem._annotations: + for adapter in adapters: + e = adapter(elem) + if e is not None: + return e + return replace + + @_generative() + def _adapt_all_clauses(self): + self._disable_orm_filtering = True + + def _adapt_col_list(self, cols): + return [ + self._adapt_clause(expression._literal_as_text(o), True, True) + for o in cols + ] + + def _adapt_clause(self, clause, as_filter, orm_only): + adapters = [] + if as_filter and self._filter_aliases: + for fa in self._filter_aliases._visitor_iterator: + adapters.append(fa.replace) + + if self._from_obj_alias: + adapters.append(self._from_obj_alias.replace) + + if self._polymorphic_adapters: + adapters.append(self.__adapt_polymorphic_element) + + if not adapters: + return clause + + if getattr(self, '_disable_orm_filtering', not orm_only): + return visitors.replacement_traverse( + clause, + {'column_collections':False}, + self.__replace_element(adapters) + ) + else: + return visitors.replacement_traverse( + clause, + {'column_collections':False}, + self.__replace_orm_element(adapters) + ) + + def _entity_zero(self): + return self._entities[0] + + def _mapper_zero(self): + return self._entity_zero().entity_zero + + def _extension_zero(self): + ent = self._entity_zero() + return getattr(ent, 'extension', ent.mapper.extension) + + @property + def _mapper_entities(self): + # TODO: this is wrong, its hardcoded to "priamry entity" when + # for the case of __all_equivs() it should not be + # the name of this accessor is wrong too + for ent in self._entities: + if hasattr(ent, 'primary_entity'): + yield ent + + def _joinpoint_zero(self): + return self._joinpoint.get('_joinpoint_entity', self._entity_zero().entity_zero) + + def _mapper_zero_or_none(self): + if not getattr(self._entities[0], 'primary_entity', False): + return None + return self._entities[0].mapper + + def _only_mapper_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale or "This operation requires a Query against a single mapper." + ) + return self._mapper_zero() + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale or "This operation requires a Query against a single mapper." + ) + return self._entity_zero() + + def _generate_mapper_zero(self): + if not getattr(self._entities[0], 'primary_entity', False): + raise sa_exc.InvalidRequestError("No primary mapper set up for this Query.") + entity = self._entities[0]._clone() + self._entities = [entity] + self._entities[1:] + return entity + + def __all_equivs(self): + equivs = {} + for ent in self._mapper_entities: + equivs.update(ent.mapper._equivalent_columns) + return equivs + + def _get_condition(self): + self._order_by = self._distinct = False + return self._no_criterion_condition("get") + + def _no_criterion_condition(self, meth): + if not self._enable_assertions: + return + if self._criterion is not None or self._statement is not None or self._from_obj or \ + self._limit is not None or self._offset is not None or \ + self._group_by or self._order_by or self._distinct: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth) + + self._from_obj = () + self._statement = self._criterion = None + self._order_by = self._group_by = self._distinct = False + + def _no_clauseelement_condition(self, meth): + if not self._enable_assertions: + return + if self._order_by: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth) + self._no_criterion_condition(meth) + + def _no_statement_condition(self, meth): + if not self._enable_assertions: + return + if self._statement: + raise sa_exc.InvalidRequestError( + ("Query.%s() being called on a Query with an existing full " + "statement - can't apply criterion.") % meth) + + def _no_limit_offset(self, meth): + if not self._enable_assertions: + return + if self._limit is not None or self._offset is not None: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a Query which already has LIMIT or OFFSET applied. " + "To modify the row-limited results of a Query, call from_self() first. " + "Otherwise, call %s() before limit() or offset() are applied." % (meth, meth) + ) + + def _no_select_modifiers(self, meth): + if not self._enable_assertions: + return + for attr, methname, notset in ( + ('_limit', 'limit()', None), + ('_offset', 'offset()', None), + ('_order_by', 'order_by()', False), + ('_group_by', 'group_by()', False), + ('_distinct', 'distinct()', False), + ): + if getattr(self, attr) is not notset: + raise sa_exc.InvalidRequestError( + "Can't call Query.%s() when %s has been called" % (meth, methname) + ) + + def _get_options(self, populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None): + if populate_existing: + self._populate_existing = populate_existing + if version_check: + self._version_check = version_check + if refresh_state: + self._refresh_state = refresh_state + if only_load_props: + self._only_load_props = set(only_load_props) + return self + + def _clone(self): + cls = self.__class__ + q = cls.__new__(cls) + q.__dict__ = self.__dict__.copy() + return q + + @property + def statement(self): + """The full SELECT statement represented by this Query. + + The statement by default will not have disambiguating labels + applied to the construct unless with_labels(True) is called + first. + + """ + + return self._compile_context(labels=self._with_labels).\ + statement._annotate({'_halt_adapt': True}) + + def subquery(self): + """return the full SELECT statement represented by this Query, + embedded within an Alias. + + Eager JOIN generation within the query is disabled. + + The statement by default will not have disambiguating labels + applied to the construct unless with_labels(True) is called + first. + + """ + return self.enable_eagerloads(False).statement.alias() + + def __clause_element__(self): + return self.enable_eagerloads(False).with_labels().statement + + @_generative() + def enable_eagerloads(self, value): + """Control whether or not eager joins and subqueries are + rendered. + + When set to False, the returned Query will not render + eager joins regardless of :func:`~sqlalchemy.orm.joinedload`, + :func:`~sqlalchemy.orm.subqueryload` options + or mapper-level ``lazy='joined'``/``lazy='subquery'`` + configurations. + + This is used primarily when nesting the Query's + statement into a subquery or other + selectable. + + """ + self._enable_eagerloads = value + + @_generative() + def with_labels(self): + """Apply column labels to the return value of Query.statement. + + Indicates that this Query's `statement` accessor should return + a SELECT statement that applies labels to all columns in the + form _; this is commonly used to + disambiguate columns from multiple tables which have the same + name. + + When the `Query` actually issues SQL to load rows, it always + uses column labeling. + + """ + self._with_labels = True + + @_generative() + def enable_assertions(self, value): + """Control whether assertions are generated. + + When set to False, the returned Query will + not assert its state before certain operations, + including that LIMIT/OFFSET has not been applied + when filter() is called, no criterion exists + when get() is called, and no "from_statement()" + exists when filter()/order_by()/group_by() etc. + is called. This more permissive mode is used by + custom Query subclasses to specify criterion or + other modifiers outside of the usual usage patterns. + + Care should be taken to ensure that the usage + pattern is even possible. A statement applied + by from_statement() will override any criterion + set by filter() or order_by(), for example. + + """ + self._enable_assertions = value + + @property + def whereclause(self): + """The WHERE criterion for this Query.""" + return self._criterion + + @_generative() + def _with_current_path(self, path): + """indicate that this query applies to objects loaded within a certain path. + + Used by deferred loaders (see strategies.py) which transfer query + options from an originating query to a newly generated query intended + for the deferred load. + + """ + self._current_path = path + + @_generative(_no_clauseelement_condition) + def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None): + """Load columns for descendant mappers of this Query's mapper. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + :param cls_or_mappers: a single class or mapper, or list of class/mappers, + which inherit from this Query's mapper. Alternatively, it + may also be the string ``'*'``, in which case all descending + mappers will be added to the FROM clause. + + :param selectable: a table or select() statement that will + be used in place of the generated FROM clause. This argument + is required if any of the desired mappers use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` + argument must represent the full set of tables and columns mapped + by every desired mapper. Otherwise, the unaccounted mapped columns + will result in their table being appended directly to the FROM + clause which will usually lead to incorrect results. + + :param discriminator: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the mapper will be used, if any. This is useful + for mappers that don't have polymorphic loading behavior by default, + such as concrete table mappers. + + """ + entity = self._generate_mapper_zero() + entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, discriminator=discriminator) + + @_generative() + def yield_per(self, count): + """Yield only ``count`` rows at a time. + + WARNING: use this method with caution; if the same instance is present + in more than one batch of rows, end-user changes to attributes will be + overwritten. + + In particular, it's usually impossible to use this setting with + eagerly loaded collections (i.e. any lazy='joined' or 'subquery') + since those collections will be cleared for a new load when + encountered in a subsequent result batch. In the case of 'subquery' + loading, the full result for all rows is fetched which generally + defeats the purpose of :meth:`~sqlalchemy.orm.query.Query.yield_per`. + + Also note that many DBAPIs do not "stream" results, pre-buffering + all rows before making them available, including mysql-python and + psycopg2. :meth:`~sqlalchemy.orm.query.Query.yield_per` will also + set the ``stream_results`` execution + option to ``True``, which currently is only understood by psycopg2 + and causes server side cursors to be used. + + """ + self._yield_per = count + self._execution_options = self._execution_options.copy() + self._execution_options['stream_results'] = True + + def get(self, ident): + """Return an instance of the object based on the given identifier, or None if not found. + + The `ident` argument is a scalar or tuple of primary key column values + in the order of the table def's primary key columns. + + """ + + # convert composite types to individual args + if hasattr(ident, '__composite_values__'): + ident = ident.__composite_values__() + + key = self._only_mapper_zero( + "get() can only be used against a single mapped class." + ).identity_key_from_primary_key(ident) + return self._get(key, ident) + + @_generative() + def correlate(self, *args): + self._correlate = self._correlate.union(_orm_selectable(s) for s in args) + + @_generative() + def autoflush(self, setting): + """Return a Query with a specific 'autoflush' setting. + + Note that a Session with autoflush=False will + not autoflush, even if this flag is set to True at the + Query level. Therefore this flag is usually used only + to disable autoflush for a specific Query. + + """ + self._autoflush = setting + + @_generative() + def populate_existing(self): + """Return a Query that will refresh all instances loaded. + + This includes all entities accessed from the database, including + secondary entities, eagerly-loaded collection items. + + All changes present on entities which are already present in the + session will be reset and the entities will all be marked "clean". + + An alternative to populate_existing() is to expire the Session + fully using session.expire_all(). + + """ + self._populate_existing = True + + def with_parent(self, instance, property=None): + """Add a join criterion corresponding to a relationship to the given + parent instance. + + instance + a persistent or detached instance which is related to class + represented by this query. + + property + string name of the property which relates this query's class to the + instance. if None, the method will attempt to find a suitable + property. + + Currently, this method only works with immediate parent relationships, + but in the future may be enhanced to work across a chain of parent + mappers. + + """ + from sqlalchemy.orm import properties + mapper = object_mapper(instance) + if property is None: + for prop in mapper.iterate_properties: + if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero(): + break + else: + raise sa_exc.InvalidRequestError( + "Could not locate a property which relates instances " + "of class '%s' to instances of class '%s'" % + (self._mapper_zero().class_.__name__, instance.__class__.__name__) + ) + else: + prop = mapper.get_property(property, resolve_synonyms=True) + return self.filter(prop.compare(operators.eq, instance, value_is_parent=True)) + + @_generative() + def add_entity(self, entity, alias=None): + """add a mapped entity to the list of result columns to be returned.""" + + if alias is not None: + entity = aliased(entity, alias) + + self._entities = list(self._entities) + m = _MapperEntity(self, entity) + self._setup_aliasizers([m]) + + def from_self(self, *entities): + """return a Query that selects from this Query's SELECT statement. + + \*entities - optional list of entities which will replace + those being selected. + + """ + fromclause = self.with_labels().enable_eagerloads(False).\ + statement.correlate(None) + q = self._from_selectable(fromclause) + if entities: + q._set_entities(entities) + return q + + @_generative() + def _from_selectable(self, fromclause): + for attr in ('_statement', '_criterion', '_order_by', '_group_by', + '_limit', '_offset', '_joinpath', '_joinpoint', + '_distinct' + ): + self.__dict__.pop(attr, None) + self._set_select_from(fromclause) + old_entities = self._entities + self._entities = [] + for e in old_entities: + e.adapt_to_selectable(self, self._from_obj[0]) + + def values(self, *columns): + """Return an iterator yielding result tuples corresponding to the given list of columns""" + + if not columns: + return iter(()) + q = self._clone() + q._set_entities(columns, entity_wrapper=_ColumnEntity) + if not q._yield_per: + q._yield_per = 10 + return iter(q) + _values = values + + def value(self, column): + """Return a scalar result corresponding to the given column expression.""" + try: + # Py3K + #return self.values(column).__next__()[0] + # Py2K + return self.values(column).next()[0] + # end Py2K + except StopIteration: + return None + + @_generative() + def add_columns(self, *column): + """Add one or more column expressions to the list + of result columns to be returned.""" + + self._entities = list(self._entities) + l = len(self._entities) + for c in column: + _ColumnEntity(self, c) + # _ColumnEntity may add many entities if the + # given arg is a FROM clause + self._setup_aliasizers(self._entities[l:]) + + @util.pending_deprecation("add_column() superceded by add_columns()") + def add_column(self, column): + """Add a column expression to the list of result columns + to be returned.""" + + return self.add_columns(column) + + def options(self, *args): + """Return a new Query object, applying the given list of + MapperOptions. + + """ + return self._options(False, *args) + + def _conditional_options(self, *args): + return self._options(True, *args) + + @_generative() + def _options(self, conditional, *args): + # most MapperOptions write to the '_attributes' dictionary, + # so copy that as well + self._attributes = self._attributes.copy() + opts = tuple(util.flatten_iterator(args)) + self._with_options = self._with_options + opts + if conditional: + for opt in opts: + opt.process_query_conditionally(self) + else: + for opt in opts: + opt.process_query(self) + + @_generative() + def with_hint(self, selectable, text, dialect_name=None): + """Add an indexing hint for the given entity or selectable to + this :class:`Query`. + + Functionality is passed straight through to + :meth:`~sqlalchemy.sql.expression.Select.with_hint`, + with the addition that ``selectable`` can be a + :class:`Table`, :class:`Alias`, or ORM entity / mapped class + /etc. + """ + mapper, selectable, is_aliased_class = _entity_info(selectable) + + self._with_hints += ((selectable, text, dialect_name),) + + @_generative() + def execution_options(self, **kwargs): + """ Set non-SQL options which take effect during execution. + + The options are the same as those accepted by + :meth:`sqlalchemy.sql.expression.Executable.execution_options`. + + Note that the ``stream_results`` execution option is enabled + automatically if the :meth:`~sqlalchemy.orm.query.Query.yield_per()` + method is used. + + """ + _execution_options = self._execution_options.copy() + for key, value in kwargs.items(): + _execution_options[key] = value + self._execution_options = _execution_options + + @_generative() + def with_lockmode(self, mode): + """Return a new Query object with the specified locking mode.""" + + self._lockmode = mode + + @_generative() + def params(self, *args, **kwargs): + """add values for bind parameters which may have been specified in filter(). + + parameters may be specified using \**kwargs, or optionally a single dictionary + as the first positional argument. The reason for both is that \**kwargs is + convenient, however some parameter dictionaries contain unicode keys in which case + \**kwargs cannot be used. + + """ + if len(args) == 1: + kwargs.update(args[0]) + elif len(args) > 0: + raise sa_exc.ArgumentError("params() takes zero or one positional argument, which is a dictionary.") + self._params = self._params.copy() + self._params.update(kwargs) + + @_generative(_no_statement_condition, _no_limit_offset) + def filter(self, criterion): + """apply the given filtering criterion to the query and return the newly resulting ``Query`` + + the criterion is any sql.ClauseElement applicable to the WHERE clause of a select. + + """ + if isinstance(criterion, basestring): + criterion = sql.text(criterion) + + if criterion is not None and not isinstance(criterion, sql.ClauseElement): + raise sa_exc.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") + + criterion = self._adapt_clause(criterion, True, True) + + if self._criterion is not None: + self._criterion = self._criterion & criterion + else: + self._criterion = criterion + + def filter_by(self, **kwargs): + """apply the given filtering criterion to the query and return the newly resulting ``Query``.""" + + clauses = [_entity_descriptor(self._joinpoint_zero(), key)[0] == value + for key, value in kwargs.iteritems()] + + return self.filter(sql.and_(*clauses)) + + @_generative(_no_statement_condition, _no_limit_offset) + @util.accepts_a_list_as_starargs(list_deprecation='deprecated') + def order_by(self, *criterion): + """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" + + if len(criterion) == 1 and criterion[0] is None: + self._order_by = None + else: + criterion = self._adapt_col_list(criterion) + + if self._order_by is False or self._order_by is None: + self._order_by = criterion + else: + self._order_by = self._order_by + criterion + + @_generative(_no_statement_condition, _no_limit_offset) + @util.accepts_a_list_as_starargs(list_deprecation='deprecated') + def group_by(self, *criterion): + """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``""" + + criterion = list(chain(*[_orm_columns(c) for c in criterion])) + + criterion = self._adapt_col_list(criterion) + + if self._group_by is False: + self._group_by = criterion + else: + self._group_by = self._group_by + criterion + + @_generative(_no_statement_condition, _no_limit_offset) + def having(self, criterion): + """apply a HAVING criterion to the query and return the newly resulting ``Query``.""" + + if isinstance(criterion, basestring): + criterion = sql.text(criterion) + + if criterion is not None and not isinstance(criterion, sql.ClauseElement): + raise sa_exc.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string") + + criterion = self._adapt_clause(criterion, True, True) + + if self._having is not None: + self._having = self._having & criterion + else: + self._having = criterion + + def union(self, *q): + """Produce a UNION of this Query against one or more queries. + + e.g.:: + + q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar') + q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo') + + q3 = q1.union(q2) + + The method accepts multiple Query objects so as to control + the level of nesting. A series of ``union()`` calls such as:: + + x.union(y).union(z).all() + + will nest on each ``union()``, and produces:: + + SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y) UNION SELECT * FROM Z) + + Whereas:: + + x.union(y, z).all() + + produces:: + + SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION SELECT * FROM Z) + + """ + + + return self._from_selectable( + expression.union(*([self]+ list(q)))) + + def union_all(self, *q): + """Produce a UNION ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that + method for usage examples. + + """ + return self._from_selectable( + expression.union_all(*([self]+ list(q))) + ) + + def intersect(self, *q): + """Produce an INTERSECT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that + method for usage examples. + + """ + return self._from_selectable( + expression.intersect(*([self]+ list(q))) + ) + + def intersect_all(self, *q): + """Produce an INTERSECT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that + method for usage examples. + + """ + return self._from_selectable( + expression.intersect_all(*([self]+ list(q))) + ) + + def except_(self, *q): + """Produce an EXCEPT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that + method for usage examples. + + """ + return self._from_selectable( + expression.except_(*([self]+ list(q))) + ) + + def except_all(self, *q): + """Produce an EXCEPT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that + method for usage examples. + + """ + return self._from_selectable( + expression.except_all(*([self]+ list(q))) + ) + + @util.accepts_a_list_as_starargs(list_deprecation='deprecated') + def join(self, *props, **kwargs): + """Create a join against this ``Query`` object's criterion + and apply generatively, returning the newly resulting ``Query``. + + Each element in \*props may be: + + * a string property name, i.e. "rooms". This will join along the + relationship of the same name from this Query's "primary" mapper, if + one is present. + + * a class-mapped attribute, i.e. Houses.rooms. This will create a + join from "Houses" table to that of the "rooms" relationship. + + * a 2-tuple containing a target class or selectable, and an "ON" + clause. The ON clause can be the property name/ attribute like + above, or a SQL expression. + + e.g.:: + + # join along string attribute names + session.query(Company).join('employees') + session.query(Company).join('employees', 'tasks') + + # join the Person entity to an alias of itself, + # along the "friends" relationship + PAlias = aliased(Person) + session.query(Person).join((Palias, Person.friends)) + + # join from Houses to the "rooms" attribute on the + # "Colonials" subclass of Houses, then join to the + # "closets" relationship on Room + session.query(Houses).join(Colonials.rooms, Room.closets) + + # join from Company entities to the "employees" collection, + # using "people JOIN engineers" as the target. Then join + # to the "computers" collection on the Engineer entity. + session.query(Company).join((people.join(engineers), 'employees'), Engineer.computers) + + # join from Articles to Keywords, using the "keywords" attribute. + # assume this is a many-to-many relationship. + session.query(Article).join(Article.keywords) + + # same thing, but spelled out entirely explicitly + # including the association table. + session.query(Article).join( + (article_keywords, Articles.id==article_keywords.c.article_id), + (Keyword, Keyword.id==article_keywords.c.keyword_id) + ) + + \**kwargs include: + + aliased - when joining, create anonymous aliases of each table. This is + used for self-referential joins or multiple joins to the same table. + Consider usage of the aliased(SomeClass) construct as a more explicit + approach to this. + + from_joinpoint - when joins are specified using string property names, + locate the property from the mapper found in the most recent previous + join() call, instead of from the root entity. + + """ + aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) + if kwargs: + raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) + return self._join(props, + outerjoin=False, create_aliases=aliased, + from_joinpoint=from_joinpoint) + + @util.accepts_a_list_as_starargs(list_deprecation='deprecated') + def outerjoin(self, *props, **kwargs): + """Create a left outer join against this ``Query`` object's criterion + and apply generatively, retunring the newly resulting ``Query``. + + Usage is the same as the ``join()`` method. + + """ + aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) + if kwargs: + raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) + return self._join(props, + outerjoin=True, create_aliases=aliased, + from_joinpoint=from_joinpoint) + + @_generative(_no_statement_condition, _no_limit_offset) + def _join(self, keys, outerjoin, create_aliases, from_joinpoint): + """consumes arguments from join() or outerjoin(), places them into a consistent + format with which to form the actual JOIN constructs. + + """ + self._polymorphic_adapters = self._polymorphic_adapters.copy() + + if not from_joinpoint: + self._reset_joinpoint() + + for arg1 in util.to_list(keys): + if isinstance(arg1, tuple): + arg1, arg2 = arg1 + else: + arg2 = None + + # determine onclause/right_entity. there + # is a little bit of legacy behavior still at work here + # which means they might be in either order. may possibly + # lock this down to (right_entity, onclause) in 0.6. + if isinstance(arg1, (interfaces.PropComparator, basestring)): + right_entity, onclause = arg2, arg1 + else: + right_entity, onclause = arg1, arg2 + + left_entity = prop = None + + if isinstance(onclause, basestring): + left_entity = self._joinpoint_zero() + + descriptor, prop = _entity_descriptor(left_entity, onclause) + onclause = descriptor + + # check for q.join(Class.propname, from_joinpoint=True) + # and Class is that of the current joinpoint + elif from_joinpoint and isinstance(onclause, interfaces.PropComparator): + left_entity = onclause.parententity + + left_mapper, left_selectable, left_is_aliased = \ + _entity_info(self._joinpoint_zero()) + if left_mapper is left_entity: + left_entity = self._joinpoint_zero() + descriptor, prop = _entity_descriptor(left_entity, onclause.key) + onclause = descriptor + + if isinstance(onclause, interfaces.PropComparator): + if right_entity is None: + right_entity = onclause.property.mapper + of_type = getattr(onclause, '_of_type', None) + if of_type: + right_entity = of_type + else: + right_entity = onclause.property.mapper + + left_entity = onclause.parententity + + prop = onclause.property + if not isinstance(onclause, attributes.QueryableAttribute): + onclause = prop + + if not create_aliases: + # check for this path already present. + # don't render in that case. + if (left_entity, right_entity, prop.key) in self._joinpoint: + self._joinpoint = self._joinpoint[(left_entity, right_entity, prop.key)] + continue + + elif onclause is not None and right_entity is None: + # TODO: no coverage here + raise NotImplementedError("query.join(a==b) not supported.") + + self._join_left_to_right( + left_entity, + right_entity, onclause, + outerjoin, create_aliases, prop) + + def _join_left_to_right(self, left, right, onclause, outerjoin, create_aliases, prop): + """append a JOIN to the query's from clause.""" + + if left is None: + left = self._joinpoint_zero() + + if left is right and \ + not create_aliases: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they are the same entity" % + (left, right)) + + left_mapper, left_selectable, left_is_aliased = _entity_info(left) + right_mapper, right_selectable, is_aliased_class = _entity_info(right) + + if right_mapper and prop and not right_mapper.common_parent(prop.mapper): + raise sa_exc.InvalidRequestError( + "Join target %s does not correspond to " + "the right side of join condition %s" % (right, onclause) + ) + + if not right_mapper and prop: + right_mapper = prop.mapper + + need_adapter = False + + if right_mapper and right is right_selectable: + if not right_selectable.is_derived_from(right_mapper.mapped_table): + raise sa_exc.InvalidRequestError( + "Selectable '%s' is not derived from '%s'" % + (right_selectable.description, right_mapper.mapped_table.description)) + + if not isinstance(right_selectable, expression.Alias): + right_selectable = right_selectable.alias() + + right = aliased(right_mapper, right_selectable) + need_adapter = True + + aliased_entity = right_mapper and \ + not is_aliased_class and \ + ( + right_mapper.with_polymorphic or + isinstance(right_mapper.mapped_table, expression.Join) + ) + + if not need_adapter and (create_aliases or aliased_entity): + right = aliased(right) + need_adapter = True + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + + self._joinpoint = jp = { + '_joinpoint_entity':right, + 'prev':((left, right, prop.key), self._joinpoint) + } + + # copy backwards to the root of the _joinpath + # dict, so that no existing dict in the path is mutated + while 'prev' in jp: + f, prev = jp['prev'] + prev = prev.copy() + prev[f] = jp + jp['prev'] = (f, prev) + jp = prev + + self._joinpath = jp + + else: + self._joinpoint = { + '_joinpoint_entity':right + } + + # if an alias() of the right side was generated here, + # apply an adapter to all subsequent filter() calls + # until reset_joinpoint() is called. + if need_adapter: + self._filter_aliases = ORMAdapter(right, + equivalents=right_mapper._equivalent_columns, chain_to=self._filter_aliases) + + # if the onclause is a ClauseElement, adapt it with any + # adapters that are in place right now + if isinstance(onclause, expression.ClauseElement): + onclause = self._adapt_clause(onclause, True, True) + + # if an alias() on the right side was generated, + # which is intended to wrap a the right side in a subquery, + # ensure that columns retrieved from this target in the result + # set are also adapted. + if aliased_entity: + self.__mapper_loads_polymorphically_with( + right_mapper, + ORMAdapter( + right, + equivalents=right_mapper._equivalent_columns + ) + ) + + join_to_left = not is_aliased_class and not left_is_aliased + + if self._from_obj: + replace_clause_index, clause = sql_util.find_join_source( + self._from_obj, + left_selectable) + if clause is not None: + # the entire query's FROM clause is an alias of itself (i.e. from_self(), similar). + # if the left clause is that one, ensure it aliases to the left side. + if self._from_obj_alias and clause is self._from_obj[0]: + join_to_left = True + + clause = orm_join(clause, + right, + onclause, isouter=outerjoin, + join_to_left=join_to_left) + + self._from_obj = \ + self._from_obj[:replace_clause_index] + \ + (clause, ) + \ + self._from_obj[replace_clause_index + 1:] + return + + if left_mapper: + for ent in self._entities: + if ent.corresponds_to(left): + clause = ent.selectable + break + else: + clause = left + else: + clause = None + + if clause is None: + raise sa_exc.InvalidRequestError("Could not find a FROM clause to join from") + + clause = orm_join(clause, right, onclause, isouter=outerjoin, join_to_left=join_to_left) + self._from_obj = self._from_obj + (clause,) + + def _reset_joinpoint(self): + self._joinpoint = self._joinpath + self._filter_aliases = None + + @_generative(_no_statement_condition) + def reset_joinpoint(self): + """return a new Query reset the 'joinpoint' of this Query reset + back to the starting mapper. Subsequent generative calls will + be constructed from the new joinpoint. + + Note that each call to join() or outerjoin() also starts from + the root. + + """ + self._reset_joinpoint() + + @_generative(_no_clauseelement_condition) + def select_from(self, *from_obj): + """Set the `from_obj` parameter of the query and return the newly + resulting ``Query``. This replaces the table which this Query selects + from with the given table. + + ``select_from()`` also accepts class arguments. Though usually not necessary, + can ensure that the full selectable of the given mapper is applied, e.g. + for joined-table mappers. + + """ + + obj = [] + for fo in from_obj: + if _is_mapped_class(fo): + mapper, selectable, is_aliased_class = _entity_info(fo) + obj.append(selectable) + elif not isinstance(fo, expression.FromClause): + raise sa_exc.ArgumentError("select_from() accepts FromClause objects only.") + else: + obj.append(fo) + + self._set_select_from(*obj) + + def __getitem__(self, item): + if isinstance(item, slice): + start, stop, step = util.decode_slice(item) + + if isinstance(stop, int) and isinstance(start, int) and stop - start <= 0: + return [] + + # perhaps we should execute a count() here so that we + # can still use LIMIT/OFFSET ? + elif (isinstance(start, int) and start < 0) \ + or (isinstance(stop, int) and stop < 0): + return list(self)[item] + + res = self.slice(start, stop) + if step is not None: + return list(res)[None:None:item.step] + else: + return list(res) + else: + return list(self[item:item+1])[0] + + @_generative(_no_statement_condition) + def slice(self, start, stop): + """apply LIMIT/OFFSET to the ``Query`` based on a " + "range and return the newly resulting ``Query``.""" + + if start is not None and stop is not None: + self._offset = (self._offset or 0) + start + self._limit = stop - start + elif start is None and stop is not None: + self._limit = stop + elif start is not None and stop is None: + self._offset = (self._offset or 0) + start + + @_generative(_no_statement_condition) + def limit(self, limit): + """Apply a ``LIMIT`` to the query and return the newly resulting + + ``Query``. + + """ + self._limit = limit + + @_generative(_no_statement_condition) + def offset(self, offset): + """Apply an ``OFFSET`` to the query and return the newly resulting + ``Query``. + + """ + self._offset = offset + + @_generative(_no_statement_condition) + def distinct(self): + """Apply a ``DISTINCT`` to the query and return the newly resulting + ``Query``. + + """ + self._distinct = True + + def all(self): + """Return the results represented by this ``Query`` as a list. + + This results in an execution of the underlying query. + + """ + return list(self) + + @_generative(_no_clauseelement_condition) + def from_statement(self, statement): + """Execute the given SELECT statement and return results. + + This method bypasses all internal statement compilation, and the + statement is executed without modification. + + The statement argument is either a string, a ``select()`` construct, + or a ``text()`` construct, and should return the set of columns + appropriate to the entity class represented by this ``Query``. + + Also see the ``instances()`` method. + + """ + if isinstance(statement, basestring): + statement = sql.text(statement) + + if not isinstance(statement, (expression._TextClause, expression._SelectBaseMixin)): + raise sa_exc.ArgumentError("from_statement accepts text(), select(), and union() objects only.") + + self._statement = statement + + def first(self): + """Return the first result of this ``Query`` or + None if the result doesn't contain any row. + + first() applies a limit of one within the generated SQL, so that + only one primary entity row is generated on the server side + (note this may consist of multiple result rows if join-loaded + collections are present). + + Calling ``first()`` results in an execution of the underlying query. + + """ + if self._statement is not None: + ret = list(self)[0:1] + else: + ret = list(self[0:1]) + if len(ret) > 0: + return ret[0] + else: + return None + + def one(self): + """Return exactly one result or raise an exception. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that does not return object + identities. + + Note that an entity query, that is, one which selects one or + more mapped classes as opposed to individual column attributes, + may ultimately represent many rows but only one row of + unique entity or entities - this is a successful result for one(). + + Calling ``one()`` results in an execution of the underlying query. + As of 0.6, ``one()`` fully fetches all results instead of applying + any kind of limit, so that the "unique"-ing of entities does not + conceal multiple object identities. + + """ + ret = list(self) + + l = len(ret) + if l == 1: + return ret[0] + elif l == 0: + raise orm_exc.NoResultFound("No row was found for one()") + else: + raise orm_exc.MultipleResultsFound( + "Multiple rows were found for one()") + + def scalar(self): + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + >>> session.query(Item).scalar() + + >>> session.query(Item.id).scalar() + 1 + >>> session.query(Item.id).filter(Item.id < 0).scalar() + None + >>> session.query(Item.id, Item.name).scalar() + 1 + >>> session.query(func.count(Parent.id)).scalar() + 20 + + This results in an execution of the underlying query. + + """ + try: + ret = self.one() + if not isinstance(ret, tuple): + return ret + return ret[0] + except orm_exc.NoResultFound: + return None + + def __iter__(self): + context = self._compile_context() + context.statement.use_labels = True + if self._autoflush and not self._populate_existing: + self.session._autoflush() + return self._execute_and_instances(context) + + def _execute_and_instances(self, querycontext): + result = self.session.execute( + querycontext.statement, params=self._params, + mapper=self._mapper_zero_or_none()) + return self.instances(result, querycontext) + + def instances(self, cursor, __context=None): + """Given a ResultProxy cursor as returned by connection.execute(), + return an ORM result as an iterator. + + e.g.:: + + result = engine.execute("select * from users") + for u in session.query(User).instances(result): + print u + """ + session = self.session + + context = __context + if context is None: + context = QueryContext(self) + + context.runid = _new_runid() + + filtered = bool(list(self._mapper_entities)) + single_entity = filtered and len(self._entities) == 1 + + if filtered: + if single_entity: + filter = lambda x: util.unique_list(x, util.IdentitySet) + else: + filter = util.unique_list + else: + filter = None + + custom_rows = single_entity and \ + 'append_result' in self._entities[0].extension + + (process, labels) = \ + zip(*[ + query_entity.row_processor(self, context, custom_rows) + for query_entity in self._entities + ]) + + if not single_entity: + labels = [l for l in labels if l] + + while True: + context.progress = {} + context.partials = {} + + if self._yield_per: + fetch = cursor.fetchmany(self._yield_per) + if not fetch: + break + else: + fetch = cursor.fetchall() + + if custom_rows: + rows = [] + for row in fetch: + process[0](row, rows) + elif single_entity: + rows = [process[0](row, None) for row in fetch] + else: + rows = [util.NamedTuple([proc(row, None) for proc in process], labels) + for row in fetch] + + if filter: + rows = filter(rows) + + if context.refresh_state and self._only_load_props \ + and context.refresh_state in context.progress: + context.refresh_state.commit( + context.refresh_state.dict, self._only_load_props) + context.progress.pop(context.refresh_state) + + session._finalize_loaded(context.progress) + + for ii, (dict_, attrs) in context.partials.iteritems(): + ii.commit(dict_, attrs) + + for row in rows: + yield row + + if not self._yield_per: + break + + def merge_result(self, iterator, load=True): + """Merge a result into this Query's Session. + + Given an iterator returned by a Query of the same structure as this one, + return an identical iterator of results, with all mapped instances + merged into the session using Session.merge(). This is an optimized + method which will merge all mapped instances, preserving the structure + of the result rows and unmapped columns with less method overhead than + that of calling Session.merge() explicitly for each value. + + The structure of the results is determined based on the column list + of this Query - if these do not correspond, unchecked errors will occur. + + The 'load' argument is the same as that of Session.merge(). + + """ + + session = self.session + if load: + # flush current contents if we expect to load data + session._autoflush() + + autoflush = session.autoflush + try: + session.autoflush = False + single_entity = len(self._entities) == 1 + if single_entity: + if isinstance(self._entities[0], _MapperEntity): + result = [session._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, _recursive={}) + for instance in iterator] + else: + result = list(iterator) + else: + mapped_entities = [i for i, e in enumerate(self._entities) + if isinstance(e, _MapperEntity)] + result = [] + for row in iterator: + newrow = list(row) + for i in mapped_entities: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, _recursive={}) + result.append(util.NamedTuple(newrow, row._labels)) + + return iter(result) + finally: + session.autoflush = autoflush + + + def _get(self, key=None, ident=None, refresh_state=None, lockmode=None, + only_load_props=None, passive=None): + lockmode = lockmode or self._lockmode + + mapper = self._mapper_zero() + if not self._populate_existing and \ + not refresh_state and \ + not mapper.always_refresh and \ + lockmode is None: + instance = self.session.identity_map.get(key) + if instance: + # item present in identity map with a different class + if not issubclass(instance.__class__, mapper.class_): + return None + + state = attributes.instance_state(instance) + + # expired - ensure it still exists + if state.expired: + if passive is attributes.PASSIVE_NO_FETCH: + return attributes.PASSIVE_NO_RESULT + try: + state() + except orm_exc.ObjectDeletedError: + self.session._remove_newly_deleted(state) + return None + return instance + elif passive is attributes.PASSIVE_NO_FETCH: + return attributes.PASSIVE_NO_RESULT + + if ident is None: + if key is not None: + ident = key[1] + else: + ident = util.to_list(ident) + + if refresh_state is None: + q = self._clone() + q._get_condition() + else: + q = self._clone() + + if ident is not None: + (_get_clause, _get_params) = mapper._get_clause + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in ident: + nones = set([ + _get_params[col].key for col, value in + zip(mapper.primary_key, ident) if value is None + ]) + _get_clause = sql_util.adapt_criterion_to_null( + _get_clause, nones) + + _get_clause = q._adapt_clause(_get_clause, True, False) + q._criterion = _get_clause + + params = dict([ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip(ident, mapper.primary_key) + ]) + + if len(params) != len(mapper.primary_key): + raise sa_exc.InvalidRequestError( + "Incorrect number of values in identifier to formulate primary " + "key for query.get(); primary key columns are %s" % + ','.join("'%s'" % c for c in mapper.primary_key)) + + q._params = params + + if lockmode is not None: + q._lockmode = lockmode + q._get_options( + populate_existing=bool(refresh_state), + version_check=(lockmode is not None), + only_load_props=only_load_props, + refresh_state=refresh_state) + q._order_by = None + + try: + return q.one() + except orm_exc.NoResultFound: + return None + + @property + def _select_args(self): + return { + 'limit':self._limit, + 'offset':self._offset, + 'distinct':self._distinct, + 'group_by':self._group_by or None, + 'having':self._having + } + + @property + def _should_nest_selectable(self): + kwargs = self._select_args + return (kwargs.get('limit') is not None or + kwargs.get('offset') is not None or + kwargs.get('distinct', False)) + + def count(self): + """Return a count of rows this Query would return. + + For simple entity queries, count() issues + a SELECT COUNT, and will specifically count the primary + key column of the first entity only. If the query uses + LIMIT, OFFSET, or DISTINCT, count() will wrap the statement + generated by this Query in a subquery, from which a SELECT COUNT + is issued, so that the contract of "how many rows + would be returned?" is honored. + + For queries that request specific columns or expressions, + count() again makes no assumptions about those expressions + and will wrap everything in a subquery. Therefore, + ``Query.count()`` is usually not what you want in this case. + To count specific columns, often in conjunction with + GROUP BY, use ``func.count()`` as an individual column expression + instead of ``Query.count()``. See the ORM tutorial + for an example. + + """ + should_nest = [self._should_nest_selectable] + def ent_cols(ent): + if isinstance(ent, _MapperEntity): + return ent.mapper.primary_key + else: + should_nest[0] = True + return [ent.column] + + return self._col_aggregate(sql.literal_column('1'), sql.func.count, + nested_cols=chain(*[ent_cols(ent) for ent in self._entities]), + should_nest = should_nest[0] + ) + + def _col_aggregate(self, col, func, nested_cols=None, should_nest=False): + context = QueryContext(self) + + for entity in self._entities: + entity.setup_context(self, context) + + if context.from_clause: + from_obj = list(context.from_clause) + else: + from_obj = context.froms + + self._adjust_for_single_inheritance(context) + + whereclause = context.whereclause + + if should_nest: + if not nested_cols: + nested_cols = [col] + else: + nested_cols = list(nested_cols) + s = sql.select(nested_cols, whereclause, + from_obj=from_obj, use_labels=True, + **self._select_args) + s = s.alias() + s = sql.select( + [func(s.corresponding_column(col) or col)]).select_from(s) + else: + s = sql.select([func(col)], whereclause, from_obj=from_obj, + **self._select_args) + + if self._autoflush and not self._populate_existing: + self.session._autoflush() + return self.session.scalar(s, params=self._params, + mapper=self._mapper_zero()) + + def delete(self, synchronize_session='evaluate'): + """Perform a bulk delete query. + + Deletes rows matched by this query from the database. + + :param synchronize_session: chooses the strategy for the removal of + matched objects from the session. Valid values are: + + False - don't synchronize the session. This option is the most + efficient and is reliable once the session is expired, which + typically occurs after a commit(), or explicitly using + expire_all(). Before the expiration, objects may still remain in + the session which were in fact deleted which can lead to confusing + results if they are accessed via get() or already loaded + collections. + + 'fetch' - performs a select query before the delete to find + objects that are matched by the delete query and need to be + removed from the session. Matched objects are removed from the + session. + + 'evaluate' - Evaluate the query's criteria in Python straight on + the objects in the session. If evaluation of the criteria isn't + implemented, an error is raised. In that case you probably + want to use the 'fetch' strategy as a fallback. + + The expression evaluator currently doesn't account for differing + string collations between the database and Python. + + Returns the number of rows deleted, excluding any cascades. + + The method does *not* offer in-Python cascading of relationships - it is + assumed that ON DELETE CASCADE is configured for any foreign key + references which require it. The Session needs to be expired (occurs + automatically after commit(), or call expire_all()) in order for the + state of dependent objects subject to delete or delete-orphan cascade + to be correctly represented. + + Also, the ``before_delete()`` and ``after_delete()`` + :class:`~sqlalchemy.orm.interfaces.MapperExtension` methods are not + called from this method. For a delete hook here, use the + ``after_bulk_delete()`` + :class:`~sqlalchemy.orm.interfaces.MapperExtension` method. + + """ + #TODO: lots of duplication and ifs - probably needs to be refactored to strategies + #TODO: cascades need handling. + + if synchronize_session not in [False, 'evaluate', 'fetch']: + raise sa_exc.ArgumentError("Valid strategies for session " + "synchronization are False, 'evaluate' and 'fetch'") + self._no_select_modifiers("delete") + + self = self.enable_eagerloads(False) + + context = self._compile_context() + if len(context.statement.froms) != 1 or \ + not isinstance(context.statement.froms[0], schema.Table): + raise sa_exc.ArgumentError("Only deletion via a single table " + "query is currently supported") + primary_table = context.statement.froms[0] + + session = self.session + + if synchronize_session == 'evaluate': + try: + evaluator_compiler = evaluator.EvaluatorCompiler() + if self.whereclause is not None: + eval_condition = evaluator_compiler.process(self.whereclause) + else: + def eval_condition(obj): + return True + + except evaluator.UnevaluatableError: + raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. " + "Specify 'fetch' or False for the synchronize_session parameter.") + + delete_stmt = sql.delete(primary_table, context.whereclause) + + if synchronize_session == 'fetch': + #TODO: use RETURNING when available + select_stmt = context.statement.with_only_columns(primary_table.primary_key) + matched_rows = session.execute(select_stmt, params=self._params).fetchall() + + if self._autoflush: + session._autoflush() + result = session.execute(delete_stmt, params=self._params) + + if synchronize_session == 'evaluate': + target_cls = self._mapper_zero().class_ + + #TODO: detect when the where clause is a trivial primary key match + objs_to_expunge = [obj for (cls, pk),obj in session.identity_map.iteritems() + if issubclass(cls, target_cls) and eval_condition(obj)] + for obj in objs_to_expunge: + session._remove_newly_deleted(attributes.instance_state(obj)) + elif synchronize_session == 'fetch': + target_mapper = self._mapper_zero() + for primary_key in matched_rows: + identity_key = target_mapper.identity_key_from_primary_key(list(primary_key)) + if identity_key in session.identity_map: + session._remove_newly_deleted(attributes.instance_state(session.identity_map[identity_key])) + + for ext in session.extensions: + ext.after_bulk_delete(session, self, context, result) + + return result.rowcount + + def update(self, values, synchronize_session='evaluate'): + """Perform a bulk update query. + + Updates rows matched by this query in the database. + + :param values: a dictionary with attributes names as keys and literal + values or sql expressions as values. + + :param synchronize_session: chooses the strategy to update the + attributes on objects in the session. Valid values are: + + False - don't synchronize the session. This option is the most + efficient and is reliable once the session is expired, which + typically occurs after a commit(), or explicitly using + expire_all(). Before the expiration, updated objects may still + remain in the session with stale values on their attributes, which + can lead to confusing results. + + 'fetch' - performs a select query before the update to find + objects that are matched by the update query. The updated + attributes are expired on matched objects. + + 'evaluate' - Evaluate the Query's criteria in Python straight on + the objects in the session. If evaluation of the criteria isn't + implemented, an exception is raised. + + The expression evaluator currently doesn't account for differing + string collations between the database and Python. + + Returns the number of rows matched by the update. + + The method does *not* offer in-Python cascading of relationships - it is assumed that + ON UPDATE CASCADE is configured for any foreign key references which require it. + + The Session needs to be expired (occurs automatically after commit(), or call expire_all()) + in order for the state of dependent objects subject foreign key cascade to be + correctly represented. + + Also, the ``before_update()`` and ``after_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` + methods are not called from this method. For an update hook here, use the + ``after_bulk_update()`` :class:`~sqlalchemy.orm.interfaces.SessionExtension` method. + + """ + + #TODO: value keys need to be mapped to corresponding sql cols and instr.attr.s to string keys + #TODO: updates of manytoone relationships need to be converted to fk assignments + #TODO: cascades need handling. + + if synchronize_session == 'expire': + util.warn_deprecated("The 'expire' value as applied to " + "the synchronize_session argument of " + "query.update() is now called 'fetch'") + synchronize_session = 'fetch' + + if synchronize_session not in [False, 'evaluate', 'fetch']: + raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'fetch'") + self._no_select_modifiers("update") + + self = self.enable_eagerloads(False) + + context = self._compile_context() + if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table): + raise sa_exc.ArgumentError("Only update via a single table query is currently supported") + primary_table = context.statement.froms[0] + + session = self.session + + if synchronize_session == 'evaluate': + try: + evaluator_compiler = evaluator.EvaluatorCompiler() + if self.whereclause is not None: + eval_condition = evaluator_compiler.process(self.whereclause) + else: + def eval_condition(obj): + return True + + value_evaluators = {} + for key,value in values.iteritems(): + key = expression._column_as_key(key) + value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value)) + except evaluator.UnevaluatableError: + raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. " + "Specify 'fetch' or False for the synchronize_session parameter.") + + update_stmt = sql.update(primary_table, context.whereclause, values) + + if synchronize_session == 'fetch': + select_stmt = context.statement.with_only_columns(primary_table.primary_key) + matched_rows = session.execute(select_stmt, params=self._params).fetchall() + + if self._autoflush: + session._autoflush() + result = session.execute(update_stmt, params=self._params) + + if synchronize_session == 'evaluate': + target_cls = self._mapper_zero().class_ + + for (cls, pk),obj in session.identity_map.iteritems(): + evaluated_keys = value_evaluators.keys() + + if issubclass(cls, target_cls) and eval_condition(obj): + state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj) + + # only evaluate unmodified attributes + to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: + dict_[key] = value_evaluators[key](obj) + + state.commit(dict_, list(to_evaluate)) + + # expire attributes with pending changes + # (there was no autoflush, so they are overwritten) + state.expire_attributes(dict_, set(evaluated_keys).difference(to_evaluate)) + + elif synchronize_session == 'fetch': + target_mapper = self._mapper_zero() + + for primary_key in matched_rows: + identity_key = target_mapper.identity_key_from_primary_key(list(primary_key)) + if identity_key in session.identity_map: + session.expire( + session.identity_map[identity_key], + [expression._column_as_key(k) for k in values] + ) + + for ext in session.extensions: + ext.after_bulk_update(session, self, context, result) + + return result.rowcount + + def _compile_context(self, labels=True): + context = QueryContext(self) + + if context.statement is not None: + return context + + if self._lockmode: + try: + for_update = {'read': 'read', + 'update': True, + 'update_nowait': 'nowait', + None: False}[self._lockmode] + except KeyError: + raise sa_exc.ArgumentError("Unknown lockmode %r" % self._lockmode) + else: + for_update = False + + for entity in self._entities: + entity.setup_context(self, context) + + for rec in context.create_eager_joins: + strategy = rec[0] + strategy(*rec[1:]) + + eager_joins = context.eager_joins.values() + + if context.from_clause: + froms = list(context.from_clause) # "load from explicit FROMs" mode, + # i.e. when select_from() or join() is used + else: + froms = context.froms # "load from discrete FROMs" mode, + # i.e. when each _MappedEntity has its own FROM + + self._adjust_for_single_inheritance(context) + + if not context.primary_columns: + if self._only_load_props: + raise sa_exc.InvalidRequestError( + "No column-based properties specified for refresh operation." + " Use session.expire() to reload collections and related items.") + else: + raise sa_exc.InvalidRequestError( + "Query contains no columns with which to SELECT from.") + + if context.multi_row_eager_loaders and self._should_nest_selectable: + # for eager joins present and LIMIT/OFFSET/DISTINCT, + # wrap the query inside a select, + # then append eager joins onto that + + if context.order_by: + order_by_col_expr = list( + chain(*[ + sql_util.find_columns(o) + for o in context.order_by + ]) + ) + else: + context.order_by = None + order_by_col_expr = [] + + inner = sql.select( + context.primary_columns + order_by_col_expr, + context.whereclause, + from_obj=froms, + use_labels=labels, + correlate=False, + order_by=context.order_by, + **self._select_args + ) + + for hint in self._with_hints: + inner = inner.with_hint(*hint) + + if self._correlate: + inner = inner.correlate(*self._correlate) + + inner = inner.alias() + + equivs = self.__all_equivs() + + context.adapter = sql_util.ColumnAdapter(inner, equivs) + + statement = sql.select( + [inner] + context.secondary_columns, + for_update=for_update, + use_labels=labels) + + if self._execution_options: + statement = statement.execution_options(**self._execution_options) + + from_clause = inner + for eager_join in eager_joins: + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of the join should be + from_clause = sql_util.splice_joins(from_clause, eager_join, eager_join.stop_on) + + statement.append_from(from_clause) + + if context.order_by: + statement.append_order_by(*context.adapter.copy_and_process(context.order_by)) + + statement.append_order_by(*context.eager_order_by) + else: + if not context.order_by: + context.order_by = None + + if self._distinct and context.order_by: + order_by_col_expr = list( + chain(*[ + sql_util.find_columns(o) + for o in context.order_by + ]) + ) + context.primary_columns += order_by_col_expr + + froms += tuple(context.eager_joins.values()) + + statement = sql.select( + context.primary_columns + context.secondary_columns, + context.whereclause, + from_obj=froms, + use_labels=labels, + for_update=for_update, + correlate=False, + order_by=context.order_by, + **self._select_args + ) + + for hint in self._with_hints: + statement = statement.with_hint(*hint) + + if self._execution_options: + statement = statement.execution_options(**self._execution_options) + + if self._correlate: + statement = statement.correlate(*self._correlate) + + if context.eager_order_by: + statement.append_order_by(*context.eager_order_by) + + context.statement = statement + + return context + + def _adjust_for_single_inheritance(self, context): + """Apply single-table-inheritance filtering. + + For all distinct single-table-inheritance mappers represented in the + columns clause of this query, add criterion to the WHERE clause of the + given QueryContext such that only the appropriate subtypes are + selected from the total results. + + """ + for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems(): + single_crit = mapper._single_table_criterion + if single_crit is not None: + if adapter: + single_crit = adapter.traverse(single_crit) + single_crit = self._adapt_clause(single_crit, False, False) + context.whereclause = sql.and_(context.whereclause, single_crit) + + def __str__(self): + return str(self._compile_context().statement) + + +class _QueryEntity(object): + """represent an entity column returned within a Query result.""" + + def __new__(cls, *args, **kwargs): + if cls is _QueryEntity: + entity = args[1] + if not isinstance(entity, basestring) and _is_mapped_class(entity): + cls = _MapperEntity + else: + cls = _ColumnEntity + return object.__new__(cls) + + def _clone(self): + q = self.__class__.__new__(self.__class__) + q.__dict__ = self.__dict__.copy() + return q + +class _MapperEntity(_QueryEntity): + """mapper/class/AliasedClass entity""" + + def __init__(self, query, entity): + self.primary_entity = not query._entities + query._entities.append(self) + + self.entities = [entity] + self.entity_zero = entity + + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): + self.mapper = mapper + self.extension = self.mapper.extension + self.adapter = adapter + self.selectable = from_obj + self._with_polymorphic = with_polymorphic + self._polymorphic_discriminator = None + self.is_aliased_class = is_aliased_class + if is_aliased_class: + self.path_entity = self.entity = self.entity_zero = entity + else: + self.path_entity = mapper + self.entity = self.entity_zero = mapper + + def set_with_polymorphic(self, query, cls_or_mappers, selectable, discriminator): + if cls_or_mappers is None: + query._reset_polymorphic_adapter(self.mapper) + return + + mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable) + self._with_polymorphic = mappers + self._polymorphic_discriminator = discriminator + + # TODO: do the wrapped thing here too so that with_polymorphic() can be + # applied to aliases + if not self.is_aliased_class: + self.selectable = from_obj + self.adapter = query._get_polymorphic_adapter(self, from_obj) + + def corresponds_to(self, entity): + if _is_aliased_class(entity) or self.is_aliased_class: + return entity is self.path_entity + else: + return entity.common_parent(self.path_entity) + + def adapt_to_selectable(self, query, sel): + query._entities.append(self) + + def _get_entity_clauses(self, query, context): + + adapter = None + if not self.is_aliased_class and query._polymorphic_adapters: + adapter = query._polymorphic_adapters.get(self.mapper, None) + + if not adapter and self.adapter: + adapter = self.adapter + + if adapter: + if query._from_obj_alias: + ret = adapter.wrap(query._from_obj_alias) + else: + ret = adapter + else: + ret = query._from_obj_alias + + return ret + + def row_processor(self, query, context, custom_rows): + adapter = self._get_entity_clauses(query, context) + + if context.adapter and adapter: + adapter = adapter.wrap(context.adapter) + elif not adapter: + adapter = context.adapter + + # polymorphic mappers which have concrete tables in their hierarchy usually + # require row aliasing unconditionally. + if not adapter and self.mapper._requires_row_aliasing: + adapter = sql_util.ColumnAdapter(self.selectable, self.mapper._equivalent_columns) + + if self.primary_entity: + _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, + extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator + ) + else: + _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, + polymorphic_discriminator=self._polymorphic_discriminator) + + if self.is_aliased_class: + entname = self.entity._sa_label_name + else: + entname = self.mapper.class_.__name__ + + return _instance, entname + + def setup_context(self, query, context): + adapter = self._get_entity_clauses(query, context) + + context.froms += (self.selectable,) + + if context.order_by is False and self.mapper.order_by: + context.order_by = self.mapper.order_by + + # apply adaptation to the mapper's order_by if needed. + if adapter: + context.order_by = adapter.adapt_list(util.to_list(context.order_by)) + + for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic): + if query._only_load_props and value.key not in query._only_load_props: + continue + value.setup( + context, + self, + (self.path_entity,), + adapter, + only_load_props=query._only_load_props, + column_collection=context.primary_columns + ) + + if self._polymorphic_discriminator is not None: + if adapter: + pd = adapter.columns[self._polymorphic_discriminator] + else: + pd = self._polymorphic_discriminator + context.primary_columns.append(pd) + + def __str__(self): + return str(self.mapper) + +class _ColumnEntity(_QueryEntity): + """Column/expression based entity.""" + + def __init__(self, query, column): + if isinstance(column, basestring): + column = sql.literal_column(column) + self._result_label = column.name + elif isinstance(column, attributes.QueryableAttribute): + self._result_label = column.key + column = column.__clause_element__() + else: + self._result_label = getattr(column, 'key', None) + + if not isinstance(column, expression.ColumnElement) and hasattr(column, '_select_iterable'): + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c) + + if c is not column: + return + + if not isinstance(column, sql.ColumnElement): + raise sa_exc.InvalidRequestError( + "SQL expression, column, or mapped entity expected - got '%r'" % column + ) + + # if the Column is unnamed, give it a + # label() so that mutable column expressions + # can be located in the result even + # if the expression's identity has been changed + # due to adaption + if not column._label: + column = column.label(None) + + query._entities.append(self) + + self.column = column + self.froms = set() + + # look for ORM entities represented within the + # given expression. Try to count only entities + # for columns whos FROM object is in the actual list + # of FROMs for the overall expression - this helps + # subqueries which were built from ORM constructs from + # leaking out their entities into the main select construct + actual_froms = set(column._from_objects) + + self.entities = util.OrderedSet( + elem._annotations['parententity'] + for elem in visitors.iterate(column, {}) + if 'parententity' in elem._annotations + and actual_froms.intersection(elem._from_objects) + ) + + if self.entities: + self.entity_zero = list(self.entities)[0] + else: + self.entity_zero = None + + def adapt_to_selectable(self, query, sel): + _ColumnEntity(query, sel.corresponding_column(self.column)) + + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): + self.selectable = from_obj + self.froms.add(from_obj) + + def corresponds_to(self, entity): + if self.entity_zero is None: + return False + elif _is_aliased_class(entity): + return entity is self.entity_zero + else: + return not _is_aliased_class(self.entity_zero) and \ + entity.common_parent(self.entity_zero) + + def _resolve_expr_against_query_aliases(self, query, expr, context): + return query._adapt_clause(expr, False, True) + + def row_processor(self, query, context, custom_rows): + column = self._resolve_expr_against_query_aliases(query, self.column, context) + + if context.adapter: + column = context.adapter.columns[column] + + def proc(row, result): + return row[column] + + return (proc, self._result_label) + + def setup_context(self, query, context): + column = self._resolve_expr_against_query_aliases(query, self.column, context) + context.froms += tuple(self.froms) + context.primary_columns.append(column) + + def __str__(self): + return str(self.column) + +log.class_logger(Query) + +class QueryContext(object): + multi_row_eager_loaders = False + adapter = None + froms = () + + def __init__(self, query): + + if query._statement is not None: + if isinstance(query._statement, expression._SelectBaseMixin) and not query._statement.use_labels: + self.statement = query._statement.apply_labels() + else: + self.statement = query._statement + else: + self.statement = None + self.from_clause = query._from_obj + self.whereclause = query._criterion + self.order_by = query._order_by + + self.query = query + self.session = query.session + self.populate_existing = query._populate_existing + self.version_check = query._version_check + self.refresh_state = query._refresh_state + self.primary_columns = [] + self.secondary_columns = [] + self.eager_order_by = [] + self.eager_joins = {} + self.create_eager_joins = [] + self.propagate_options = set(o for o in query._with_options if o.propagate_to_loaders) + self.attributes = query._attributes.copy() + +class AliasOption(interfaces.MapperOption): + + def __init__(self, alias): + self.alias = alias + + def process_query(self, query): + if isinstance(self.alias, basestring): + alias = query._mapper_zero().mapped_table.alias(self.alias) + else: + alias = self.alias + query._from_obj_alias = sql_util.ColumnAdapter(alias) + + +_runid = 1L +_id_lock = util.threading.Lock() + +def _new_runid(): + global _runid + _id_lock.acquire() + try: + _runid += 1 + return _runid + finally: + _id_lock.release() diff --git a/sqlalchemy/orm/scoping.py b/sqlalchemy/orm/scoping.py new file mode 100644 index 0000000..40bbb32 --- /dev/null +++ b/sqlalchemy/orm/scoping.py @@ -0,0 +1,205 @@ +# scoping.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import sqlalchemy.exceptions as sa_exc +from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \ + to_list, get_cls_kwargs, deprecated +from sqlalchemy.orm import ( + EXT_CONTINUE, MapperExtension, class_mapper, object_session + ) +from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.orm.session import Session + + +__all__ = ['ScopedSession'] + + +class ScopedSession(object): + """Provides thread-local management of Sessions. + + Usage:: + + Session = scoped_session(sessionmaker(autoflush=True)) + + ... use session normally. + + """ + + def __init__(self, session_factory, scopefunc=None): + self.session_factory = session_factory + if scopefunc: + self.registry = ScopedRegistry(session_factory, scopefunc) + else: + self.registry = ThreadLocalRegistry(session_factory) + self.extension = _ScopedExt(self) + + def __call__(self, **kwargs): + if kwargs: + scope = kwargs.pop('scope', False) + if scope is not None: + if self.registry.has(): + raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") + else: + sess = self.session_factory(**kwargs) + self.registry.set(sess) + return sess + else: + return self.session_factory(**kwargs) + else: + return self.registry() + + def remove(self): + """Dispose of the current contextual session.""" + + if self.registry.has(): + self.registry().close() + self.registry.clear() + + @deprecated("Session.mapper is deprecated. " + "Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper " + "for information on how to replicate its behavior.") + def mapper(self, *args, **kwargs): + """return a mapper() function which associates this ScopedSession with the Mapper. + + DEPRECATED. + + """ + + from sqlalchemy.orm import mapper + + extension_args = dict((arg, kwargs.pop(arg)) + for arg in get_cls_kwargs(_ScopedExt) + if arg in kwargs) + + kwargs['extension'] = extension = to_list(kwargs.get('extension', [])) + if extension_args: + extension.append(self.extension.configure(**extension_args)) + else: + extension.append(self.extension) + return mapper(*args, **kwargs) + + def configure(self, **kwargs): + """reconfigure the sessionmaker used by this ScopedSession.""" + + self.session_factory.configure(**kwargs) + + def query_property(self, query_cls=None): + """return a class property which produces a `Query` object against the + class when called. + + e.g.:: + Session = scoped_session(sessionmaker()) + + class MyClass(object): + query = Session.query_property() + + # after mappers are defined + result = MyClass.query.filter(MyClass.name=='foo').all() + + Produces instances of the session's configured query class by + default. To override and use a custom implementation, provide + a ``query_cls`` callable. The callable will be invoked with + the class's mapper as a positional argument and a session + keyword argument. + + There is no limit to the number of query properties placed on + a class. + + """ + class query(object): + def __get__(s, instance, owner): + try: + mapper = class_mapper(owner) + if mapper: + if query_cls: + # custom query class + return query_cls(mapper, session=self.registry()) + else: + # session's configured query class + return self.registry().query(mapper) + except orm_exc.UnmappedClassError: + return None + return query() + +def instrument(name): + def do(self, *args, **kwargs): + return getattr(self.registry(), name)(*args, **kwargs) + return do +for meth in Session.public_methods: + setattr(ScopedSession, meth, instrument(meth)) + +def makeprop(name): + def set(self, attr): + setattr(self.registry(), name, attr) + def get(self): + return getattr(self.registry(), name) + return property(get, set) +for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'): + setattr(ScopedSession, prop, makeprop(prop)) + +def clslevel(name): + def do(cls, *args, **kwargs): + return getattr(Session, name)(*args, **kwargs) + return classmethod(do) +for prop in ('close_all', 'object_session', 'identity_key'): + setattr(ScopedSession, prop, clslevel(prop)) + +class _ScopedExt(MapperExtension): + def __init__(self, context, validate=False, save_on_init=True): + self.context = context + self.validate = validate + self.save_on_init = save_on_init + self.set_kwargs_on_init = True + + def validating(self): + return _ScopedExt(self.context, validate=True) + + def configure(self, **kwargs): + return _ScopedExt(self.context, **kwargs) + + def instrument_class(self, mapper, class_): + class query(object): + def __getattr__(s, key): + return getattr(self.context.registry().query(class_), key) + def __call__(s): + return self.context.registry().query(class_) + def __get__(self, instance, cls): + return self + + if not 'query' in class_.__dict__: + class_.query = query() + + if self.set_kwargs_on_init and class_.__init__ is object.__init__: + class_.__init__ = self._default__init__(mapper) + + def _default__init__(ext, mapper): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + if ext.validate: + if not mapper.get_property(key, resolve_synonyms=False, + raiseerr=False): + raise sa_exc.ArgumentError( + "Invalid __init__ argument: '%s'" % key) + setattr(self, key, value) + return __init__ + + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + if self.save_on_init: + session = kwargs.pop('_sa_session', None) + if session is None: + session = self.context.registry() + session._save_without_cascade(instance) + return EXT_CONTINUE + + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + sess = object_session(instance) + if sess: + sess.expunge(instance) + return EXT_CONTINUE + + def dispose_class(self, mapper, class_): + if hasattr(class_, 'query'): + delattr(class_, 'query') diff --git a/sqlalchemy/orm/session.py b/sqlalchemy/orm/session.py new file mode 100644 index 0000000..0a3fbe7 --- /dev/null +++ b/sqlalchemy/orm/session.py @@ -0,0 +1,1604 @@ +# session.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provides the Session class and related utilities.""" + +import weakref +from itertools import chain +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import util, sql, engine, log +from sqlalchemy.sql import util as sql_util, expression +from sqlalchemy.orm import ( + SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state + ) +from sqlalchemy.orm.util import object_mapper as _object_mapper +from sqlalchemy.orm.util import class_mapper as _class_mapper +from sqlalchemy.orm.util import ( + _class_to_mapper, _state_has_identity, _state_mapper, + ) +from sqlalchemy.orm.mapper import Mapper, _none_set +from sqlalchemy.orm.unitofwork import UOWTransaction +from sqlalchemy.orm import identity + +__all__ = ['Session', 'SessionTransaction', 'SessionExtension'] + + +def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, + expire_on_commit=True, **kwargs): + """Generate a custom-configured :class:`~sqlalchemy.orm.session.Session` class. + + The returned object is a subclass of ``Session``, which, when instantiated + with no arguments, uses the keyword arguments configured here as its + constructor arguments. + + It is intended that the `sessionmaker()` function be called within the + global scope of an application, and the returned class be made available + to the rest of the application as the single class used to instantiate + sessions. + + e.g.:: + + # global scope + Session = sessionmaker(autoflush=False) + + # later, in a local scope, create and use a session: + sess = Session() + + Any keyword arguments sent to the constructor itself will override the + "configured" keywords:: + + Session = sessionmaker() + + # bind an individual session to a connection + sess = Session(bind=connection) + + The class also includes a special classmethod ``configure()``, which + allows additional configurational options to take place after the custom + ``Session`` class has been generated. This is useful particularly for + defining the specific ``Engine`` (or engines) to which new instances of + ``Session`` should be bound:: + + Session = sessionmaker() + Session.configure(bind=create_engine('sqlite:///foo.db')) + + sess = Session() + + Options: + + autocommit + Defaults to ``False``. When ``True``, the ``Session`` does not keep a + persistent transaction running, and will acquire connections from the + engine on an as-needed basis, returning them immediately after their + use. Flushes will begin and commit (or possibly rollback) their own + transaction if no transaction is present. When using this mode, the + `session.begin()` method may be used to begin a transaction explicitly. + + Leaving it on its default value of ``False`` means that the ``Session`` + will acquire a connection and begin a transaction the first time it is + used, which it will maintain persistently until ``rollback()``, + ``commit()``, or ``close()`` is called. When the transaction is released + by any of these methods, the ``Session`` is ready for the next usage, + which will again acquire and maintain a new connection/transaction. + + autoflush + When ``True``, all query operations will issue a ``flush()`` call to + this ``Session`` before proceeding. This is a convenience feature so + that ``flush()`` need not be called repeatedly in order for database + queries to retrieve results. It's typical that ``autoflush`` is used in + conjunction with ``autocommit=False``. In this scenario, explicit calls + to ``flush()`` are rarely needed; you usually only need to call + ``commit()`` (which flushes) to finalize changes. + + bind + An optional ``Engine`` or ``Connection`` to which this ``Session`` + should be bound. When specified, all SQL operations performed by this + session will execute via this connectable. + + binds + An optional dictionary, which contains more granular "bind" information + than the ``bind`` parameter provides. This dictionary can map individual + ``Table`` instances as well as ``Mapper`` instances to individual + ``Engine`` or ``Connection`` objects. Operations which proceed relative + to a particular ``Mapper`` will consult this dictionary for the direct + ``Mapper`` instance as well as the mapper's ``mapped_table`` attribute + in order to locate an connectable to use. The full resolution is + described in the ``get_bind()`` method of ``Session``. Usage looks + like:: + + sess = Session(binds={ + SomeMappedClass: create_engine('postgresql://engine1'), + somemapper: create_engine('postgresql://engine2'), + some_table: create_engine('postgresql://engine3'), + }) + + Also see the ``bind_mapper()`` and ``bind_table()`` methods. + + \class_ + Specify an alternate class other than ``sqlalchemy.orm.session.Session`` + which should be used by the returned class. This is the only argument + that is local to the ``sessionmaker()`` function, and is not sent + directly to the constructor for ``Session``. + + _enable_transaction_accounting + Defaults to ``True``. A legacy-only flag which when ``False`` + disables *all* 0.5-style object accounting on transaction boundaries, + including auto-expiry of instances on rollback and commit, maintenance of + the "new" and "deleted" lists upon rollback, and autoflush + of pending changes upon begin(), all of which are interdependent. + + expire_on_commit + Defaults to ``True``. When ``True``, all instances will be fully expired after + each ``commit()``, so that all attribute/object access subsequent to a completed + transaction will load from the most recent database state. + + extension + An optional :class:`~sqlalchemy.orm.session.SessionExtension` instance, or + a list of such instances, which + will receive pre- and post- commit and flush events, as well as a + post-rollback event. User- defined code may be placed within these + hooks using a user-defined subclass of ``SessionExtension``. + + query_cls + Class which should be used to create new Query objects, as returned + by the ``query()`` method. Defaults to :class:`~sqlalchemy.orm.query.Query`. + + twophase + When ``True``, all transactions will be started using + :mod:~sqlalchemy.engine_TwoPhaseTransaction. During a ``commit()``, after + ``flush()`` has been issued for all attached databases, the + ``prepare()`` method on each database's ``TwoPhaseTransaction`` will be + called. This allows each database to roll back the entire transaction, + before each transaction is committed. + + weak_identity_map + When set to the default value of ``True``, a weak-referencing map is + used; instances which are not externally referenced will be garbage + collected immediately. For dereferenced instances which have pending + changes present, the attribute management system will create a temporary + strong-reference to the object which lasts until the changes are flushed + to the database, at which point it's again dereferenced. Alternatively, + when using the value ``False``, the identity map uses a regular Python + dictionary to store instances. The session will maintain all instances + present until they are removed using expunge(), clear(), or purge(). + + """ + kwargs['bind'] = bind + kwargs['autoflush'] = autoflush + kwargs['autocommit'] = autocommit + kwargs['expire_on_commit'] = expire_on_commit + + if class_ is None: + class_ = Session + + class Sess(object): + def __init__(self, **local_kwargs): + for k in kwargs: + local_kwargs.setdefault(k, kwargs[k]) + super(Sess, self).__init__(**local_kwargs) + + def configure(self, **new_kwargs): + """(Re)configure the arguments for this sessionmaker. + + e.g.:: + + Session = sessionmaker() + + Session.configure(bind=create_engine('sqlite://')) + """ + kwargs.update(new_kwargs) + configure = classmethod(configure) + s = type.__new__(type, "Session", (Sess, class_), {}) + return s + + +class SessionTransaction(object): + """A Session-level transaction. + + This corresponds to one or more :class:`~sqlalchemy.engine.Transaction` + instances behind the scenes, with one ``Transaction`` per ``Engine`` in + use. + + Direct usage of ``SessionTransaction`` is not necessary as of SQLAlchemy + 0.4; use the ``begin()`` and ``commit()`` methods on ``Session`` itself. + + The ``SessionTransaction`` object is **not** thread-safe. + + .. index:: + single: thread safety; SessionTransaction + + """ + + def __init__(self, session, parent=None, nested=False): + self.session = session + self._connections = {} + self._parent = parent + self.nested = nested + self._active = True + self._prepared = False + if not parent and nested: + raise sa_exc.InvalidRequestError( + "Can't start a SAVEPOINT transaction when no existing " + "transaction is in progress") + + if self.session._enable_transaction_accounting: + self._take_snapshot() + + @property + def is_active(self): + return self.session is not None and self._active + + def _assert_is_active(self): + self._assert_is_open() + if not self._active: + raise sa_exc.InvalidRequestError( + "The transaction is inactive due to a rollback in a " + "subtransaction. Issue rollback() to cancel the transaction.") + + def _assert_is_open(self, error_msg="The transaction is closed"): + if self.session is None: + raise sa_exc.InvalidRequestError(error_msg) + + @property + def _is_transaction_boundary(self): + return self.nested or not self._parent + + def connection(self, bindkey, **kwargs): + self._assert_is_active() + engine = self.session.get_bind(bindkey, **kwargs) + return self._connection_for_bind(engine) + + def _begin(self, nested=False): + self._assert_is_active() + return SessionTransaction( + self.session, self, nested=nested) + + def _iterate_parents(self, upto=None): + if self._parent is upto: + return (self,) + else: + if self._parent is None: + raise sa_exc.InvalidRequestError( + "Transaction %s is not on the active transaction list" % ( + upto)) + return (self,) + self._parent._iterate_parents(upto) + + def _take_snapshot(self): + if not self._is_transaction_boundary: + self._new = self._parent._new + self._deleted = self._parent._deleted + return + + if not self.session._flushing: + self.session.flush() + + self._new = weakref.WeakKeyDictionary() + self._deleted = weakref.WeakKeyDictionary() + + def _restore_snapshot(self): + assert self._is_transaction_boundary + + for s in set(self._new).union(self.session._new): + self.session._expunge_state(s) + + for s in set(self._deleted).union(self.session._deleted): + self.session._update_impl(s) + + assert not self.session._deleted + + for s in self.session.identity_map.all_states(): + _expire_state(s, s.dict, None, instance_dict=self.session.identity_map) + + def _remove_snapshot(self): + assert self._is_transaction_boundary + + if not self.nested and self.session.expire_on_commit: + for s in self.session.identity_map.all_states(): + _expire_state(s, s.dict, None, instance_dict=self.session.identity_map) + + def _connection_for_bind(self, bind): + self._assert_is_active() + + if bind in self._connections: + return self._connections[bind][0] + + if self._parent: + conn = self._parent._connection_for_bind(bind) + if not self.nested: + return conn + else: + if isinstance(bind, engine.Connection): + conn = bind + if conn.engine in self._connections: + raise sa_exc.InvalidRequestError( + "Session already has a Connection associated for the " + "given Connection's Engine") + else: + conn = bind.contextual_connect() + + if self.session.twophase and self._parent is None: + transaction = conn.begin_twophase() + elif self.nested: + transaction = conn.begin_nested() + else: + transaction = conn.begin() + + self._connections[conn] = self._connections[conn.engine] = \ + (conn, transaction, conn is not bind) + for ext in self.session.extensions: + ext.after_begin(self.session, self, conn) + return conn + + def prepare(self): + if self._parent is not None or not self.session.twophase: + raise sa_exc.InvalidRequestError( + "Only root two phase transactions of can be prepared") + self._prepare_impl() + + def _prepare_impl(self): + self._assert_is_active() + if self._parent is None or self.nested: + for ext in self.session.extensions: + ext.before_commit(self.session) + + stx = self.session.transaction + if stx is not self: + for subtransaction in stx._iterate_parents(upto=self): + subtransaction.commit() + + if not self.session._flushing: + self.session.flush() + + if self._parent is None and self.session.twophase: + try: + for t in set(self._connections.values()): + t[1].prepare() + except: + self.rollback() + raise + + self._deactivate() + self._prepared = True + + def commit(self): + self._assert_is_open() + if not self._prepared: + self._prepare_impl() + + if self._parent is None or self.nested: + for t in set(self._connections.values()): + t[1].commit() + + for ext in self.session.extensions: + ext.after_commit(self.session) + + if self.session._enable_transaction_accounting: + self._remove_snapshot() + + self.close() + return self._parent + + def rollback(self): + self._assert_is_open() + + stx = self.session.transaction + if stx is not self: + for subtransaction in stx._iterate_parents(upto=self): + subtransaction.close() + + if self.is_active or self._prepared: + for transaction in self._iterate_parents(): + if transaction._parent is None or transaction.nested: + transaction._rollback_impl() + transaction._deactivate() + break + else: + transaction._deactivate() + + self.close() + return self._parent + + def _rollback_impl(self): + for t in set(self._connections.values()): + t[1].rollback() + + if self.session._enable_transaction_accounting: + self._restore_snapshot() + + for ext in self.session.extensions: + ext.after_rollback(self.session) + + def _deactivate(self): + self._active = False + + def close(self): + self.session.transaction = self._parent + if self._parent is None: + for connection, transaction, autoclose in set(self._connections.values()): + if autoclose: + connection.close() + else: + transaction.close() + if not self.session.autocommit: + self.session.begin() + self._deactivate() + self.session = None + self._connections = None + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self._assert_is_open("Cannot end transaction context. The transaction was closed from within the context") + if self.session.transaction is None: + return + if type is None: + try: + self.commit() + except: + self.rollback() + raise + else: + self.rollback() + +class Session(object): + """Manages persistence operations for ORM-mapped objects. + + The Session is the front end to SQLAlchemy's **Unit of Work** + implementation. The concept behind Unit of Work is to track modifications + to a field of objects, and then be able to flush those changes to the + database in a single operation. + + SQLAlchemy's unit of work includes these functions: + + * The ability to track in-memory changes on scalar- and collection-based + object attributes, such that database persistence operations can be + assembled based on those changes. + + * The ability to organize individual SQL queries and population of newly + generated primary and foreign key-holding attributes during a persist + operation such that referential integrity is maintained at all times. + + * The ability to maintain insert ordering against the order in which new + instances were added to the session. + + * An Identity Map, which is a dictionary keying instances to their unique + primary key identity. This ensures that only one copy of a particular + entity is ever present within the session, even if repeated load + operations for the same entity occur. This allows many parts of an + application to get a handle to a particular object without any chance of + modifications going to two different places. + + When dealing with instances of mapped classes, an instance may be + *attached* to a particular Session, else it is *unattached* . An instance + also may or may not correspond to an actual row in the database. These + conditions break up into four distinct states: + + * *Transient* - an instance that's not in a session, and is not saved to + the database; i.e. it has no database identity. The only relationship + such an object has to the ORM is that its class has a ``mapper()`` + associated with it. + + * *Pending* - when you ``add()`` a transient instance, it becomes + pending. It still wasn't actually flushed to the database yet, but it + will be when the next flush occurs. + + * *Persistent* - An instance which is present in the session and has a + record in the database. You get persistent instances by either flushing + so that the pending instances become persistent, or by querying the + database for existing instances (or moving persistent instances from + other sessions into your local session). + + * *Detached* - an instance which has a record in the database, but is not + in any session. Theres nothing wrong with this, and you can use objects + normally when they're detached, **except** they will not be able to + issue any SQL in order to load collections or attributes which are not + yet loaded, or were marked as "expired". + + The session methods which control instance state include ``add()``, + ``delete()``, ``merge()``, and ``expunge()``. + + The Session object is generally **not** threadsafe. A session which is + set to ``autocommit`` and is only read from may be used by concurrent + threads if it's acceptable that some object instances may be loaded twice. + + The typical pattern to managing Sessions in a multi-threaded environment + is either to use mutexes to limit concurrent access to one thread at a + time, or more commonly to establish a unique session for every thread, + using a threadlocal variable. SQLAlchemy provides a thread-managed + Session adapter, provided by the :func:`~sqlalchemy.orm.scoped_session` + function. + + """ + + public_methods = ( + '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', + 'close', 'commit', 'connection', 'delete', 'execute', 'expire', + 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', 'is_modified', + 'merge', 'query', 'refresh', 'rollback', + 'scalar') + + def __init__(self, bind=None, autoflush=True, expire_on_commit=True, + _enable_transaction_accounting=True, + autocommit=False, twophase=False, + weak_identity_map=True, binds=None, extension=None, query_cls=query.Query): + """Construct a new Session. + + Arguments to ``Session`` are described using the + :func:`~sqlalchemy.orm.sessionmaker` function. + + """ + + if weak_identity_map: + self._identity_cls = identity.WeakInstanceDict + else: + self._identity_cls = identity.StrongInstanceDict + self.identity_map = self._identity_cls() + + self._new = {} # InstanceState->object, strong refs object + self._deleted = {} # same + self.bind = bind + self.__binds = {} + self._flushing = False + self.transaction = None + self.hash_key = id(self) + self.autoflush = autoflush + self.autocommit = autocommit + self.expire_on_commit = expire_on_commit + self._enable_transaction_accounting = _enable_transaction_accounting + self.twophase = twophase + self.extensions = util.to_list(extension) or [] + self._query_cls = query_cls + self._mapper_flush_opts = {} + + if binds is not None: + for mapperortable, bind in binds.iteritems(): + if isinstance(mapperortable, (type, Mapper)): + self.bind_mapper(mapperortable, bind) + else: + self.bind_table(mapperortable, bind) + + if not self.autocommit: + self.begin() + _sessions[self.hash_key] = self + + def begin(self, subtransactions=False, nested=False): + """Begin a transaction on this Session. + + If this Session is already within a transaction, either a plain + transaction or nested transaction, an error is raised, unless + ``subtransactions=True`` or ``nested=True`` is specified. + + The ``subtransactions=True`` flag indicates that this ``begin()`` can + create a subtransaction if a transaction is already in progress. A + subtransaction is a non-transactional, delimiting construct that + allows matching begin()/commit() pairs to be nested together, with + only the outermost begin/commit pair actually affecting transactional + state. When a rollback is issued, the subtransaction will directly + roll back the innermost real transaction, however each subtransaction + still must be explicitly rolled back to maintain proper stacking of + subtransactions. + + If no transaction is in progress, then a real transaction is begun. + + The ``nested`` flag begins a SAVEPOINT transaction and is equivalent + to calling ``begin_nested()``. + + """ + if self.transaction is not None: + if subtransactions or nested: + self.transaction = self.transaction._begin( + nested=nested) + else: + raise sa_exc.InvalidRequestError( + "A transaction is already begun. Use subtransactions=True " + "to allow subtransactions.") + else: + self.transaction = SessionTransaction( + self, nested=nested) + return self.transaction # needed for __enter__/__exit__ hook + + def begin_nested(self): + """Begin a `nested` transaction on this Session. + + The target database(s) must support SQL SAVEPOINTs or a + SQLAlchemy-supported vendor implementation of the idea. + + The nested transaction is a real transation, unlike a "subtransaction" + which corresponds to multiple ``begin()`` calls. The next + ``rollback()`` or ``commit()`` call will operate upon this nested + transaction. + + """ + return self.begin(nested=True) + + def rollback(self): + """Rollback the current transaction in progress. + + If no transaction is in progress, this method is a pass-through. + + This method rolls back the current transaction or nested transaction + regardless of subtransactions being in effect. All subtransactions up + to the first real transaction are closed. Subtransactions occur when + begin() is called multiple times. + + """ + if self.transaction is None: + pass + else: + self.transaction.rollback() + + def commit(self): + """Flush pending changes and commit the current transaction. + + If no transaction is in progress, this method raises an + InvalidRequestError. + + If a subtransaction is in effect (which occurs when begin() is called + multiple times), the subtransaction will be closed, and the next call + to ``commit()`` will operate on the enclosing transaction. + + For a session configured with autocommit=False, a new transaction will + be begun immediately after the commit, but note that the newly begun + transaction does *not* use any connection resources until the first + SQL is actually emitted. + + """ + if self.transaction is None: + if not self.autocommit: + self.begin() + else: + raise sa_exc.InvalidRequestError("No transaction is begun.") + + self.transaction.commit() + + def prepare(self): + """Prepare the current transaction in progress for two phase commit. + + If no transaction is in progress, this method raises an + InvalidRequestError. + + Only root transactions of two phase sessions can be prepared. If the + current transaction is not such, an InvalidRequestError is raised. + + """ + if self.transaction is None: + if not self.autocommit: + self.begin() + else: + raise sa_exc.InvalidRequestError("No transaction is begun.") + + self.transaction.prepare() + + def connection(self, mapper=None, clause=None): + """Return the active Connection. + + Retrieves the ``Connection`` managing the current transaction. Any + operations executed on the Connection will take place in the same + transactional context as ``Session`` operations. + + For ``autocommit`` Sessions with no active manual transaction, + ``connection()`` is a passthrough to ``contextual_connect()`` on the + underlying engine. + + Ambiguity in multi-bind or unbound Sessions can be resolved through + any of the optional keyword arguments. See ``get_bind()`` for more + information. + + mapper + Optional, a ``mapper`` or mapped class + + clause + Optional, any ``ClauseElement`` + + """ + return self._connection_for_bind(self.get_bind(mapper, clause)) + + def _connection_for_bind(self, engine, **kwargs): + if self.transaction is not None: + return self.transaction._connection_for_bind(engine) + else: + return engine.contextual_connect(**kwargs) + + def execute(self, clause, params=None, mapper=None, **kw): + """Execute a clause within the current transaction. + + Returns a ``ResultProxy`` of execution results. `autocommit` Sessions + will create a transaction on the fly. + + Connection ambiguity in multi-bind or unbound Sessions will be + resolved by inspecting the clause for binds. The 'mapper' and + 'instance' keyword arguments may be used if this is insufficient, See + ``get_bind()`` for more information. + + clause + A ClauseElement (i.e. select(), text(), etc.) or + string SQL statement to be executed + + params + Optional, a dictionary of bind parameters. + + mapper + Optional, a ``mapper`` or mapped class + + \**kw + Additional keyword arguments are sent to :meth:`get_bind()` + which locates a connectable to use for the execution. + Subclasses of :class:`Session` may override this. + + """ + clause = expression._literal_as_text(clause) + + engine = self.get_bind(mapper, clause=clause, **kw) + + return self._connection_for_bind(engine, close_with_result=True).execute( + clause, params or {}) + + def scalar(self, clause, params=None, mapper=None, **kw): + """Like execute() but return a scalar result.""" + + return self.execute(clause, params=params, mapper=mapper, **kw).scalar() + + def close(self): + """Close this Session. + + This clears all items and ends any transaction in progress. + + If this session were created with ``autocommit=False``, a new + transaction is immediately begun. Note that this new transaction does + not use any connection resources until they are first needed. + + """ + self.expunge_all() + if self.transaction is not None: + for transaction in self.transaction._iterate_parents(): + transaction.close() + + @classmethod + def close_all(cls): + """Close *all* sessions in memory.""" + + for sess in _sessions.values(): + sess.close() + + def expunge_all(self): + """Remove all object instances from this ``Session``. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + """ + for state in self.identity_map.all_states() + list(self._new): + state.detach() + + self.identity_map = self._identity_cls() + self._new = {} + self._deleted = {} + + # TODO: need much more test coverage for bind_mapper() and similar ! + # TODO: + crystalize + document resolution order vis. bind_mapper/bind_table + + def bind_mapper(self, mapper, bind): + """Bind operations for a mapper to a Connectable. + + mapper + A mapper instance or mapped class + + bind + Any Connectable: a ``Engine`` or ``Connection``. + + All subsequent operations involving this mapper will use the given + `bind`. + + """ + if isinstance(mapper, type): + mapper = _class_mapper(mapper) + + self.__binds[mapper.base_mapper] = bind + for t in mapper._all_tables: + self.__binds[t] = bind + + def bind_table(self, table, bind): + """Bind operations on a Table to a Connectable. + + table + A ``Table`` instance + + bind + Any Connectable: a ``Engine`` or ``Connection``. + + All subsequent operations involving this ``Table`` will use the + given `bind`. + + """ + self.__binds[table] = bind + + def get_bind(self, mapper, clause=None): + """Return an engine corresponding to the given arguments. + + All arguments are optional. + + mapper + Optional, a ``Mapper`` or mapped class + + clause + Optional, A ClauseElement (i.e. select(), text(), etc.) + + """ + if mapper is clause is None: + if self.bind: + return self.bind + else: + raise sa_exc.UnboundExecutionError( + "This session is not bound to a single Engine or " + "Connection, and no context was provided to locate " + "a binding.") + + c_mapper = mapper is not None and _class_to_mapper(mapper) or None + + # manually bound? + if self.__binds: + if c_mapper: + if c_mapper.base_mapper in self.__binds: + return self.__binds[c_mapper.base_mapper] + elif c_mapper.mapped_table in self.__binds: + return self.__binds[c_mapper.mapped_table] + if clause is not None: + for t in sql_util.find_tables(clause, include_crud=True): + if t in self.__binds: + return self.__binds[t] + + if self.bind: + return self.bind + + if isinstance(clause, sql.expression.ClauseElement) and clause.bind: + return clause.bind + + if c_mapper and c_mapper.mapped_table.bind: + return c_mapper.mapped_table.bind + + context = [] + if mapper is not None: + context.append('mapper %s' % c_mapper) + if clause is not None: + context.append('SQL expression') + + raise sa_exc.UnboundExecutionError( + "Could not locate a bind configured on %s or this Session" % ( + ', '.join(context))) + + def query(self, *entities, **kwargs): + """Return a new ``Query`` object corresponding to this ``Session``.""" + + return self._query_cls(entities, self, **kwargs) + + def _autoflush(self): + if self.autoflush and not self._flushing: + self.flush() + + def _finalize_loaded(self, states): + for state, dict_ in states.items(): + state.commit_all(dict_, self.identity_map) + + def refresh(self, instance, attribute_names=None, lockmode=None): + """Refresh the attributes on the given instance. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + Lazy-loaded relational attributes will remain lazily loaded, so that + the instance-wide refresh operation will be followed immediately by + the lazy load of that attribute. + + Eagerly-loaded relational attributes will eagerly load within the + single refresh operation. + + :param attribute_names: optional. An iterable collection of + string attribute names indicating a subset of attributes to + be refreshed. + + :param lockmode: Passed to the :class:`~sqlalchemy.orm.query.Query` + as used by :meth:`~sqlalchemy.orm.query.Query.with_lockmode`. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + self._validate_persistent(state) + if self.query(_object_mapper(instance))._get( + state.key, refresh_state=state, + lockmode=lockmode, + only_load_props=attribute_names) is None: + raise sa_exc.InvalidRequestError( + "Could not refresh instance '%s'" % + mapperutil.instance_str(instance)) + + def expire_all(self): + """Expires all persistent instances within this Session.""" + + for state in self.identity_map.all_states(): + _expire_state(state, state.dict, None, instance_dict=self.identity_map) + + def expire(self, instance, attribute_names=None): + """Expire the attributes on an instance. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, query will be issued to the database and + the attributes will be refreshed with their current database value. + ``expire()`` is a lazy variant of ``refresh()``. + + The ``attribute_names`` argument is an iterable collection + of attribute names indicating a subset of attributes to be + expired. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + self._validate_persistent(state) + if attribute_names: + _expire_state(state, state.dict, + attribute_names=attribute_names, instance_dict=self.identity_map) + else: + # pre-fetch the full cascade since the expire is going to + # remove associations + cascaded = list(_cascade_state_iterator('refresh-expire', state)) + _expire_state(state, state.dict, None, instance_dict=self.identity_map) + for (state, m, o) in cascaded: + _expire_state(state, state.dict, None, instance_dict=self.identity_map) + + def prune(self): + """Remove unreferenced instances cached in the identity map. + + Note that this method is only meaningful if "weak_identity_map" is set + to False. The default weak identity map is self-pruning. + + Removes any object in this Session's identity map that is not + referenced in user code, modified, new or scheduled for deletion. + Returns the number of objects pruned. + + """ + return self.identity_map.prune() + + def expunge(self, instance): + """Remove the `instance` from this ``Session``. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + if state.session_id is not self.hash_key: + raise sa_exc.InvalidRequestError( + "Instance %s is not present in this Session" % + mapperutil.state_str(state)) + for s, m, o in [(state, None, None)] + list(_cascade_state_iterator('expunge', state)): + self._expunge_state(s) + + def _expunge_state(self, state): + if state in self._new: + self._new.pop(state) + state.detach() + elif self.identity_map.contains_state(state): + self.identity_map.discard(state) + self._deleted.pop(state, None) + state.detach() + + def _register_newly_persistent(self, state): + mapper = _state_mapper(state) + + # prevent against last minute dereferences of the object + obj = state.obj() + if obj is not None: + + instance_key = mapper._identity_key_from_state(state) + + if state.key is None: + state.key = instance_key + elif state.key != instance_key: + # primary key switch. + # use discard() in case another state has already replaced this + # one in the identity map (see test/orm/test_naturalpks.py ReversePKsTest) + self.identity_map.discard(state) + state.key = instance_key + + self.identity_map.replace(state) + state.commit_all(state.dict, self.identity_map) + + # remove from new last, might be the last strong ref + if state in self._new: + if self._enable_transaction_accounting and self.transaction: + self.transaction._new[state] = True + self._new.pop(state) + + def _remove_newly_deleted(self, state): + if self._enable_transaction_accounting and self.transaction: + self.transaction._deleted[state] = True + + self.identity_map.discard(state) + self._deleted.pop(state, None) + + def _save_without_cascade(self, instance): + """Used by scoping.py to save on init without cascade.""" + + state = _state_for_unsaved_instance(instance, create=True) + self._save_impl(state) + + def add(self, instance): + """Place an object in the ``Session``. + + Its state will be persisted to the database on the next flush + operation. + + Repeated calls to ``add()`` will be ignored. The opposite of ``add()`` + is ``expunge()``. + + """ + state = _state_for_unknown_persistence_instance(instance) + self._save_or_update_state(state) + + def add_all(self, instances): + """Add the given collection of instances to this ``Session``.""" + + for instance in instances: + self.add(instance) + + def _save_or_update_state(self, state): + self._save_or_update_impl(state) + self._cascade_save_or_update(state) + + def _cascade_save_or_update(self, state): + for state, mapper in _cascade_unknown_state_iterator( + 'save-update', state, halt_on=self.__contains__): + self._save_or_update_impl(state) + + def delete(self, instance): + """Mark an instance as deleted. + + The database delete operation occurs upon ``flush()``. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + + if state.key is None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % + mapperutil.state_str(state)) + + if state in self._deleted: + return + + # ensure object is attached to allow the + # cascade operation to load deferred attributes + # and collections + self._attach(state) + + # grab the cascades before adding the item to the deleted list + # so that autoflush does not delete the item + cascade_states = list(_cascade_state_iterator('delete', state)) + + self._deleted[state] = state.obj() + self.identity_map.add(state) + + for state, m, o in cascade_states: + self._delete_impl(state) + + def merge(self, instance, load=True, **kw): + """Copy the state an instance onto the persistent instance with the same identifier. + + If there is no persistent instance currently associated with the + session, it will be loaded. Return the persistent instance. If the + given instance is unsaved, save a copy of and return it as a newly + persistent instance. The given instance does not become associated + with the session. + + This operation cascades to associated instances if the association is + mapped with ``cascade="merge"``. + + """ + if 'dont_load' in kw: + load = not kw['dont_load'] + util.warn_deprecated("dont_load=True has been renamed to load=False.") + + _recursive = {} + + if load: + # flush current contents if we expect to load data + self._autoflush() + + _object_mapper(instance) # verify mapped + autoflush = self.autoflush + try: + self.autoflush = False + return self._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, _recursive=_recursive) + finally: + self.autoflush = autoflush + + def _merge(self, state, state_dict, load=True, _recursive=None): + mapper = _state_mapper(state) + if state in _recursive: + return _recursive[state] + + new_instance = False + key = state.key + + if key is None: + if not load: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects transient (i.e. unpersisted) objects. flush() " + "all changes on mapped instances before merging with " + "load=False.") + key = mapper._identity_key_from_state(state) + + if key in self.identity_map: + merged = self.identity_map[key] + + elif not load: + if state.modified: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects marked as 'dirty'. flush() all changes on " + "mapped instances before merging with load=False.") + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_state.key = key + self._update_impl(merged_state) + new_instance = True + + elif not _none_set.issubset(key[1]) or \ + (mapper.allow_partial_pks and + not _none_set.issuperset(key[1])): + merged = self.query(mapper.class_).get(key[1]) + else: + merged = None + + if merged is None: + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + new_instance = True + self._save_or_update_state(merged_state) + else: + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + + _recursive[state] = merged + + # check that we didn't just pull the exact same + # state out. + if state is not merged_state: + merged_state.load_path = state.load_path + merged_state.load_options = state.load_options + + for prop in mapper.iterate_properties: + prop.merge(self, state, state_dict, merged_state, merged_dict, load, _recursive) + + if not load: + # remove any history + merged_state.commit_all(merged_dict, self.identity_map) + + if new_instance: + merged_state._run_on_load(merged) + return merged + + @classmethod + def identity_key(cls, *args, **kwargs): + return mapperutil.identity_key(*args, **kwargs) + + @classmethod + def object_session(cls, instance): + """Return the ``Session`` to which an object belongs.""" + + return object_session(instance) + + def _validate_persistent(self, state): + if not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persistent within this Session" % + mapperutil.state_str(state)) + + def _save_impl(self, state): + if state.key is not None: + raise sa_exc.InvalidRequestError( + "Object '%s' already has an identity - it can't be registered " + "as pending" % mapperutil.state_str(state)) + + self._attach(state) + if state not in self._new: + self._new[state] = state.obj() + state.insert_order = len(self._new) + + def _update_impl(self, state): + if (self.identity_map.contains_state(state) and + state not in self._deleted): + return + + if state.key is None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % + mapperutil.state_str(state)) + + self._attach(state) + self._deleted.pop(state, None) + self.identity_map.add(state) + + def _save_or_update_impl(self, state): + if state.key is None: + self._save_impl(state) + else: + self._update_impl(state) + + def _delete_impl(self, state): + if state in self._deleted: + return + + if state.key is None: + return + + self._attach(state) + self._deleted[state] = state.obj() + self.identity_map.add(state) + + def _attach(self, state): + if state.key and \ + state.key in self.identity_map and \ + not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError( + "Can't attach instance %s; another instance with key %s is already present in this session." % + (mapperutil.state_str(state), state.key) + ) + + if state.session_id and state.session_id is not self.hash_key: + raise sa_exc.InvalidRequestError( + "Object '%s' is already attached to session '%s' " + "(this is '%s')" % (mapperutil.state_str(state), + state.session_id, self.hash_key)) + + if state.session_id != self.hash_key: + state.session_id = self.hash_key + for ext in self.extensions: + ext.after_attach(self, state.obj()) + + def __contains__(self, instance): + """Return True if the instance is associated with this session. + + The instance may be pending or persistent within the Session for a + result of True. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + return self._contains_state(state) + + def __iter__(self): + """Iterate over all pending or persistent instances within this Session.""" + + return iter(list(self._new.values()) + self.identity_map.values()) + + def _contains_state(self, state): + return state in self._new or self.identity_map.contains_state(state) + + def flush(self, objects=None): + """Flush all the object changes to the database. + + Writes out all pending object creations, deletions and modifications + to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are + automatically ordered by the Session's unit of work dependency + solver.. + + Database operations will be issued in the current transactional + context and do not affect the state of the transaction. You may + flush() as often as you like within a transaction to move changes from + Python to the database's transaction buffer. + + For ``autocommit`` Sessions with no active manual transaction, flush() + will create a transaction on the fly that surrounds the entire set of + operations int the flush. + + objects + Optional; a list or tuple collection. Restricts the flush operation + to only these objects, rather than all pending changes. + Deprecated - this flag prevents the session from properly maintaining + accounting among inter-object relations and can cause invalid results. + + """ + + if objects: + util.warn_deprecated( + "The 'objects' argument to session.flush() is deprecated; " + "Please do not add objects to the session which should not yet be persisted.") + + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + + def _flush(self, objects=None): + if (not self.identity_map.check_modified() and + not self._deleted and not self._new): + return + + dirty = self._dirty_states + if not dirty and not self._deleted and not self._new: + self.identity_map._modified.clear() + return + + flush_context = UOWTransaction(self) + + if self.extensions: + for ext in self.extensions: + ext.before_flush(self, flush_context, objects) + dirty = self._dirty_states + + deleted = set(self._deleted) + new = set(self._new) + + dirty = set(dirty).difference(deleted) + + # create the set of all objects we want to operate upon + if objects: + # specific list passed in + objset = set() + for o in objects: + try: + state = attributes.instance_state(o) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(o) + objset.add(state) + else: + objset = None + + # store objects whose fate has been decided + processed = set() + + # put all saves/updates into the flush context. detect top-level + # orphans and throw them into deleted. + if objset: + proc = new.union(dirty).intersection(objset).difference(deleted) + else: + proc = new.union(dirty).difference(deleted) + + for state in proc: + is_orphan = _state_mapper(state)._is_orphan(state) + if is_orphan and not _state_has_identity(state): + path = ", nor ".join( + ["any parent '%s' instance " + "via that classes' '%s' attribute" % + (cls.__name__, key) + for (key, cls) in chain(*(m.delete_orphans for m in _state_mapper(state).iterate_to_root()))]) + raise exc.FlushError( + "Instance %s is an unsaved, pending instance and is an " + "orphan (is not attached to %s)" % ( + mapperutil.state_str(state), path)) + flush_context.register_object(state, isdelete=is_orphan) + processed.add(state) + + # put all remaining deletes into the flush context. + if objset: + proc = deleted.intersection(objset).difference(processed) + else: + proc = deleted.difference(processed) + for state in proc: + flush_context.register_object(state, isdelete=True) + + if len(flush_context.tasks) == 0: + return + + flush_context.transaction = transaction = self.begin( + subtransactions=True) + try: + flush_context.execute() + + for ext in self.extensions: + ext.after_flush(self, flush_context) + transaction.commit() + except: + transaction.rollback() + raise + + flush_context.finalize_flush_changes() + + # useful assertions: + #if not objects: + # assert not self.identity_map._modified + #else: + # assert self.identity_map._modified == self.identity_map._modified.difference(objects) + #self.identity_map._modified.clear() + + for ext in self.extensions: + ext.after_flush_postexec(self, flush_context) + + def is_modified(self, instance, include_collections=True, passive=False): + """Return True if instance has modified attributes. + + This method retrieves a history instance for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value. Note that instances present + in the 'dirty' collection may result in a value of ``False`` when + tested with this method. + + `include_collections` indicates if multivalued collections should be + included in the operation. Setting this to False is a way to detect + only local-column based properties (i.e. scalar columns or many-to-one + foreign keys) that would result in an UPDATE for this instance upon + flush. + + The `passive` flag indicates if unloaded attributes and collections + should not be loaded in the course of performing this test. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + dict_ = state.dict + for attr in state.manager.attributes: + if \ + ( + not include_collections and + hasattr(attr.impl, 'get_collection') + ) or not hasattr(attr.impl, 'get_history'): + continue + + (added, unchanged, deleted) = \ + attr.impl.get_history(state, dict_, passive=passive) + + if added or deleted: + return True + return False + + @property + def is_active(self): + """True if this Session has an active transaction.""" + + return self.transaction and self.transaction.is_active + + @property + def _dirty_states(self): + """The set of all persistent states considered dirty. + + This method returns all states that were modified including + those that were possibly deleted. + + """ + return self.identity_map._dirty_states() + + @property + def dirty(self): + """The set of all persistent instances considered dirty. + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the is_modified() method. + + """ + return util.IdentitySet( + [state.obj() + for state in self._dirty_states + if state not in self._deleted]) + + @property + def deleted(self): + "The set of all instances marked as 'deleted' within this ``Session``" + + return util.IdentitySet(self._deleted.values()) + + @property + def new(self): + "The set of all instances marked as 'new' within this ``Session``." + + return util.IdentitySet(self._new.values()) + +_expire_state = state.InstanceState.expire_attributes + +UOWEventHandler = unitofwork.UOWEventHandler + +_sessions = weakref.WeakValueDictionary() + +def _cascade_state_iterator(cascade, state, **kwargs): + mapper = _state_mapper(state) + # yield the state, object, mapper. yielding the object + # allows the iterator's results to be held in a list without + # states being garbage collected + for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs): + yield attributes.instance_state(o), o, m + +def _cascade_unknown_state_iterator(cascade, state, **kwargs): + mapper = _state_mapper(state) + for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs): + yield _state_for_unknown_persistence_instance(o), m + +def _state_for_unsaved_instance(instance, create=False): + try: + state = attributes.instance_state(instance) + except AttributeError: + raise exc.UnmappedInstanceError(instance) + if state: + if state.key is not None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is already persistent" % + mapperutil.state_str(state)) + elif create: + manager = attributes.manager_of_class(instance.__class__) + if manager is None: + raise exc.UnmappedInstanceError(instance) + state = manager.setup_instance(instance) + else: + raise exc.UnmappedInstanceError(instance) + + return state + +def _state_for_unknown_persistence_instance(instance): + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + + return state + +def make_transient(instance): + """Make the given instance 'transient'. + + This will remove its association with any + session and additionally will remove its "identity key", + such that it's as though the object were newly constructed, + except retaining its values. + + """ + state = attributes.instance_state(instance) + s = _state_session(state) + if s: + s._expunge_state(state) + del state.key + + +def object_session(instance): + """Return the ``Session`` to which instance belongs, or None.""" + + return _state_session(attributes.instance_state(instance)) + +def _state_session(state): + if state.session_id: + try: + return _sessions[state.session_id] + except KeyError: + pass + return None + +# Lazy initialization to avoid circular imports +unitofwork.object_session = object_session +unitofwork._state_session = _state_session +from sqlalchemy.orm import mapper +mapper._expire_state = _expire_state +mapper._state_session = _state_session diff --git a/sqlalchemy/orm/shard.py b/sqlalchemy/orm/shard.py new file mode 100644 index 0000000..9cb26db --- /dev/null +++ b/sqlalchemy/orm/shard.py @@ -0,0 +1,15 @@ +# shard.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy import util + +util.warn_deprecated( + "Horizontal sharding is now importable via " + "'import sqlalchemy.ext.horizontal_shard" +) + +from sqlalchemy.ext.horizontal_shard import * + diff --git a/sqlalchemy/orm/state.py b/sqlalchemy/orm/state.py new file mode 100644 index 0000000..25466b3 --- /dev/null +++ b/sqlalchemy/orm/state.py @@ -0,0 +1,527 @@ +from sqlalchemy.util import EMPTY_SET +import weakref +from sqlalchemy import util +from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, PASSIVE_OFF, \ + NEVER_SET, NO_VALUE, manager_of_class, \ + ATTR_WAS_SET +from sqlalchemy.orm import attributes, exc as orm_exc, interfaces + +import sys +attributes.state = sys.modules['sqlalchemy.orm.state'] + +class InstanceState(object): + """tracks state information at the instance level.""" + + session_id = None + key = None + runid = None + load_options = EMPTY_SET + load_path = () + insert_order = None + mutable_dict = None + _strong_obj = None + modified = False + expired = False + + def __init__(self, obj, manager): + self.class_ = obj.__class__ + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) + + @util.memoized_property + def committed_state(self): + return {} + + @util.memoized_property + def parents(self): + return {} + + @util.memoized_property + def pending(self): + return {} + + @util.memoized_property + def callables(self): + return {} + + def detach(self): + if self.session_id: + try: + del self.session_id + except AttributeError: + pass + + def dispose(self): + self.detach() + del self.obj + + def _cleanup(self, ref): + instance_dict = self._instance_dict() + if instance_dict: + try: + instance_dict.remove(self) + except AssertionError: + pass + # remove possible cycles + self.__dict__.pop('callables', None) + self.dispose() + + def obj(self): + return None + + @property + def dict(self): + o = self.obj() + if o is not None: + return attributes.instance_dict(o) + else: + return {} + + @property + def sort_key(self): + return self.key and self.key[1] or (self.insert_order, ) + + def initialize_instance(*mixed, **kwargs): + self, instance, args = mixed[0], mixed[1], mixed[2:] + manager = self.manager + + for fn in manager.events.on_init: + fn(self, instance, args, kwargs) + + # LESSTHANIDEAL: + # adjust for the case where the InstanceState was created before + # mapper compilation, and this actually needs to be a MutableAttrInstanceState + if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState: + self.__class__ = MutableAttrInstanceState + self.obj = weakref.ref(self.obj(), self._cleanup) + self.mutable_dict = {} + + try: + return manager.events.original_init(*mixed[1:], **kwargs) + except: + for fn in manager.events.on_init_failure: + fn(self, instance, args, kwargs) + raise + + def get_history(self, key, **kwargs): + return self.manager.get_impl(key).get_history(self, self.dict, **kwargs) + + def get_impl(self, key): + return self.manager.get_impl(key) + + def get_pending(self, key): + if key not in self.pending: + self.pending[key] = PendingCollection() + return self.pending[key] + + def value_as_iterable(self, key, passive=PASSIVE_OFF): + """return an InstanceState attribute as a list, + regardless of it being a scalar or collection-based + attribute. + + returns None if passive is not PASSIVE_OFF and the getter returns + PASSIVE_NO_RESULT. + """ + + impl = self.get_impl(key) + dict_ = self.dict + x = impl.get(self, dict_, passive=passive) + if x is PASSIVE_NO_RESULT: + return None + elif hasattr(impl, 'get_collection'): + return impl.get_collection(self, dict_, x, passive=passive) + else: + return [x] + + def _run_on_load(self, instance): + self.manager.events.run('on_load', instance) + + def __getstate__(self): + d = {'instance':self.obj()} + + d.update( + (k, self.__dict__[k]) for k in ( + 'committed_state', 'pending', 'parents', 'modified', 'expired', + 'callables', 'key', 'load_options', 'mutable_dict' + ) if k in self.__dict__ + ) + if self.load_path: + d['load_path'] = interfaces.serialize_path(self.load_path) + return d + + def __setstate__(self, state): + self.obj = weakref.ref(state['instance'], self._cleanup) + self.class_ = state['instance'].__class__ + self.manager = manager = manager_of_class(self.class_) + if manager is None: + raise orm_exc.UnmappedInstanceError( + state['instance'], + "Cannot deserialize object of type %r - no mapper() has" + " been configured for this class within the current Python process!" % + self.class_) + elif manager.mapper and not manager.mapper.compiled: + manager.mapper.compile() + + self.committed_state = state.get('committed_state', {}) + self.pending = state.get('pending', {}) + self.parents = state.get('parents', {}) + self.modified = state.get('modified', False) + self.expired = state.get('expired', False) + self.callables = state.get('callables', {}) + + if self.modified: + self._strong_obj = state['instance'] + + self.__dict__.update([ + (k, state[k]) for k in ( + 'key', 'load_options', 'mutable_dict' + ) if k in state + ]) + + if 'load_path' in state: + self.load_path = interfaces.deserialize_path(state['load_path']) + + def initialize(self, key): + """Set this attribute to an empty value or collection, + based on the AttributeImpl in use.""" + + self.manager.get_impl(key).initialize(self, self.dict) + + def reset(self, dict_, key): + """Remove the given attribute and any + callables associated with it.""" + + dict_.pop(key, None) + self.callables.pop(key, None) + + def expire_attribute_pre_commit(self, dict_, key): + """a fast expire that can be called by column loaders during a load. + + The additional bookkeeping is finished up in commit_all(). + + This method is actually called a lot with joined-table + loading, when the second table isn't present in the result. + + """ + dict_.pop(key, None) + self.callables[key] = self + + def set_callable(self, dict_, key, callable_): + """Remove the given attribute and set the given callable + as a loader.""" + + dict_.pop(key, None) + self.callables[key] = callable_ + + def expire_attributes(self, dict_, attribute_names, instance_dict=None): + """Expire all or a group of attributes. + + If all attributes are expired, the "expired" flag is set to True. + + """ + if attribute_names is None: + attribute_names = self.manager.keys() + self.expired = True + if self.modified: + if not instance_dict: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict._modified.discard(self) + else: + instance_dict._modified.discard(self) + + self.modified = False + filter_deferred = True + else: + filter_deferred = False + + to_clear = ( + self.__dict__.get('pending', None), + self.__dict__.get('committed_state', None), + self.mutable_dict + ) + + for key in attribute_names: + impl = self.manager[key].impl + if impl.accepts_scalar_loader and \ + (not filter_deferred or impl.expire_missing or key in dict_): + self.callables[key] = self + dict_.pop(key, None) + + for d in to_clear: + if d is not None: + d.pop(key, None) + + def __call__(self, **kw): + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable (picklable). + + """ + + if kw.get('passive') is attributes.PASSIVE_NO_FETCH: + return attributes.PASSIVE_NO_RESULT + + toload = self.expired_attributes.\ + intersection(self.unmodified) + + self.manager.deferred_scalar_loader(self, toload) + + # if the loader failed, or this + # instance state didn't have an identity, + # the attributes still might be in the callables + # dict. ensure they are removed. + for k in toload.intersection(self.callables): + del self.callables[k] + + return ATTR_WAS_SET + + @property + def unmodified(self): + """Return the set of keys which have no uncommitted changes""" + + return set(self.manager).difference(self.committed_state) + + @property + def unloaded(self): + """Return the set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that + was never populated or modified. + + """ + return set(self.manager).\ + difference(self.committed_state).\ + difference(self.dict) + + @property + def expired_attributes(self): + """Return the set of keys which are 'expired' to be loaded by + the manager's deferred scalar loader, assuming no pending + changes. + + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs. + + """ + return set([k for k, v in self.callables.items() if v is self]) + + def _instance_dict(self): + return None + + def _is_really_none(self): + return self.obj() + + def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF): + needs_committed = attr.key not in self.committed_state + + if needs_committed: + if previous is NEVER_SET: + if passive: + if attr.key in dict_: + previous = dict_[attr.key] + else: + previous = attr.get(self, dict_) + + if should_copy and previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + + if needs_committed: + self.committed_state[attr.key] = previous + + if not self.modified: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict._modified.add(self) + + self.modified = True + if self._strong_obj is None: + self._strong_obj = self.obj() + + def commit(self, dict_, keys): + """Commit attributes. + + This is used by a partial-attribute load operation to mark committed + those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after + this step if a value was not populated in state.dict. + + """ + class_manager = self.manager + for key in keys: + if key in dict_ and key in class_manager.mutable_attributes: + self.committed_state[key] = self.manager[key].impl.copy(dict_[key]) + else: + self.committed_state.pop(key, None) + + self.expired = False + + for key in set(self.callables).\ + intersection(keys).\ + intersection(dict_): + del self.callables[key] + + def commit_all(self, dict_, instance_dict=None): + """commit all attributes unconditionally. + + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. + + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers/callables for attributes loaded are removed. + + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. + + """ + + self.__dict__.pop('committed_state', None) + self.__dict__.pop('pending', None) + + if 'callables' in self.__dict__: + callables = self.callables + for key in list(callables): + if key in dict_ and callables[key] is self: + del callables[key] + + for key in self.manager.mutable_attributes: + if key in dict_: + self.committed_state[key] = self.manager[key].impl.copy(dict_[key]) + + if instance_dict and self.modified: + instance_dict._modified.discard(self) + + self.modified = self.expired = False + self._strong_obj = None + +class MutableAttrInstanceState(InstanceState): + """InstanceState implementation for objects that reference 'mutable' + attributes. + + Has a more involved "cleanup" handler that checks mutable attributes + for changes upon dereference, resurrecting if needed. + + """ + + @util.memoized_property + def mutable_dict(self): + return {} + + def _get_modified(self, dict_=None): + if self.__dict__.get('modified', False): + return True + else: + if dict_ is None: + dict_ = self.dict + for key in self.manager.mutable_attributes: + if self.manager[key].impl.check_mutable_modified(self, dict_): + return True + else: + return False + + def _set_modified(self, value): + self.__dict__['modified'] = value + + modified = property(_get_modified, _set_modified) + + @property + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + dict_ = self.dict + + return set([ + key for key in self.manager + if (key not in self.committed_state or + (key in self.manager.mutable_attributes and + not self.manager[key].impl.check_mutable_modified(self, dict_)))]) + + def _is_really_none(self): + """do a check modified/resurrect. + + This would be called in the extremely rare + race condition that the weakref returned None but + the cleanup handler had not yet established the + __resurrect callable as its replacement. + + """ + if self.modified: + self.obj = self.__resurrect + return self.obj() + else: + return None + + def reset(self, dict_, key): + self.mutable_dict.pop(key, None) + InstanceState.reset(self, dict_, key) + + def _cleanup(self, ref): + """weakref callback. + + This method may be called by an asynchronous + gc. + + If the state shows pending changes, the weakref + is replaced by the __resurrect callable which will + re-establish an object reference on next access, + else removes this InstanceState from the owning + identity map, if any. + + """ + if self._get_modified(self.mutable_dict): + self.obj = self.__resurrect + else: + instance_dict = self._instance_dict() + if instance_dict: + try: + instance_dict.remove(self) + except AssertionError: + pass + self.dispose() + + def __resurrect(self): + """A substitute for the obj() weakref function which resurrects.""" + + # store strong ref'ed version of the object; will revert + # to weakref when changes are persisted + + obj = self.manager.new_instance(state=self) + self.obj = weakref.ref(obj, self._cleanup) + self._strong_obj = obj + obj.__dict__.update(self.mutable_dict) + + # re-establishes identity attributes from the key + self.manager.events.run('on_resurrect', self, obj) + + # TODO: don't really think we should run this here. + # resurrect is only meant to preserve the minimal state needed to + # do an UPDATE, not to produce a fully usable object + self._run_on_load(obj) + + return obj + +class PendingCollection(object): + """A writable placeholder for an unloaded collection. + + Stores items appended to and removed from a collection that has not yet + been loaded. When the collection is loaded, the changes stored in + PendingCollection are applied to it to produce the final result. + + """ + def __init__(self): + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def append(self, value): + if value in self.deleted_items: + self.deleted_items.remove(value) + self.added_items.add(value) + + def remove(self, value): + if value in self.added_items: + self.added_items.remove(value) + self.deleted_items.add(value) + diff --git a/sqlalchemy/orm/strategies.py b/sqlalchemy/orm/strategies.py new file mode 100644 index 0000000..25c2f83 --- /dev/null +++ b/sqlalchemy/orm/strategies.py @@ -0,0 +1,1229 @@ +# strategies.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""sqlalchemy.orm.interfaces.LoaderStrategy + implementations, and related MapperOptions.""" + +from sqlalchemy import exc as sa_exc +from sqlalchemy import sql, util, log +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import visitors, expression, operators +from sqlalchemy.orm import mapper, attributes, interfaces, exc as orm_exc +from sqlalchemy.orm.interfaces import ( + LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, + serialize_path, deserialize_path, StrategizedProperty + ) +from sqlalchemy.orm import session as sessionlib +from sqlalchemy.orm import util as mapperutil +import itertools + +def _register_attribute(strategy, mapper, useobject, + compare_function=None, + typecallable=None, + copy_function=None, + mutable_scalars=False, + uselist=False, + callable_=None, + proxy_property=None, + active_history=False, + impl_class=None, + **kw +): + + prop = strategy.parent_property + attribute_ext = list(util.to_list(prop.extension, default=[])) + + if useobject and prop.single_parent: + attribute_ext.insert(0, _SingleParentValidator(prop)) + + if prop.key in prop.parent._validators: + attribute_ext.insert(0, + mapperutil.Validator(prop.key, prop.parent._validators[prop.key]) + ) + + if useobject: + attribute_ext.append(sessionlib.UOWEventHandler(prop.key)) + + + for m in mapper.polymorphic_iterator(): + if prop is m._props.get(prop.key): + + attributes.register_attribute_impl( + m.class_, + prop.key, + parent_token=prop, + mutable_scalars=mutable_scalars, + uselist=uselist, + copy_function=copy_function, + compare_function=compare_function, + useobject=useobject, + extension=attribute_ext, + trackparent=useobject, + typecallable=typecallable, + callable_=callable_, + active_history=active_history, + impl_class=impl_class, + **kw + ) + +class UninstrumentedColumnLoader(LoaderStrategy): + """Represent the a non-instrumented MapperProperty. + + The polymorphic_on argument of mapper() often results in this, + if the argument is against the with_polymorphic selectable. + + """ + def init(self): + self.columns = self.parent_property.columns + + def setup_query(self, context, entity, path, adapter, + column_collection=None, **kwargs): + for c in self.columns: + if adapter: + c = adapter.columns[c] + column_collection.append(c) + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + return None, None + +class ColumnLoader(LoaderStrategy): + """Strategize the loading of a plain column-based MapperProperty.""" + + def init(self): + self.columns = self.parent_property.columns + self.is_composite = hasattr(self.parent_property, 'composite_class') + + def setup_query(self, context, entity, path, adapter, + column_collection=None, **kwargs): + for c in self.columns: + if adapter: + c = adapter.columns[c] + column_collection.append(c) + + def init_class_attribute(self, mapper): + self.is_class_level = True + coltype = self.columns[0].type + # TODO: check all columns ? check for foreign key as well? + active_history = self.columns[0].primary_key + + _register_attribute(self, mapper, useobject=False, + compare_function=coltype.compare_values, + copy_function=coltype.copy_value, + mutable_scalars=self.columns[0].type.is_mutable(), + active_history = active_history + ) + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + key, col = self.key, self.columns[0] + if adapter: + col = adapter.columns[col] + + if col is not None and col in row: + def new_execute(state, dict_, row): + dict_[key] = row[col] + else: + def new_execute(state, dict_, row): + state.expire_attribute_pre_commit(dict_, key) + return new_execute, None + +log.class_logger(ColumnLoader) + +class CompositeColumnLoader(ColumnLoader): + """Strategize the loading of a composite column-based MapperProperty.""" + + def init_class_attribute(self, mapper): + self.is_class_level = True + self.logger.info("%s register managed composite attribute", self) + + def copy(obj): + if obj is None: + return None + return self.parent_property.\ + composite_class(*obj.__composite_values__()) + + def compare(a, b): + if a is None or b is None: + return a is b + + for col, aprop, bprop in zip(self.columns, + a.__composite_values__(), + b.__composite_values__()): + if not col.type.compare_values(aprop, bprop): + return False + else: + return True + + _register_attribute(self, mapper, useobject=False, + compare_function=compare, + copy_function=copy, + mutable_scalars=True + #active_history ? + ) + + def create_row_processor(self, selectcontext, path, mapper, + row, adapter): + key = self.key + columns = self.columns + composite_class = self.parent_property.composite_class + if adapter: + columns = [adapter.columns[c] for c in columns] + + for c in columns: + if c not in row: + def new_execute(state, dict_, row): + state.expire_attribute_pre_commit(dict_, key) + break + else: + def new_execute(state, dict_, row): + dict_[key] = composite_class(*[row[c] for c in columns]) + + return new_execute, None + +log.class_logger(CompositeColumnLoader) + +class DeferredColumnLoader(LoaderStrategy): + """Strategize the loading of a deferred column-based MapperProperty.""" + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + col = self.columns[0] + if adapter: + col = adapter.columns[col] + + key = self.key + if col in row: + return self.parent_property._get_strategy(ColumnLoader).\ + create_row_processor( + selectcontext, path, mapper, row, adapter) + + elif not self.is_class_level: + def new_execute(state, dict_, row): + state.set_callable(dict_, key, LoadDeferredColumns(state, key)) + else: + def new_execute(state, dict_, row): + # reset state on the key so that deferred callables + # fire off on next access. + state.reset(dict_, key) + + return new_execute, None + + def init(self): + if hasattr(self.parent_property, 'composite_class'): + raise NotImplementedError("Deferred loading for composite " + "types not implemented yet") + self.columns = self.parent_property.columns + self.group = self.parent_property.group + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute(self, mapper, useobject=False, + compare_function=self.columns[0].type.compare_values, + copy_function=self.columns[0].type.copy_value, + mutable_scalars=self.columns[0].type.is_mutable(), + callable_=self._class_level_loader, + expire_missing=False + ) + + def setup_query(self, context, entity, path, adapter, + only_load_props=None, **kwargs): + if ( + self.group is not None and + context.attributes.get(('undefer', self.group), False) + ) or (only_load_props and self.key in only_load_props): + self.parent_property._get_strategy(ColumnLoader).\ + setup_query(context, entity, + path, adapter, **kwargs) + + def _class_level_loader(self, state): + if not mapperutil._state_has_identity(state): + return None + + return LoadDeferredColumns(state, self.key) + + +log.class_logger(DeferredColumnLoader) + +class LoadDeferredColumns(object): + """serializable loader object used by DeferredColumnLoader""" + + def __init__(self, state, key): + self.state, self.key = state, key + + def __call__(self, **kw): + if kw.get('passive') is attributes.PASSIVE_NO_FETCH: + return attributes.PASSIVE_NO_RESULT + + state = self.state + + localparent = mapper._state_mapper(state) + + prop = localparent.get_property(self.key) + strategy = prop._get_strategy(DeferredColumnLoader) + + if strategy.group: + toload = [ + p.key for p in + localparent.iterate_properties + if isinstance(p, StrategizedProperty) and + isinstance(p.strategy, DeferredColumnLoader) and + p.group==strategy.group + ] + else: + toload = [self.key] + + # narrow the keys down to just those which have no history + group = [k for k in toload if k in state.unmodified] + + if strategy._should_log_debug(): + strategy.logger.debug( + "deferred load %s group %s", + (mapperutil.state_attribute_str(state, self.key), + group and ','.join(group) or 'None') + ) + + session = sessionlib._state_session(state) + if session is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "deferred load operation of attribute '%s' cannot proceed" % + (mapperutil.state_str(state), self.key) + ) + + query = session.query(localparent) + ident = state.key[1] + query._get(None, ident=ident, + only_load_props=group, refresh_state=state) + return attributes.ATTR_WAS_SET + +class DeferredOption(StrategizedOption): + propagate_to_loaders = True + + def __init__(self, key, defer=False): + super(DeferredOption, self).__init__(key) + self.defer = defer + + def get_strategy_class(self): + if self.defer: + return DeferredColumnLoader + else: + return ColumnLoader + +class UndeferGroupOption(MapperOption): + propagate_to_loaders = True + + def __init__(self, group): + self.group = group + + def process_query(self, query): + query._attributes[('undefer', self.group)] = True + +class AbstractRelationshipLoader(LoaderStrategy): + """LoaderStratgies which deal with related objects.""" + + def init(self): + self.mapper = self.parent_property.mapper + self.target = self.parent_property.target + self.table = self.parent_property.table + self.uselist = self.parent_property.uselist + +class NoLoader(AbstractRelationshipLoader): + """Strategize a relationship() that doesn't load data automatically.""" + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute(self, mapper, + useobject=True, + uselist=self.parent_property.uselist, + typecallable = self.parent_property.collection_class, + ) + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + def new_execute(state, dict_, row): + state.initialize(self.key) + return new_execute, None + +log.class_logger(NoLoader) + +class LazyLoader(AbstractRelationshipLoader): + """Strategize a relationship() that loads when first accessed.""" + + def init(self): + super(LazyLoader, self).init() + self.__lazywhere, \ + self.__bind_to_col, \ + self._equated_columns = self._create_lazy_clause(self.parent_property) + + self.logger.info("%s lazy loading clause %s", self, self.__lazywhere) + + # determine if our "lazywhere" clause is the same as the mapper's + # get() clause. then we can just use mapper.get() + #from sqlalchemy.orm import query + self.use_get = not self.uselist and \ + self.mapper._get_clause[0].compare( + self.__lazywhere, + use_proxies=True, + equivalents=self.mapper._equivalent_columns + ) + + if self.use_get: + for col in self._equated_columns.keys(): + if col in self.mapper._equivalent_columns: + for c in self.mapper._equivalent_columns[col]: + self._equated_columns[c] = self._equated_columns[col] + + self.logger.info("%s will use query.get() to " + "optimize instance loads" % self) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + # MANYTOONE currently only needs the + # "old" value for delete-orphan + # cascades. the required _SingleParentValidator + # will enable active_history + # in that case. otherwise we don't need the + # "old" value during backref operations. + _register_attribute(self, + mapper, + useobject=True, + callable_=self._class_level_loader, + uselist = self.parent_property.uselist, + typecallable = self.parent_property.collection_class, + active_history = \ + self.parent_property.direction is not \ + interfaces.MANYTOONE or \ + not self.use_get, + ) + + def lazy_clause(self, state, reverse_direction=False, + alias_secondary=False, adapt_source=None): + if state is None: + return self._lazy_none_clause( + reverse_direction, + adapt_source=adapt_source) + + if not reverse_direction: + criterion, bind_to_col, rev = \ + self.__lazywhere, \ + self.__bind_to_col, \ + self._equated_columns + else: + criterion, bind_to_col, rev = \ + LazyLoader._create_lazy_clause( + self.parent_property, + reverse_direction=reverse_direction) + + if reverse_direction: + mapper = self.parent_property.mapper + else: + mapper = self.parent_property.parent + + def visit_bindparam(bindparam): + if bindparam.key in bind_to_col: + # use the "committed" (database) version to get + # query column values + # also its a deferred value; so that when used + # by Query, the committed value is used + # after an autoflush occurs + o = state.obj() # strong ref + bindparam.value = \ + lambda: mapper._get_committed_attr_by_column( + o, bind_to_col[bindparam.key]) + + if self.parent_property.secondary is not None and alias_secondary: + criterion = sql_util.ClauseAdapter( + self.parent_property.secondary.alias()).\ + traverse(criterion) + + criterion = visitors.cloned_traverse( + criterion, {}, {'bindparam':visit_bindparam}) + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): + if not reverse_direction: + criterion, bind_to_col, rev = \ + self.__lazywhere, \ + self.__bind_to_col,\ + self._equated_columns + else: + criterion, bind_to_col, rev = \ + LazyLoader._create_lazy_clause( + self.parent_property, + reverse_direction=reverse_direction) + + criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _class_level_loader(self, state): + if not mapperutil._state_has_identity(state): + return None + + return LoadLazyAttribute(state, self.key) + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + key = self.key + if not self.is_class_level: + def new_execute(state, dict_, row): + # we are not the primary manager for this attribute + # on this class - set up a + # per-instance lazyloader, which will override the + # class-level behavior. + # this currently only happens when using a + # "lazyload" option on a "no load" + # attribute - "eager" attributes always have a + # class-level lazyloader installed. + state.set_callable(dict_, key, LoadLazyAttribute(state, key)) + else: + def new_execute(state, dict_, row): + # we are the primary manager for this attribute on + # this class - reset its + # per-instance attribute state, so that the class-level + # lazy loader is + # executed when next referenced on this instance. + # this is needed in + # populate_existing() types of scenarios to reset + # any existing state. + state.reset(dict_, key) + + return new_execute, None + + @classmethod + def _create_lazy_clause(cls, prop, reverse_direction=False): + binds = util.column_dict() + lookup = util.column_dict() + equated_columns = util.column_dict() + + if reverse_direction and prop.secondaryjoin is None: + for l, r in prop.local_remote_pairs: + _list = lookup.setdefault(r, []) + _list.append((r, l)) + equated_columns[l] = r + else: + for l, r in prop.local_remote_pairs: + _list = lookup.setdefault(l, []) + _list.append((l, r)) + equated_columns[r] = l + + def col_to_bind(col): + if col in lookup: + for tobind, equated in lookup[col]: + if equated in binds: + return None + if col not in binds: + binds[col] = sql.bindparam(None, None, type_=col.type) + return binds[col] + return None + + lazywhere = prop.primaryjoin + + if prop.secondaryjoin is None or not reverse_direction: + lazywhere = visitors.replacement_traverse( + lazywhere, {}, col_to_bind) + + if prop.secondaryjoin is not None: + secondaryjoin = prop.secondaryjoin + if reverse_direction: + secondaryjoin = visitors.replacement_traverse( + secondaryjoin, {}, col_to_bind) + lazywhere = sql.and_(lazywhere, secondaryjoin) + + bind_to_col = dict((binds[col].key, col) for col in binds) + + return lazywhere, bind_to_col, equated_columns + +log.class_logger(LazyLoader) + +class LoadLazyAttribute(object): + """serializable loader object used by LazyLoader""" + + def __init__(self, state, key): + self.state, self.key = state, key + + def __getstate__(self): + return (self.state, self.key) + + def __setstate__(self, state): + self.state, self.key = state + + def __call__(self, **kw): + state = self.state + instance_mapper = mapper._state_mapper(state) + prop = instance_mapper.get_property(self.key) + strategy = prop._get_strategy(LazyLoader) + + if kw.get('passive') is attributes.PASSIVE_NO_FETCH and \ + not strategy.use_get: + return attributes.PASSIVE_NO_RESULT + + if strategy._should_log_debug(): + strategy.logger.debug("loading %s", + mapperutil.state_attribute_str( + state, self.key)) + + session = sessionlib._state_session(state) + if session is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "lazy load operation of attribute '%s' cannot proceed" % + (mapperutil.state_str(state), self.key) + ) + + q = session.query(prop.mapper)._adapt_all_clauses() + + if state.load_path: + q = q._with_current_path(state.load_path + (self.key,)) + + # if we have a simple primary key load, use mapper.get() + # to possibly save a DB round trip + if strategy.use_get: + ident = [] + allnulls = True + for primary_key in prop.mapper.primary_key: + val = instance_mapper.\ + _get_committed_state_attr_by_column( + state, + strategy._equated_columns[primary_key], + **kw) + if val is attributes.PASSIVE_NO_RESULT: + return val + allnulls = allnulls and val is None + ident.append(val) + + if allnulls: + return None + + if state.load_options: + q = q._conditional_options(*state.load_options) + + key = prop.mapper.identity_key_from_primary_key(ident) + return q._get(key, ident, **kw) + + + if prop.order_by: + q = q.order_by(*util.to_list(prop.order_by)) + + if state.load_options: + q = q._conditional_options(*state.load_options) + q = q.filter(strategy.lazy_clause(state)) + + result = q.all() + if strategy.uselist: + return result + else: + l = len(result) + if l: + if l > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for lazily-loaded attribute '%s' " + % prop) + + return result[0] + else: + return None + +class SubqueryLoader(AbstractRelationshipLoader): + def init(self): + super(SubqueryLoader, self).init() + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property.\ + _get_strategy(LazyLoader).\ + init_class_attribute(mapper) + + def setup_query(self, context, entity, + path, adapter, column_collection=None, + parentmapper=None, **kwargs): + + if not context.query._enable_eagerloads: + return + + path = path + (self.key, ) + + # build up a path indicating the path from the leftmost + # entity to the thing we're subquery loading. + subq_path = context.attributes.get(('subquery_path', None), ()) + + subq_path = subq_path + path + + reduced_path = interfaces._reduce_path(path) + + # join-depth / recursion check + if ("loaderstrategy", reduced_path) not in context.attributes: + if self.join_depth: + if len(path) / 2 > self.join_depth: + return + else: + if self.mapper.base_mapper in interfaces._reduce_path(subq_path): + return + + orig_query = context.attributes.get( + ("orig_query", SubqueryLoader), + context.query) + + # determine attributes of the leftmost mapper + if self.parent.isa(subq_path[0]) and self.key==subq_path[1]: + leftmost_mapper, leftmost_prop = \ + self.parent, self.parent_property + else: + leftmost_mapper, leftmost_prop = \ + subq_path[0], \ + subq_path[0].get_property(subq_path[1]) + leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop) + + leftmost_attr = [ + leftmost_mapper._get_col_to_prop(c).class_attribute + for c in leftmost_cols + ] + + # reformat the original query + # to look only for significant columns + q = orig_query._clone() + # TODO: why does polymporphic etc. require hardcoding + # into _adapt_col_list ? Does query.add_columns(...) work + # with polymorphic loading ? + q._set_entities(q._adapt_col_list(leftmost_attr)) + + # don't need ORDER BY if no limit/offset + if q._limit is None and q._offset is None: + q._order_by = None + + # the original query now becomes a subquery + # which we'll join onto. + embed_q = q.with_labels().subquery() + left_alias = mapperutil.AliasedClass(leftmost_mapper, embed_q) + + # q becomes a new query. basically doing a longhand + # "from_self()". (from_self() itself not quite industrial + # strength enough for all contingencies...but very close) + + q = q.session.query(self.mapper) + q._attributes = { + ("orig_query", SubqueryLoader): orig_query, + ('subquery_path', None) : subq_path + } + + # figure out what's being joined. a.k.a. the fun part + to_join = [ + (subq_path[i], subq_path[i+1]) + for i in xrange(0, len(subq_path), 2) + ] + + if len(to_join) < 2: + parent_alias = left_alias + else: + parent_alias = mapperutil.AliasedClass(self.parent) + + local_cols, remote_cols = \ + self._local_remote_columns(self.parent_property) + + local_attr = [ + getattr(parent_alias, self.parent._get_col_to_prop(c).key) + for c in local_cols + ] + q = q.order_by(*local_attr) + q = q.add_columns(*local_attr) + + for i, (mapper, key) in enumerate(to_join): + + # we need to use query.join() as opposed to + # orm.join() here because of the + # rich behavior it brings when dealing with + # "with_polymorphic" mappers. "aliased" + # and "from_joinpoint" take care of most of + # the chaining and aliasing for us. + + first = i == 0 + middle = i < len(to_join) - 1 + second_to_last = i == len(to_join) - 2 + + if first: + attr = getattr(left_alias, key) + else: + attr = key + + if second_to_last: + q = q.join((parent_alias, attr), from_joinpoint=True) + else: + q = q.join(attr, aliased=middle, from_joinpoint=True) + + # propagate loader options etc. to the new query. + # these will fire relative to subq_path. + q = q._with_current_path(subq_path) + q = q._conditional_options(*orig_query._with_options) + + if self.parent_property.order_by: + # if there's an ORDER BY, alias it the same + # way joinedloader does, but we have to pull out + # the "eagerjoin" from the query. + # this really only picks up the "secondary" table + # right now. + eagerjoin = q._from_obj[0] + eager_order_by = \ + eagerjoin._target_adapter.\ + copy_and_process( + util.to_list( + self.parent_property.order_by + ) + ) + q = q.order_by(*eager_order_by) + + # add new query to attributes to be picked up + # by create_row_processor + context.attributes[('subquery', reduced_path)] = q + + def _local_remote_columns(self, prop): + if prop.secondary is None: + return zip(*prop.local_remote_pairs) + else: + return \ + [p[0] for p in prop.synchronize_pairs],\ + [ + p[0] for p in prop. + secondary_synchronize_pairs + ] + + def create_row_processor(self, context, path, mapper, row, adapter): + path = path + (self.key,) + + path = interfaces._reduce_path(path) + + if ('subquery', path) not in context.attributes: + return None, None + + local_cols, remote_cols = self._local_remote_columns(self.parent_property) + + remote_attr = [ + self.mapper._get_col_to_prop(c).key + for c in remote_cols] + + q = context.attributes[('subquery', path)] + + collections = dict( + (k, [v[0] for v in v]) + for k, v in itertools.groupby( + q, + lambda x:x[1:] + )) + + if adapter: + local_cols = [adapter.columns[c] for c in local_cols] + + if self.uselist: + def execute(state, dict_, row): + collection = collections.get( + tuple([row[col] for col in local_cols]), + () + ) + state.get_impl(self.key).\ + set_committed_value(state, dict_, collection) + else: + def execute(state, dict_, row): + collection = collections.get( + tuple([row[col] for col in local_cols]), + (None,) + ) + if len(collection) > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self) + + scalar = collection[0] + state.get_impl(self.key).\ + set_committed_value(state, dict_, scalar) + + return execute, None + +log.class_logger(SubqueryLoader) + +class EagerLoader(AbstractRelationshipLoader): + """Strategize a relationship() that loads within the process + of the parent object being selected.""" + + def init(self): + super(EagerLoader, self).init() + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property.\ + _get_strategy(LazyLoader).init_class_attribute(mapper) + + def setup_query(self, context, entity, path, adapter, \ + column_collection=None, parentmapper=None, + **kwargs): + """Add a left outer join to the statement thats being constructed.""" + + if not context.query._enable_eagerloads: + return + + path = path + (self.key,) + + reduced_path = interfaces._reduce_path(path) + + # check for user-defined eager alias + if ("user_defined_eager_row_processor", reduced_path) in\ + context.attributes: + clauses = context.attributes[ + ("user_defined_eager_row_processor", + reduced_path)] + + adapter = entity._get_entity_clauses(context.query, context) + if adapter and clauses: + context.attributes[ + ("user_defined_eager_row_processor", + reduced_path)] = clauses = clauses.wrap(adapter) + elif adapter: + context.attributes[ + ("user_defined_eager_row_processor", + reduced_path)] = clauses = adapter + + add_to_collection = context.primary_columns + + else: + # check for join_depth or basic recursion, + # if the current path was not explicitly stated as + # a desired "loaderstrategy" (i.e. via query.options()) + if ("loaderstrategy", reduced_path) not in context.attributes: + if self.join_depth: + if len(path) / 2 > self.join_depth: + return + else: + if self.mapper.base_mapper in reduced_path: + return + + clauses = mapperutil.ORMAdapter( + mapperutil.AliasedClass(self.mapper), + equivalents=self.mapper._equivalent_columns, + adapt_required=True) + + if self.parent_property.direction != interfaces.MANYTOONE: + context.multi_row_eager_loaders = True + + context.create_eager_joins.append( + (self._create_eager_join, context, + entity, path, adapter, + parentmapper, clauses) + ) + + add_to_collection = context.secondary_columns + context.attributes[ + ("eager_row_processor", reduced_path) + ] = clauses + + for value in self.mapper._iterate_polymorphic_properties(): + value.setup( + context, + entity, + path + (self.mapper,), + clauses, + parentmapper=self.mapper, + column_collection=add_to_collection) + + def _create_eager_join(self, context, entity, + path, adapter, parentmapper, clauses): + + if parentmapper is None: + localparent = entity.mapper + else: + localparent = parentmapper + + # whether or not the Query will wrap the selectable in a subquery, + # and then attach eager load joins to that (i.e., in the case of + # LIMIT/OFFSET etc.) + should_nest_selectable = context.multi_row_eager_loaders and \ + context.query._should_nest_selectable + + entity_key = None + if entity not in context.eager_joins and \ + not should_nest_selectable and \ + context.from_clause: + index, clause = \ + sql_util.find_join_source( + context.from_clause, entity.selectable) + if clause is not None: + # join to an existing FROM clause on the query. + # key it to its list index in the eager_joins dict. + # Query._compile_context will adapt as needed and + # append to the FROM clause of the select(). + entity_key, default_towrap = index, clause + + if entity_key is None: + entity_key, default_towrap = entity, entity.selectable + + towrap = context.eager_joins.setdefault(entity_key, default_towrap) + + join_to_left = False + if adapter: + if getattr(adapter, 'aliased_class', None): + onclause = getattr( + adapter.aliased_class, self.key, + self.parent_property) + else: + onclause = getattr( + mapperutil.AliasedClass( + self.parent, + adapter.selectable + ), + self.key, self.parent_property + ) + + if onclause is self.parent_property: + # TODO: this is a temporary hack to + # account for polymorphic eager loads where + # the eagerload is referencing via of_type(). + join_to_left = True + else: + onclause = self.parent_property + + innerjoin = context.attributes.get( + ("eager_join_type", path), + self.parent_property.innerjoin) + + context.eager_joins[entity_key] = eagerjoin = \ + mapperutil.join( + towrap, + clauses.aliased_class, + onclause, + join_to_left=join_to_left, + isouter=not innerjoin + ) + + # send a hint to the Query as to where it may "splice" this join + eagerjoin.stop_on = entity.selectable + + if self.parent_property.secondary is None and \ + not parentmapper: + # for parentclause that is the non-eager end of the join, + # ensure all the parent cols in the primaryjoin are actually + # in the + # columns clause (i.e. are not deferred), so that aliasing applied + # by the Query propagates those columns outward. + # This has the effect + # of "undefering" those columns. + for col in sql_util.find_columns( + self.parent_property.primaryjoin): + if localparent.mapped_table.c.contains_column(col): + if adapter: + col = adapter.columns[col] + context.primary_columns.append(col) + + if self.parent_property.order_by: + context.eager_order_by += \ + eagerjoin._target_adapter.\ + copy_and_process( + util.to_list( + self.parent_property.order_by + ) + ) + + + def _create_eager_adapter(self, context, row, adapter, path): + reduced_path = interfaces._reduce_path(path) + if ("user_defined_eager_row_processor", reduced_path) in \ + context.attributes: + decorator = context.attributes[ + ("user_defined_eager_row_processor", + reduced_path)] + # user defined eagerloads are part of the "primary" + # portion of the load. + # the adapters applied to the Query should be honored. + if context.adapter and decorator: + decorator = decorator.wrap(context.adapter) + elif context.adapter: + decorator = context.adapter + elif ("eager_row_processor", reduced_path) in context.attributes: + decorator = context.attributes[ + ("eager_row_processor", reduced_path)] + else: + return False + + try: + identity_key = self.mapper.identity_key_from_row(row, decorator) + return decorator + except KeyError, k: + # no identity key - dont return a row + # processor, will cause a degrade to lazy + return False + + def create_row_processor(self, context, path, mapper, row, adapter): + path = path + (self.key,) + + eager_adapter = self._create_eager_adapter( + context, + row, + adapter, path) + + if eager_adapter is not False: + key = self.key + _instance = self.mapper._instance_processor( + context, + path + (self.mapper,), + eager_adapter) + + if not self.uselist: + def new_execute(state, dict_, row): + # set a scalar object instance directly on the parent + # object, bypassing InstrumentedAttribute event handlers. + dict_[key] = _instance(row, None) + + def existing_execute(state, dict_, row): + # call _instance on the row, even though the object has + # been created, so that we further descend into properties + existing = _instance(row, None) + if existing is not None \ + and key in dict_ \ + and existing is not dict_[key]: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self) + return new_execute, existing_execute + else: + def new_execute(state, dict_, row): + collection = attributes.init_state_collection( + state, dict_, key) + result_list = util.UniqueAppender(collection, + 'append_without_event') + context.attributes[(state, key)] = result_list + _instance(row, result_list) + + def existing_execute(state, dict_, row): + if (state, key) in context.attributes: + result_list = context.attributes[(state, key)] + else: + # appender_key can be absent from context.attributes + # with isnew=False when self-referential eager loading + # is used; the same instance may be present in two + # distinct sets of result columns + collection = attributes.init_state_collection(state, + dict_, key) + result_list = util.UniqueAppender( + collection, + 'append_without_event') + context.attributes[(state, key)] = result_list + _instance(row, result_list) + return new_execute, existing_execute + else: + return self.parent_property.\ + _get_strategy(LazyLoader).\ + create_row_processor( + context, path, + mapper, row, adapter) + +log.class_logger(EagerLoader) + +class EagerLazyOption(StrategizedOption): + def __init__(self, key, lazy=True, chained=False, + propagate_to_loaders=True + ): + super(EagerLazyOption, self).__init__(key) + self.lazy = lazy + self.chained = chained + self.propagate_to_loaders = propagate_to_loaders + self.strategy_cls = factory(lazy) + + @property + def is_eager(self): + return self.lazy in (False, 'joined', 'subquery') + + @property + def is_chained(self): + return self.is_eager and self.chained + + def get_strategy_class(self): + return self.strategy_cls + +def factory(identifier): + if identifier is False or identifier == 'joined': + return EagerLoader + elif identifier is None or identifier == 'noload': + return NoLoader + elif identifier is False or identifier == 'select': + return LazyLoader + elif identifier == 'subquery': + return SubqueryLoader + else: + return LazyLoader + + + +class EagerJoinOption(PropertyOption): + + def __init__(self, key, innerjoin, chained=False): + super(EagerJoinOption, self).__init__(key) + self.innerjoin = innerjoin + self.chained = chained + + def is_chained(self): + return self.chained + + def process_query_property(self, query, paths, mappers): + if self.is_chained(): + for path in paths: + query._attributes[("eager_join_type", path)] = self.innerjoin + else: + query._attributes[("eager_join_type", paths[-1])] = self.innerjoin + +class LoadEagerFromAliasOption(PropertyOption): + + def __init__(self, key, alias=None): + super(LoadEagerFromAliasOption, self).__init__(key) + if alias is not None: + if not isinstance(alias, basestring): + m, alias, is_aliased_class = mapperutil._entity_info(alias) + self.alias = alias + + def process_query_property(self, query, paths, mappers): + if self.alias is not None: + if isinstance(self.alias, basestring): + mapper = mappers[-1] + (root_mapper, propname) = paths[-1][-2:] + prop = mapper.get_property(propname, resolve_synonyms=True) + self.alias = prop.target.alias(self.alias) + query._attributes[ + ("user_defined_eager_row_processor", + interfaces._reduce_path(paths[-1])) + ] = sql_util.ColumnAdapter(self.alias) + else: + (root_mapper, propname) = paths[-1][-2:] + mapper = mappers[-1] + prop = mapper.get_property(propname, resolve_synonyms=True) + adapter = query._polymorphic_adapters.get(prop.mapper, None) + query._attributes[ + ("user_defined_eager_row_processor", + interfaces._reduce_path(paths[-1]))] = adapter + +class _SingleParentValidator(interfaces.AttributeExtension): + def __init__(self, prop): + self.prop = prop + + def _do_check(self, state, value, oldvalue, initiator): + if value is not None: + hasparent = initiator.hasparent(attributes.instance_state(value)) + if hasparent and oldvalue is not value: + raise sa_exc.InvalidRequestError( + "Instance %s is already associated with an instance " + "of %s via its %s attribute, and is only allowed a " + "single parent." % + (mapperutil.instance_str(value), state.class_, self.prop) + ) + return value + + def append(self, state, value, initiator): + return self._do_check(state, value, None, initiator) + + def set(self, state, value, oldvalue, initiator): + return self._do_check(state, value, oldvalue, initiator) + + diff --git a/sqlalchemy/orm/sync.py b/sqlalchemy/orm/sync.py new file mode 100644 index 0000000..30daacb --- /dev/null +++ b/sqlalchemy/orm/sync.py @@ -0,0 +1,98 @@ +# mapper/sync.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used for copying data +between instances based on join conditions. +""" + +from sqlalchemy.orm import exc, util as mapperutil + +def populate(source, source_mapper, dest, dest_mapper, + synchronize_pairs, uowcommit, passive_updates): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column(source, l) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) + + try: + dest_mapper._set_state_attr_by_column(dest, r, value) + except exc.UnmappedColumnError: + _raise_col_to_prop(True, source_mapper, l, dest_mapper, r) + + # techically the "r.primary_key" check isn't + # needed here, but we check for this condition to limit + # how often this logic is invoked for memory/performance + # reasons, since we only need this info for a primary key + # destination. + if l.primary_key and r.primary_key and \ + r.references(l) and passive_updates: + uowcommit.attributes[("pk_cascaded", dest, r)] = True + +def clear(dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + if r.primary_key: + raise AssertionError( + "Dependency rule tried to blank-out primary key " + "column '%s' on instance '%s'" % + (r, mapperutil.state_str(dest)) + ) + try: + dest_mapper._set_state_attr_by_column(dest, r, None) + except exc.UnmappedColumnError: + _raise_col_to_prop(True, None, l, dest_mapper, r) + +def update(source, source_mapper, dest, old_prefix, synchronize_pairs): + for l, r in synchronize_pairs: + try: + oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l) + value = source_mapper._get_state_attr_by_column(source, l) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + dest[r.key] = value + dest[old_prefix + r.key] = oldvalue + +def populate_dict(source, source_mapper, dict_, synchronize_pairs): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column(source, l) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + + dict_[r.key] = value + +def source_modified(uowcommit, source, source_mapper, synchronize_pairs): + """return true if the source object has changes from an old to a + new value on the given synchronize pairs + + """ + for l, r in synchronize_pairs: + try: + prop = source_mapper._get_col_to_prop(l) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + history = uowcommit.get_attribute_history(source, prop.key, passive=True) + if len(history.deleted): + return True + else: + return False + +def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column): + if isdest: + raise exc.UnmappedColumnError( + "Can't execute sync rule for destination column '%s'; " + "mapper '%s' does not map this column. Try using an explicit" + " `foreign_keys` collection which does not include this column " + "(or use a viewonly=True relation)." % (dest_column, source_mapper) + ) + else: + raise exc.UnmappedColumnError( + "Can't execute sync rule for source column '%s'; mapper '%s' " + "does not map this column. Try using an explicit `foreign_keys`" + " collection which does not include destination column '%s' (or " + "use a viewonly=True relation)." % + (source_column, source_mapper, dest_column) + ) diff --git a/sqlalchemy/orm/unitofwork.py b/sqlalchemy/orm/unitofwork.py new file mode 100644 index 0000000..30b0b61 --- /dev/null +++ b/sqlalchemy/orm/unitofwork.py @@ -0,0 +1,781 @@ +# orm/unitofwork.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The internals for the Unit Of Work system. + +Includes hooks into the attributes package enabling the routing of +change events to Unit Of Work objects, as well as the flush() +mechanism which creates a dependency structure that executes change +operations. + +A Unit of Work is essentially a system of maintaining a graph of +in-memory objects and their modified state. Objects are maintained as +unique against their primary key identity using an *identity map* +pattern. The Unit of Work then maintains lists of objects that are +new, dirty, or deleted and provides the capability to flush all those +changes at once. + +""" + +from sqlalchemy import util, log, topological +from sqlalchemy.orm import attributes, interfaces +from sqlalchemy.orm import util as mapperutil +from sqlalchemy.orm.mapper import _state_mapper + +# Load lazily +object_session = None +_state_session = None + +class UOWEventHandler(interfaces.AttributeExtension): + """An event handler added to all relationship attributes which handles + session cascade operations. + """ + + active_history = False + + def __init__(self, key): + self.key = key + + def append(self, state, item, initiator): + # process "save_update" cascade rules for when an instance is appended to the list of another instance + sess = _state_session(state) + if sess: + prop = _state_mapper(state).get_property(self.key) + if prop.cascade.save_update and item not in sess: + sess.add(item) + return item + + def remove(self, state, item, initiator): + sess = _state_session(state) + if sess: + prop = _state_mapper(state).get_property(self.key) + # expunge pending orphans + if prop.cascade.delete_orphan and \ + item in sess.new and \ + prop.mapper._is_orphan(attributes.instance_state(item)): + sess.expunge(item) + + def set(self, state, newvalue, oldvalue, initiator): + # process "save_update" cascade rules for when an instance is attached to another instance + if oldvalue is newvalue: + return newvalue + sess = _state_session(state) + if sess: + prop = _state_mapper(state).get_property(self.key) + if newvalue is not None and prop.cascade.save_update and newvalue not in sess: + sess.add(newvalue) + if prop.cascade.delete_orphan and oldvalue in sess.new and \ + prop.mapper._is_orphan(attributes.instance_state(oldvalue)): + sess.expunge(oldvalue) + return newvalue + + +class UOWTransaction(object): + """Handles the details of organizing and executing transaction + tasks during a UnitOfWork object's flush() operation. + + The central operation is to form a graph of nodes represented by the + ``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object + that issues SQL and instance-synchronizing operations via the related + packages. + """ + + def __init__(self, session): + self.session = session + self.mapper_flush_opts = session._mapper_flush_opts + + # stores tuples of mapper/dependent mapper pairs, + # representing a partial ordering fed into topological sort + self.dependencies = set() + + # dictionary of mappers to UOWTasks + self.tasks = {} + + # dictionary used by external actors to store arbitrary state + # information. + self.attributes = {} + + self.processors = set() + + def get_attribute_history(self, state, key, passive=True): + hashkey = ("history", state, key) + + # cache the objects, not the states; the strong reference here + # prevents newly loaded objects from being dereferenced during the + # flush process + if hashkey in self.attributes: + (history, cached_passive) = self.attributes[hashkey] + # if the cached lookup was "passive" and now we want non-passive, do a non-passive + # lookup and re-cache + if cached_passive and not passive: + history = attributes.get_state_history(state, key, passive=False) + self.attributes[hashkey] = (history, passive) + else: + history = attributes.get_state_history(state, key, passive=passive) + self.attributes[hashkey] = (history, passive) + + if not history or not state.get_impl(key).uses_objects: + return history + else: + return history.as_state() + + def register_object(self, state, isdelete=False, + listonly=False, postupdate=False, post_update_cols=None): + + # if object is not in the overall session, do nothing + if not self.session._contains_state(state): + return + + mapper = _state_mapper(state) + + task = self.get_task_by_mapper(mapper) + if postupdate: + task.append_postupdate(state, post_update_cols) + else: + task.append(state, listonly=listonly, isdelete=isdelete) + + # ensure the mapper for this object has had its + # DependencyProcessors added. + if mapper not in self.processors: + mapper._register_processors(self) + self.processors.add(mapper) + + if mapper.base_mapper not in self.processors: + mapper.base_mapper._register_processors(self) + self.processors.add(mapper.base_mapper) + + def set_row_switch(self, state): + """mark a deleted object as a 'row switch'. + + this indicates that an INSERT statement elsewhere corresponds to this DELETE; + the INSERT is converted to an UPDATE and the DELETE does not occur. + + """ + mapper = _state_mapper(state) + task = self.get_task_by_mapper(mapper) + taskelement = task._objects[state] + taskelement.isdelete = "rowswitch" + + def is_deleted(self, state): + """return true if the given state is marked as deleted within this UOWTransaction.""" + + mapper = _state_mapper(state) + task = self.get_task_by_mapper(mapper) + return task.is_deleted(state) + + def get_task_by_mapper(self, mapper, dontcreate=False): + """return UOWTask element corresponding to the given mapper. + + Will create a new UOWTask, including a UOWTask corresponding to the + "base" inherited mapper, if needed, unless the dontcreate flag is True. + + """ + try: + return self.tasks[mapper] + except KeyError: + if dontcreate: + return None + + base_mapper = mapper.base_mapper + if base_mapper in self.tasks: + base_task = self.tasks[base_mapper] + else: + self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper) + base_mapper._register_dependencies(self) + + if mapper not in self.tasks: + self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task) + mapper._register_dependencies(self) + else: + task = self.tasks[mapper] + + return task + + def register_dependency(self, mapper, dependency): + """register a dependency between two mappers. + + Called by ``mapper.PropertyLoader`` to register the objects + handled by one mapper being dependent on the objects handled + by another. + + """ + # correct for primary mapper + # also convert to the "base mapper", the parentmost task at the top of an inheritance chain + # dependency sorting is done via non-inheriting mappers only, dependencies between mappers + # in the same inheritance chain is done at the per-object level + mapper = mapper.primary_mapper().base_mapper + dependency = dependency.primary_mapper().base_mapper + + self.dependencies.add((mapper, dependency)) + + def register_processor(self, mapper, processor, mapperfrom): + """register a dependency processor, corresponding to + operations which occur between two mappers. + + """ + # correct for primary mapper + mapper = mapper.primary_mapper() + mapperfrom = mapperfrom.primary_mapper() + + task = self.get_task_by_mapper(mapper) + targettask = self.get_task_by_mapper(mapperfrom) + up = UOWDependencyProcessor(processor, targettask) + task.dependencies.add(up) + + def execute(self): + """Execute this UOWTransaction. + + This will organize all collected UOWTasks into a dependency-sorted + list which is then traversed using the traversal scheme + encoded in the UOWExecutor class. Operations to mappers and dependency + processors are fired off in order to issue SQL to the database and + synchronize instance attributes with database values and related + foreign key values.""" + + # pre-execute dependency processors. this process may + # result in new tasks, objects and/or dependency processors being added, + # particularly with 'delete-orphan' cascade rules. + # keep running through the full list of tasks until all + # objects have been processed. + while True: + ret = False + for task in self.tasks.values(): + for up in list(task.dependencies): + if up.preexecute(self): + ret = True + if not ret: + break + + tasks = self._sort_dependencies() + if self._should_log_info(): + self.logger.info("Task dump:\n%s", self._dump(tasks)) + UOWExecutor().execute(self, tasks) + self.logger.info("Execute Complete") + + def _dump(self, tasks): + from uowdumper import UOWDumper + return UOWDumper.dump(tasks) + + @property + def elements(self): + """Iterate UOWTaskElements.""" + + for task in self.tasks.itervalues(): + for elem in task.elements: + yield elem + + def finalize_flush_changes(self): + """mark processed objects as clean / deleted after a successful flush(). + + this method is called within the flush() method after the + execute() method has succeeded and the transaction has been committed. + """ + + for elem in self.elements: + if elem.isdelete: + self.session._remove_newly_deleted(elem.state) + elif not elem.listonly: + self.session._register_newly_persistent(elem.state) + + def _sort_dependencies(self): + nodes = topological.sort_with_cycles(self.dependencies, + [t.mapper for t in self.tasks.itervalues() if t.base_task is t] + ) + + ret = [] + for item, cycles in nodes: + task = self.get_task_by_mapper(item) + if cycles: + for t in task._sort_circular_dependencies( + self, + [self.get_task_by_mapper(i) for i in cycles] + ): + ret.append(t) + else: + ret.append(task) + + return ret + +log.class_logger(UOWTransaction) + +class UOWTask(object): + """A collection of mapped states corresponding to a particular mapper.""" + + def __init__(self, uowtransaction, mapper, base_task=None): + self.uowtransaction = uowtransaction + + # base_task is the UOWTask which represents the "base mapper" + # in our mapper's inheritance chain. if the mapper does not + # inherit from any other mapper, the base_task is self. + # the _inheriting_tasks dictionary is a dictionary present only + # on the "base_task"-holding UOWTask, which maps all mappers within + # an inheritance hierarchy to their corresponding UOWTask instances. + if base_task is None: + self.base_task = self + self._inheriting_tasks = {mapper:self} + else: + self.base_task = base_task + base_task._inheriting_tasks[mapper] = self + + # the Mapper which this UOWTask corresponds to + self.mapper = mapper + + # mapping of InstanceState -> UOWTaskElement + self._objects = {} + + self.dependent_tasks = [] + self.dependencies = set() + self.cyclical_dependencies = set() + + @util.memoized_property + def inheriting_mappers(self): + return list(self.mapper.polymorphic_iterator()) + + @property + def polymorphic_tasks(self): + """Return an iterator of UOWTask objects corresponding to the + inheritance sequence of this UOWTask's mapper. + + e.g. if mapper B and mapper C inherit from mapper A, and + mapper D inherits from B: + + mapperA -> mapperB -> mapperD + -> mapperC + + the inheritance sequence starting at mapper A is a depth-first + traversal: + + [mapperA, mapperB, mapperD, mapperC] + + this method will therefore return + + [UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD), + UOWTask(mapperC)] + + The concept of "polymporphic iteration" is adapted into + several property-based iterators which return object + instances, UOWTaskElements and UOWDependencyProcessors in an + order corresponding to this sequence of parent UOWTasks. This + is used to issue operations related to inheritance-chains of + mappers in the proper order based on dependencies between + those mappers. + + """ + for mapper in self.inheriting_mappers: + t = self.base_task._inheriting_tasks.get(mapper, None) + if t is not None: + yield t + + def is_empty(self): + """return True if this UOWTask is 'empty', meaning it has no child items. + + used only for debugging output. + """ + + return not self._objects and not self.dependencies + + def append(self, state, listonly=False, isdelete=False): + if state not in self._objects: + self._objects[state] = rec = UOWTaskElement(state) + else: + rec = self._objects[state] + + rec.update(listonly, isdelete) + + def append_postupdate(self, state, post_update_cols): + """issue a 'post update' UPDATE statement via this object's mapper immediately. + + this operation is used only with relationships that specify the `post_update=True` + flag. + """ + + # postupdates are UPDATED immeditely (for now) + # convert post_update_cols list to a Set so that __hash__() is used to compare columns + # instead of __eq__() + self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=set(post_update_cols)) + + def __contains__(self, state): + """return True if the given object is contained within this UOWTask or inheriting tasks.""" + + for task in self.polymorphic_tasks: + if state in task._objects: + return True + else: + return False + + def is_deleted(self, state): + """return True if the given object is marked as to be deleted within this UOWTask.""" + + try: + return self._objects[state].isdelete + except KeyError: + return False + + def _polymorphic_collection(fn): + """return a property that will adapt the collection returned by the + given callable into a polymorphic traversal.""" + + @property + def collection(self): + for task in self.polymorphic_tasks: + for rec in fn(task): + yield rec + return collection + + def _polymorphic_collection_filtered(fn): + + def collection(self, mappers): + for task in self.polymorphic_tasks: + if task.mapper in mappers: + for rec in fn(task): + yield rec + return collection + + @property + def elements(self): + return self._objects.values() + + @_polymorphic_collection + def polymorphic_elements(self): + return self.elements + + @_polymorphic_collection_filtered + def filter_polymorphic_elements(self): + return self.elements + + @property + def polymorphic_tosave_elements(self): + return [rec for rec in self.polymorphic_elements if not rec.isdelete] + + @property + def polymorphic_todelete_elements(self): + return [rec for rec in self.polymorphic_elements if rec.isdelete] + + @property + def polymorphic_tosave_objects(self): + return [ + rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is False + ] + + @property + def polymorphic_todelete_objects(self): + return [ + rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is True + ] + + @_polymorphic_collection + def polymorphic_dependencies(self): + return self.dependencies + + @_polymorphic_collection + def polymorphic_cyclical_dependencies(self): + return self.cyclical_dependencies + + def _sort_circular_dependencies(self, trans, cycles): + """Topologically sort individual entities with row-level dependencies. + + Builds a modified UOWTask structure, and is invoked when the + per-mapper topological structure is found to have cycles. + + """ + + dependencies = {} + def set_processor_for_state(state, depprocessor, target_state, isdelete): + if state not in dependencies: + dependencies[state] = {} + tasks = dependencies[state] + if depprocessor not in tasks: + tasks[depprocessor] = UOWDependencyProcessor( + depprocessor.processor, + UOWTask(self.uowtransaction, depprocessor.targettask.mapper) + ) + tasks[depprocessor].targettask.append(target_state, isdelete=isdelete) + + cycles = set(cycles) + def dependency_in_cycles(dep): + proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True) + targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True) + return targettask in cycles and (proctask is not None and proctask in cycles) + + deps_by_targettask = {} + extradeplist = [] + for task in cycles: + for dep in task.polymorphic_dependencies: + if not dependency_in_cycles(dep): + extradeplist.append(dep) + for t in dep.targettask.polymorphic_tasks: + l = deps_by_targettask.setdefault(t, []) + l.append(dep) + + object_to_original_task = {} + tuples = [] + + for task in cycles: + for subtask in task.polymorphic_tasks: + for taskelement in subtask.elements: + state = taskelement.state + object_to_original_task[state] = subtask + if subtask not in deps_by_targettask: + continue + for dep in deps_by_targettask[subtask]: + if not dep.processor.has_dependencies or not dependency_in_cycles(dep): + continue + (processor, targettask) = (dep.processor, dep.targettask) + isdelete = taskelement.isdelete + + # list of dependent objects from this object + (added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True) + if not added and not unchanged and not deleted: + continue + + # the task corresponding to saving/deleting of those dependent objects + childtask = trans.get_task_by_mapper(processor.mapper) + + childlist = added + unchanged + deleted + + for o in childlist: + if o is None: + continue + + if o not in childtask: + childtask.append(o, listonly=True) + object_to_original_task[o] = childtask + + whosdep = dep.whose_dependent_on_who(state, o) + if whosdep is not None: + tuples.append(whosdep) + + if whosdep[0] is state: + set_processor_for_state(whosdep[0], dep, whosdep[0], isdelete=isdelete) + else: + set_processor_for_state(whosdep[0], dep, whosdep[1], isdelete=isdelete) + else: + # TODO: no test coverage here + set_processor_for_state(state, dep, state, isdelete=isdelete) + + t = UOWTask(self.uowtransaction, self.mapper) + t.dependencies.update(extradeplist) + + used_tasks = set() + + # rationale for "tree" sort as opposed to a straight + # dependency - keep non-dependent objects + # grouped together, so that insert ordering as determined + # by session.add() is maintained. + # An alternative might be to represent the "insert order" + # as part of the topological sort itself, which would + # eliminate the need for this step (but may make the original + # topological sort more expensive) + head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys()) + if head is not None: + original_to_tasks = {} + stack = [(head, t)] + while stack: + ((state, cycles, children), parenttask) = stack.pop() + + originating_task = object_to_original_task[state] + used_tasks.add(originating_task) + + if (parenttask, originating_task) not in original_to_tasks: + task = UOWTask(self.uowtransaction, originating_task.mapper) + original_to_tasks[(parenttask, originating_task)] = task + parenttask.dependent_tasks.append(task) + else: + task = original_to_tasks[(parenttask, originating_task)] + + task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) + + if state in dependencies: + task.cyclical_dependencies.update(dependencies[state].itervalues()) + + stack += [(n, task) for n in children] + + ret = [t] + + # add tasks that were in the cycle, but didnt get assembled + # into the cyclical tree, to the start of the list + for t2 in cycles: + if t2 not in used_tasks and t2 is not self: + localtask = UOWTask(self.uowtransaction, t2.mapper) + for state in t2.elements: + localtask.append(state, t2.listonly, isdelete=t2._objects[state].isdelete) + for dep in t2.dependencies: + localtask.dependencies.add(dep) + ret.insert(0, localtask) + + return ret + + def __repr__(self): + return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper)) + +class UOWTaskElement(object): + """Corresponds to a single InstanceState to be saved, deleted, + or otherwise marked as having dependencies. A collection of + UOWTaskElements are held by a UOWTask. + + """ + def __init__(self, state): + self.state = state + self.listonly = True + self.isdelete = False + self.preprocessed = set() + + def update(self, listonly, isdelete): + if not listonly and self.listonly: + self.listonly = False + self.preprocessed.clear() + if isdelete and not self.isdelete: + self.isdelete = True + self.preprocessed.clear() + + def __repr__(self): + return "UOWTaskElement/%d: %s/%d %s" % ( + id(self), + self.state.class_.__name__, + id(self.state.obj()), + (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) + ) + +class UOWDependencyProcessor(object): + """In between the saving and deleting of objects, process + dependent data, such as filling in a foreign key on a child item + from a new primary key, or deleting association rows before a + delete. This object acts as a proxy to a DependencyProcessor. + + """ + def __init__(self, processor, targettask): + self.processor = processor + self.targettask = targettask + prop = processor.prop + + # define a set of mappers which + # will filter the lists of entities + # this UOWDP processes. this allows + # MapperProperties to be overridden + # at least for concrete mappers. + self._mappers = set([ + m + for m in self.processor.parent.polymorphic_iterator() + if m._props[prop.key] is prop + ]).union(self.processor.mapper.polymorphic_iterator()) + + def __repr__(self): + return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask)) + + def __eq__(self, other): + return other.processor is self.processor and other.targettask is self.targettask + + def __hash__(self): + return hash((self.processor, self.targettask)) + + def preexecute(self, trans): + """preprocess all objects contained within this ``UOWDependencyProcessor``s target task. + + This may locate additional objects which should be part of the + transaction, such as those affected deletes, orphans to be + deleted, etc. + + Once an object is preprocessed, its ``UOWTaskElement`` is marked as processed. If subsequent + changes occur to the ``UOWTaskElement``, its processed flag is reset, and will require processing + again. + + Return True if any objects were preprocessed, or False if no + objects were preprocessed. If True is returned, the parent ``UOWTransaction`` will + ultimately call ``preexecute()`` again on all processors until no new objects are processed. + """ + + def getobj(elem): + elem.preprocessed.add(self) + return elem.state + + ret = False + elements = [getobj(elem) for elem in + self.targettask.filter_polymorphic_elements(self._mappers) + if self not in elem.preprocessed and not elem.isdelete] + if elements: + ret = True + self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) + + elements = [getobj(elem) for elem in + self.targettask.filter_polymorphic_elements(self._mappers) + if self not in elem.preprocessed and elem.isdelete] + if elements: + ret = True + self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) + return ret + + def execute(self, trans, delete): + """process all objects contained within this ``UOWDependencyProcessor``s target task.""" + + + elements = [e for e in + self.targettask.filter_polymorphic_elements(self._mappers) + if bool(e.isdelete)==delete] + + self.processor.process_dependencies( + self.targettask, + [elem.state for elem in elements], + trans, + delete=delete) + + def get_object_dependencies(self, state, trans, passive): + return trans.get_attribute_history(state, self.processor.key, passive=passive) + + def whose_dependent_on_who(self, state1, state2): + """establish which object is operationally dependent amongst a parent/child + using the semantics stated by the dependency processor. + + This method is used to establish a partial ordering (set of dependency tuples) + when toplogically sorting on a per-instance basis. + + """ + return self.processor.whose_dependent_on_who(state1, state2) + +class UOWExecutor(object): + """Encapsulates the execution traversal of a UOWTransaction structure.""" + + def execute(self, trans, tasks, isdelete=None): + if isdelete is not True: + for task in tasks: + self.execute_save_steps(trans, task) + if isdelete is not False: + for task in reversed(tasks): + self.execute_delete_steps(trans, task) + + def save_objects(self, trans, task): + task.mapper._save_obj(task.polymorphic_tosave_objects, trans) + + def delete_objects(self, trans, task): + task.mapper._delete_obj(task.polymorphic_todelete_objects, trans) + + def execute_dependency(self, trans, dep, isdelete): + dep.execute(trans, isdelete) + + def execute_save_steps(self, trans, task): + self.save_objects(trans, task) + for dep in task.polymorphic_cyclical_dependencies: + self.execute_dependency(trans, dep, False) + for dep in task.polymorphic_cyclical_dependencies: + self.execute_dependency(trans, dep, True) + self.execute_cyclical_dependencies(trans, task, False) + self.execute_dependencies(trans, task) + + def execute_delete_steps(self, trans, task): + self.execute_cyclical_dependencies(trans, task, True) + self.delete_objects(trans, task) + + def execute_dependencies(self, trans, task): + polymorphic_dependencies = list(task.polymorphic_dependencies) + for dep in polymorphic_dependencies: + self.execute_dependency(trans, dep, False) + for dep in reversed(polymorphic_dependencies): + self.execute_dependency(trans, dep, True) + + def execute_cyclical_dependencies(self, trans, task, isdelete): + for t in task.dependent_tasks: + self.execute(trans, [t], isdelete) diff --git a/sqlalchemy/orm/uowdumper.py b/sqlalchemy/orm/uowdumper.py new file mode 100644 index 0000000..dd96b6b --- /dev/null +++ b/sqlalchemy/orm/uowdumper.py @@ -0,0 +1,101 @@ +# orm/uowdumper.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Dumps out a string representation of a UOWTask structure""" + +from sqlalchemy.orm import unitofwork +from sqlalchemy.orm import util as mapperutil +import StringIO + +class UOWDumper(unitofwork.UOWExecutor): + def __init__(self, tasks, buf): + self.indent = 0 + self.tasks = tasks + self.buf = buf + self.execute(None, tasks) + + @classmethod + def dump(cls, tasks): + buf = StringIO.StringIO() + UOWDumper(tasks, buf) + return buf.getvalue() + + def execute(self, trans, tasks, isdelete=None): + if isdelete is not True: + for task in tasks: + self._execute(trans, task, False) + if isdelete is not False: + for task in reversed(tasks): + self._execute(trans, task, True) + + def _execute(self, trans, task, isdelete): + try: + i = self._indent() + if i: + i = i[:-1] + "+-" + self.buf.write(i + " " + self._repr_task(task)) + self.buf.write(" (" + (isdelete and "delete " or "save/update ") + "phase) \n") + self.indent += 1 + super(UOWDumper, self).execute(trans, [task], isdelete) + finally: + self.indent -= 1 + + + def save_objects(self, trans, task): + for rec in sorted(task.polymorphic_tosave_elements, key=lambda a: a.state.sort_key): + if rec.listonly: + continue + self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n") + + def delete_objects(self, trans, task): + for rec in task.polymorphic_todelete_elements: + if rec.listonly: + continue + self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n") + + def execute_dependency(self, transaction, dep, isdelete): + self._dump_processor(dep, isdelete) + + def _dump_processor(self, proc, deletes): + if deletes: + val = proc.targettask.polymorphic_todelete_elements + else: + val = proc.targettask.polymorphic_tosave_elements + + for v in val: + self.buf.write(self._indent() + " +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n") + + def _repr_task_element(self, te, attribute=None, process=False): + if getattr(te, 'state', None) is None: + objid = "(placeholder)" + else: + if attribute is not None: + objid = "%s.%s" % (mapperutil.state_str(te.state), attribute) + else: + objid = mapperutil.state_str(te.state) + if process: + return "Process %s" % (objid) + else: + return "%s %s" % ((te.isdelete and "Delete" or "Save"), objid) + + def _repr_task(self, task): + if task.mapper is not None: + if task.mapper.__class__.__name__ == 'Mapper': + name = task.mapper.class_.__name__ + "/" + task.mapper.local_table.description + else: + name = repr(task.mapper) + else: + name = '(none)' + return ("UOWTask(%s, %s)" % (hex(id(task)), name)) + + def _repr_task_class(self, task): + if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper': + return task.mapper.class_.__name__ + else: + return '(none)' + + def _indent(self): + return " |" * self.indent diff --git a/sqlalchemy/orm/util.py b/sqlalchemy/orm/util.py new file mode 100644 index 0000000..63b9d56 --- /dev/null +++ b/sqlalchemy/orm/util.py @@ -0,0 +1,668 @@ +# mapper/util.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import sql, util +from sqlalchemy.sql import expression, util as sql_util, operators +from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, \ + MapperProperty, AttributeExtension +from sqlalchemy.orm import attributes, exc + +mapperlib = None + +all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", + "expunge", "save-update", "refresh-expire", + "none")) + +_INSTRUMENTOR = ('mapper', 'instrumentor') + +class CascadeOptions(object): + """Keeps track of the options sent to relationship().cascade""" + + def __init__(self, arg=""): + if not arg: + values = set() + else: + values = set(c.strip() for c in arg.split(',')) + self.delete_orphan = "delete-orphan" in values + self.delete = "delete" in values or "all" in values + self.save_update = "save-update" in values or "all" in values + self.merge = "merge" in values or "all" in values + self.expunge = "expunge" in values or "all" in values + self.refresh_expire = "refresh-expire" in values or "all" in values + + if self.delete_orphan and not self.delete: + util.warn("The 'delete-orphan' cascade option requires " + "'delete'. This will raise an error in 0.6.") + + for x in values: + if x not in all_cascades: + raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x) + + def __contains__(self, item): + return getattr(self, item.replace("-", "_"), False) + + def __repr__(self): + return "CascadeOptions(%s)" % repr(",".join( + [x for x in ['delete', 'save_update', 'merge', 'expunge', + 'delete_orphan', 'refresh-expire'] + if getattr(self, x, False) is True])) + + +class Validator(AttributeExtension): + """Runs a validation method on an attribute value to be set or appended. + + The Validator class is used by the :func:`~sqlalchemy.orm.validates` + decorator, and direct access is usually not needed. + + """ + + def __init__(self, key, validator): + """Construct a new Validator. + + key - name of the attribute to be validated; + will be passed as the second argument to + the validation method (the first is the object instance itself). + + validator - an function or instance method which accepts + three arguments; an instance (usually just 'self' for a method), + the key name of the attribute, and the value. The function should + return the same value given, unless it wishes to modify it. + + """ + self.key = key + self.validator = validator + + def append(self, state, value, initiator): + return self.validator(state.obj(), self.key, value) + + def set(self, state, value, oldvalue, initiator): + return self.validator(state.obj(), self.key, value) + +def polymorphic_union(table_map, typecolname, aliasname='p_union'): + """Create a ``UNION`` statement used by a polymorphic mapper. + + See :ref:`concrete_inheritance` for an example of how + this is used. + """ + + colnames = set() + colnamemaps = {} + types = {} + for key in table_map.keys(): + table = table_map[key] + + # mysql doesnt like selecting from a select; make it an alias of the select + if isinstance(table, sql.Select): + table = table.alias() + table_map[key] = table + + m = {} + for c in table.c: + colnames.add(c.key) + m[c.key] = c + types[c.key] = c.type + colnamemaps[table] = m + + def col(name, table): + try: + return colnamemaps[table][name] + except KeyError: + return sql.cast(sql.null(), types[name]).label(name) + + result = [] + for type, table in table_map.iteritems(): + if typecolname is not None: + result.append(sql.select([col(name, table) for name in colnames] + + [sql.literal_column("'%s'" % type).label(typecolname)], + from_obj=[table])) + else: + result.append(sql.select([col(name, table) for name in colnames], + from_obj=[table])) + return sql.union_all(*result).alias(aliasname) + +def identity_key(*args, **kwargs): + """Get an identity key. + + Valid call signatures: + + * ``identity_key(class, ident)`` + + class + mapped class (must be a positional argument) + + ident + primary key, if the key is composite this is a tuple + + + * ``identity_key(instance=instance)`` + + instance + object instance (must be given as a keyword arg) + + * ``identity_key(class, row=row)`` + + class + mapped class (must be a positional argument) + + row + result proxy row (must be given as a keyword arg) + + """ + if args: + if len(args) == 1: + class_ = args[0] + try: + row = kwargs.pop("row") + except KeyError: + ident = kwargs.pop("ident") + elif len(args) == 2: + class_, ident = args + elif len(args) == 3: + class_, ident = args + else: + raise sa_exc.ArgumentError("expected up to three " + "positional arguments, got %s" % len(args)) + if kwargs: + raise sa_exc.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys())) + mapper = class_mapper(class_) + if "ident" in locals(): + return mapper.identity_key_from_primary_key(ident) + return mapper.identity_key_from_row(row) + instance = kwargs.pop("instance") + if kwargs: + raise sa_exc.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys())) + mapper = object_mapper(instance) + return mapper.identity_key_from_instance(instance) + +class ExtensionCarrier(dict): + """Fronts an ordered collection of MapperExtension objects. + + Bundles multiple MapperExtensions into a unified callable unit, + encapsulating ordering, looping and EXT_CONTINUE logic. The + ExtensionCarrier implements the MapperExtension interface, e.g.:: + + carrier.after_insert(...args...) + + The dictionary interface provides containment for implemented + method names mapped to a callable which executes that method + for participating extensions. + + """ + + interface = set(method for method in dir(MapperExtension) + if not method.startswith('_')) + + def __init__(self, extensions=None): + self._extensions = [] + for ext in extensions or (): + self.append(ext) + + def copy(self): + return ExtensionCarrier(self._extensions) + + def push(self, extension): + """Insert a MapperExtension at the beginning of the collection.""" + self._register(extension) + self._extensions.insert(0, extension) + + def append(self, extension): + """Append a MapperExtension at the end of the collection.""" + self._register(extension) + self._extensions.append(extension) + + def __iter__(self): + """Iterate over MapperExtensions in the collection.""" + return iter(self._extensions) + + def _register(self, extension): + """Register callable fronts for overridden interface methods.""" + + for method in self.interface.difference(self): + impl = getattr(extension, method, None) + if impl and impl is not getattr(MapperExtension, method): + self[method] = self._create_do(method) + + def _create_do(self, method): + """Return a closure that loops over impls of the named method.""" + + def _do(*args, **kwargs): + for ext in self._extensions: + ret = getattr(ext, method)(*args, **kwargs) + if ret is not EXT_CONTINUE: + return ret + else: + return EXT_CONTINUE + _do.__name__ = method + return _do + + @staticmethod + def _pass(*args, **kwargs): + return EXT_CONTINUE + + def __getattr__(self, key): + """Delegate MapperExtension methods to bundled fronts.""" + + if key not in self.interface: + raise AttributeError(key) + return self.get(key, self._pass) + +class ORMAdapter(sql_util.ColumnAdapter): + """Extends ColumnAdapter to accept ORM entities. + + The selectable is extracted from the given entity, + and the AliasedClass if any is referenced. + + """ + def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False): + self.mapper, selectable, is_aliased_class = _entity_info(entity) + if is_aliased_class: + self.aliased_class = entity + else: + self.aliased_class = None + sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to, adapt_required=adapt_required) + + def replace(self, elem): + entity = elem._annotations.get('parentmapper', None) + if not entity or entity.isa(self.mapper): + return sql_util.ColumnAdapter.replace(self, elem) + else: + return None + +class AliasedClass(object): + """Represents an "aliased" form of a mapped class for usage with Query. + + The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias` + construct, this object mimics the mapped class using a + __getattr__ scheme and maintains a reference to a + real :class:`~sqlalchemy.sql.expression.Alias` object. + + Usage is via the :class:`~sqlalchemy.orm.aliased()` synonym:: + + # find all pairs of users with the same name + user_alias = aliased(User) + session.query(User, user_alias).\\ + join((user_alias, User.id > user_alias.id)).\\ + filter(User.name==user_alias.name) + + """ + def __init__(self, cls, alias=None, name=None): + self.__mapper = _class_to_mapper(cls) + self.__target = self.__mapper.class_ + if alias is None: + alias = self.__mapper._with_polymorphic_selectable.alias() + self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns) + self.__alias = alias + # used to assign a name to the RowTuple object + # returned by Query. + self._sa_label_name = name + self.__name__ = 'AliasedClass_' + str(self.__target) + + def __getstate__(self): + return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name} + + def __setstate__(self, state): + self.__mapper = state['mapper'] + self.__target = self.__mapper.class_ + alias = state['alias'] + self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns) + self.__alias = alias + name = state['name'] + self._sa_label_name = name + self.__name__ = 'AliasedClass_' + str(self.__target) + + def __adapt_element(self, elem): + return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper}) + + def __adapt_prop(self, prop): + existing = getattr(self.__target, prop.key) + comparator = existing.comparator.adapted(self.__adapt_element) + + queryattr = attributes.QueryableAttribute(prop.key, + impl=existing.impl, parententity=self, comparator=comparator) + setattr(self, prop.key, queryattr) + return queryattr + + def __getattr__(self, key): + prop = self.__mapper._get_property(key, raiseerr=False) + if prop: + return self.__adapt_prop(prop) + + for base in self.__target.__mro__: + try: + attr = object.__getattribute__(base, key) + except AttributeError: + continue + else: + break + else: + raise AttributeError(key) + + if hasattr(attr, 'func_code'): + is_method = getattr(self.__target, key, None) + if is_method and is_method.im_self is not None: + return util.types.MethodType(attr.im_func, self, self) + else: + return None + elif hasattr(attr, '__get__'): + return attr.__get__(None, self) + else: + return attr + + def __repr__(self): + return '' % ( + id(self), self.__target.__name__) + +def _orm_annotate(element, exclude=None): + """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag. + + Elements within the exclude collection will be cloned but not annotated. + + """ + return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude) + +_orm_deannotate = sql_util._deep_deannotate + +class _ORMJoin(expression.Join): + """Extend Join to support ORM constructs as input.""" + + __visit_name__ = expression.Join.__visit_name__ + + def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True): + adapt_from = None + + if hasattr(left, '_orm_mappers'): + left_mapper = left._orm_mappers[1] + if join_to_left: + adapt_from = left.right + else: + left_mapper, left, left_is_aliased = _entity_info(left) + if join_to_left and (left_is_aliased or not left_mapper): + adapt_from = left + + right_mapper, right, right_is_aliased = _entity_info(right) + if right_is_aliased: + adapt_to = right + else: + adapt_to = None + + if left_mapper or right_mapper: + self._orm_mappers = (left_mapper, right_mapper) + + if isinstance(onclause, basestring): + prop = left_mapper.get_property(onclause) + elif isinstance(onclause, attributes.QueryableAttribute): + if adapt_from is None: + adapt_from = onclause.__clause_element__() + prop = onclause.property + elif isinstance(onclause, MapperProperty): + prop = onclause + else: + prop = None + + if prop: + pj, sj, source, dest, secondary, target_adapter = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + dest_polymorphic=True, + of_type=right_mapper) + + if sj is not None: + left = sql.join(left, secondary, pj, isouter) + onclause = sj + else: + onclause = pj + self._target_adapter = target_adapter + + expression.Join.__init__(self, left, right, onclause, isouter) + + def join(self, right, onclause=None, isouter=False, join_to_left=True): + return _ORMJoin(self, right, onclause, isouter, join_to_left) + + def outerjoin(self, right, onclause=None, join_to_left=True): + return _ORMJoin(self, right, onclause, True, join_to_left) + +def join(left, right, onclause=None, isouter=False, join_to_left=True): + """Produce an inner join between left and right clauses. + + In addition to the interface provided by + :func:`~sqlalchemy.sql.expression.join()`, left and right may be mapped + classes or AliasedClass instances. The onclause may be a + string name of a relationship(), or a class-bound descriptor + representing a relationship. + + join_to_left indicates to attempt aliasing the ON clause, + in whatever form it is passed, to the selectable + passed as the left side. If False, the onclause + is used as is. + + """ + return _ORMJoin(left, right, onclause, isouter, join_to_left) + +def outerjoin(left, right, onclause=None, join_to_left=True): + """Produce a left outer join between left and right clauses. + + In addition to the interface provided by + :func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be mapped + classes or AliasedClass instances. The onclause may be a + string name of a relationship(), or a class-bound descriptor + representing a relationship. + + """ + return _ORMJoin(left, right, onclause, True, join_to_left) + +def with_parent(instance, prop): + """Return criterion which selects instances with a given parent. + + instance + a parent instance, which should be persistent or detached. + + property + a class-attached descriptor, MapperProperty or string property name + attached to the parent instance. + + \**kwargs + all extra keyword arguments are propagated to the constructor of + Query. + + """ + if isinstance(prop, basestring): + mapper = object_mapper(instance) + prop = mapper.get_property(prop, resolve_synonyms=True) + elif isinstance(prop, attributes.QueryableAttribute): + prop = prop.property + + return prop.compare(operators.eq, instance, value_is_parent=True) + + +def _entity_info(entity, compile=True): + """Return mapping information given a class, mapper, or AliasedClass. + + Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this + is an aliased() construct. + + If the given entity is not a mapper, mapped class, or aliased construct, + returns None, the entity, False. This is typically used to allow + unmapped selectables through. + + """ + if isinstance(entity, AliasedClass): + return entity._AliasedClass__mapper, entity._AliasedClass__alias, True + + global mapperlib + if mapperlib is None: + from sqlalchemy.orm import mapperlib + + if isinstance(entity, mapperlib.Mapper): + mapper = entity + + elif isinstance(entity, type): + class_manager = attributes.manager_of_class(entity) + + if class_manager is None: + return None, entity, False + + mapper = class_manager.mapper + else: + return None, entity, False + + if compile: + mapper = mapper.compile() + return mapper, mapper._with_polymorphic_selectable, False + +def _entity_descriptor(entity, key): + """Return attribute/property information given an entity and string name. + + Returns a 2-tuple representing InstrumentedAttribute/MapperProperty. + + """ + if isinstance(entity, AliasedClass): + try: + desc = getattr(entity, key) + return desc, desc.property + except AttributeError: + raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key)) + + elif isinstance(entity, type): + try: + desc = attributes.manager_of_class(entity)[key] + return desc, desc.property + except KeyError: + raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key)) + + else: + try: + desc = entity.class_manager[key] + return desc, desc.property + except KeyError: + raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key)) + +def _orm_columns(entity): + mapper, selectable, is_aliased_class = _entity_info(entity) + if isinstance(selectable, expression.Selectable): + return [c for c in selectable.c] + else: + return [selectable] + +def _orm_selectable(entity): + mapper, selectable, is_aliased_class = _entity_info(entity) + return selectable + +def _is_aliased_class(entity): + return isinstance(entity, AliasedClass) + +def _state_mapper(state): + return state.manager.mapper + +def object_mapper(instance): + """Given an object, return the primary Mapper associated with the object instance. + + Raises UnmappedInstanceError if no mapping is configured. + + """ + try: + state = attributes.instance_state(instance) + if not state.manager.mapper: + raise exc.UnmappedInstanceError(instance) + return state.manager.mapper + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + +def class_mapper(class_, compile=True): + """Given a class, return the primary Mapper associated with the key. + + Raises UnmappedClassError if no mapping is configured. + + """ + try: + class_manager = attributes.manager_of_class(class_) + mapper = class_manager.mapper + + # HACK until [ticket:1142] is complete + if mapper is None: + raise AttributeError + + except exc.NO_STATE: + raise exc.UnmappedClassError(class_) + + if compile: + mapper = mapper.compile() + return mapper + +def _class_to_mapper(class_or_mapper, compile=True): + if _is_aliased_class(class_or_mapper): + return class_or_mapper._AliasedClass__mapper + elif isinstance(class_or_mapper, type): + return class_mapper(class_or_mapper, compile=compile) + elif hasattr(class_or_mapper, 'compile'): + if compile: + return class_or_mapper.compile() + else: + return class_or_mapper + else: + raise exc.UnmappedClassError(class_or_mapper) + +def has_identity(object): + state = attributes.instance_state(object) + return _state_has_identity(state) + +def _state_has_identity(state): + return bool(state.key) + +def _is_mapped_class(cls): + global mapperlib + if mapperlib is None: + from sqlalchemy.orm import mapperlib + if isinstance(cls, (AliasedClass, mapperlib.Mapper)): + return True + if isinstance(cls, expression.ClauseElement): + return False + if isinstance(cls, type): + manager = attributes.manager_of_class(cls) + return manager and _INSTRUMENTOR in manager.info + return False + +def instance_str(instance): + """Return a string describing an instance.""" + + return state_str(attributes.instance_state(instance)) + +def state_str(state): + """Return a string describing an instance via its InstanceState.""" + + if state is None: + return "None" + else: + return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj())) + +def attribute_str(instance, attribute): + return instance_str(instance) + "." + attribute + +def state_attribute_str(state, attribute): + return state_str(state) + "." + attribute + +def identity_equal(a, b): + if a is b: + return True + if a is None or b is None: + return False + try: + state_a = attributes.instance_state(a) + state_b = attributes.instance_state(b) + except exc.NO_STATE: + return False + if state_a.key is None or state_b.key is None: + return False + return state_a.key == state_b.key + + +# TODO: Avoid circular import. +attributes.identity_equal = identity_equal +attributes._is_aliased_class = _is_aliased_class +attributes._entity_info = _entity_info diff --git a/sqlalchemy/pool.py b/sqlalchemy/pool.py new file mode 100644 index 0000000..31ab7fa --- /dev/null +++ b/sqlalchemy/pool.py @@ -0,0 +1,913 @@ +# pool.py - Connection pooling for SQLAlchemy +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +"""Connection pooling for DB-API connections. + +Provides a number of connection pool implementations for a variety of +usage scenarios and thread behavior requirements imposed by the +application, DB-API or database itself. + +Also provides a DB-API 2.0 connection proxying mechanism allowing +regular DB-API connect() methods to be transparently managed by a +SQLAlchemy connection pool. +""" + +import weakref, time, threading + +from sqlalchemy import exc, log +from sqlalchemy import queue as sqla_queue +from sqlalchemy.util import threading, pickle, as_interface, memoized_property + +proxies = {} + +def manage(module, **params): + """Return a proxy for a DB-API module that automatically + pools connections. + + Given a DB-API 2.0 module and pool management parameters, returns + a proxy for the module that will automatically pool connections, + creating new connection pools for each distinct set of connection + arguments sent to the decorated module's connect() function. + + :param module: a DB-API 2.0 database module + + :param poolclass: the class used by the pool module to provide + pooling. Defaults to :class:`QueuePool`. + + :param \*\*params: will be passed through to *poolclass* + + """ + try: + return proxies[module] + except KeyError: + return proxies.setdefault(module, _DBProxy(module, **params)) + +def clear_managers(): + """Remove all current DB-API 2.0 managers. + + All pools and connections are disposed. + """ + + for manager in proxies.itervalues(): + manager.close() + proxies.clear() + +class Pool(log.Identified): + """Abstract base class for connection pools.""" + + def __init__(self, + creator, recycle=-1, echo=None, + use_threadlocal=False, + logging_name=None, + reset_on_return=True, listeners=None): + """ + Construct a Pool. + + :param creator: a callable function that returns a DB-API + connection object. The function will be called with + parameters. + + :param recycle: If set to non -1, number of seconds between + connection recycling, which means upon checkout, if this + timeout is surpassed the connection will be closed and + replaced with a newly opened connection. Defaults to -1. + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + id. + + :param echo: If True, connections being pulled and retrieved + from the pool will be logged to the standard output, as well + as pool sizing information. Echoing can also be achieved by + enabling logging for the "sqlalchemy.pool" + namespace. Defaults to False. + + :param use_threadlocal: If set to True, repeated calls to + :meth:`connect` within the same application thread will be + guaranteed to return the same connection object, if one has + already been retrieved from the pool and has not been + returned yet. Offers a slight performance advantage at the + cost of individual transactions by default. The + :meth:`unique_connection` method is provided to bypass the + threadlocal behavior installed into :meth:`connect`. + + :param reset_on_return: If true, reset the database state of + connections returned to the pool. This is typically a + ROLLBACK to release locks and transaction resources. + Disable at your own peril. Defaults to True. + + :param listeners: A list of + :class:`~sqlalchemy.interfaces.PoolListener`-like objects or + dictionaries of callables that receive events when DB-API + connections are created, checked out and checked in to the + pool. + + """ + if logging_name: + self.logging_name = logging_name + self.logger = log.instance_logger(self, echoflag=echo) + self._threadconns = threading.local() + self._creator = creator + self._recycle = recycle + self._use_threadlocal = use_threadlocal + self._reset_on_return = reset_on_return + self.echo = echo + self.listeners = [] + self._on_connect = [] + self._on_first_connect = [] + self._on_checkout = [] + self._on_checkin = [] + + if listeners: + for l in listeners: + self.add_listener(l) + + def unique_connection(self): + return _ConnectionFairy(self).checkout() + + def create_connection(self): + return _ConnectionRecord(self) + + def recreate(self): + """Return a new instance with identical creation arguments.""" + + raise NotImplementedError() + + def dispose(self): + """Dispose of this pool. + + This method leaves the possibility of checked-out connections + remaining open, It is advised to not reuse the pool once dispose() + is called, and to instead use a new pool constructed by the + recreate() method. + """ + + raise NotImplementedError() + + def connect(self): + if not self._use_threadlocal: + return _ConnectionFairy(self).checkout() + + try: + rec = self._threadconns.current() + if rec: + return rec.checkout() + except AttributeError: + pass + + agent = _ConnectionFairy(self) + self._threadconns.current = weakref.ref(agent) + return agent.checkout() + + def return_conn(self, record): + if self._use_threadlocal and hasattr(self._threadconns, "current"): + del self._threadconns.current + self.do_return_conn(record) + + def get(self): + return self.do_get() + + def do_get(self): + raise NotImplementedError() + + def do_return_conn(self, conn): + raise NotImplementedError() + + def status(self): + raise NotImplementedError() + + def add_listener(self, listener): + """Add a ``PoolListener``-like object to this pool. + + ``listener`` may be an object that implements some or all of + PoolListener, or a dictionary of callables containing implementations + of some or all of the named methods in PoolListener. + + """ + + listener = as_interface(listener, + methods=('connect', 'first_connect', 'checkout', 'checkin')) + + self.listeners.append(listener) + if hasattr(listener, 'connect'): + self._on_connect.append(listener) + if hasattr(listener, 'first_connect'): + self._on_first_connect.append(listener) + if hasattr(listener, 'checkout'): + self._on_checkout.append(listener) + if hasattr(listener, 'checkin'): + self._on_checkin.append(listener) + +class _ConnectionRecord(object): + def __init__(self, pool): + self.__pool = pool + self.connection = self.__connect() + self.info = {} + ls = pool.__dict__.pop('_on_first_connect', None) + if ls is not None: + for l in ls: + l.first_connect(self.connection, self) + if pool._on_connect: + for l in pool._on_connect: + l.connect(self.connection, self) + + def close(self): + if self.connection is not None: + self.__pool.logger.debug("Closing connection %r", self.connection) + try: + self.connection.close() + except (SystemExit, KeyboardInterrupt): + raise + except: + self.__pool.logger.debug("Exception closing connection %r", + self.connection) + + def invalidate(self, e=None): + if e is not None: + self.__pool.logger.info("Invalidate connection %r (reason: %s:%s)", + self.connection, e.__class__.__name__, e) + else: + self.__pool.logger.info("Invalidate connection %r", self.connection) + self.__close() + self.connection = None + + def get_connection(self): + if self.connection is None: + self.connection = self.__connect() + self.info.clear() + if self.__pool._on_connect: + for l in self.__pool._on_connect: + l.connect(self.connection, self) + elif self.__pool._recycle > -1 and \ + time.time() - self.starttime > self.__pool._recycle: + self.__pool.logger.info("Connection %r exceeded timeout; recycling", + self.connection) + self.__close() + self.connection = self.__connect() + self.info.clear() + if self.__pool._on_connect: + for l in self.__pool._on_connect: + l.connect(self.connection, self) + return self.connection + + def __close(self): + try: + self.__pool.logger.debug("Closing connection %r", self.connection) + self.connection.close() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + self.__pool.logger.debug("Connection %r threw an error on close: %s", + self.connection, e) + + def __connect(self): + try: + self.starttime = time.time() + connection = self.__pool._creator() + self.__pool.logger.debug("Created new connection %r", connection) + return connection + except Exception, e: + self.__pool.logger.debug("Error on connect(): %s", e) + raise + + +def _finalize_fairy(connection, connection_record, pool, ref=None): + _refs.discard(connection_record) + + if ref is not None and (connection_record.fairy is not ref or isinstance(pool, AssertionPool)): + return + + if connection is not None: + try: + if pool._reset_on_return: + connection.rollback() + # Immediately close detached instances + if connection_record is None: + connection.close() + except Exception, e: + if connection_record is not None: + connection_record.invalidate(e=e) + if isinstance(e, (SystemExit, KeyboardInterrupt)): + raise + + if connection_record is not None: + connection_record.fairy = None + pool.logger.debug("Connection %r being returned to pool", connection) + if pool._on_checkin: + for l in pool._on_checkin: + l.checkin(connection, connection_record) + pool.return_conn(connection_record) + +_refs = set() + +class _ConnectionFairy(object): + """Proxies a DB-API connection and provides return-on-dereference support.""" + + __slots__ = '_pool', '__counter', 'connection', \ + '_connection_record', '__weakref__', '_detached_info' + + def __init__(self, pool): + self._pool = pool + self.__counter = 0 + try: + rec = self._connection_record = pool.get() + conn = self.connection = self._connection_record.get_connection() + rec.fairy = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref)) + _refs.add(rec) + except: + self.connection = None # helps with endless __getattr__ loops later on + self._connection_record = None + raise + self._pool.logger.debug("Connection %r checked out from pool" % + self.connection) + + @property + def _logger(self): + return self._pool.logger + + @property + def is_valid(self): + return self.connection is not None + + @property + def info(self): + """An info collection unique to this DB-API connection.""" + + try: + return self._connection_record.info + except AttributeError: + if self.connection is None: + raise exc.InvalidRequestError("This connection is closed") + try: + return self._detached_info + except AttributeError: + self._detached_info = value = {} + return value + + def invalidate(self, e=None): + """Mark this connection as invalidated. + + The connection will be immediately closed. The containing + ConnectionRecord will create a new connection when next used. + """ + + if self.connection is None: + raise exc.InvalidRequestError("This connection is closed") + if self._connection_record is not None: + self._connection_record.invalidate(e=e) + self.connection = None + self._close() + + def cursor(self, *args, **kwargs): + try: + c = self.connection.cursor(*args, **kwargs) + return _CursorFairy(self, c) + except Exception, e: + self.invalidate(e=e) + raise + + def __getattr__(self, key): + return getattr(self.connection, key) + + def checkout(self): + if self.connection is None: + raise exc.InvalidRequestError("This connection is closed") + self.__counter += 1 + + if not self._pool._on_checkout or self.__counter != 1: + return self + + # Pool listeners can trigger a reconnection on checkout + attempts = 2 + while attempts > 0: + try: + for l in self._pool._on_checkout: + l.checkout(self.connection, self._connection_record, self) + return self + except exc.DisconnectionError, e: + self._pool.logger.info( + "Disconnection detected on checkout: %s", e) + self._connection_record.invalidate(e) + self.connection = self._connection_record.get_connection() + attempts -= 1 + + self._pool.logger.info("Reconnection attempts exhausted on checkout") + self.invalidate() + raise exc.InvalidRequestError("This connection is closed") + + def detach(self): + """Separate this connection from its Pool. + + This means that the connection will no longer be returned to the + pool when closed, and will instead be literally closed. The + containing ConnectionRecord is separated from the DB-API connection, + and will create a new connection when next used. + + Note that any overall connection limiting constraints imposed by a + Pool implementation may be violated after a detach, as the detached + connection is removed from the pool's knowledge and control. + """ + + if self._connection_record is not None: + _refs.remove(self._connection_record) + self._connection_record.fairy = None + self._connection_record.connection = None + self._pool.do_return_conn(self._connection_record) + self._detached_info = \ + self._connection_record.info.copy() + self._connection_record = None + + def close(self): + self.__counter -= 1 + if self.__counter == 0: + self._close() + + def _close(self): + _finalize_fairy(self.connection, self._connection_record, self._pool) + self.connection = None + self._connection_record = None + +class _CursorFairy(object): + __slots__ = '_parent', 'cursor', 'execute' + + def __init__(self, parent, cursor): + self._parent = parent + self.cursor = cursor + self.execute = cursor.execute + + def invalidate(self, e=None): + self._parent.invalidate(e=e) + + def __iter__(self): + return iter(self.cursor) + + def close(self): + try: + self.cursor.close() + except Exception, e: + try: + ex_text = str(e) + except TypeError: + ex_text = repr(e) + self.__parent._logger.warn("Error closing cursor: %s", ex_text) + + if isinstance(e, (SystemExit, KeyboardInterrupt)): + raise + + def __setattr__(self, key, value): + if key in self.__slots__: + object.__setattr__(self, key, value) + else: + setattr(self.cursor, key, value) + + def __getattr__(self, key): + return getattr(self.cursor, key) + +class SingletonThreadPool(Pool): + """A Pool that maintains one connection per thread. + + Maintains one connection per each thread, never moving a connection to a + thread other than the one which it was created in. + + This is used for SQLite, which both does not handle multithreading by + default, and also requires a singleton connection if a :memory: database + is being used. + + Options are the same as those of :class:`Pool`, as well as: + + :param pool_size: The number of threads in which to maintain connections + at once. Defaults to five. + + """ + + def __init__(self, creator, pool_size=5, **kw): + kw['use_threadlocal'] = True + Pool.__init__(self, creator, **kw) + self._conn = threading.local() + self._all_conns = set() + self.size = pool_size + + def recreate(self): + self.logger.info("Pool recreating") + return SingletonThreadPool(self._creator, + pool_size=self.size, + recycle=self._recycle, + echo=self.echo, + use_threadlocal=self._use_threadlocal, + listeners=self.listeners) + + def dispose(self): + """Dispose of this pool.""" + + for conn in self._all_conns: + try: + conn.close() + except (SystemExit, KeyboardInterrupt): + raise + except: + # pysqlite won't even let you close a conn from a thread + # that didn't create it + pass + + self._all_conns.clear() + + def dispose_local(self): + if hasattr(self._conn, 'current'): + conn = self._conn.current() + self._all_conns.discard(conn) + del self._conn.current + + def cleanup(self): + while len(self._all_conns) > self.size: + self._all_conns.pop() + + def status(self): + return "SingletonThreadPool id:%d size: %d" % (id(self), len(self._all_conns)) + + def do_return_conn(self, conn): + pass + + def do_get(self): + try: + c = self._conn.current() + if c: + return c + except AttributeError: + pass + c = self.create_connection() + self._conn.current = weakref.ref(c) + self._all_conns.add(c) + if len(self._all_conns) > self.size: + self.cleanup() + return c + +class QueuePool(Pool): + """A Pool that imposes a limit on the number of open connections.""" + + def __init__(self, creator, pool_size=5, max_overflow=10, timeout=30, + **kw): + """ + Construct a QueuePool. + + :param creator: a callable function that returns a DB-API + connection object. The function will be called with + parameters. + + :param pool_size: The size of the pool to be maintained. This + is the largest number of connections that will be kept + persistently in the pool. Note that the pool begins with no + connections; once this number of connections is requested, + that number of connections will remain. Defaults to 5. + + :param max_overflow: The maximum overflow size of the + pool. When the number of checked-out connections reaches the + size set in pool_size, additional connections will be + returned up to this limit. When those additional connections + are returned to the pool, they are disconnected and + discarded. It follows then that the total number of + simultaneous connections the pool will allow is pool_size + + `max_overflow`, and the total number of "sleeping" + connections the pool will allow is pool_size. `max_overflow` + can be set to -1 to indicate no overflow limit; no limit + will be placed on the total number of concurrent + connections. Defaults to 10. + + :param timeout: The number of seconds to wait before giving up + on returning a connection. Defaults to 30. + + :param recycle: If set to non -1, number of seconds between + connection recycling, which means upon checkout, if this + timeout is surpassed the connection will be closed and + replaced with a newly opened connection. Defaults to -1. + + :param echo: If True, connections being pulled and retrieved + from the pool will be logged to the standard output, as well + as pool sizing information. Echoing can also be achieved by + enabling logging for the "sqlalchemy.pool" + namespace. Defaults to False. + + :param use_threadlocal: If set to True, repeated calls to + :meth:`connect` within the same application thread will be + guaranteed to return the same connection object, if one has + already been retrieved from the pool and has not been + returned yet. Offers a slight performance advantage at the + cost of individual transactions by default. The + :meth:`unique_connection` method is provided to bypass the + threadlocal behavior installed into :meth:`connect`. + + :param reset_on_return: If true, reset the database state of + connections returned to the pool. This is typically a + ROLLBACK to release locks and transaction resources. + Disable at your own peril. Defaults to True. + + :param listeners: A list of + :class:`~sqlalchemy.interfaces.PoolListener`-like objects or + dictionaries of callables that receive events when DB-API + connections are created, checked out and checked in to the + pool. + + """ + Pool.__init__(self, creator, **kw) + self._pool = sqla_queue.Queue(pool_size) + self._overflow = 0 - pool_size + self._max_overflow = max_overflow + self._timeout = timeout + self._overflow_lock = self._max_overflow > -1 and threading.Lock() or None + + def recreate(self): + self.logger.info("Pool recreating") + return QueuePool(self._creator, pool_size=self._pool.maxsize, + max_overflow=self._max_overflow, timeout=self._timeout, + recycle=self._recycle, echo=self.echo, + use_threadlocal=self._use_threadlocal, listeners=self.listeners) + + def do_return_conn(self, conn): + try: + self._pool.put(conn, False) + except sqla_queue.Full: + if self._overflow_lock is None: + self._overflow -= 1 + else: + self._overflow_lock.acquire() + try: + self._overflow -= 1 + finally: + self._overflow_lock.release() + + def do_get(self): + try: + wait = self._max_overflow > -1 and self._overflow >= self._max_overflow + return self._pool.get(wait, self._timeout) + except sqla_queue.Empty: + if self._max_overflow > -1 and self._overflow >= self._max_overflow: + if not wait: + return self.do_get() + else: + raise exc.TimeoutError( + "QueuePool limit of size %d overflow %d reached, " + "connection timed out, timeout %d" % + (self.size(), self.overflow(), self._timeout)) + + if self._overflow_lock is not None: + self._overflow_lock.acquire() + + if self._max_overflow > -1 and self._overflow >= self._max_overflow: + if self._overflow_lock is not None: + self._overflow_lock.release() + return self.do_get() + + try: + con = self.create_connection() + self._overflow += 1 + finally: + if self._overflow_lock is not None: + self._overflow_lock.release() + return con + + def dispose(self): + while True: + try: + conn = self._pool.get(False) + conn.close() + except sqla_queue.Empty: + break + + self._overflow = 0 - self.size() + self.logger.info("Pool disposed. %s", self.status()) + + def status(self): + return "Pool size: %d Connections in pool: %d "\ + "Current Overflow: %d Current Checked out "\ + "connections: %d" % (self.size(), + self.checkedin(), + self.overflow(), + self.checkedout()) + + def size(self): + return self._pool.maxsize + + def checkedin(self): + return self._pool.qsize() + + def overflow(self): + return self._overflow + + def checkedout(self): + return self._pool.maxsize - self._pool.qsize() + self._overflow + +class NullPool(Pool): + """A Pool which does not pool connections. + + Instead it literally opens and closes the underlying DB-API connection + per each connection open/close. + + Reconnect-related functions such as ``recycle`` and connection + invalidation are not supported by this Pool implementation, since + no connections are held persistently. + + """ + + def status(self): + return "NullPool" + + def do_return_conn(self, conn): + conn.close() + + def do_return_invalid(self, conn): + pass + + def do_get(self): + return self.create_connection() + + def recreate(self): + self.logger.info("Pool recreating") + + return NullPool(self._creator, + recycle=self._recycle, + echo=self.echo, + use_threadlocal=self._use_threadlocal, + listeners=self.listeners) + + def dispose(self): + pass + + +class StaticPool(Pool): + """A Pool of exactly one connection, used for all requests. + + Reconnect-related functions such as ``recycle`` and connection + invalidation (which is also used to support auto-reconnect) are not + currently supported by this Pool implementation but may be implemented + in a future release. + + """ + + @memoized_property + def _conn(self): + return self._creator() + + @memoized_property + def connection(self): + return _ConnectionRecord(self) + + def status(self): + return "StaticPool" + + def dispose(self): + if '_conn' in self.__dict__: + self._conn.close() + self._conn = None + + def recreate(self): + self.logger.info("Pool recreating") + return self.__class__(creator=self._creator, + recycle=self._recycle, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + echo=self.echo, + listeners=self.listeners) + + def create_connection(self): + return self._conn + + def do_return_conn(self, conn): + pass + + def do_return_invalid(self, conn): + pass + + def do_get(self): + return self.connection + +class AssertionPool(Pool): + """A Pool that allows at most one checked out connection at any given time. + + This will raise an exception if more than one connection is checked out + at a time. Useful for debugging code that is using more connections + than desired. + + """ + + def __init__(self, *args, **kw): + self._conn = None + self._checked_out = False + Pool.__init__(self, *args, **kw) + + def status(self): + return "AssertionPool" + + def do_return_conn(self, conn): + if not self._checked_out: + raise AssertionError("connection is not checked out") + self._checked_out = False + assert conn is self._conn + + def do_return_invalid(self, conn): + self._conn = None + self._checked_out = False + + def dispose(self): + self._checked_out = False + if self._conn: + self._conn.close() + + def recreate(self): + self.logger.info("Pool recreating") + return AssertionPool(self._creator, echo=self.echo, + listeners=self.listeners) + + def do_get(self): + if self._checked_out: + raise AssertionError("connection is already checked out") + + if not self._conn: + self._conn = self.create_connection() + + self._checked_out = True + return self._conn + +class _DBProxy(object): + """Layers connection pooling behavior on top of a standard DB-API module. + + Proxies a DB-API 2.0 connect() call to a connection pool keyed to the + specific connect parameters. Other functions and attributes are delegated + to the underlying DB-API module. + """ + + def __init__(self, module, poolclass=QueuePool, **kw): + """Initializes a new proxy. + + module + a DB-API 2.0 module + + poolclass + a Pool class, defaulting to QueuePool + + Other parameters are sent to the Pool object's constructor. + + """ + + self.module = module + self.kw = kw + self.poolclass = poolclass + self.pools = {} + self._create_pool_mutex = threading.Lock() + + def close(self): + for key in self.pools.keys(): + del self.pools[key] + + def __del__(self): + self.close() + + def __getattr__(self, key): + return getattr(self.module, key) + + def get_pool(self, *args, **kw): + key = self._serialize(*args, **kw) + try: + return self.pools[key] + except KeyError: + self._create_pool_mutex.acquire() + try: + if key not in self.pools: + pool = self.poolclass(lambda: self.module.connect(*args, **kw), **self.kw) + self.pools[key] = pool + return pool + else: + return self.pools[key] + finally: + self._create_pool_mutex.release() + + def connect(self, *args, **kw): + """Activate a connection to the database. + + Connect to the database using this DBProxy's module and the given + connect arguments. If the arguments match an existing pool, the + connection will be returned from the pool's current thread-local + connection instance, or if there is no thread-local connection + instance it will be checked out from the set of pooled connections. + + If the pool has no available connections and allows new connections + to be created, a new database connection will be made. + + """ + + return self.get_pool(*args, **kw).connect() + + def dispose(self, *args, **kw): + """Dispose the pool referenced by the given connect arguments.""" + + key = self._serialize(*args, **kw) + try: + del self.pools[key] + except KeyError: + pass + + def _serialize(self, *args, **kw): + return pickle.dumps([args, kw]) diff --git a/sqlalchemy/processors.py b/sqlalchemy/processors.py new file mode 100644 index 0000000..c99ca4c --- /dev/null +++ b/sqlalchemy/processors.py @@ -0,0 +1,101 @@ +# processors.py +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. + +""" + +import codecs +import re +import datetime + +def str_to_datetime_processor_factory(regexp, type_): + rmatch = regexp.match + # Even on python2.6 datetime.strptime is both slower than this code + # and it does not support microseconds. + def process(value): + if value is None: + return None + else: + return type_(*map(int, rmatch(value).groups(0))) + return process + +try: + from sqlalchemy.cprocessors import UnicodeResultProcessor, \ + DecimalResultProcessor, \ + to_float, to_str, int_to_boolean, \ + str_to_datetime, str_to_time, \ + str_to_date + + def to_unicode_processor_factory(encoding, errors=None): + # this is cumbersome but it would be even more so on the C side + if errors is not None: + return UnicodeResultProcessor(encoding, errors).process + else: + return UnicodeResultProcessor(encoding).process + + def to_decimal_processor_factory(target_class, scale=10): + # Note that the scale argument is not taken into account for integer + # values in the C implementation while it is in the Python one. + # For example, the Python implementation might return + # Decimal('5.00000') whereas the C implementation will + # return Decimal('5'). These are equivalent of course. + return DecimalResultProcessor(target_class, "%%.%df" % scale).process + +except ImportError: + def to_unicode_processor_factory(encoding, errors=None): + decoder = codecs.getdecoder(encoding) + + def process(value): + if value is None: + return None + else: + # decoder returns a tuple: (value, len). Simply dropping the + # len part is safe: it is done that way in the normal + # 'xx'.decode(encoding) code path. + return decoder(value, errors)[0] + return process + + def to_decimal_processor_factory(target_class, scale=10): + fstring = "%%.%df" % scale + + def process(value): + if value is None: + return None + else: + return target_class(fstring % value) + return process + + def to_float(value): + if value is None: + return None + else: + return float(value) + + def to_str(value): + if value is None: + return None + else: + return str(value) + + def int_to_boolean(value): + if value is None: + return None + else: + return value and True or False + + DATETIME_RE = re.compile("(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") + TIME_RE = re.compile("(\d+):(\d+):(\d+)(?:\.(\d+))?") + DATE_RE = re.compile("(\d+)-(\d+)-(\d+)") + + str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE, + datetime.datetime) + str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) + str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) + diff --git a/sqlalchemy/queue.py b/sqlalchemy/queue.py new file mode 100644 index 0000000..2aaeea9 --- /dev/null +++ b/sqlalchemy/queue.py @@ -0,0 +1,183 @@ +"""An adaptation of Py2.3/2.4's Queue module which supports reentrant +behavior, using RLock instead of Lock for its mutex object. + +This is to support the connection pool's usage of weakref callbacks to return +connections to the underlying Queue, which can in extremely +rare cases be invoked within the ``get()`` method of the Queue itself, +producing a ``put()`` inside the ``get()`` and therefore a reentrant +condition.""" + +from collections import deque +from time import time as _time +from sqlalchemy.util import threading + +__all__ = ['Empty', 'Full', 'Queue'] + +class Empty(Exception): + "Exception raised by Queue.get(block=0)/get_nowait()." + + pass + +class Full(Exception): + "Exception raised by Queue.put(block=0)/put_nowait()." + + pass + +class Queue: + def __init__(self, maxsize=0): + """Initialize a queue object with a given maximum size. + + If `maxsize` is <= 0, the queue size is infinite. + """ + + self._init(maxsize) + # mutex must be held whenever the queue is mutating. All methods + # that acquire mutex must release it before returning. mutex + # is shared between the two conditions, so acquiring and + # releasing the conditions also acquires and releases mutex. + self.mutex = threading.RLock() + # Notify not_empty whenever an item is added to the queue; a + # thread waiting to get is notified then. + self.not_empty = threading.Condition(self.mutex) + # Notify not_full whenever an item is removed from the queue; + # a thread waiting to put is notified then. + self.not_full = threading.Condition(self.mutex) + + def qsize(self): + """Return the approximate size of the queue (not reliable!).""" + + self.mutex.acquire() + n = self._qsize() + self.mutex.release() + return n + + def empty(self): + """Return True if the queue is empty, False otherwise (not reliable!).""" + + self.mutex.acquire() + n = self._empty() + self.mutex.release() + return n + + def full(self): + """Return True if the queue is full, False otherwise (not reliable!).""" + + self.mutex.acquire() + n = self._full() + self.mutex.release() + return n + + def put(self, item, block=True, timeout=None): + """Put an item into the queue. + + If optional args `block` is True and `timeout` is None (the + default), block if necessary until a free slot is + available. If `timeout` is a positive number, it blocks at + most `timeout` seconds and raises the ``Full`` exception if no + free slot was available within that time. Otherwise (`block` + is false), put an item on the queue if a free slot is + immediately available, else raise the ``Full`` exception + (`timeout` is ignored in that case). + """ + + self.not_full.acquire() + try: + if not block: + if self._full(): + raise Full + elif timeout is None: + while self._full(): + self.not_full.wait() + else: + if timeout < 0: + raise ValueError("'timeout' must be a positive number") + endtime = _time() + timeout + while self._full(): + remaining = endtime - _time() + if remaining <= 0.0: + raise Full + self.not_full.wait(remaining) + self._put(item) + self.not_empty.notify() + finally: + self.not_full.release() + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + Only enqueue the item if a free slot is immediately available. + Otherwise raise the ``Full`` exception. + """ + return self.put(item, False) + + def get(self, block=True, timeout=None): + """Remove and return an item from the queue. + + If optional args `block` is True and `timeout` is None (the + default), block if necessary until an item is available. If + `timeout` is a positive number, it blocks at most `timeout` + seconds and raises the ``Empty`` exception if no item was + available within that time. Otherwise (`block` is false), + return an item if one is immediately available, else raise the + ``Empty`` exception (`timeout` is ignored in that case). + """ + + self.not_empty.acquire() + try: + if not block: + if self._empty(): + raise Empty + elif timeout is None: + while self._empty(): + self.not_empty.wait() + else: + if timeout < 0: + raise ValueError("'timeout' must be a positive number") + endtime = _time() + timeout + while self._empty(): + remaining = endtime - _time() + if remaining <= 0.0: + raise Empty + self.not_empty.wait(remaining) + item = self._get() + self.not_full.notify() + return item + finally: + self.not_empty.release() + + def get_nowait(self): + """Remove and return an item from the queue without blocking. + + Only get an item if one is immediately available. Otherwise + raise the ``Empty`` exception. + """ + + return self.get(False) + + # Override these methods to implement other queue organizations + # (e.g. stack or priority queue). + # These will only be called with appropriate locks held + + # Initialize the queue representation + def _init(self, maxsize): + self.maxsize = maxsize + self.queue = deque() + + def _qsize(self): + return len(self.queue) + + # Check whether the queue is empty + def _empty(self): + return not self.queue + + # Check whether the queue is full + def _full(self): + return self.maxsize > 0 and len(self.queue) == self.maxsize + + # Put a new item in the queue + def _put(self, item): + self.queue.append(item) + + # Get an item from the queue + def _get(self): + return self.queue.popleft() diff --git a/sqlalchemy/schema.py b/sqlalchemy/schema.py new file mode 100644 index 0000000..8ffb68a --- /dev/null +++ b/sqlalchemy/schema.py @@ -0,0 +1,2386 @@ +# schema.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The schema module provides the building blocks for database metadata. + +Each element within this module describes a database entity which can be +created and dropped, or is otherwise part of such an entity. Examples include +tables, columns, sequences, and indexes. + +All entities are subclasses of :class:`~sqlalchemy.schema.SchemaItem`, and as defined +in this module they are intended to be agnostic of any vendor-specific +constructs. + +A collection of entities are grouped into a unit called +:class:`~sqlalchemy.schema.MetaData`. MetaData serves as a logical grouping of schema +elements, and can also be associated with an actual database connection such +that operations involving the contained elements can contact the database as +needed. + +Two of the elements here also build upon their "syntactic" counterparts, which +are defined in :class:`~sqlalchemy.sql.expression.`, specifically +:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.Column`. Since these objects +are part of the SQL expression language, they are usable as components in SQL +expressions. + +""" +import re, inspect +from sqlalchemy import exc, util, dialects +from sqlalchemy.sql import expression, visitors + +URL = None + +__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', + 'ForeignKeyConstraint', 'PrimaryKeyConstraint', 'CheckConstraint', + 'UniqueConstraint', 'DefaultGenerator', 'Constraint', 'MetaData', + 'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault', + 'DefaultClause', 'FetchedValue', 'ColumnDefault', 'DDL', + 'CreateTable', 'DropTable', 'CreateSequence', 'DropSequence', + 'AddConstraint', 'DropConstraint', + ] +__all__.sort() + +RETAIN_SCHEMA = util.symbol('retain_schema') + +class SchemaItem(visitors.Visitable): + """Base class for items that define a database schema.""" + + __visit_name__ = 'schema_item' + quote = None + + def _init_items(self, *args): + """Initialize the list of child items for this SchemaItem.""" + + for item in args: + if item is not None: + item._set_parent(self) + + def _set_parent(self, parent): + """Associate with this SchemaItem's parent object.""" + + raise NotImplementedError() + + def get_children(self, **kwargs): + """used to allow SchemaVisitor access""" + return [] + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + @util.memoized_property + def info(self): + return {} + +def _get_table_key(name, schema): + if schema is None: + return name + else: + return schema + "." + name + +class Table(SchemaItem, expression.TableClause): + """Represent a table in a database. + + e.g.:: + + mytable = Table("mytable", metadata, + Column('mytable_id', Integer, primary_key=True), + Column('value', String(50)) + ) + + The Table object constructs a unique instance of itself based on its + name within the given MetaData object. Constructor + arguments are as follows: + + :param name: The name of this table as represented in the database. + + This property, along with the *schema*, indicates the *singleton + identity* of this table in relation to its parent :class:`MetaData`. + Additional calls to :class:`Table` with the same name, metadata, + and schema name will return the same :class:`Table` object. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word. Names with any number of upper + case characters will be quoted and sent exactly. Note that this + behavior applies even for databases which standardize upper + case names as case insensitive such as Oracle. + + :param metadata: a :class:`MetaData` object which will contain this + table. The metadata is used as a point of association of this table + with other tables which are referenced via foreign key. It also + may be used to associate this table with a particular + :class:`~sqlalchemy.engine.base.Connectable`. + + :param \*args: Additional positional arguments are used primarily + to add the list of :class:`Column` objects contained within this table. + Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem` + constructs may be added here, including :class:`PrimaryKeyConstraint`, + and :class:`ForeignKeyConstraint`. + + :param autoload: Defaults to False: the Columns for this table should be reflected + from the database. Usually there will be no Column objects in the + constructor if this property is set. + + :param autoload_with: If autoload==True, this is an optional Engine or Connection + instance to be used for the table reflection. If ``None``, the + underlying MetaData's bound connectable will be used. + + :param implicit_returning: True by default - indicates that + RETURNING can be used by default to fetch newly inserted primary key + values, for backends which support this. Note that + create_engine() also provides an implicit_returning flag. + + :param include_columns: A list of strings indicating a subset of columns to be loaded via + the ``autoload`` operation; table columns who aren't present in + this list will not be represented on the resulting ``Table`` + object. Defaults to ``None`` which indicates all columns should + be reflected. + + :param info: A dictionary which defaults to ``{}``. A space to store application + specific data. This must be a dictionary. + + :param mustexist: When ``True``, indicates that this Table must already + be present in the given :class:`MetaData`` collection. + + :param prefixes: + A list of strings to insert after CREATE in the CREATE TABLE + statement. They will be separated by spaces. + + :param quote: Force quoting of this table's name on or off, corresponding + to ``True`` or ``False``. When left at its default of ``None``, + the column identifier will be quoted according to whether the name is + case sensitive (identifiers with at least one upper case character are + treated as case sensitive), or if it's a reserved word. This flag + is only needed to force quoting of a reserved word which is not known + by the SQLAlchemy dialect. + + :param quote_schema: same as 'quote' but applies to the schema identifier. + + :param schema: The *schema name* for this table, which is required if the table + resides in a schema other than the default selected schema for the + engine's database connection. Defaults to ``None``. + + :param useexisting: When ``True``, indicates that if this Table is already + present in the given :class:`MetaData`, apply further arguments within + the constructor to the existing :class:`Table`. If this flag is not + set, an error is raised when the parameters of an existing :class:`Table` + are overwritten. + + """ + + __visit_name__ = 'table' + + ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') + + def __new__(cls, *args, **kw): + if not args: + # python3k pickle seems to call this + return object.__new__(cls) + + try: + name, metadata, args = args[0], args[1], args[2:] + except IndexError: + raise TypeError("Table() takes at least two arguments") + + schema = kw.get('schema', None) + useexisting = kw.pop('useexisting', False) + mustexist = kw.pop('mustexist', False) + key = _get_table_key(name, schema) + if key in metadata.tables: + if not useexisting and bool(args): + raise exc.InvalidRequestError( + "Table '%s' is already defined for this MetaData instance. " + "Specify 'useexisting=True' to redefine options and " + "columns on an existing Table object." % key) + table = metadata.tables[key] + table._init_existing(*args, **kw) + return table + else: + if mustexist: + raise exc.InvalidRequestError( + "Table '%s' not defined" % (key)) + metadata.tables[key] = table = object.__new__(cls) + try: + table._init(name, metadata, *args, **kw) + return table + except: + metadata.tables.pop(key) + raise + + def __init__(self, *args, **kw): + # __init__ is overridden to prevent __new__ from + # calling the superclass constructor. + pass + + def _init(self, name, metadata, *args, **kwargs): + super(Table, self).__init__(name) + self.metadata = metadata + self.schema = kwargs.pop('schema', None) + self.indexes = set() + self.constraints = set() + self._columns = expression.ColumnCollection() + self._set_primary_key(PrimaryKeyConstraint()) + self._foreign_keys = util.OrderedSet() + self.ddl_listeners = util.defaultdict(list) + self.kwargs = {} + if self.schema is not None: + self.fullname = "%s.%s" % (self.schema, self.name) + else: + self.fullname = self.name + + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', None) + include_columns = kwargs.pop('include_columns', None) + + self.implicit_returning = kwargs.pop('implicit_returning', True) + self.quote = kwargs.pop('quote', None) + self.quote_schema = kwargs.pop('quote_schema', None) + if 'info' in kwargs: + self.info = kwargs.pop('info') + + self._prefixes = kwargs.pop('prefixes', []) + + self._extra_kwargs(**kwargs) + + # load column definitions from the database if 'autoload' is defined + # we do it after the table is in the singleton dictionary to support + # circular foreign keys + if autoload: + if autoload_with: + autoload_with.reflecttable(self, include_columns=include_columns) + else: + _bind_or_error(metadata, msg="No engine is bound to this Table's MetaData. " + "Pass an engine to the Table via " + "autoload_with=, " + "or associate the MetaData with an engine via " + "metadata.bind=").\ + reflecttable(self, include_columns=include_columns) + + # initialize all the column, etc. objects. done after reflection to + # allow user-overrides + self._init_items(*args) + + def _init_existing(self, *args, **kwargs): + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', None) + schema = kwargs.pop('schema', None) + if schema and schema != self.schema: + raise exc.ArgumentError( + "Can't change schema of existing table from '%s' to '%s'", + (self.schema, schema)) + + include_columns = kwargs.pop('include_columns', None) + if include_columns: + for c in self.c: + if c.name not in include_columns: + self.c.remove(c) + + for key in ('quote', 'quote_schema'): + if key in kwargs: + setattr(self, key, kwargs.pop(key)) + + if 'info' in kwargs: + self.info = kwargs.pop('info') + + self._extra_kwargs(**kwargs) + self._init_items(*args) + + def _extra_kwargs(self, **kwargs): + # validate remaining kwargs that they all specify DB prefixes + if len([k for k in kwargs + if not re.match(r'^(?:%s)_' % '|'.join(dialects.__all__), k)]): + raise TypeError( + "Invalid argument(s) for Table: %r" % kwargs.keys()) + self.kwargs.update(kwargs) + + def _set_primary_key(self, pk): + if getattr(self, '_primary_key', None) in self.constraints: + self.constraints.remove(self._primary_key) + self._primary_key = pk + self.constraints.add(pk) + + for c in pk.columns: + c.primary_key = True + + @util.memoized_property + def _autoincrement_column(self): + for col in self.primary_key: + if col.autoincrement and \ + isinstance(col.type, types.Integer) and \ + not col.foreign_keys and \ + isinstance(col.default, (type(None), Sequence)): + + return col + + @property + def key(self): + return _get_table_key(self.name, self.schema) + + @property + def primary_key(self): + return self._primary_key + + def __repr__(self): + return "Table(%s)" % ', '.join( + [repr(self.name)] + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]) + + def __str__(self): + return _get_table_key(self.description, self.schema) + + @property + def bind(self): + """Return the connectable associated with this Table.""" + + return self.metadata and self.metadata.bind or None + + def append_column(self, column): + """Append a ``Column`` to this ``Table``.""" + + column._set_parent(self) + + def append_constraint(self, constraint): + """Append a ``Constraint`` to this ``Table``.""" + + constraint._set_parent(self) + + def append_ddl_listener(self, event, listener): + """Append a DDL event listener to this ``Table``. + + The ``listener`` callable will be triggered when this ``Table`` is + created or dropped, either directly before or after the DDL is issued + to the database. The listener may modify the Table, but may not abort + the event itself. + + Arguments are: + + event + One of ``Table.ddl_events``; e.g. 'before-create', 'after-create', + 'before-drop' or 'after-drop'. + + listener + A callable, invoked with three positional arguments: + + event + The event currently being handled + target + The ``Table`` object being created or dropped + bind + The ``Connection`` bueing used for DDL execution. + + Listeners are added to the Table's ``ddl_listeners`` attribute. + """ + + if event not in self.ddl_events: + raise LookupError(event) + self.ddl_listeners[event].append(listener) + + def _set_parent(self, metadata): + metadata.tables[_get_table_key(self.name, self.schema)] = self + self.metadata = metadata + + def get_children(self, column_collections=True, schema_visitor=False, **kwargs): + if not schema_visitor: + return expression.TableClause.get_children( + self, column_collections=column_collections, **kwargs) + else: + if column_collections: + return list(self.columns) + else: + return [] + + def exists(self, bind=None): + """Return True if this table exists.""" + + if bind is None: + bind = _bind_or_error(self) + + return bind.run_callable(bind.dialect.has_table, self.name, schema=self.schema) + + def create(self, bind=None, checkfirst=False): + """Issue a ``CREATE`` statement for this table. + + See also ``metadata.create_all()``. + """ + self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self]) + + def drop(self, bind=None, checkfirst=False): + """Issue a ``DROP`` statement for this table. + + See also ``metadata.drop_all()``. + """ + self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self]) + + def tometadata(self, metadata, schema=RETAIN_SCHEMA): + """Return a copy of this ``Table`` associated with a different ``MetaData``.""" + + try: + if schema is RETAIN_SCHEMA: + schema = self.schema + key = _get_table_key(self.name, schema) + return metadata.tables[key] + except KeyError: + args = [] + for c in self.columns: + args.append(c.copy(schema=schema)) + for c in self.constraints: + args.append(c.copy(schema=schema)) + return Table(self.name, metadata, schema=schema, *args) + +class Column(SchemaItem, expression.ColumnClause): + """Represents a column in a database table.""" + + __visit_name__ = 'column' + + def __init__(self, *args, **kwargs): + """ + Construct a new ``Column`` object. + + :param name: The name of this column as represented in the database. + This argument may be the first positional argument, or specified + via keyword. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word. Names with any number of upper + case characters will be quoted and sent exactly. Note that this + behavior applies even for databases which standardize upper + case names as case insensitive such as Oracle. + + The name field may be omitted at construction time and applied + later, at any time before the Column is associated with a + :class:`Table`. This is to support convenient + usage within the :mod:`~sqlalchemy.ext.declarative` extension. + + :param type\_: The column's type, indicated using an instance which + subclasses :class:`~sqlalchemy.types.AbstractType`. If no arguments + are required for the type, the class of the type can be sent + as well, e.g.:: + + # use a type with arguments + Column('data', String(50)) + + # use no arguments + Column('level', Integer) + + The ``type`` argument may be the second positional argument + or specified by keyword. + + If this column also contains a :class:`ForeignKey`, + the type argument may be left as ``None`` in which case the + type assigned will be that of the referenced column. + + :param \*args: Additional positional arguments include various + :class:`SchemaItem` derived constructs which will be applied + as options to the column. These include instances of + :class:`Constraint`, :class:`ForeignKey`, :class:`ColumnDefault`, + and :class:`Sequence`. In some cases an equivalent keyword + argument is available such as ``server_default``, ``default`` + and ``unique``. + + :param autoincrement: This flag may be set to ``False`` to + indicate an integer primary key column that should not be + considered to be the "autoincrement" column, that is + the integer primary key column which generates values + implicitly upon INSERT and whose value is usually returned + via the DBAPI cursor.lastrowid attribute. It defaults + to ``True`` to satisfy the common use case of a table + with a single integer primary key column. If the table + has a composite primary key consisting of more than one + integer column, set this flag to True only on the + column that should be considered "autoincrement". + + The setting *only* has an effect for columns which are: + + * Integer derived (i.e. INT, SMALLINT, BIGINT) + + * Part of the primary key + + * Are not referenced by any foreign keys + + * have no server side or client side defaults (with the exception + of Postgresql SERIAL). + + The setting has these two effects on columns that meet the + above criteria: + + * DDL issued for the column will include database-specific + keywords intended to signify this column as an + "autoincrement" column, such as AUTO INCREMENT on MySQL, + SERIAL on Postgresql, and IDENTITY on MS-SQL. It does + *not* issue AUTOINCREMENT for SQLite since this is a + special SQLite flag that is not required for autoincrementing + behavior. See the SQLite dialect documentation for + information on SQLite's AUTOINCREMENT. + + * The column will be considered to be available as + cursor.lastrowid or equivalent, for those dialects which + "post fetch" newly inserted identifiers after a row has + been inserted (SQLite, MySQL, MS-SQL). It does not have + any effect in this regard for databases that use sequences + to generate primary key identifiers (i.e. Firebird, Postgresql, + Oracle). + + :param default: A scalar, Python callable, or + :class:`~sqlalchemy.sql.expression.ClauseElement` representing the + *default value* for this column, which will be invoked upon insert + if this column is otherwise not specified in the VALUES clause of + the insert. This is a shortcut to using :class:`ColumnDefault` as + a positional argument. + + Contrast this argument to ``server_default`` which creates a + default generator on the database side. + + :param key: An optional string identifier which will identify this + ``Column`` object on the :class:`Table`. When a key is provided, + this is the only identifier referencing the ``Column`` within the + application, including ORM attribute mapping; the ``name`` field + is used only when rendering SQL. + + :param index: When ``True``, indicates that the column is indexed. + This is a shortcut for using a :class:`Index` construct on the + table. To specify indexes with explicit names or indexes that + contain multiple columns, use the :class:`Index` construct + instead. + + :param info: A dictionary which defaults to ``{}``. A space to store + application specific data. This must be a dictionary. + + :param nullable: If set to the default of ``True``, indicates the + column will be rendered as allowing NULL, else it's rendered as + NOT NULL. This parameter is only used when issuing CREATE TABLE + statements. + + :param onupdate: A scalar, Python callable, or + :class:`~sqlalchemy.sql.expression.ClauseElement` representing a + default value to be applied to the column within UPDATE + statements, which wil be invoked upon update if this column is not + present in the SET clause of the update. This is a shortcut to + using :class:`ColumnDefault` as a positional argument with + ``for_update=True``. + + :param primary_key: If ``True``, marks this column as a primary key + column. Multiple columns can have this flag set to specify + composite primary keys. As an alternative, the primary key of a + :class:`Table` can be specified via an explicit + :class:`PrimaryKeyConstraint` object. + + :param server_default: A :class:`FetchedValue` instance, str, Unicode + or :func:`~sqlalchemy.sql.expression.text` construct representing + the DDL DEFAULT value for the column. + + String types will be emitted as-is, surrounded by single quotes:: + + Column('x', Text, server_default="val") + + x TEXT DEFAULT 'val' + + A :func:`~sqlalchemy.sql.expression.text` expression will be + rendered as-is, without quotes:: + + Column('y', DateTime, server_default=text('NOW()'))0 + + y DATETIME DEFAULT NOW() + + Strings and text() will be converted into a :class:`DefaultClause` + object upon initialization. + + Use :class:`FetchedValue` to indicate that an already-existing + column will generate a default value on the database side which + will be available to SQLAlchemy for post-fetch after inserts. This + construct does not specify any DDL and the implementation is left + to the database, such as via a trigger. + + :param server_onupdate: A :class:`FetchedValue` instance + representing a database-side default generation function. This + indicates to SQLAlchemy that a newly generated value will be + available after updates. This construct does not specify any DDL + and the implementation is left to the database, such as via a + trigger. + + :param quote: Force quoting of this column's name on or off, + corresponding to ``True`` or ``False``. When left at its default + of ``None``, the column identifier will be quoted according to + whether the name is case sensitive (identifiers with at least one + upper case character are treated as case sensitive), or if it's a + reserved word. This flag is only needed to force quoting of a + reserved word which is not known by the SQLAlchemy dialect. + + :param unique: When ``True``, indicates that this column contains a + unique constraint, or if ``index`` is ``True`` as well, indicates + that the :class:`Index` should be created with the unique flag. + To specify multiple columns in the constraint/index or to specify + an explicit name, use the :class:`UniqueConstraint` or + :class:`Index` constructs explicitly. + + """ + + name = kwargs.pop('name', None) + type_ = kwargs.pop('type_', None) + args = list(args) + if args: + if isinstance(args[0], basestring): + if name is not None: + raise exc.ArgumentError( + "May not pass name positionally and as a keyword.") + name = args.pop(0) + if args: + coltype = args[0] + + if (isinstance(coltype, types.AbstractType) or + (isinstance(coltype, type) and + issubclass(coltype, types.AbstractType))): + if type_ is not None: + raise exc.ArgumentError( + "May not pass type_ positionally and as a keyword.") + type_ = args.pop(0) + + no_type = type_ is None + + super(Column, self).__init__(name, None, type_) + self.key = kwargs.pop('key', name) + self.primary_key = kwargs.pop('primary_key', False) + self.nullable = kwargs.pop('nullable', not self.primary_key) + self.default = kwargs.pop('default', None) + self.server_default = kwargs.pop('server_default', None) + self.server_onupdate = kwargs.pop('server_onupdate', None) + self.index = kwargs.pop('index', None) + self.unique = kwargs.pop('unique', None) + self.quote = kwargs.pop('quote', None) + self.onupdate = kwargs.pop('onupdate', None) + self.autoincrement = kwargs.pop('autoincrement', True) + self.constraints = set() + self.foreign_keys = util.OrderedSet() + self._table_events = set() + + # check if this Column is proxying another column + if '_proxies' in kwargs: + self.proxies = kwargs.pop('_proxies') + # otherwise, add DDL-related events + elif isinstance(self.type, types.SchemaType): + self.type._set_parent(self) + + if self.default is not None: + if isinstance(self.default, (ColumnDefault, Sequence)): + args.append(self.default) + else: + args.append(ColumnDefault(self.default)) + + if self.server_default is not None: + if isinstance(self.server_default, FetchedValue): + args.append(self.server_default) + else: + args.append(DefaultClause(self.server_default)) + + if self.onupdate is not None: + if isinstance(self.onupdate, (ColumnDefault, Sequence)): + args.append(self.onupdate) + else: + args.append(ColumnDefault(self.onupdate, for_update=True)) + + if self.server_onupdate is not None: + if isinstance(self.server_onupdate, FetchedValue): + args.append(self.server_default) + else: + args.append(DefaultClause(self.server_onupdate, + for_update=True)) + self._init_items(*args) + + if not self.foreign_keys and no_type: + raise exc.ArgumentError("'type' is required on Column objects " + "which have no foreign keys.") + util.set_creation_order(self) + + if 'info' in kwargs: + self.info = kwargs.pop('info') + + if kwargs: + raise exc.ArgumentError( + "Unknown arguments passed to Column: " + repr(kwargs.keys())) + + def __str__(self): + if self.name is None: + return "(no name)" + elif self.table is not None: + if self.table.named_with_column: + return (self.table.description + "." + self.description) + else: + return self.description + else: + return self.description + + def references(self, column): + """Return True if this Column references the given column via foreign key.""" + for fk in self.foreign_keys: + if fk.references(column.table): + return True + else: + return False + + def append_foreign_key(self, fk): + fk._set_parent(self) + + def __repr__(self): + kwarg = [] + if self.key != self.name: + kwarg.append('key') + if self.primary_key: + kwarg.append('primary_key') + if not self.nullable: + kwarg.append('nullable') + if self.onupdate: + kwarg.append('onupdate') + if self.default: + kwarg.append('default') + if self.server_default: + kwarg.append('server_default') + return "Column(%s)" % ', '.join( + [repr(self.name)] + [repr(self.type)] + + [repr(x) for x in self.foreign_keys if x is not None] + + [repr(x) for x in self.constraints] + + [(self.table is not None and "table=<%s>" % self.table.description or "")] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) + + def _set_parent(self, table): + if self.name is None: + raise exc.ArgumentError( + "Column must be constructed with a name or assign .name " + "before adding to a Table.") + if self.key is None: + self.key = self.name + + if getattr(self, 'table', None) is not None: + raise exc.ArgumentError("this Column already has a table!") + + if self.key in table._columns: + col = table._columns.get(self.key) + for fk in col.foreign_keys: + col.foreign_keys.remove(fk) + table.foreign_keys.remove(fk) + table.constraints.remove(fk.constraint) + + table._columns.replace(self) + + if self.primary_key: + table.primary_key._replace(self) + elif self.key in table.primary_key: + raise exc.ArgumentError( + "Trying to redefine primary-key column '%s' as a " + "non-primary-key column on table '%s'" % ( + self.key, table.fullname)) + self.table = table + + if self.index: + if isinstance(self.index, basestring): + raise exc.ArgumentError( + "The 'index' keyword argument on Column is boolean only. " + "To create indexes with a specific name, create an " + "explicit Index object external to the Table.") + Index('ix_%s' % self._label, self, unique=self.unique) + elif self.unique: + if isinstance(self.unique, basestring): + raise exc.ArgumentError( + "The 'unique' keyword argument on Column is boolean only. " + "To create unique constraints or indexes with a specific " + "name, append an explicit UniqueConstraint to the Table's " + "list of elements, or create an explicit Index object " + "external to the Table.") + table.append_constraint(UniqueConstraint(self.key)) + + for fn in self._table_events: + fn(table, self) + del self._table_events + + def _on_table_attach(self, fn): + if self.table is not None: + fn(self.table, self) + else: + self._table_events.add(fn) + + def copy(self, **kw): + """Create a copy of this ``Column``, unitialized. + + This is used in ``Table.tometadata``. + + """ + + # Constraint objects plus non-constraint-bound ForeignKey objects + args = \ + [c.copy(**kw) for c in self.constraints] + \ + [c.copy(**kw) for c in self.foreign_keys if not c.constraint] + + c = Column( + name=self.name, + type_=self.type, + key = self.key, + primary_key = self.primary_key, + nullable = self.nullable, + quote=self.quote, + index=self.index, + autoincrement=self.autoincrement, + default=self.default, + server_default=self.server_default, + onupdate=self.onupdate, + server_onupdate=self.server_onupdate, + *args + ) + if hasattr(self, '_table_events'): + c._table_events = list(self._table_events) + return c + + def _make_proxy(self, selectable, name=None): + """Create a *proxy* for this column. + + This is a copy of this ``Column`` referenced by a different parent + (such as an alias or select statement). The column should + be used only in select scenarios, as its full DDL/default + information is not transferred. + + """ + fk = [ForeignKey(f.column) for f in self.foreign_keys] + c = Column( + name or self.name, + self.type, + key = name or self.key, + primary_key = self.primary_key, + nullable = self.nullable, + quote=self.quote, _proxies=[self], *fk) + c.table = selectable + selectable.columns.add(c) + if self.primary_key: + selectable.primary_key.add(c) + for fn in c._table_events: + fn(selectable, c) + del c._table_events + return c + + def get_children(self, schema_visitor=False, **kwargs): + if schema_visitor: + return [x for x in (self.default, self.onupdate) if x is not None] + \ + list(self.foreign_keys) + list(self.constraints) + else: + return expression.ColumnClause.get_children(self, **kwargs) + + +class ForeignKey(SchemaItem): + """Defines a dependency between two columns. + + ``ForeignKey`` is specified as an argument to a :class:`Column` object, + e.g.:: + + t = Table("remote_table", metadata, + Column("remote_id", ForeignKey("main_table.id")) + ) + + Note that ``ForeignKey`` is only a marker object that defines + a dependency between two columns. The actual constraint + is in all cases represented by the :class:`ForeignKeyConstraint` + object. This object will be generated automatically when + a ``ForeignKey`` is associated with a :class:`Column` which + in turn is associated with a :class:`Table`. Conversely, + when :class:`ForeignKeyConstraint` is applied to a :class:`Table`, + ``ForeignKey`` markers are automatically generated to be + present on each associated :class:`Column`, which are also + associated with the constraint object. + + Note that you cannot define a "composite" foreign key constraint, + that is a constraint between a grouping of multiple parent/child + columns, using ``ForeignKey`` objects. To define this grouping, + the :class:`ForeignKeyConstraint` object must be used, and applied + to the :class:`Table`. The associated ``ForeignKey`` objects + are created automatically. + + The ``ForeignKey`` objects associated with an individual + :class:`Column` object are available in the `foreign_keys` collection + of that column. + + Further examples of foreign key configuration are in + :ref:`metadata_foreignkeys`. + + """ + + __visit_name__ = 'foreign_key' + + def __init__(self, column, _constraint=None, use_alter=False, name=None, + onupdate=None, ondelete=None, deferrable=None, + initially=None, link_to_name=False): + """ + Construct a column-level FOREIGN KEY. + + The :class:`ForeignKey` object when constructed generates a + :class:`ForeignKeyConstraint` which is associated with the parent + :class:`Table` object's collection of constraints. + + :param column: A single target column for the key relationship. A + :class:`Column` object or a column name as a string: + ``tablename.columnkey`` or ``schema.tablename.columnkey``. + ``columnkey`` is the ``key`` which has been assigned to the column + (defaults to the column name itself), unless ``link_to_name`` is + ``True`` in which case the rendered name of the column is used. + + :param name: Optional string. An in-database name for the key if + `constraint` is not provided. + + :param onupdate: Optional string. If set, emit ON UPDATE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param ondelete: Optional string. If set, emit ON DELETE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT + DEFERRABLE when issuing DDL for this constraint. + + :param initially: Optional string. If set, emit INITIALLY when + issuing DDL for this constraint. + + :param link_to_name: if True, the string name given in ``column`` is + the rendered name of the referenced column, not its locally + assigned ``key``. + + :param use_alter: passed to the underlying + :class:`ForeignKeyConstraint` to indicate the constraint should be + generated/dropped externally from the CREATE TABLE/ DROP TABLE + statement. See that classes' constructor for details. + + """ + + self._colspec = column + + # the linked ForeignKeyConstraint. + # ForeignKey will create this when parent Column + # is attached to a Table, *or* ForeignKeyConstraint + # object passes itself in when creating ForeignKey + # markers. + self.constraint = _constraint + + + self.use_alter = use_alter + self.name = name + self.onupdate = onupdate + self.ondelete = ondelete + self.deferrable = deferrable + self.initially = initially + self.link_to_name = link_to_name + + def __repr__(self): + return "ForeignKey(%r)" % self._get_colspec() + + def copy(self, schema=None): + """Produce a copy of this ForeignKey object.""" + + return ForeignKey( + self._get_colspec(schema=schema), + use_alter=self.use_alter, + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + deferrable=self.deferrable, + initially=self.initially, + link_to_name=self.link_to_name + ) + + def _get_colspec(self, schema=None): + if schema: + return schema + "." + self.column.table.name + "." + self.column.key + elif isinstance(self._colspec, basestring): + return self._colspec + elif hasattr(self._colspec, '__clause_element__'): + _column = self._colspec.__clause_element__() + else: + _column = self._colspec + + return "%s.%s" % (_column.table.fullname, _column.key) + + target_fullname = property(_get_colspec) + + def references(self, table): + """Return True if the given table is referenced by this ForeignKey.""" + return table.corresponding_column(self.column) is not None + + def get_referent(self, table): + """Return the column in the given table referenced by this ForeignKey. + + Returns None if this ``ForeignKey`` does not reference the given table. + + """ + + return table.corresponding_column(self.column) + + @util.memoized_property + def column(self): + # ForeignKey inits its remote column as late as possible, so tables + # can be defined without dependencies + if isinstance(self._colspec, basestring): + # locate the parent table this foreign key is attached to. we + # use the "original" column which our parent column represents + # (its a list of columns/other ColumnElements if the parent + # table is a UNION) + for c in self.parent.base_columns: + if isinstance(c, Column): + parenttable = c.table + break + else: + raise exc.ArgumentError( + "Parent column '%s' does not descend from a " + "table-attached Column" % str(self.parent)) + + m = self._colspec.split('.') + + if m is None: + raise exc.ArgumentError( + "Invalid foreign key column specification: %s" % + self._colspec) + + # A FK between column 'bar' and table 'foo' can be + # specified as 'foo', 'foo.bar', 'dbo.foo.bar', + # 'otherdb.dbo.foo.bar'. Once we have the column name and + # the table name, treat everything else as the schema + # name. Some databases (e.g. Sybase) support + # inter-database foreign keys. See tickets#1341 and -- + # indirectly related -- Ticket #594. This assumes that '.' + # will never appear *within* any component of the FK. + + (schema, tname, colname) = (None, None, None) + if (len(m) == 1): + tname = m.pop() + else: + colname = m.pop() + tname = m.pop() + + if (len(m) > 0): + schema = '.'.join(m) + + if _get_table_key(tname, schema) not in parenttable.metadata: + raise exc.NoReferencedTableError( + "Could not find table '%s' with which to generate a " + "foreign key" % tname) + table = Table(tname, parenttable.metadata, + mustexist=True, schema=schema) + + _column = None + if colname is None: + # colname is None in the case that ForeignKey argument + # was specified as table name only, in which case we + # match the column name to the same column on the + # parent. + key = self.parent + _column = table.c.get(self.parent.key, None) + elif self.link_to_name: + key = colname + for c in table.c: + if c.name == colname: + _column = c + else: + key = colname + _column = table.c.get(colname, None) + + if _column is None: + raise exc.NoReferencedColumnError( + "Could not create ForeignKey '%s' on table '%s': " + "table '%s' has no column named '%s'" % ( + self._colspec, parenttable.name, table.name, key)) + + elif hasattr(self._colspec, '__clause_element__'): + _column = self._colspec.__clause_element__() + else: + _column = self._colspec + + # propagate TypeEngine to parent if it didn't have one + if isinstance(self.parent.type, types.NullType): + self.parent.type = _column.type + return _column + + def _set_parent(self, column): + if hasattr(self, 'parent'): + if self.parent is column: + return + raise exc.InvalidRequestError("This ForeignKey already has a parent !") + self.parent = column + self.parent.foreign_keys.add(self) + self.parent._on_table_attach(self._set_table) + + def _set_table(self, table, column): + # standalone ForeignKey - create ForeignKeyConstraint + # on the hosting Table when attached to the Table. + if self.constraint is None and isinstance(table, Table): + self.constraint = ForeignKeyConstraint( + [], [], use_alter=self.use_alter, name=self.name, + onupdate=self.onupdate, ondelete=self.ondelete, + deferrable=self.deferrable, initially=self.initially, + ) + self.constraint._elements[self.parent] = self + self.constraint._set_parent(table) + table.foreign_keys.add(self) + +class DefaultGenerator(SchemaItem): + """Base class for column *default* values.""" + + __visit_name__ = 'default_generator' + + is_sequence = False + + def __init__(self, for_update=False): + self.for_update = for_update + + def _set_parent(self, column): + self.column = column + if self.for_update: + self.column.onupdate = self + else: + self.column.default = self + + def execute(self, bind=None, **kwargs): + if bind is None: + bind = _bind_or_error(self) + return bind._execute_default(self, **kwargs) + + @property + def bind(self): + """Return the connectable associated with this default.""" + if getattr(self, 'column', None) is not None: + return self.column.table.bind + else: + return None + + def __repr__(self): + return "DefaultGenerator()" + + +class ColumnDefault(DefaultGenerator): + """A plain default value on a column. + + This could correspond to a constant, a callable function, or a SQL clause. + """ + + def __init__(self, arg, **kwargs): + super(ColumnDefault, self).__init__(**kwargs) + if isinstance(arg, FetchedValue): + raise exc.ArgumentError( + "ColumnDefault may not be a server-side default type.") + if util.callable(arg): + arg = self._maybe_wrap_callable(arg) + self.arg = arg + + @util.memoized_property + def is_callable(self): + return util.callable(self.arg) + + @util.memoized_property + def is_clause_element(self): + return isinstance(self.arg, expression.ClauseElement) + + @util.memoized_property + def is_scalar(self): + return not self.is_callable and not self.is_clause_element and not self.is_sequence + + def _maybe_wrap_callable(self, fn): + """Backward compat: Wrap callables that don't accept a context.""" + + if inspect.isfunction(fn): + inspectable = fn + elif inspect.isclass(fn): + inspectable = fn.__init__ + elif hasattr(fn, '__call__'): + inspectable = fn.__call__ + else: + # probably not inspectable, try anyways. + inspectable = fn + try: + argspec = inspect.getargspec(inspectable) + except TypeError: + return lambda ctx: fn() + + positionals = len(argspec[0]) + + # Py3K compat - no unbound methods + if inspect.ismethod(inspectable) or inspect.isclass(fn): + positionals -= 1 + + if positionals == 0: + return lambda ctx: fn() + + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if positionals - defaulted > 1: + raise exc.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + return fn + + def _visit_name(self): + if self.for_update: + return "column_onupdate" + else: + return "column_default" + __visit_name__ = property(_visit_name) + + def __repr__(self): + return "ColumnDefault(%r)" % self.arg + +class Sequence(DefaultGenerator): + """Represents a named database sequence.""" + + __visit_name__ = 'sequence' + + is_sequence = True + + def __init__(self, name, start=None, increment=None, schema=None, + optional=False, quote=None, metadata=None, for_update=False): + super(Sequence, self).__init__(for_update=for_update) + self.name = name + self.start = start + self.increment = increment + self.optional = optional + self.quote = quote + self.schema = schema + self.metadata = metadata + + @util.memoized_property + def is_callable(self): + return False + + @util.memoized_property + def is_clause_element(self): + return False + + def __repr__(self): + return "Sequence(%s)" % ', '.join( + [repr(self.name)] + + ["%s=%s" % (k, repr(getattr(self, k))) + for k in ['start', 'increment', 'optional']]) + + def _set_parent(self, column): + super(Sequence, self)._set_parent(column) + column._on_table_attach(self._set_table) + + def _set_table(self, table, column): + self.metadata = table.metadata + + @property + def bind(self): + if self.metadata: + return self.metadata.bind + else: + return None + + def create(self, bind=None, checkfirst=True): + """Creates this sequence in the database.""" + + if bind is None: + bind = _bind_or_error(self) + bind.create(self, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Drops this sequence from the database.""" + + if bind is None: + bind = _bind_or_error(self) + bind.drop(self, checkfirst=checkfirst) + + +class FetchedValue(object): + """A default that takes effect on the database side.""" + + def __init__(self, for_update=False): + self.for_update = for_update + + def _set_parent(self, column): + self.column = column + if self.for_update: + self.column.server_onupdate = self + else: + self.column.server_default = self + + def __repr__(self): + return 'FetchedValue(for_update=%r)' % self.for_update + + +class DefaultClause(FetchedValue): + """A DDL-specified DEFAULT column value.""" + + def __init__(self, arg, for_update=False): + util.assert_arg_type(arg, (basestring, + expression.ClauseElement, + expression._TextClause), 'arg') + super(DefaultClause, self).__init__(for_update) + self.arg = arg + + def __repr__(self): + return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) + +class PassiveDefault(DefaultClause): + def __init__(self, *arg, **kw): + util.warn_deprecated("PassiveDefault is deprecated. Use DefaultClause.") + DefaultClause.__init__(self, *arg, **kw) + +class Constraint(SchemaItem): + """A table-level SQL constraint.""" + + __visit_name__ = 'constraint' + + def __init__(self, name=None, deferrable=None, initially=None, + _create_rule=None): + """Create a SQL constraint. + + name + Optional, the in-database name of this ``Constraint``. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + _create_rule + a callable which is passed the DDLCompiler object during + compilation. Returns True or False to signal inline generation of + this Constraint. + + The AddConstraint and DropConstraint DDL constructs provide + DDLElement's more comprehensive "conditional DDL" approach that is + passed a database connection when DDL is being issued. _create_rule + is instead called during any CREATE TABLE compilation, where there + may not be any transaction/connection in progress. However, it + allows conditional compilation of the constraint even for backends + which do not support addition of constraints through ALTER TABLE, + which currently includes SQLite. + + _create_rule is used by some types to create constraints. + Currently, its call signature is subject to change at any time. + + """ + + self.name = name + self.deferrable = deferrable + self.initially = initially + self._create_rule = _create_rule + + @property + def table(self): + try: + if isinstance(self.parent, Table): + return self.parent + except AttributeError: + pass + raise exc.InvalidRequestError("This constraint is not bound to a table. Did you mean to call table.add_constraint(constraint) ?") + + def _set_parent(self, parent): + self.parent = parent + parent.constraints.add(self) + + def copy(self, **kw): + raise NotImplementedError() + +class ColumnCollectionConstraint(Constraint): + """A constraint that proxies a ColumnCollection.""" + + def __init__(self, *columns, **kw): + """ + \*columns + A sequence of column names or Column objects. + + name + Optional, the in-database name of this constraint. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + """ + super(ColumnCollectionConstraint, self).__init__(**kw) + self.columns = expression.ColumnCollection() + self._pending_colargs = [_to_schema_column_or_string(c) for c in columns] + if self._pending_colargs and \ + isinstance(self._pending_colargs[0], Column) and \ + self._pending_colargs[0].table is not None: + self._set_parent(self._pending_colargs[0].table) + + def _set_parent(self, table): + super(ColumnCollectionConstraint, self)._set_parent(table) + for col in self._pending_colargs: + if isinstance(col, basestring): + col = table.c[col] + self.columns.add(col) + + def __contains__(self, x): + return x in self.columns + + def copy(self, **kw): + return self.__class__(name=self.name, deferrable=self.deferrable, + initially=self.initially, *self.columns.keys()) + + def contains_column(self, col): + return self.columns.contains_column(col) + + def __iter__(self): + return iter(self.columns) + + def __len__(self): + return len(self.columns) + + +class CheckConstraint(Constraint): + """A table- or column-level CHECK constraint. + + Can be included in the definition of a Table or Column. + """ + + def __init__(self, sqltext, name=None, deferrable=None, + initially=None, table=None, _create_rule=None): + """Construct a CHECK constraint. + + sqltext + A string containing the constraint definition, which will be used + verbatim, or a SQL expression construct. + + name + Optional, the in-database name of the constraint. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + """ + + super(CheckConstraint, self).__init__(name, deferrable, initially, _create_rule) + self.sqltext = expression._literal_as_text(sqltext) + if table is not None: + self._set_parent(table) + + def __visit_name__(self): + if isinstance(self.parent, Table): + return "check_constraint" + else: + return "column_check_constraint" + __visit_name__ = property(__visit_name__) + + def copy(self, **kw): + return CheckConstraint(self.sqltext, name=self.name) + +class ForeignKeyConstraint(Constraint): + """A table-level FOREIGN KEY constraint. + + Defines a single column or composite FOREIGN KEY ... REFERENCES + constraint. For a no-frills, single column foreign key, adding a + :class:`ForeignKey` to the definition of a :class:`Column` is a shorthand + equivalent for an unnamed, single column :class:`ForeignKeyConstraint`. + + Examples of foreign key configuration are in :ref:`metadata_foreignkeys`. + + """ + __visit_name__ = 'foreign_key_constraint' + + def __init__(self, columns, refcolumns, name=None, onupdate=None, + ondelete=None, deferrable=None, initially=None, use_alter=False, + link_to_name=False, table=None): + """Construct a composite-capable FOREIGN KEY. + + :param columns: A sequence of local column names. The named columns + must be defined and present in the parent Table. The names should + match the ``key`` given to each column (defaults to the name) unless + ``link_to_name`` is True. + + :param refcolumns: A sequence of foreign column names or Column + objects. The columns must all be located within the same Table. + + :param name: Optional, the in-database name of the key. + + :param onupdate: Optional string. If set, emit ON UPDATE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param ondelete: Optional string. If set, emit ON DELETE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT + DEFERRABLE when issuing DDL for this constraint. + + :param initially: Optional string. If set, emit INITIALLY when + issuing DDL for this constraint. + + :param link_to_name: if True, the string name given in ``column`` is + the rendered name of the referenced column, not its locally assigned + ``key``. + + :param use_alter: If True, do not emit the DDL for this constraint as + part of the CREATE TABLE definition. Instead, generate it via an + ALTER TABLE statement issued after the full collection of tables + have been created, and drop it via an ALTER TABLE statement before + the full collection of tables are dropped. This is shorthand for the + usage of :class:`AddConstraint` and :class:`DropConstraint` applied + as "after-create" and "before-drop" events on the MetaData object. + This is normally used to generate/drop constraints on objects that + are mutually dependent on each other. + + """ + super(ForeignKeyConstraint, self).__init__(name, deferrable, initially) + + self.onupdate = onupdate + self.ondelete = ondelete + self.link_to_name = link_to_name + if self.name is None and use_alter: + raise exc.ArgumentError("Alterable Constraint requires a name") + self.use_alter = use_alter + + self._elements = util.OrderedDict() + + # standalone ForeignKeyConstraint - create + # associated ForeignKey objects which will be applied to hosted + # Column objects (in col.foreign_keys), either now or when attached + # to the Table for string-specified names + for col, refcol in zip(columns, refcolumns): + self._elements[col] = ForeignKey( + refcol, + _constraint=self, + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + use_alter=self.use_alter, + link_to_name=self.link_to_name + ) + + if table: + self._set_parent(table) + + @property + def columns(self): + return self._elements.keys() + + @property + def elements(self): + return self._elements.values() + + def _set_parent(self, table): + super(ForeignKeyConstraint, self)._set_parent(table) + for col, fk in self._elements.iteritems(): + # string-specified column names now get + # resolved to Column objects + if isinstance(col, basestring): + col = table.c[col] + fk._set_parent(col) + + if self.use_alter: + def supports_alter(ddl, event, schema_item, bind, **kw): + return table in set(kw['tables']) and bind.dialect.supports_alter + AddConstraint(self, on=supports_alter).execute_at('after-create', table.metadata) + DropConstraint(self, on=supports_alter).execute_at('before-drop', table.metadata) + + def copy(self, **kw): + return ForeignKeyConstraint( + [x.parent.name for x in self._elements.values()], + [x._get_colspec(**kw) for x in self._elements.values()], + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + use_alter=self.use_alter, + deferrable=self.deferrable, + initially=self.initially, + link_to_name=self.link_to_name + ) + +class PrimaryKeyConstraint(ColumnCollectionConstraint): + """A table-level PRIMARY KEY constraint. + + Defines a single column or composite PRIMARY KEY constraint. For a + no-frills primary key, adding ``primary_key=True`` to one or more + ``Column`` definitions is a shorthand equivalent for an unnamed single- or + multiple-column PrimaryKeyConstraint. + """ + + __visit_name__ = 'primary_key_constraint' + + def _set_parent(self, table): + super(PrimaryKeyConstraint, self)._set_parent(table) + table._set_primary_key(self) + + def _replace(self, col): + self.columns.replace(col) + +class UniqueConstraint(ColumnCollectionConstraint): + """A table-level UNIQUE constraint. + + Defines a single column or composite UNIQUE constraint. For a no-frills, + single column constraint, adding ``unique=True`` to the ``Column`` + definition is a shorthand equivalent for an unnamed, single column + UniqueConstraint. + """ + + __visit_name__ = 'unique_constraint' + +class Index(SchemaItem): + """A table-level INDEX. + + Defines a composite (one or more column) INDEX. For a no-frills, single + column index, adding ``index=True`` to the ``Column`` definition is + a shorthand equivalent for an unnamed, single column Index. + """ + + __visit_name__ = 'index' + + def __init__(self, name, *columns, **kwargs): + """Construct an index object. + + Arguments are: + + name + The name of the index + + \*columns + Columns to include in the index. All columns must belong to the same + table. + + \**kwargs + Keyword arguments include: + + unique + Defaults to False: create a unique index. + + postgresql_where + Defaults to None: create a partial index when using PostgreSQL + """ + + self.name = name + self.columns = expression.ColumnCollection() + self.table = None + self.unique = kwargs.pop('unique', False) + self.kwargs = kwargs + + for column in columns: + column = _to_schema_column(column) + if self.table is None: + self._set_parent(column.table) + elif column.table != self.table: + # all columns muse be from same table + raise exc.ArgumentError( + "All index columns must be from same table. " + "%s is from %s not %s" % (column, column.table, self.table)) + self.columns.add(column) + + def _set_parent(self, table): + self.table = table + table.indexes.add(self) + + @property + def bind(self): + """Return the connectable associated with this Index.""" + + return self.table.bind + + def create(self, bind=None): + if bind is None: + bind = _bind_or_error(self) + bind.create(self) + return self + + def drop(self, bind=None): + if bind is None: + bind = _bind_or_error(self) + bind.drop(self) + + def __repr__(self): + return 'Index("%s", %s%s)' % (self.name, + ', '.join(repr(c) for c in self.columns), + (self.unique and ', unique=True') or '') + +class MetaData(SchemaItem): + """A collection of Tables and their associated schema constructs. + + Holds a collection of Tables and an optional binding to an ``Engine`` or + ``Connection``. If bound, the :class:`~sqlalchemy.schema.Table` objects + in the collection and their columns may participate in implicit SQL + execution. + + The `Table` objects themselves are stored in the `metadata.tables` + dictionary. + + The ``bind`` property may be assigned to dynamically. A common pattern is + to start unbound and then bind later when an engine is available:: + + metadata = MetaData() + # define tables + Table('mytable', metadata, ...) + # connect to an engine later, perhaps after loading a URL from a + # configuration file + metadata.bind = an_engine + + MetaData is a thread-safe object after tables have been explicitly defined + or loaded via reflection. + + .. index:: + single: thread safety; MetaData + + """ + + __visit_name__ = 'metadata' + + ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') + + def __init__(self, bind=None, reflect=False): + """Create a new MetaData object. + + bind + An Engine or Connection to bind to. May also be a string or URL + instance, these are passed to create_engine() and this MetaData will + be bound to the resulting engine. + + reflect + Optional, automatically load all tables from the bound database. + Defaults to False. ``bind`` is required when this option is set. + For finer control over loaded tables, use the ``reflect`` method of + ``MetaData``. + + """ + self.tables = {} + self.bind = bind + self.metadata = self + self.ddl_listeners = util.defaultdict(list) + if reflect: + if not bind: + raise exc.ArgumentError( + "A bind must be supplied in conjunction with reflect=True") + self.reflect() + + def __repr__(self): + return 'MetaData(%r)' % self.bind + + def __contains__(self, table_or_key): + if not isinstance(table_or_key, basestring): + table_or_key = table_or_key.key + return table_or_key in self.tables + + def __getstate__(self): + return {'tables': self.tables} + + def __setstate__(self, state): + self.tables = state['tables'] + self._bind = None + + def is_bound(self): + """True if this MetaData is bound to an Engine or Connection.""" + + return self._bind is not None + + def bind(self): + """An Engine or Connection to which this MetaData is bound. + + This property may be assigned an ``Engine`` or ``Connection``, or + assigned a string or URL to automatically create a basic ``Engine`` + for this bind with ``create_engine()``. + + """ + return self._bind + + def _bind_to(self, bind): + """Bind this MetaData to an Engine, Connection, string or URL.""" + + global URL + if URL is None: + from sqlalchemy.engine.url import URL + + if isinstance(bind, (basestring, URL)): + from sqlalchemy import create_engine + self._bind = create_engine(bind) + else: + self._bind = bind + bind = property(bind, _bind_to) + + def clear(self): + """Clear all Table objects from this MetaData.""" + # TODO: why have clear()/remove() but not all + # other accesors/mutators for the tables dict ? + self.tables.clear() + + def remove(self, table): + """Remove the given Table object from this MetaData.""" + + # TODO: scan all other tables and remove FK _column + del self.tables[table.key] + + @property + def sorted_tables(self): + """Returns a list of ``Table`` objects sorted in order of + dependency. + """ + from sqlalchemy.sql.util import sort_tables + return sort_tables(self.tables.itervalues()) + + def reflect(self, bind=None, schema=None, only=None): + """Load all available table definitions from the database. + + Automatically creates ``Table`` entries in this ``MetaData`` for any + table available in the database but not yet present in the + ``MetaData``. May be called multiple times to pick up tables recently + added to the database, however no special action is taken if a table + in this ``MetaData`` no longer exists in the database. + + bind + A :class:`~sqlalchemy.engine.base.Connectable` used to access the database; if None, uses the + existing bind on this ``MetaData``, if any. + + schema + Optional, query and reflect tables from an alterate schema. + + only + Optional. Load only a sub-set of available named tables. May be + specified as a sequence of names or a callable. + + If a sequence of names is provided, only those tables will be + reflected. An error is raised if a table is requested but not + available. Named tables already present in this ``MetaData`` are + ignored. + + If a callable is provided, it will be used as a boolean predicate to + filter the list of potential table names. The callable is called + with a table name and this ``MetaData`` instance as positional + arguments and should return a true value for any table to reflect. + + """ + reflect_opts = {'autoload': True} + if bind is None: + bind = _bind_or_error(self) + conn = None + else: + reflect_opts['autoload_with'] = bind + conn = bind.contextual_connect() + + if schema is not None: + reflect_opts['schema'] = schema + + available = util.OrderedSet(bind.engine.table_names(schema, + connection=conn)) + current = set(self.tables.iterkeys()) + + if only is None: + load = [name for name in available if name not in current] + elif util.callable(only): + load = [name for name in available + if name not in current and only(name, self)] + else: + missing = [name for name in only if name not in available] + if missing: + s = schema and (" schema '%s'" % schema) or '' + raise exc.InvalidRequestError( + 'Could not reflect: requested table(s) not available ' + 'in %s%s: (%s)' % (bind.engine.url, s, ', '.join(missing))) + load = [name for name in only if name not in current] + + for name in load: + Table(name, self, **reflect_opts) + + def append_ddl_listener(self, event, listener): + """Append a DDL event listener to this ``MetaData``. + + The ``listener`` callable will be triggered when this ``MetaData`` is + involved in DDL creates or drops, and will be invoked either before + all Table-related actions or after. + + Arguments are: + + event + One of ``MetaData.ddl_events``; 'before-create', 'after-create', + 'before-drop' or 'after-drop'. + listener + A callable, invoked with three positional arguments: + + event + The event currently being handled + target + The ``MetaData`` object being operated upon + bind + The ``Connection`` bueing used for DDL execution. + + Listeners are added to the MetaData's ``ddl_listeners`` attribute. + + Note: MetaData listeners are invoked even when ``Tables`` are created + in isolation. This may change in a future release. I.e.:: + + # triggers all MetaData and Table listeners: + metadata.create_all() + + # triggers MetaData listeners too: + some.table.create() + + """ + if event not in self.ddl_events: + raise LookupError(event) + self.ddl_listeners[event].append(listener) + + def create_all(self, bind=None, tables=None, checkfirst=True): + """Create all tables stored in this metadata. + + Conditional by default, will not attempt to recreate tables already + present in the target database. + + bind + A :class:`~sqlalchemy.engine.base.Connectable` used to access the database; if None, uses the + existing bind on this ``MetaData``, if any. + + tables + Optional list of ``Table`` objects, which is a subset of the total + tables in the ``MetaData`` (others are ignored). + + checkfirst + Defaults to True, don't issue CREATEs for tables already present + in the target database. + + """ + if bind is None: + bind = _bind_or_error(self) + bind.create(self, checkfirst=checkfirst, tables=tables) + + def drop_all(self, bind=None, tables=None, checkfirst=True): + """Drop all tables stored in this metadata. + + Conditional by default, will not attempt to drop tables not present in + the target database. + + bind + A :class:`~sqlalchemy.engine.base.Connectable` used to access the database; if None, uses + the existing bind on this ``MetaData``, if any. + + tables + Optional list of ``Table`` objects, which is a subset of the + total tables in the ``MetaData`` (others are ignored). + + checkfirst + Defaults to True, only issue DROPs for tables confirmed to be present + in the target database. + + """ + if bind is None: + bind = _bind_or_error(self) + bind.drop(self, checkfirst=checkfirst, tables=tables) + +class ThreadLocalMetaData(MetaData): + """A MetaData variant that presents a different ``bind`` in every thread. + + Makes the ``bind`` property of the MetaData a thread-local value, allowing + this collection of tables to be bound to different ``Engine`` + implementations or connections in each thread. + + The ThreadLocalMetaData starts off bound to None in each thread. Binds + must be made explicitly by assigning to the ``bind`` property or using + ``connect()``. You can also re-bind dynamically multiple times per + thread, just like a regular ``MetaData``. + + """ + + __visit_name__ = 'metadata' + + def __init__(self): + """Construct a ThreadLocalMetaData.""" + + self.context = util.threading.local() + self.__engines = {} + super(ThreadLocalMetaData, self).__init__() + + def bind(self): + """The bound Engine or Connection for this thread. + + This property may be assigned an Engine or Connection, or assigned a + string or URL to automatically create a basic Engine for this bind + with ``create_engine()``.""" + + return getattr(self.context, '_engine', None) + + def _bind_to(self, bind): + """Bind to a Connectable in the caller's thread.""" + + global URL + if URL is None: + from sqlalchemy.engine.url import URL + + if isinstance(bind, (basestring, URL)): + try: + self.context._engine = self.__engines[bind] + except KeyError: + from sqlalchemy import create_engine + e = create_engine(bind) + self.__engines[bind] = e + self.context._engine = e + else: + # TODO: this is squirrely. we shouldnt have to hold onto engines + # in a case like this + if bind not in self.__engines: + self.__engines[bind] = bind + self.context._engine = bind + + bind = property(bind, _bind_to) + + def is_bound(self): + """True if there is a bind for this thread.""" + return (hasattr(self.context, '_engine') and + self.context._engine is not None) + + def dispose(self): + """Dispose all bound engines, in all thread contexts.""" + + for e in self.__engines.itervalues(): + if hasattr(e, 'dispose'): + e.dispose() + +class SchemaVisitor(visitors.ClauseVisitor): + """Define the visiting for ``SchemaItem`` objects.""" + + __traverse_options__ = {'schema_visitor':True} + + +class DDLElement(expression.Executable, expression.ClauseElement): + """Base class for DDL expression constructs.""" + + _execution_options = expression.Executable.\ + _execution_options.union({'autocommit':True}) + + target = None + on = None + + def execute(self, bind=None, target=None): + """Execute this DDL immediately. + + Executes the DDL statement in isolation using the supplied + :class:`~sqlalchemy.engine.base.Connectable` or :class:`~sqlalchemy.engine.base.Connectable` assigned to the ``.bind`` property, + if not supplied. If the DDL has a conditional ``on`` criteria, it + will be invoked with None as the event. + + bind + Optional, an ``Engine`` or ``Connection``. If not supplied, a + valid :class:`~sqlalchemy.engine.base.Connectable` must be present in the ``.bind`` property. + + target + Optional, defaults to None. The target SchemaItem for the + execute call. Will be passed to the ``on`` callable if any, + and may also provide string expansion data for the + statement. See ``execute_at`` for more information. + """ + + if bind is None: + bind = _bind_or_error(self) + + if self._should_execute(None, target, bind): + return bind.execute(self.against(target)) + else: + bind.engine.logger.info("DDL execution skipped, criteria not met.") + + def execute_at(self, event, target): + """Link execution of this DDL to the DDL lifecycle of a SchemaItem. + + Links this ``DDLElement`` to a ``Table`` or ``MetaData`` instance, executing + it when that schema item is created or dropped. The DDL statement + will be executed using the same Connection and transactional context + as the Table create/drop itself. The ``.bind`` property of this + statement is ignored. + + event + One of the events defined in the schema item's ``.ddl_events``; + e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop' + + target + The Table or MetaData instance for which this DDLElement will + be associated with. + + A DDLElement instance can be linked to any number of schema items. + + ``execute_at`` builds on the ``append_ddl_listener`` interface of + MetaDta and Table objects. + + Caveat: Creating or dropping a Table in isolation will also trigger + any DDL set to ``execute_at`` that Table's MetaData. This may change + in a future release. + """ + + if not hasattr(target, 'ddl_listeners'): + raise exc.ArgumentError( + "%s does not support DDL events" % type(target).__name__) + if event not in target.ddl_events: + raise exc.ArgumentError( + "Unknown event, expected one of (%s), got '%r'" % + (', '.join(target.ddl_events), event)) + target.ddl_listeners[event].append(self) + return self + + @expression._generative + def against(self, target): + """Return a copy of this DDL against a specific schema item.""" + + self.target = target + + def __call__(self, event, target, bind, **kw): + """Execute the DDL as a ddl_listener.""" + + if self._should_execute(event, target, bind, **kw): + return bind.execute(self.against(target)) + + def _check_ddl_on(self, on): + if (on is not None and + (not isinstance(on, (basestring, tuple, list, set)) and not util.callable(on))): + raise exc.ArgumentError( + "Expected the name of a database dialect, a tuple of names, or a callable for " + "'on' criteria, got type '%s'." % type(on).__name__) + + def _should_execute(self, event, target, bind, **kw): + if self.on is None: + return True + elif isinstance(self.on, basestring): + return self.on == bind.engine.name + elif isinstance(self.on, (tuple, list, set)): + return bind.engine.name in self.on + else: + return self.on(self, event, target, bind, **kw) + + def bind(self): + if self._bind: + return self._bind + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + + return dialect.ddl_compiler(dialect, self, **kw) + +class DDL(DDLElement): + """A literal DDL statement. + + Specifies literal SQL DDL to be executed by the database. DDL objects can + be attached to ``Tables`` or ``MetaData`` instances, conditionally + executing SQL as part of the DDL lifecycle of those schema items. Basic + templating support allows a single DDL instance to handle repetitive tasks + for multiple tables. + + Examples:: + + tbl = Table('users', metadata, Column('uid', Integer)) # ... + DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl) + + spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb') + spow.execute_at('after-create', tbl) + + drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + connection.execute(drop_spow) + + When operating on Table events, the following ``statement`` + string substitions are available:: + + %(table)s - the Table name, with any required quoting applied + %(schema)s - the schema name, with any required quoting applied + %(fullname)s - the Table name including schema, quoted if needed + + The DDL's ``context``, if any, will be combined with the standard + substutions noted above. Keys present in the context will override + the standard substitutions. + + """ + + __visit_name__ = "ddl" + + def __init__(self, statement, on=None, context=None, bind=None): + """Create a DDL statement. + + statement + A string or unicode string to be executed. Statements will be + processed with Python's string formatting operator. See the + ``context`` argument and the ``execute_at`` method. + + A literal '%' in a statement must be escaped as '%%'. + + SQL bind parameters are not available in DDL statements. + + on + Optional filtering criteria. May be a string, tuple or a callable + predicate. If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something', on='postgresql') + + If a tuple, specifies multiple dialect names:: + + DDL('something', on=('postgresql', 'mysql')) + + If a callable, it will be invoked with four positional arguments + as well as optional keyword arguments: + + ddl + This DDL element. + + event + The name of the event that has triggered this DDL, such as + 'after-create' Will be None if the DDL is executed explicitly. + + target + The ``Table`` or ``MetaData`` object which is the target of + this event. May be None if the DDL is executed explicitly. + + connection + The ``Connection`` being used for DDL execution + + \**kw + Keyword arguments which may be sent include: + tables - a list of Table objects which are to be created/ + dropped within a MetaData.create_all() or drop_all() method + call. + + If the callable returns a true value, the DDL statement will be + executed. + + context + Optional dictionary, defaults to None. These values will be + available for use in string substitutions on the DDL statement. + + bind + Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()`` + is invoked without a bind argument. + + """ + + if not isinstance(statement, basestring): + raise exc.ArgumentError( + "Expected a string or unicode SQL statement, got '%r'" % + statement) + + self.statement = statement + self.context = context or {} + + self._check_ddl_on(on) + self.on = on + self._bind = bind + + + def __repr__(self): + return '<%s@%s; %s>' % ( + type(self).__name__, id(self), + ', '.join([repr(self.statement)] + + ['%s=%r' % (key, getattr(self, key)) + for key in ('on', 'context') + if getattr(self, key)])) + +def _to_schema_column(element): + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, Column): + raise exc.ArgumentError("schema.Column object expected") + return element + +def _to_schema_column_or_string(element): + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + return element + +class _CreateDropBase(DDLElement): + """Base class for DDL constucts that represent CREATE and DROP or equivalents. + + The common theme of _CreateDropBase is a single + ``element`` attribute which refers to the element + to be created or dropped. + + """ + + def __init__(self, element, on=None, bind=None): + self.element = element + self._check_ddl_on(on) + self.on = on + self.bind = bind + + def _create_rule_disable(self, compiler): + """Allow disable of _create_rule using a callable. + + Pass to _create_rule using + util.portable_instancemethod(self._create_rule_disable) + to retain serializability. + + """ + return False + +class CreateTable(_CreateDropBase): + """Represent a CREATE TABLE statement.""" + + __visit_name__ = "create_table" + +class DropTable(_CreateDropBase): + """Represent a DROP TABLE statement.""" + + __visit_name__ = "drop_table" + +class CreateSequence(_CreateDropBase): + """Represent a CREATE SEQUENCE statement.""" + + __visit_name__ = "create_sequence" + +class DropSequence(_CreateDropBase): + """Represent a DROP SEQUENCE statement.""" + + __visit_name__ = "drop_sequence" + +class CreateIndex(_CreateDropBase): + """Represent a CREATE INDEX statement.""" + + __visit_name__ = "create_index" + +class DropIndex(_CreateDropBase): + """Represent a DROP INDEX statement.""" + + __visit_name__ = "drop_index" + +class AddConstraint(_CreateDropBase): + """Represent an ALTER TABLE ADD CONSTRAINT statement.""" + + __visit_name__ = "add_constraint" + + def __init__(self, element, *args, **kw): + super(AddConstraint, self).__init__(element, *args, **kw) + element._create_rule = util.portable_instancemethod(self._create_rule_disable) + +class DropConstraint(_CreateDropBase): + """Represent an ALTER TABLE DROP CONSTRAINT statement.""" + + __visit_name__ = "drop_constraint" + + def __init__(self, element, cascade=False, **kw): + self.cascade = cascade + super(DropConstraint, self).__init__(element, **kw) + element._create_rule = util.portable_instancemethod(self._create_rule_disable) + +def _bind_or_error(schemaitem, msg=None): + bind = schemaitem.bind + if not bind: + name = schemaitem.__class__.__name__ + label = getattr(schemaitem, 'fullname', + getattr(schemaitem, 'name', None)) + if label: + item = '%s %r' % (name, label) + else: + item = name + if isinstance(schemaitem, (MetaData, DDL)): + bindable = "the %s's .bind" % name + else: + bindable = "this %s's .metadata.bind" % name + + if msg is None: + msg = ('The %s is not bound to an Engine or Connection. ' + 'Execution can not proceed without a database to execute ' + 'against. Either execute with an explicit connection or ' + 'assign %s to enable implicit execution.') % (item, bindable) + raise exc.UnboundExecutionError(msg) + return bind + diff --git a/sqlalchemy/sql/__init__.py b/sqlalchemy/sql/__init__.py new file mode 100644 index 0000000..aa18eac --- /dev/null +++ b/sqlalchemy/sql/__init__.py @@ -0,0 +1,58 @@ +from sqlalchemy.sql.expression import ( + Alias, + ClauseElement, + ColumnCollection, + ColumnElement, + CompoundSelect, + Delete, + FromClause, + Insert, + Join, + Select, + Selectable, + TableClause, + Update, + alias, + and_, + asc, + between, + bindparam, + case, + cast, + collate, + column, + delete, + desc, + distinct, + except_, + except_all, + exists, + extract, + func, + insert, + intersect, + intersect_all, + join, + label, + literal, + literal_column, + modifier, + not_, + null, + or_, + outerjoin, + outparam, + select, + subquery, + table, + text, + tuple_, + union, + union_all, + update, + ) + +from sqlalchemy.sql.visitors import ClauseVisitor + +__tmp = locals().keys() +__all__ = sorted([i for i in __tmp if not i.startswith('__')]) diff --git a/sqlalchemy/sql/compiler.py b/sqlalchemy/sql/compiler.py new file mode 100644 index 0000000..78c6577 --- /dev/null +++ b/sqlalchemy/sql/compiler.py @@ -0,0 +1,1612 @@ +# compiler.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Base SQL and DDL compiler implementations. + +Classes provided include: + +:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL +strings + +:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:module:`~sqlalchemy.ext.compiler`. + +""" + +import re +from sqlalchemy import schema, engine, util, exc +from sqlalchemy.sql import operators, functions, util as sql_util, visitors +from sqlalchemy.sql import expression as sql +import decimal + +RESERVED_WORDS = set([ + 'all', 'analyse', 'analyze', 'and', 'any', 'array', + 'as', 'asc', 'asymmetric', 'authorization', 'between', + 'binary', 'both', 'case', 'cast', 'check', 'collate', + 'column', 'constraint', 'create', 'cross', 'current_date', + 'current_role', 'current_time', 'current_timestamp', + 'current_user', 'default', 'deferrable', 'desc', + 'distinct', 'do', 'else', 'end', 'except', 'false', + 'for', 'foreign', 'freeze', 'from', 'full', 'grant', + 'group', 'having', 'ilike', 'in', 'initially', 'inner', + 'intersect', 'into', 'is', 'isnull', 'join', 'leading', + 'left', 'like', 'limit', 'localtime', 'localtimestamp', + 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', + 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', + 'placing', 'primary', 'references', 'right', 'select', + 'session_user', 'set', 'similar', 'some', 'symmetric', 'table', + 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', + 'using', 'verbose', 'when', 'where']) + +LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) +ILLEGAL_INITIAL_CHARACTERS = set([str(x) for x in xrange(0, 10)]).union(['$']) + +BIND_PARAMS = re.compile(r'(? ', + operators.ge : ' >= ', + operators.eq : ' = ', + operators.concat_op : ' || ', + operators.between_op : ' BETWEEN ', + operators.match_op : ' MATCH ', + operators.in_op : ' IN ', + operators.notin_op : ' NOT IN ', + operators.comma_op : ', ', + operators.from_ : ' FROM ', + operators.as_ : ' AS ', + operators.is_ : ' IS ', + operators.isnot : ' IS NOT ', + operators.collate : ' COLLATE ', + + # unary + operators.exists : 'EXISTS ', + operators.distinct_op : 'DISTINCT ', + operators.inv : 'NOT ', + + # modifiers + operators.desc_op : ' DESC', + operators.asc_op : ' ASC', +} + +FUNCTIONS = { + functions.coalesce : 'coalesce%(expr)s', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', + functions.current_timestamp: 'CURRENT_TIMESTAMP', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', + functions.localtimestamp: 'LOCALTIMESTAMP', + functions.random: 'random%(expr)s', + functions.sysdate: 'sysdate', + functions.session_user :'SESSION_USER', + functions.user: 'USER' +} + +EXTRACT_MAP = { + 'month': 'month', + 'day': 'day', + 'year': 'year', + 'second': 'second', + 'hour': 'hour', + 'doy': 'doy', + 'minute': 'minute', + 'quarter': 'quarter', + 'dow': 'dow', + 'week': 'week', + 'epoch': 'epoch', + 'milliseconds': 'milliseconds', + 'microseconds': 'microseconds', + 'timezone_hour': 'timezone_hour', + 'timezone_minute': 'timezone_minute' +} + +COMPOUND_KEYWORDS = { + sql.CompoundSelect.UNION : 'UNION', + sql.CompoundSelect.UNION_ALL : 'UNION ALL', + sql.CompoundSelect.EXCEPT : 'EXCEPT', + sql.CompoundSelect.EXCEPT_ALL : 'EXCEPT ALL', + sql.CompoundSelect.INTERSECT : 'INTERSECT', + sql.CompoundSelect.INTERSECT_ALL : 'INTERSECT ALL' +} + +class _CompileLabel(visitors.Visitable): + """lightweight label object which acts as an expression._Label.""" + + __visit_name__ = 'label' + __slots__ = 'element', 'name' + + def __init__(self, col, name): + self.element = col + self.name = name + + @property + def quote(self): + return self.element.quote + +class SQLCompiler(engine.Compiled): + """Default implementation of Compiled. + + Compiles ClauseElements into SQL strings. Uses a similar visit + paradigm as visitors.ClauseVisitor but implements its own traversal. + + """ + + extract_map = EXTRACT_MAP + + compound_keywords = COMPOUND_KEYWORDS + + # class-level defaults which can be set at the instance + # level to define if this Compiled instance represents + # INSERT/UPDATE/DELETE + isdelete = isinsert = isupdate = False + + # holds the "returning" collection of columns if + # the statement is CRUD and defines returning columns + # either implicitly or explicitly + returning = None + + # set to True classwide to generate RETURNING + # clauses before the VALUES or WHERE clause (i.e. MSSQL) + returning_precedes_values = False + + # SQL 92 doesn't allow bind parameters to be used + # in the columns clause of a SELECT, nor does it allow + # ambiguous expressions like "? = ?". A compiler + # subclass can set this flag to False if the target + # driver/DB enforces this + ansi_bind_rules = False + + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): + """Construct a new ``DefaultCompiler`` object. + + dialect + Dialect to be used + + statement + ClauseElement to be compiled + + column_keys + a list of column names to be compiled into an INSERT or UPDATE + statement. + + """ + engine.Compiled.__init__(self, dialect, statement, **kwargs) + + self.column_keys = column_keys + + # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) + self.inline = inline or getattr(statement, 'inline', False) + + # a dictionary of bind parameter keys to _BindParamClause instances. + self.binds = {} + + # a dictionary of _BindParamClause instances to "compiled" names that are + # actually present in the generated SQL + self.bind_names = util.column_dict() + + # stack which keeps track of nested SELECT statements + self.stack = [] + + # relates label names in the final SQL to + # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. + # ResultProxy uses this for type processing and column targeting + self.result_map = {} + + # true if the paramstyle is positional + self.positional = self.dialect.positional + if self.positional: + self.positiontup = [] + + self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle] + + # an IdentifierPreparer that formats the quoting of identifiers + self.preparer = self.dialect.identifier_preparer + + self.label_length = self.dialect.label_length or self.dialect.max_identifier_length + + # a map which tracks "anonymous" identifiers that are + # created on the fly here + self.anon_map = util.PopulateDict(self._process_anon) + + # a map which tracks "truncated" names based on dialect.label_length + # or dialect.max_identifier_length + self.truncated_names = {} + + def is_subquery(self): + return len(self.stack) > 1 + + @property + def sql_compiler(self): + return self + + def construct_params(self, params=None, _group_number=None): + """return a dictionary of bind parameter keys and values""" + + if params: + pd = {} + for bindparam, name in self.bind_names.iteritems(): + for paramname in (bindparam.key, name): + if paramname in params: + pd[name] = params[paramname] + break + else: + if bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" % + (bindparam.key, _group_number)) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key) + elif util.callable(bindparam.value): + pd[name] = bindparam.value() + else: + pd[name] = bindparam.value + return pd + else: + pd = {} + for bindparam in self.bind_names: + if util.callable(bindparam.value): + pd[self.bind_names[bindparam]] = bindparam.value() + else: + pd[self.bind_names[bindparam]] = bindparam.value + return pd + + params = property(construct_params, doc=""" + Return the bind params for this compiled object. + + """) + + def default_from(self): + """Called when a SELECT statement has no froms, and no FROM clause is to be appended. + + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. + + """ + return "" + + def visit_grouping(self, grouping, asfrom=False, **kwargs): + return "(" + self.process(grouping.element, **kwargs) + ")" + + def visit_label(self, label, result_map=None, + within_label_clause=False, + within_columns_clause=False, **kw): + # only render labels within the columns clause + # or ORDER BY clause of a select. dialect-specific compilers + # can modify this behavior. + if within_columns_clause and not within_label_clause: + labelname = isinstance(label.name, sql._generated_label) and \ + self._truncated_identifier("colident", label.name) or label.name + + if result_map is not None: + result_map[labelname.lower()] = \ + (label.name, (label, label.element, labelname), label.element.type) + + return self.process(label.element, + within_columns_clause=True, + within_label_clause=True, + **kw) + \ + OPERATORS[operators.as_] + \ + self.preparer.format_label(label, labelname) + else: + return self.process(label.element, + within_columns_clause=False, + **kw) + + def visit_column(self, column, result_map=None, **kwargs): + name = column.name + if not column.is_literal and isinstance(name, sql._generated_label): + name = self._truncated_identifier("colident", name) + + if result_map is not None: + result_map[name.lower()] = (name, (column, ), column.type) + + if column.is_literal: + name = self.escape_literal_column(name) + else: + name = self.preparer.quote(name, column.quote) + + if column.table is None or not column.table.named_with_column: + return name + else: + if column.table.schema: + schema_prefix = self.preparer.quote_schema( + column.table.schema, + column.table.quote_schema) + '.' + else: + schema_prefix = '' + tablename = column.table.name + tablename = isinstance(tablename, sql._generated_label) and \ + self._truncated_identifier("alias", tablename) or tablename + + return schema_prefix + \ + self.preparer.quote(tablename, column.table.quote) + "." + name + + def escape_literal_column(self, text): + """provide escaping for the literal_column() construct.""" + + # TODO: some dialects might need different behavior here + return text.replace('%', '%%') + + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name + + def visit_index(self, index, **kwargs): + return index.name + + def visit_typeclause(self, typeclause, **kwargs): + return self.dialect.type_compiler.process(typeclause.type) + + def post_process_text(self, text): + return text + + def visit_textclause(self, textclause, **kwargs): + if textclause.typemap is not None: + for colname, type_ in textclause.typemap.iteritems(): + self.result_map[colname.lower()] = (colname, None, type_) + + def do_bindparam(m): + name = m.group(1) + if name in textclause.bindparams: + return self.process(textclause.bindparams[name]) + else: + return self.bindparam_string(name) + + # un-escape any \:params + return BIND_PARAMS_ESC.sub(lambda m: m.group(1), + BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text)) + ) + + def visit_null(self, null, **kwargs): + return 'NULL' + + def visit_clauselist(self, clauselist, **kwargs): + sep = clauselist.operator + if sep is None: + sep = " " + else: + sep = OPERATORS[clauselist.operator] + return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses) + if s is not None) + + def visit_case(self, clause, **kwargs): + x = "CASE " + if clause.value is not None: + x += self.process(clause.value, **kwargs) + " " + for cond, result in clause.whens: + x += "WHEN " + self.process(cond, **kwargs) + \ + " THEN " + self.process(result, **kwargs) + " " + if clause.else_ is not None: + x += "ELSE " + self.process(clause.else_, **kwargs) + " " + x += "END" + return x + + def visit_cast(self, cast, **kwargs): + return "CAST(%s AS %s)" % \ + (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs)) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs)) + + def visit_function(self, func, result_map=None, **kwargs): + if result_map is not None: + result_map[func.name.lower()] = (func.name, None, func.type) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) + else: + name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") + return ".".join(func.packagenames + [name]) % \ + {'expr':self.function_argspec(func, **kwargs)} + + def function_argspec(self, func, **kwargs): + return self.process(func.clause_expr, **kwargs) + + def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs): + entry = self.stack and self.stack[-1] or {} + self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) + + keyword = self.compound_keywords.get(cs.keyword) + + text = (" " + keyword + " ").join( + (self.process(c, asfrom=asfrom, parens=False, + compound_index=i, **kwargs) + for i, c in enumerate(cs.selects)) + ) + + group_by = self.process(cs._group_by_clause, asfrom=asfrom, **kwargs) + if group_by: + text += " GROUP BY " + group_by + + text += self.order_by_clause(cs, **kwargs) + text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or "" + + self.stack.pop(-1) + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def visit_unary(self, unary, **kw): + s = self.process(unary.element, **kw) + if unary.operator: + s = OPERATORS[unary.operator] + s + if unary.modifier: + s = s + OPERATORS[unary.modifier] + return s + + def visit_binary(self, binary, **kw): + # don't allow "? = ?" to render + if self.ansi_bind_rules and \ + isinstance(binary.left, sql._BindParamClause) and \ + isinstance(binary.right, sql._BindParamClause): + kw['literal_binds'] = True + + return self._operator_dispatch(binary.operator, + binary, + lambda opstr: self.process(binary.left, **kw) + + opstr + + self.process(binary.right, **kw), + **kw + ) + + def visit_like_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s LIKE %s' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notlike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT LIKE %s' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) LIKE lower(%s)' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) NOT LIKE lower(%s)' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def _operator_dispatch(self, operator, element, fn, **kw): + if util.callable(operator): + disp = getattr(self, "visit_%s" % operator.__name__, None) + if disp: + return disp(element, **kw) + else: + return fn(OPERATORS[operator]) + else: + return fn(" " + operator + " ") + + def visit_bindparam(self, bindparam, within_columns_clause=False, + literal_binds=False, **kwargs): + if literal_binds or \ + (within_columns_clause and \ + self.ansi_bind_rules): + if bindparam.value is None: + raise exc.CompileError("Bind parameter without a " + "renderable value not allowed here.") + return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs) + + name = self._truncate_bindparam(bindparam) + if name in self.binds: + existing = self.binds[name] + if existing is not bindparam: + if existing.unique or bindparam.unique: + raise exc.CompileError( + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % bindparam.key + ) + elif getattr(existing, '_is_crud', False): + raise exc.CompileError( + "Bind parameter name '%s' is reserved " + "for the VALUES or SET clause of this insert/update statement." + % bindparam.key + ) + + self.binds[bindparam.key] = self.binds[name] = bindparam + return self.bindparam_string(name) + + def render_literal_bindparam(self, bindparam, **kw): + value = bindparam.value + processor = bindparam.bind_processor(self.dialect) + if processor: + value = processor(value) + return self.render_literal_value(value, bindparam.type) + + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + This is used for statement sections that do not accept bind paramters + on the target driver/database. + + This should be implemented by subclasses using the quoting services + of the DBAPI. + + """ + if isinstance(value, basestring): + value = value.replace("'", "''") + return "'%s'" % value + elif value is None: + return "NULL" + elif isinstance(value, (float, int, long)): + return repr(value) + elif isinstance(value, decimal.Decimal): + return str(value) + else: + raise NotImplementedError("Don't know how to literal-quote value %r" % value) + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + + bind_name = bindparam.key + bind_name = isinstance(bind_name, sql._generated_label) and \ + self._truncated_identifier("bindparam", bind_name) or bind_name + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + + return bind_name + + def _truncated_identifier(self, ident_class, name): + if (ident_class, name) in self.truncated_names: + return self.truncated_names[(ident_class, name)] + + anonname = name % self.anon_map + + if len(anonname) > self.label_length: + counter = self.truncated_names.get(ident_class, 1) + truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:] + self.truncated_names[ident_class] = counter + 1 + else: + truncname = anonname + self.truncated_names[(ident_class, name)] = truncname + return truncname + + def _anonymize(self, name): + return name % self.anon_map + + def _process_anon(self, key): + (ident, derived) = key.split(' ', 1) + anonymous_counter = self.anon_map.get(derived, 1) + self.anon_map[derived] = anonymous_counter + 1 + return derived + "_" + str(anonymous_counter) + + def bindparam_string(self, name): + if self.positional: + self.positiontup.append(name) + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + else: + return self.bindtemplate % {'name':name} + + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): + if asfrom or ashint: + alias_name = isinstance(alias.name, sql._generated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + if ashint: + return self.preparer.format_alias(alias, alias_name) + elif asfrom: + ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ + self.preparer.format_alias(alias, alias_name) + + if fromhints and alias in fromhints: + hinttext = self.get_from_hint_text(alias, fromhints[alias]) + if hinttext: + ret += " " + hinttext + + return ret + else: + return self.process(alias.original, **kwargs) + + def label_select_column(self, select, column, asfrom): + """label columns present in a select().""" + + if isinstance(column, sql._Label): + return column + + if select is not None and select.use_labels and column._label: + return _CompileLabel(column, column._label) + + if \ + asfrom and \ + isinstance(column, sql.ColumnClause) and \ + not column.is_literal and \ + column.table is not None and \ + not isinstance(column.table, sql.Select): + return _CompileLabel(column, sql._generated_label(column.name)) + elif not isinstance(column, + (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ + and (not hasattr(column, 'name') or isinstance(column, sql.Function)): + return _CompileLabel(column, column.anon_label) + else: + return column + + def get_select_hint_text(self, byfroms): + return None + + def get_from_hint_text(self, table, text): + return None + + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, fromhints=None, + compound_index=1, **kwargs): + + entry = self.stack and self.stack[-1] or {} + + existingfroms = entry.get('from', None) + + froms = select._get_display_froms(existingfroms) + + correlate_froms = set(sql._from_objects(*froms)) + + # TODO: might want to propagate existing froms for select(select(select)) + # where innermost select should correlate to outermost + # if existingfroms: + # correlate_froms = correlate_froms.union(existingfroms) + + self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper}) + + if compound_index==1 and not entry or entry.get('iswrapper', False): + column_clause_args = {'result_map':self.result_map} + else: + column_clause_args = {} + + # the actual list of columns to print in the SELECT column list. + inner_columns = [ + c for c in [ + self.process( + self.label_select_column(select, co, asfrom=asfrom), + within_columns_clause=True, + **column_clause_args) + for co in util.unique_list(select.inner_columns) + ] + if c is not None + ] + + text = "SELECT " # we're off to a good start ! + + if select._hints: + byfrom = dict([ + (from_, hinttext % {'name':self.process(from_, ashint=True)}) + for (from_, dialect), hinttext in + select._hints.iteritems() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + if hint_text: + text += hint_text + " " + + if select._prefixes: + text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " + text += self.get_select_precolumns(select) + text += ', '.join(inner_columns) + + if froms: + text += " \nFROM " + + if select._hints: + text += ', '.join([self.process(f, + asfrom=True, fromhints=byfrom, + **kwargs) + for f in froms]) + else: + text += ', '.join([self.process(f, + asfrom=True, **kwargs) + for f in froms]) + else: + text += self.default_from() + + if select._whereclause is not None: + t = self.process(select._whereclause, **kwargs) + if t: + text += " \nWHERE " + t + + if select._group_by_clause.clauses: + group_by = self.process(select._group_by_clause, **kwargs) + if group_by: + text += " GROUP BY " + group_by + + if select._having is not None: + t = self.process(select._having, **kwargs) + if t: + text += " \nHAVING " + t + + if select._order_by_clause.clauses: + text += self.order_by_clause(select, **kwargs) + if select._limit is not None or select._offset is not None: + text += self.limit_clause(select) + if select.for_update: + text += self.for_update_clause(select) + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just before + column list. + + """ + return select._distinct and "DISTINCT " or "" + + def order_by_clause(self, select, **kw): + order_by = self.process(select._order_by_clause, **kw) + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def for_update_clause(self, select): + if select.for_update: + return " FOR UPDATE" + else: + return "" + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select._offset) + return text + + def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs): + if asfrom or ashint: + if getattr(table, "schema", None): + ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \ + "." + self.preparer.quote(table.name, table.quote) + else: + ret = self.preparer.quote(table.name, table.quote) + if fromhints and table in fromhints: + hinttext = self.get_from_hint_text(table, fromhints[table]) + if hinttext: + ret += " " + hinttext + return ret + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + return (self.process(join.left, asfrom=True, **kwargs) + \ + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + self.process(join.right, asfrom=True, **kwargs) + " ON " + \ + self.process(join.onclause, **kwargs)) + + def visit_sequence(self, seq): + return None + + def visit_insert(self, insert_stmt): + self.isinsert = True + colparams = self._get_colparams(insert_stmt) + + if not colparams and \ + not self.dialect.supports_default_values and \ + not self.dialect.supports_empty_insert: + raise exc.CompileError("The version of %s you are using does " + "not support empty inserts." % + self.dialect.name) + + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT" + + prefixes = [self.process(x) for x in insert_stmt._prefixes] + if prefixes: + text += " " + " ".join(prefixes) + + text += " INTO " + preparer.format_table(insert_stmt.table) + + if colparams or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams]) + + if self.returning or insert_stmt._returning: + self.returning = self.returning or insert_stmt._returning + returning_clause = self.returning_clause(insert_stmt, self.returning) + + if self.returning_precedes_values: + text += " " + returning_clause + + if not colparams and supports_default_values: + text += " DEFAULT VALUES" + else: + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) + + if self.returning and not self.returning_precedes_values: + text += " " + returning_clause + + return text + + def visit_update(self, update_stmt): + self.stack.append({'from': set([update_stmt.table])}) + + self.isupdate = True + colparams = self._get_colparams(update_stmt) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + + text += ' SET ' + \ + ', '.join( + self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] + for c in colparams + ) + + if update_stmt._returning: + self.returning = update_stmt._returning + if self.returning_precedes_values: + text += " " + self.returning_clause(update_stmt, update_stmt._returning) + + if update_stmt._whereclause is not None: + text += " WHERE " + self.process(update_stmt._whereclause) + + if self.returning and not self.returning_precedes_values: + text += " " + self.returning_clause(update_stmt, update_stmt._returning) + + self.stack.pop(-1) + + return text + + def _create_crud_bind_param(self, col, value, required=False): + bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) + bindparam._is_crud = True + if col.key in self.binds: + raise exc.CompileError( + "Bind parameter name '%s' is reserved " + "for the VALUES or SET clause of this insert/update statement." + % col.key + ) + + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + def _get_colparams(self, stmt): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and returning + column collections, used for default handling and ultimately + populating the ResultProxy's prefetch_cols() and postfetch_cols() + collections. + + """ + + self.postfetch = [] + self.prefetch = [] + self.returning = [] + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if self.column_keys is None and stmt.parameters is None: + return [ + (c, self._create_crud_bind_param(c, None, required=True)) + for c in stmt.table.columns + ] + + required = object() + + # if we have statement parameters - set defaults in the + # compiled params + if self.column_keys is None: + parameters = {} + else: + parameters = dict((sql._column_as_key(key), required) + for key in self.column_keys + if not stmt.parameters or key not in stmt.parameters) + + if stmt.parameters is not None: + for k, v in stmt.parameters.iteritems(): + parameters.setdefault(sql._column_as_key(k), v) + + # create a list of column assignment clauses as tuples + values = [] + + need_pks = self.isinsert and \ + not self.inline and \ + not stmt._returning + + implicit_returning = need_pks and \ + self.dialect.implicit_returning and \ + stmt.table.implicit_returning + + postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid + + # iterating through columns at the top to maintain ordering. + # otherwise we might iterate through individual sets of + # "defaults", "primary key cols", etc. + for c in stmt.table.columns: + if c.key in parameters: + value = parameters[c.key] + if sql._is_literal(value): + value = self._create_crud_bind_param(c, value, required=value is required) + else: + self.postfetch.append(c) + value = self.process(value.self_group()) + values.append((c, value)) + + elif self.isinsert: + if c.primary_key and \ + need_pks and \ + ( + implicit_returning or + not postfetch_lastrowid or + c is not stmt.table._autoincrement_column + ): + + if implicit_returning: + if c.default is not None: + if c.default.is_sequence: + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + self.returning.append(c) + elif c.default.is_clause_element: + values.append((c, self.process(c.default.arg.self_group()))) + self.returning.append(c) + else: + values.append((c, self._create_crud_bind_param(c, None))) + self.prefetch.append(c) + else: + self.returning.append(c) + else: + if ( + c.default is not None and \ + ( + self.dialect.supports_sequences or + not c.default.is_sequence + ) + ) or self.dialect.preexecute_autoincrement_sequences: + + values.append((c, self._create_crud_bind_param(c, None))) + self.prefetch.append(c) + + elif c.default is not None: + if c.default.is_sequence: + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + if not c.primary_key: + self.postfetch.append(c) + elif c.default.is_clause_element: + values.append((c, self.process(c.default.arg.self_group()))) + + if not c.primary_key: + # dont add primary key column to postfetch + self.postfetch.append(c) + else: + values.append((c, self._create_crud_bind_param(c, None))) + self.prefetch.append(c) + elif c.server_default is not None: + if not c.primary_key: + self.postfetch.append(c) + + elif self.isupdate: + if c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append((c, self.process(c.onupdate.arg.self_group()))) + self.postfetch.append(c) + else: + values.append((c, self._create_crud_bind_param(c, None))) + self.prefetch.append(c) + elif c.server_onupdate is not None: + self.postfetch.append(c) + return values + + def visit_delete(self, delete_stmt): + self.stack.append({'from': set([delete_stmt.table])}) + self.isdelete = True + + text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + + if delete_stmt._returning: + self.returning = delete_stmt._returning + if self.returning_precedes_values: + text += " " + self.returning_clause(delete_stmt, delete_stmt._returning) + + if delete_stmt._whereclause is not None: + text += " WHERE " + self.process(delete_stmt._whereclause) + + if self.returning and not self.returning_precedes_values: + text += " " + self.returning_clause(delete_stmt, delete_stmt._returning) + + self.stack.pop(-1) + + return text + + def visit_savepoint(self, savepoint_stmt): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_release_savepoint(self, savepoint_stmt): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + +class DDLCompiler(engine.Compiled): + + @util.memoized_property + def sql_compiler(self): + return self.dialect.statement_compiler(self.dialect, self.statement) + + @property + def preparer(self): + return self.dialect.identifier_preparer + + def construct_params(self, params=None): + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.target, schema.Table): + context = context.copy() + + preparer = self.dialect.identifier_preparer + path = preparer.format_table_seq(ddl.target) + if len(path) == 1: + table, sch = path[0], '' + else: + table, sch = path[-1], path[0] + + context.setdefault('table', table) + context.setdefault('schema', sch) + context.setdefault('fullname', preparer.format_table(ddl.target)) + + return ddl.statement % context + + def visit_create_table(self, create): + table = create.element + preparer = self.dialect.identifier_preparer + + text = "\n" + " ".join(['CREATE'] + \ + table._prefixes + \ + ['TABLE', + preparer.format_table(table), + "("]) + separator = "\n" + + # if only one primary key, specify it along with the column + first_pk = False + for column in table.columns: + text += separator + separator = ", \n" + text += "\t" + self.get_column_specification( + column, + first_pk=column.primary_key and not first_pk + ) + if column.primary_key: + first_pk = True + const = " ".join(self.process(constraint) for constraint in column.constraints) + if const: + text += " " + const + + const = self.create_table_constraints(table) + if const: + text += ", \n\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def create_table_constraints(self, table): + + # On some DB order is significant: visit PK first, then the + # other constraints (engine.ReflectionTest.testbasic failed on FB2) + constraints = [] + if table.primary_key: + constraints.append(table.primary_key) + + constraints.extend([c for c in table.constraints if c is not table.primary_key]) + + return ", \n\t".join(p for p in + (self.process(constraint) for constraint in constraints + if ( + constraint._create_rule is None or + constraint._create_rule(self)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None + ) + + def visit_drop_table(self, drop): + return "\nDROP TABLE " + self.preparer.format_table(drop.element) + + def visit_create_index(self, create): + index = create.element + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join(preparer.quote(c.name, c.quote) + for c in index.columns)) + return text + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX " + \ + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) + + def visit_add_constraint(self, create): + preparer = self.preparer + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element) + ) + + def visit_create_sequence(self, create): + text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + if create.element.increment is not None: + text += " INCREMENT BY %d" % create.element.increment + if create.element.start is not None: + text += " START WITH %d" % create.element.start + return text + + def visit_drop_sequence(self, drop): + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + + def visit_drop_constraint(self, drop): + preparer = self.preparer + return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( + self.preparer.format_table(drop.element.table), + self.preparer.format_constraint(drop.element), + drop.cascade and " CASCADE" or "" + ) + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def post_create_table(self, table): + return '' + + def _validate_identifier(self, ident, truncate): + if truncate: + if len(ident) > self.dialect.max_identifier_length: + counter = getattr(self, 'counter', 0) + self.counter = counter + 1 + return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] + else: + return ident + else: + self.dialect.validate_identifier(ident) + return ident + + def get_column_default_string(self, column): + if isinstance(column.server_default, schema.DefaultClause): + if isinstance(column.server_default.arg, basestring): + return "'%s'" % column.server_default.arg + else: + return self.sql_compiler.process(column.server_default.arg) + else: + return None + + def visit_check_constraint(self, constraint): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + sqltext = sql_util.expression_as_ddl(constraint.sqltext) + text += "CHECK (%s)" % self.sql_compiler.process(sqltext) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_column_check_constraint(self, constraint): + text = " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text + + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return '' + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += "PRIMARY KEY " + text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_foreign_key_constraint(self, constraint): + preparer = self.dialect.identifier_preparer + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + preparer.format_constraint(constraint) + remote_table = list(constraint._elements.values())[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + ', '.join(preparer.quote(f.parent.name, f.parent.quote) + for f in constraint._elements.values()), + preparer.format_table(remote_table), + ', '.join(preparer.quote(f.column.name, f.column.quote) + for f in constraint._elements.values()) + ) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_unique_constraint(self, constraint): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + if constraint.onupdate is not None: + text += " ON UPDATE %s" % constraint.onupdate + return text + + def define_constraint_deferrability(self, constraint): + text = "" + if constraint.deferrable is not None: + if constraint.deferrable: + text += " DEFERRABLE" + else: + text += " NOT DEFERRABLE" + if constraint.initially is not None: + text += " INITIALLY %s" % constraint.initially + return text + + +class GenericTypeCompiler(engine.TypeCompiler): + def visit_CHAR(self, type_): + return "CHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_NCHAR(self, type_): + return "NCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_FLOAT(self, type_): + return "FLOAT" + + def visit_NUMERIC(self, type_): + if type_.precision is None: + return "NUMERIC" + elif type_.scale is None: + return "NUMERIC(%(precision)s)" % {'precision': type_.precision} + else: + return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale} + + def visit_DECIMAL(self, type_): + return "DECIMAL" + + def visit_INTEGER(self, type_): + return "INTEGER" + + def visit_SMALLINT(self, type_): + return "SMALLINT" + + def visit_BIGINT(self, type_): + return "BIGINT" + + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_CLOB(self, type_): + return "CLOB" + + def visit_NCLOB(self, type_): + return "NCLOB" + + def visit_VARCHAR(self, type_): + return "VARCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_NVARCHAR(self, type_): + return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_BLOB(self, type_): + return "BLOB" + + def visit_BINARY(self, type_): + return "BINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_VARBINARY(self, type_): + return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + + def visit_TEXT(self, type_): + return "TEXT" + + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_boolean(self, type_): + return self.visit_BOOLEAN(type_) + + def visit_time(self, type_): + return self.visit_TIME(type_) + + def visit_datetime(self, type_): + return self.visit_DATETIME(type_) + + def visit_date(self, type_): + return self.visit_DATE(type_) + + def visit_big_integer(self, type_): + return self.visit_BIGINT(type_) + + def visit_small_integer(self, type_): + return self.visit_SMALLINT(type_) + + def visit_integer(self, type_): + return self.visit_INTEGER(type_) + + def visit_float(self, type_): + return self.visit_FLOAT(type_) + + def visit_numeric(self, type_): + return self.visit_NUMERIC(type_) + + def visit_string(self, type_): + return self.visit_VARCHAR(type_) + + def visit_unicode(self, type_): + return self.visit_VARCHAR(type_) + + def visit_text(self, type_): + return self.visit_TEXT(type_) + + def visit_unicode_text(self, type_): + return self.visit_TEXT(type_) + + def visit_enum(self, type_): + return self.visit_VARCHAR(type_) + + def visit_null(self, type_): + raise NotImplementedError("Can't generate DDL for the null type") + + def visit_type_decorator(self, type_): + return self.process(type_.type_engine(self.dialect)) + + def visit_user_defined(self, type_): + return type_.get_col_spec() + +class IdentifierPreparer(object): + """Handle quoting and case-folding of identifiers based on options.""" + + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + + def __init__(self, dialect, initial_quote='"', + final_quote=None, escape_quote='"', omit_schema=False): + """Construct a new ``IdentifierPreparer`` object. + + initial_quote + Character that begins a delimited identifier. + + final_quote + Character that ends a delimited identifier. Defaults to `initial_quote`. + + omit_schema + Prevent prepending schema name. Useful for databases that do + not support schemae. + """ + + self.dialect = dialect + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.escape_quote = escape_quote + self.escape_to_quote = self.escape_quote * 2 + self.omit_schema = omit_schema + self._strings = {} + + def _escape_identifier(self, value): + """Escape an identifier. + + Subclasses should override this to provide database-dependent + escaping behavior. + """ + + return value.replace(self.escape_quote, self.escape_to_quote) + + def _unescape_identifier(self, value): + """Canonicalize an escaped identifier. + + Subclasses should override this to provide database-dependent + unescaping behavior that reverses _escape_identifier. + """ + + return value.replace(self.escape_to_quote, self.escape_quote) + + def quote_identifier(self, value): + """Quote an identifier. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + + return self.initial_quote + self._escape_identifier(value) + self.final_quote + + def _requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return (lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(unicode(value)) + or (lc_value != value)) + + def quote_schema(self, schema, force): + """Quote a schema. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + return self.quote(schema, force) + + def quote(self, ident, force): + if force is None: + if ident in self._strings: + return self._strings[ident] + else: + if self._requires_quotes(ident): + self._strings[ident] = self.quote_identifier(ident) + else: + self._strings[ident] = ident + return self._strings[ident] + elif force: + return self.quote_identifier(ident) + else: + return ident + + def format_sequence(self, sequence, use_schema=True): + name = self.quote(sequence.name, sequence.quote) + if not self.omit_schema and use_schema and sequence.schema is not None: + name = self.quote_schema(sequence.schema, sequence.quote) + "." + name + return name + + def format_label(self, label, name=None): + return self.quote(name or label.name, label.quote) + + def format_alias(self, alias, name=None): + return self.quote(name or alias.name, alias.quote) + + def format_savepoint(self, savepoint, name=None): + return self.quote(name or savepoint.ident, savepoint.quote) + + def format_constraint(self, constraint): + return self.quote(constraint.name, constraint.quote) + + def format_table(self, table, use_schema=True, name=None): + """Prepare a quoted table and schema name.""" + + if name is None: + name = table.name + result = self.quote(name, table.quote) + if not self.omit_schema and use_schema and getattr(table, "schema", None): + result = self.quote_schema(table.schema, table.quote_schema) + "." + result + return result + + def format_column(self, column, use_table=False, name=None, table_name=None): + """Prepare a quoted column name.""" + + if name is None: + name = column.name + if not getattr(column, 'is_literal', False): + if use_table: + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote) + else: + return self.quote(name, column.quote) + else: + # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted + if use_table: + return self.format_table(column.table, use_schema=False, name=table_name) + "." + name + else: + return name + + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + # Dialects with more levels in their fully qualified references + # ('database', 'owner', etc.) could override this and return + # a longer sequence. + + if not self.omit_schema and use_schema and getattr(table, 'schema', None): + return (self.quote_schema(table.schema, table.quote_schema), + self.format_table(table, use_schema=False)) + else: + return (self.format_table(table, use_schema=False), ) + + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = \ + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] + r = re.compile( + r'(?:' + r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' + r'|([^\.]+))(?=\.|$))+' % + { 'initial': initial, + 'final': final, + 'escaped': escaped_final }) + return r + + def unformat_identifiers(self, identifiers): + """Unpack 'schema.table.column'-like strings into components.""" + + r = self._r_identifiers + return [self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)]] diff --git a/sqlalchemy/sql/expression.py b/sqlalchemy/sql/expression.py new file mode 100644 index 0000000..3aaa06f --- /dev/null +++ b/sqlalchemy/sql/expression.py @@ -0,0 +1,4258 @@ +# expression.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines the base components of SQL expression trees. + +All components are derived from a common base class +:class:`ClauseElement`. Common behaviors are organized +based on class hierarchies, in some cases via mixins. + +All object construction from this package occurs via functions which +in some cases will construct composite :class:`ClauseElement` structures +together, and in other cases simply return a single :class:`ClauseElement` +constructed directly. The function interface affords a more "DSL-ish" +feel to constructing SQL expressions and also allows future class +reorganizations. + +Even though classes are not constructed directly from the outside, +most classes which have additional public methods are considered to be +public (i.e. have no leading underscore). Other classes which are +"semi-public" are marked with a single leading underscore; these +classes usually have few or no public methods and are less guaranteed +to stay the same in future releases. + +""" + +import itertools, re +from operator import attrgetter + +from sqlalchemy import util, exc #, types as sqltypes +from sqlalchemy.sql import operators +from sqlalchemy.sql.visitors import Visitable, cloned_traverse +import operator + +functions, schema, sql_util, sqltypes = None, None, None, None +DefaultDialect, ClauseAdapter, Annotated = None, None, None + +__all__ = [ + 'Alias', 'ClauseElement', + 'ColumnCollection', 'ColumnElement', + 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', + 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', + 'between', 'bindparam', 'case', 'cast', 'column', 'delete', + 'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', + 'modifier', 'collate', + 'insert', 'intersect', 'intersect_all', 'join', 'label', 'literal', + 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select', + 'subquery', 'table', 'text', 'tuple_', 'union', 'union_all', 'update', ] + +PARSE_AUTOCOMMIT = util._symbol('PARSE_AUTOCOMMIT') + +def desc(column): + """Return a descending ``ORDER BY`` clause element. + + e.g.:: + + order_by = [desc(table1.mycol)] + + """ + return _UnaryExpression(column, modifier=operators.desc_op) + +def asc(column): + """Return an ascending ``ORDER BY`` clause element. + + e.g.:: + + order_by = [asc(table1.mycol)] + + """ + return _UnaryExpression(column, modifier=operators.asc_op) + +def outerjoin(left, right, onclause=None): + """Return an ``OUTER JOIN`` clause element. + + The returned object is an instance of :class:`Join`. + + Similar functionality is also available via the :func:`outerjoin()` + method on any :class:`FromClause`. + + left + The left side of the join. + + right + The right side of the join. + + onclause + Optional criterion for the ``ON`` clause, is derived from + foreign key relationships established between left and right + otherwise. + + To chain joins together, use the :func:`join()` or :func:`outerjoin()` + methods on the resulting :class:`Join` object. + + """ + return Join(left, right, onclause, isouter=True) + +def join(left, right, onclause=None, isouter=False): + """Return a ``JOIN`` clause element (regular inner join). + + The returned object is an instance of :class:`Join`. + + Similar functionality is also available via the :func:`join()` method + on any :class:`FromClause`. + + left + The left side of the join. + + right + The right side of the join. + + onclause + Optional criterion for the ``ON`` clause, is derived from + foreign key relationships established between left and right + otherwise. + + To chain joins together, use the :func:`join()` or :func:`outerjoin()` + methods on the resulting :class:`Join` object. + + """ + return Join(left, right, onclause, isouter) + +def select(columns=None, whereclause=None, from_obj=[], **kwargs): + """Returns a ``SELECT`` clause element. + + Similar functionality is also available via the :func:`select()` + method on any :class:`FromClause`. + + The returned object is an instance of :class:`Select`. + + All arguments which accept :class:`ClauseElement` arguments also accept + string arguments, which will be converted as appropriate into + either :func:`text()` or :func:`literal_column()` constructs. + + columns + A list of :class:`ClauseElement` objects, typically :class:`ColumnElement` + objects or subclasses, which will form the columns clause of the + resulting statement. For all members which are instances of + :class:`Selectable`, the individual :class:`ColumnElement` members of the + :class:`Selectable` will be added individually to the columns clause. + For example, specifying a :class:`~sqlalchemy.schema.Table` instance will result in all + the contained :class:`~sqlalchemy.schema.Column` objects within to be added to the + columns clause. + + This argument is not present on the form of :func:`select()` + available on :class:`~sqlalchemy.schema.Table`. + + whereclause + A :class:`ClauseElement` expression which will be used to form the + ``WHERE`` clause. + + from_obj + A list of :class:`ClauseElement` objects which will be added to the + ``FROM`` clause of the resulting statement. Note that "from" + objects are automatically located within the columns and + whereclause ClauseElements. Use this parameter to explicitly + specify "from" objects which are not automatically locatable. + This could include :class:`~sqlalchemy.schema.Table` objects that aren't otherwise + present, or :class:`Join` objects whose presence will supercede that + of the :class:`~sqlalchemy.schema.Table` objects already located in the other clauses. + + \**kwargs + Additional parameters include: + + autocommit + Deprecated. Use .execution_options(autocommit=) + to set the autocommit option. + + prefixes + a list of strings or :class:`ClauseElement` objects to include + directly after the SELECT keyword in the generated statement, + for dialect-specific query features. + + distinct=False + when ``True``, applies a ``DISTINCT`` qualifier to the columns + clause of the resulting statement. + + use_labels=False + when ``True``, the statement will be generated using labels + for each column in the columns clause, which qualify each + column with its parent table's (or aliases) name so that name + conflicts between columns in different tables don't occur. + The format of the label is _. The "c" + collection of the resulting :class:`Select` object will use these + names as well for targeting column members. + + for_update=False + when ``True``, applies ``FOR UPDATE`` to the end of the + resulting statement. Certain database dialects also support + alternate values for this parameter, for example mysql + supports "read" which translates to ``LOCK IN SHARE MODE``, + and oracle supports "nowait" which translates to ``FOR UPDATE + NOWAIT``. + + correlate=True + indicates that this :class:`Select` object should have its + contained :class:`FromClause` elements "correlated" to an enclosing + :class:`Select` object. This means that any :class:`ClauseElement` + instance within the "froms" collection of this :class:`Select` + which is also present in the "froms" collection of an + enclosing select will not be rendered in the ``FROM`` clause + of this select statement. + + group_by + a list of :class:`ClauseElement` objects which will comprise the + ``GROUP BY`` clause of the resulting select. + + having + a :class:`ClauseElement` that will comprise the ``HAVING`` clause + of the resulting select when ``GROUP BY`` is used. + + order_by + a scalar or list of :class:`ClauseElement` objects which will + comprise the ``ORDER BY`` clause of the resulting select. + + limit=None + a numerical value which usually compiles to a ``LIMIT`` + expression in the resulting select. Databases that don't + support ``LIMIT`` will attempt to provide similar + functionality. + + offset=None + a numeric value which usually compiles to an ``OFFSET`` + expression in the resulting select. Databases that don't + support ``OFFSET`` will attempt to provide similar + functionality. + + bind=None + an ``Engine`` or ``Connection`` instance to which the + resulting ``Select ` object will be bound. The ``Select`` + object will otherwise automatically bind to whatever + ``Connectable`` instances can be located within its contained + :class:`ClauseElement` members. + + """ + return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) + +def subquery(alias, *args, **kwargs): + """Return an :class:`Alias` object derived + from a :class:`Select`. + + name + alias name + + \*args, \**kwargs + + all other arguments are delivered to the + :func:`select` function. + + """ + return Select(*args, **kwargs).alias(alias) + +def insert(table, values=None, inline=False, **kwargs): + """Return an :class:`Insert` clause element. + + Similar functionality is available via the :func:`insert()` method on + :class:`~sqlalchemy.schema.Table`. + + :param table: The table to be inserted into. + + :param values: A dictionary which specifies the column specifications of the + ``INSERT``, and is optional. If left as None, the column + specifications are determined from the bind parameters used + during the compile phase of the ``INSERT`` statement. If the + bind parameters also are None during the compile phase, then the + column specifications will be generated from the full list of + table columns. Note that the :meth:`~Insert.values()` generative method + may also be used for this. + + :param prefixes: A list of modifier keywords to be inserted between INSERT + and INTO. Alternatively, the :meth:`~Insert.prefix_with` generative method + may be used. + + :param inline: if True, SQL defaults will be compiled 'inline' into the + statement and not pre-executed. + + If both `values` and compile-time bind parameters are present, the + compile-time bind parameters override the information specified + within `values` on a per-key basis. + + The keys within `values` can be either :class:`~sqlalchemy.schema.Column` objects or their + string identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``INSERT`` statement's table, the statement will be correlated + against the ``INSERT`` statement. + + """ + return Insert(table, values, inline=inline, **kwargs) + +def update(table, whereclause=None, values=None, inline=False, **kwargs): + """Return an :class:`Update` clause element. + + Similar functionality is available via the :func:`update()` method on + :class:`~sqlalchemy.schema.Table`. + + :param table: The table to be updated. + + :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` condition + of the ``UPDATE`` statement. Note that the :meth:`~Update.where()` + generative method may also be used for this. + + :param values: + A dictionary which specifies the ``SET`` conditions of the + ``UPDATE``, and is optional. If left as None, the ``SET`` + conditions are determined from the bind parameters used during + the compile phase of the ``UPDATE`` statement. If the bind + parameters also are None during the compile phase, then the + ``SET`` conditions will be generated from the full list of table + columns. Note that the :meth:`~Update.values()` generative method may + also be used for this. + + :param inline: + if True, SQL defaults will be compiled 'inline' into the statement + and not pre-executed. + + If both `values` and compile-time bind parameters are present, the + compile-time bind parameters override the information specified + within `values` on a per-key basis. + + The keys within `values` can be either :class:`~sqlalchemy.schema.Column` objects or their + string identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``UPDATE`` statement's table, the statement will be correlated + against the ``UPDATE`` statement. + + """ + return Update( + table, + whereclause=whereclause, + values=values, + inline=inline, + **kwargs) + +def delete(table, whereclause = None, **kwargs): + """Return a :class:`Delete` clause element. + + Similar functionality is available via the :func:`delete()` method on + :class:`~sqlalchemy.schema.Table`. + + :param table: The table to be updated. + + :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` + condition of the ``UPDATE`` statement. Note that the :meth:`~Delete.where()` + generative method may be used instead. + + """ + return Delete(table, whereclause, **kwargs) + +def and_(*clauses): + """Join a list of clauses together using the ``AND`` operator. + + The ``&`` operator is also overloaded on all + :class:`_CompareMixin` subclasses to produce the + same result. + + """ + if len(clauses) == 1: + return clauses[0] + return BooleanClauseList(operator=operators.and_, *clauses) + +def or_(*clauses): + """Join a list of clauses together using the ``OR`` operator. + + The ``|`` operator is also overloaded on all + :class:`_CompareMixin` subclasses to produce the + same result. + + """ + if len(clauses) == 1: + return clauses[0] + return BooleanClauseList(operator=operators.or_, *clauses) + +def not_(clause): + """Return a negation of the given clause, i.e. ``NOT(clause)``. + + The ``~`` operator is also overloaded on all + :class:`_CompareMixin` subclasses to produce the + same result. + + """ + return operators.inv(_literal_as_binds(clause)) + +def distinct(expr): + """Return a ``DISTINCT`` clause.""" + expr = _literal_as_binds(expr) + return _UnaryExpression(expr, operator=operators.distinct_op, type_=expr.type) + +def between(ctest, cleft, cright): + """Return a ``BETWEEN`` predicate clause. + + Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``. + + The :func:`between()` method on all + :class:`_CompareMixin` subclasses provides + similar functionality. + + """ + ctest = _literal_as_binds(ctest) + return ctest.between(cleft, cright) + + +def case(whens, value=None, else_=None): + """Produce a ``CASE`` statement. + + whens + A sequence of pairs, or alternatively a dict, + to be translated into "WHEN / THEN" clauses. + + value + Optional for simple case statements, produces + a column expression as in "CASE WHEN ..." + + else\_ + Optional as well, for case defaults produces + the "ELSE" portion of the "CASE" statement. + + The expressions used for THEN and ELSE, + when specified as strings, will be interpreted + as bound values. To specify textual SQL expressions + for these, use the literal_column() or + text() construct. + + The expressions used for the WHEN criterion + may only be literal strings when "value" is + present, i.e. CASE table.somecol WHEN "x" THEN "y". + Otherwise, literal strings are not accepted + in this position, and either the text() + or literal() constructs must be used to + interpret raw string values. + + Usage examples:: + + case([(orderline.c.qty > 100, item.c.specialprice), + (orderline.c.qty > 10, item.c.bulkprice) + ], else_=item.c.regularprice) + case(value=emp.c.type, whens={ + 'engineer': emp.c.salary * 1.1, + 'manager': emp.c.salary * 3, + }) + + Using :func:`literal_column()`, to allow for databases that + do not support bind parameters in the ``then`` clause. The type + can be specified which determines the type of the :func:`case()` construct + overall:: + + case([(orderline.c.qty > 100, literal_column("'greaterthan100'", String)), + (orderline.c.qty > 10, literal_column("'greaterthan10'", String)) + ], else_=literal_column("'lethan10'", String)) + + """ + + return _Case(whens, value=value, else_=else_) + +def cast(clause, totype, **kwargs): + """Return a ``CAST`` function. + + Equivalent of SQL ``CAST(clause AS totype)``. + + Use with a :class:`~sqlalchemy.types.TypeEngine` subclass, i.e:: + + cast(table.c.unit_price * table.c.qty, Numeric(10,4)) + + or:: + + cast(table.c.timestamp, DATE) + + """ + return _Cast(clause, totype, **kwargs) + +def extract(field, expr): + """Return the clause ``extract(field FROM expr)``.""" + + return _Extract(field, expr) + +def collate(expression, collation): + """Return the clause ``expression COLLATE collation``.""" + + expr = _literal_as_binds(expression) + return _BinaryExpression( + expr, + _literal_as_text(collation), + operators.collate, type_=expr.type) + +def exists(*args, **kwargs): + """Return an ``EXISTS`` clause as applied to a :class:`Select` object. + + Calling styles are of the following forms:: + + # use on an existing select() + s = select([table.c.col1]).where(table.c.col2==5) + s = exists(s) + + # construct a select() at once + exists(['*'], **select_arguments).where(criterion) + + # columns argument is optional, generates "EXISTS (SELECT *)" + # by default. + exists().where(table.c.col2==5) + + """ + return _Exists(*args, **kwargs) + +def union(*selects, **kwargs): + """Return a ``UNION`` of multiple selectables. + + The returned object is an instance of + :class:`CompoundSelect`. + + A similar :func:`union()` method is available on all + :class:`FromClause` subclasses. + + \*selects + a list of :class:`Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs) + +def union_all(*selects, **kwargs): + """Return a ``UNION ALL`` of multiple selectables. + + The returned object is an instance of + :class:`CompoundSelect`. + + A similar :func:`union_all()` method is available on all + :class:`FromClause` subclasses. + + \*selects + a list of :class:`Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.UNION_ALL, *selects, **kwargs) + +def except_(*selects, **kwargs): + """Return an ``EXCEPT`` of multiple selectables. + + The returned object is an instance of + :class:`CompoundSelect`. + + \*selects + a list of :class:`Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.EXCEPT, *selects, **kwargs) + +def except_all(*selects, **kwargs): + """Return an ``EXCEPT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`CompoundSelect`. + + \*selects + a list of :class:`Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects, **kwargs) + +def intersect(*selects, **kwargs): + """Return an ``INTERSECT`` of multiple selectables. + + The returned object is an instance of + :class:`CompoundSelect`. + + \*selects + a list of :class:`Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.INTERSECT, *selects, **kwargs) + +def intersect_all(*selects, **kwargs): + """Return an ``INTERSECT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`CompoundSelect`. + + \*selects + a list of :class:`Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs) + +def alias(selectable, alias=None): + """Return an :class:`Alias` object. + + An :class:`Alias` represents any :class:`FromClause` + with an alternate name assigned within SQL, typically using the ``AS`` + clause when generated, e.g. ``SELECT * FROM table AS aliasname``. + + Similar functionality is available via the :func:`alias()` method + available on all :class:`FromClause` subclasses. + + selectable + any :class:`FromClause` subclass, such as a table, select + statement, etc.. + + alias + string name to be assigned as the alias. If ``None``, a + random name will be generated. + + """ + return Alias(selectable, alias=alias) + + +def literal(value, type_=None): + """Return a literal clause, bound to a bind parameter. + + Literal clauses are created automatically when non- :class:`ClauseElement` + objects (such as strings, ints, dates, etc.) are used in a comparison + operation with a :class:`_CompareMixin` + subclass, such as a :class:`~sqlalchemy.schema.Column` object. Use this function to force the + generation of a literal clause, which will be created as a + :class:`_BindParamClause` with a bound value. + + :param value: the value to be bound. Can be any Python object supported by + the underlying DB-API, or is translatable via the given type argument. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which + will provide bind-parameter translation for this literal. + + """ + return _BindParamClause(None, value, type_=type_, unique=True) + +def tuple_(*expr): + """Return a SQL tuple. + + Main usage is to produce a composite IN construct:: + + tuple_(table.c.col1, table.c.col2).in_( + [(1, 2), (5, 12), (10, 19)] + ) + + """ + return _Tuple(*expr) + +def label(name, obj): + """Return a :class:`_Label` object for the + given :class:`ColumnElement`. + + A label changes the name of an element in the columns clause of a + ``SELECT`` statement, typically via the ``AS`` SQL keyword. + + This functionality is more conveniently available via the + :func:`label()` method on :class:`ColumnElement`. + + name + label name + + obj + a :class:`ColumnElement`. + + """ + return _Label(name, obj) + +def column(text, type_=None): + """Return a textual column clause, as would be in the columns clause of a + ``SELECT`` statement. + + The object returned is an instance of + :class:`ColumnClause`, which represents the + "syntactical" portion of the schema-level + :class:`~sqlalchemy.schema.Column` object. + + text + the name of the column. Quoting rules will be applied to the + clause like any other column name. For textual column + constructs that are not to be quoted, use the + :func:`literal_column` function. + + type\_ + an optional :class:`~sqlalchemy.types.TypeEngine` object which will + provide result-set translation for this column. + + """ + return ColumnClause(text, type_=type_) + +def literal_column(text, type_=None): + """Return a textual column expression, as would be in the columns + clause of a ``SELECT`` statement. + + The object returned supports further expressions in the same way as any + other column object, including comparison, math and string operations. + The type\_ parameter is important to determine proper expression behavior + (such as, '+' means string concatenation or numerical addition based on + the type). + + text + the text of the expression; can be any SQL expression. Quoting rules + will not be applied. To specify a column-name expression which should + be subject to quoting rules, use the + :func:`column` function. + + type\_ + an optional :class:`~sqlalchemy.types.TypeEngine` object which will + provide result-set translation and additional expression semantics for + this column. If left as None the type will be NullType. + + """ + return ColumnClause(text, type_=type_, is_literal=True) + +def table(name, *columns): + """Return a :class:`TableClause` object. + + This is a primitive version of the :class:`~sqlalchemy.schema.Table` object, + which is a subclass of this object. + + """ + return TableClause(name, *columns) + +def bindparam(key, value=None, type_=None, unique=False, required=False): + """Create a bind parameter clause with the given key. + + value + a default value for this bind parameter. a bindparam with a + value is called a ``value-based bindparam``. + + type\_ + a sqlalchemy.types.TypeEngine object indicating the type of this + bind param, will invoke type-specific bind parameter processing + + unique + if True, bind params sharing the same name will have their + underlying ``key`` modified to a uniquely generated name. + mostly useful with value-based bind params. + + required + A value is required at execution time. + + """ + if isinstance(key, ColumnClause): + return _BindParamClause(key.name, value, type_=key.type, + unique=unique, required=required) + else: + return _BindParamClause(key, value, type_=type_, + unique=unique, required=required) + +def outparam(key, type_=None): + """Create an 'OUT' parameter for usage in functions (stored procedures), + for databases which support them. + + The ``outparam`` can be used like a regular function parameter. + The "output" value will be available from the + :class:`~sqlalchemy.engine.ResultProxy` object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + + """ + return _BindParamClause( + key, None, type_=type_, unique=False, isoutparam=True) + +def text(text, bind=None, *args, **kwargs): + """Create literal text to be inserted into a query. + + When constructing a query from a :func:`select()`, :func:`update()`, + :func:`insert()` or :func:`delete()`, using plain strings for argument + values will usually result in text objects being created + automatically. Use this function when creating textual clauses + outside of other :class:`ClauseElement` objects, or optionally wherever + plain text is to be used. + + text + the text of the SQL statement to be created. use ``:`` + to specify bind parameters; they will be compiled to their + engine-specific format. + + bind + an optional connection or engine to be used for this text query. + + autocommit=True + Deprecated. Use .execution_options(autocommit=) + to set the autocommit option. + + bindparams + a list of :func:`bindparam()` instances which can be used to define + the types and/or initial values for the bind parameters within + the textual statement; the keynames of the bindparams must match + those within the text of the statement. The types will be used + for pre-processing on bind values. + + typemap + a dictionary mapping the names of columns represented in the + ``SELECT`` clause of the textual statement to type objects, + which will be used to perform post-processing on columns within + the result set (for textual statements that produce result + sets). + + """ + return _TextClause(text, bind=bind, *args, **kwargs) + +def null(): + """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql + statement. + + """ + return _Null() + +class _FunctionGenerator(object): + """Generate :class:`Function` objects based on getattr calls.""" + + def __init__(self, **opts): + self.__names = [] + self.opts = opts + + def __getattr__(self, name): + # passthru __ attributes; fixes pydoc + if name.startswith('__'): + try: + return self.__dict__[name] + except KeyError: + raise AttributeError(name) + + elif name.endswith('_'): + name = name[0:-1] + f = _FunctionGenerator(**self.opts) + f.__names = list(self.__names) + [name] + return f + + def __call__(self, *c, **kwargs): + o = self.opts.copy() + o.update(kwargs) + if len(self.__names) == 1: + global functions + if functions is None: + from sqlalchemy.sql import functions + func = getattr(functions, self.__names[-1].lower(), None) + if func is not None: + return func(*c, **o) + + return Function( + self.__names[-1], packagenames=self.__names[0:-1], *c, **o) + +# "func" global - i.e. func.count() +func = _FunctionGenerator() + +# "modifier" global - i.e. modifier.distinct +# TODO: use UnaryExpression for this instead ? +modifier = _FunctionGenerator(group=False) + +class _generated_label(unicode): + """A unicode subclass used to identify dynamically generated names.""" + +def _escape_for_generated(x): + if isinstance(x, _generated_label): + return x + else: + return x.replace('%', '%%') + +def _clone(element): + return element._clone() + +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + return itertools.chain(*[x._cloned_set for x in elements]) + +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the enties present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set(elem for elem in a if all_overlap.intersection(elem._cloned_set)) + + +def _is_literal(element): + return not isinstance(element, Visitable) and \ + not hasattr(element, '__clause_element__') + +def _from_objects(*elements): + return itertools.chain(*[element._from_objects for element in elements]) + +def _labeled(element): + if not hasattr(element, 'name'): + return element.label(None) + else: + return element + +def _column_as_key(element): + if isinstance(element, basestring): + return element + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + return element.key + +def _literal_as_text(element): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + elif not isinstance(element, Visitable): + return _TextClause(unicode(element)) + else: + return element + +def _clause_element_as_expr(element): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + else: + return element + +def _literal_as_column(element): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + elif not isinstance(element, Visitable): + return literal_column(str(element)) + else: + return element + +def _literal_as_binds(element, name=None, type_=None): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + elif not isinstance(element, Visitable): + if element is None: + return null() + else: + return _BindParamClause(name, element, type_=type_, unique=True) + else: + return element + +def _type_from_args(args): + for a in args: + if not isinstance(a.type, sqltypes.NullType): + return a.type + else: + return sqltypes.NullType + +def _no_literals(element): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + elif not isinstance(element, Visitable): + raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' " + "function to indicate a SQL expression " + "literal, or 'literal()' to indicate a " + "bound value." % element) + else: + return element + +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column(column, + require_embedded=require_embedded) + if c is None: + raise exc.InvalidRequestError( + "Given column '%s', attached to table '%s', " + "failed to locate a corresponding column from table '%s'" + % + (column, + getattr(column, 'table', None),fromclause.description) + ) + return c + +@util.decorator +def _generative(fn, *args, **kw): + """Mark a method as generative.""" + + self = args[0]._generate() + fn(self, *args[1:], **kw) + return self + + +def is_column(col): + """True if ``col`` is an instance of :class:`ColumnElement`.""" + + return isinstance(col, ColumnElement) + + +class ClauseElement(Visitable): + """Base class for elements of a programmatically constructed SQL + expression. + + """ + __visit_name__ = 'clause' + + _annotations = {} + supports_execution = False + _from_objects = [] + _bind = None + + def _clone(self): + """Create a shallow copy of this ClauseElement. + + This method may be used by a generative API. Its also used as + part of the "deep" copy afforded by a traversal that combines + the _copy_internals() method. + + """ + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + c.__dict__.pop('_cloned_set', None) + + # this is a marker that helps to "equate" clauses to each other + # when a Select returns its list of FROM clauses. the cloning + # process leaves around a lot of remnants of the previous clause + # typically in the form of column expressions still attached to the + # old table. + c._is_clone_of = self + + return c + + @util.memoized_property + def _cloned_set(self): + """Return the set consisting all cloned anscestors of this + ClauseElement. + + Includes this ClauseElement. This accessor tends to be used for + FromClause objects to identify 'equivalent' FROM clauses, regardless + of transformative operations. + + """ + s = util.column_set() + f = self + while f is not None: + s.add(f) + f = getattr(f, '_is_clone_of', None) + return s + + def __getstate__(self): + d = self.__dict__.copy() + d.pop('_is_clone_of', None) + return d + + if util.jython: + def __hash__(self): + """Return a distinct hash code. + + ClauseElements may have special equality comparisons which + makes us rely on them having unique hash codes for use in + hash-based collections. Stock __hash__ doesn't guarantee + unique values on platforms with moving GCs. + """ + return id(self) + + def _annotate(self, values): + """return a copy of this ClauseElement with the given annotations + dictionary. + + """ + global Annotated + if Annotated is None: + from sqlalchemy.sql.util import Annotated + return Annotated(self, values) + + def _deannotate(self): + """return a copy of this ClauseElement with an empty annotations + dictionary. + + """ + return self._clone() + + def unique_params(self, *optionaldict, **kwargs): + """Return a copy with :func:`bindparam()` elments replaced. + + Same functionality as ``params()``, except adds `unique=True` + to affected bind parameters so that multiple statements can be + used. + + """ + return self._params(True, optionaldict, kwargs) + + def params(self, *optionaldict, **kwargs): + """Return a copy with :func:`bindparam()` elments replaced. + + Returns a copy of this ClauseElement with :func:`bindparam()` + elements replaced with values taken from the given dictionary:: + + >>> clause = column('x') + bindparam('foo') + >>> print clause.compile().params + {'foo':None} + >>> print clause.params({'foo':7}).compile().params + {'foo':7} + + """ + return self._params(False, optionaldict, kwargs) + + def _params(self, unique, optionaldict, kwargs): + if len(optionaldict) == 1: + kwargs.update(optionaldict[0]) + elif len(optionaldict) > 1: + raise exc.ArgumentError( + "params() takes zero or one positional dictionary argument") + + def visit_bindparam(bind): + if bind.key in kwargs: + bind.value = kwargs[bind.key] + if unique: + bind._convert_to_unique() + return cloned_traverse(self, {}, {'bindparam':visit_bindparam}) + + def compare(self, other, **kw): + """Compare this ClauseElement to the given ClauseElement. + + Subclasses should override the default behavior, which is a + straight identity comparison. + + \**kw are arguments consumed by subclass compare() methods and + may be used to modify the criteria for comparison. + (see :class:`ColumnElement`) + + """ + return self is other + + def _copy_internals(self, clone=_clone): + """Reassign internal elements to be clones of themselves. + + Called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy. + + """ + pass + + def get_children(self, **kwargs): + """Return immediate child elements of this :class:`ClauseElement`. + + This is used for visit traversal. + + \**kwargs may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + return [] + + def self_group(self, against=None): + return self + + # TODO: remove .bind as a method from the root ClauseElement. + # we should only be deriving binds from FromClause elements + # and certain SchemaItem subclasses. + # the "search_for_bind" functionality can still be used by + # execute(), however. + @property + def bind(self): + """Returns the Engine or Connection to which this ClauseElement is + bound, or None if none found. + + """ + if self._bind is not None: + return self._bind + + for f in _from_objects(self): + if f is self: + continue + engine = f.bind + if engine is not None: + return engine + else: + return None + + def execute(self, *multiparams, **params): + """Compile and execute this :class:`ClauseElement`.""" + + e = self.bind + if e is None: + label = getattr(self, 'description', self.__class__.__name__) + msg = ('This %s is not bound and does not support direct ' + 'execution. Supply this statement to a Connection or ' + 'Engine for execution. Or, assign a bind to the statement ' + 'or the Metadata of its underlying tables to enable ' + 'implicit execution via this method.' % label) + raise exc.UnboundExecutionError(msg) + return e._execute_clauseelement(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Compile and execute this :class:`ClauseElement`, returning the result's + scalar representation. + + """ + return self.execute(*multiparams, **params).scalar() + + def compile(self, bind=None, dialect=None, **kw): + """Compile this SQL expression. + + The return value is a :class:`~sqlalchemy.engine.Compiled` object. + Calling ``str()`` or ``unicode()`` on the returned value will yield a + string representation of the result. The + :class:`~sqlalchemy.engine.Compiled` object also can return a + dictionary of bind parameter names and values + using the ``params`` accessor. + + :param bind: An ``Engine`` or ``Connection`` from which a + ``Compiled`` will be acquired. This argument takes precedence over + this :class:`ClauseElement`'s bound engine, if any. + + :param column_keys: Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause of the + compiled statement. If ``None``, all columns from the target table + object are rendered. + + :param dialect: A ``Dialect`` instance frmo which a ``Compiled`` + will be acquired. This argument takes precedence over the `bind` + argument as well as this :class:`ClauseElement`'s bound engine, if any. + + :param inline: Used for INSERT statements, for a dialect which does + not support inline retrieval of newly generated primary key + columns, will force the expression used to create the new primary + key value to be rendered inline within the INSERT statement's + VALUES clause. This typically refers to Sequence execution but may + also refer to any server-side default generation function + associated with a primary key `Column`. + + """ + + if not dialect: + if bind: + dialect = bind.dialect + elif self.bind: + dialect = self.bind.dialect + bind = self.bind + else: + global DefaultDialect + if DefaultDialect is None: + from sqlalchemy.engine.default import DefaultDialect + dialect = DefaultDialect() + compiler = self._compiler(dialect, bind=bind, **kw) + compiler.compile() + return compiler + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + + return dialect.statement_compiler(dialect, self, **kw) + + def __str__(self): + # Py3K + #return unicode(self.compile()) + # Py2K + return unicode(self.compile()).encode('ascii', 'backslashreplace') + # end Py2K + + def __and__(self, other): + return and_(self, other) + + def __or__(self, other): + return or_(self, other) + + def __invert__(self): + return self._negate() + + def __nonzero__(self): + raise TypeError("Boolean value of this clause is not defined") + + def _negate(self): + if hasattr(self, 'negation_clause'): + return self.negation_clause + else: + return _UnaryExpression( + self.self_group(against=operators.inv), + operator=operators.inv, + negate=None) + + def __repr__(self): + friendly = getattr(self, 'description', None) + if friendly is None: + return object.__repr__(self) + else: + return '<%s.%s at 0x%x; %s>' % ( + self.__module__, self.__class__.__name__, id(self), friendly) + + +class _Immutable(object): + """mark a ClauseElement as 'immutable' when expressions are cloned.""" + + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def _clone(self): + return self + +class Operators(object): + def __and__(self, other): + return self.operate(operators.and_, other) + + def __or__(self, other): + return self.operate(operators.or_, other) + + def __invert__(self): + return self.operate(operators.inv) + + def op(self, opstring): + def op(b): + return self.operate(operators.op, opstring, b) + return op + + def operate(self, op, *other, **kwargs): + raise NotImplementedError(str(op)) + + def reverse_operate(self, op, other, **kwargs): + raise NotImplementedError(str(op)) + +class ColumnOperators(Operators): + """Defines comparison and math operations.""" + + timetuple = None + """Hack, allows datetime objects to be compared on the LHS.""" + + def __lt__(self, other): + return self.operate(operators.lt, other) + + def __le__(self, other): + return self.operate(operators.le, other) + + __hash__ = Operators.__hash__ + + def __eq__(self, other): + return self.operate(operators.eq, other) + + def __ne__(self, other): + return self.operate(operators.ne, other) + + def __gt__(self, other): + return self.operate(operators.gt, other) + + def __ge__(self, other): + return self.operate(operators.ge, other) + + def __neg__(self): + return self.operate(operators.neg) + + def concat(self, other): + return self.operate(operators.concat_op, other) + + def like(self, other, escape=None): + return self.operate(operators.like_op, other, escape=escape) + + def ilike(self, other, escape=None): + return self.operate(operators.ilike_op, other, escape=escape) + + def in_(self, other): + return self.operate(operators.in_op, other) + + def startswith(self, other, **kwargs): + return self.operate(operators.startswith_op, other, **kwargs) + + def endswith(self, other, **kwargs): + return self.operate(operators.endswith_op, other, **kwargs) + + def contains(self, other, **kwargs): + return self.operate(operators.contains_op, other, **kwargs) + + def match(self, other, **kwargs): + return self.operate(operators.match_op, other, **kwargs) + + def desc(self): + return self.operate(operators.desc_op) + + def asc(self): + return self.operate(operators.asc_op) + + def collate(self, collation): + return self.operate(operators.collate, collation) + + def __radd__(self, other): + return self.reverse_operate(operators.add, other) + + def __rsub__(self, other): + return self.reverse_operate(operators.sub, other) + + def __rmul__(self, other): + return self.reverse_operate(operators.mul, other) + + def __rdiv__(self, other): + return self.reverse_operate(operators.div, other) + + def between(self, cleft, cright): + return self.operate(operators.between_op, cleft, cright) + + def distinct(self): + return self.operate(operators.distinct_op) + + def __add__(self, other): + return self.operate(operators.add, other) + + def __sub__(self, other): + return self.operate(operators.sub, other) + + def __mul__(self, other): + return self.operate(operators.mul, other) + + def __div__(self, other): + return self.operate(operators.div, other) + + def __mod__(self, other): + return self.operate(operators.mod, other) + + def __truediv__(self, other): + return self.operate(operators.truediv, other) + + def __rtruediv__(self, other): + return self.reverse_operate(operators.truediv, other) + +class _CompareMixin(ColumnOperators): + """Defines comparison and math operations for :class:`ClauseElement` instances.""" + + def __compare(self, op, obj, negate=None, reverse=False, **kwargs): + if obj is None or isinstance(obj, _Null): + if op == operators.eq: + return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot) + elif op == operators.ne: + return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_) + else: + raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") + else: + obj = self._check_literal(op, obj) + + if reverse: + return _BinaryExpression(obj, + self, + op, + type_=sqltypes.BOOLEANTYPE, + negate=negate, modifiers=kwargs) + else: + return _BinaryExpression(self, + obj, + op, + type_=sqltypes.BOOLEANTYPE, + negate=negate, modifiers=kwargs) + + def __operate(self, op, obj, reverse=False): + obj = self._check_literal(op, obj) + + if reverse: + left, right = obj, self + else: + left, right = self, obj + + if left.type is None: + op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type) + elif right.type is None: + op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE) + else: + op, result_type = left.type._adapt_expression(op, right.type) + + return _BinaryExpression(left, right, op, type_=result_type) + + + # a mapping of operators with the method they use, along with their negated + # operator for comparison operators + operators = { + operators.add : (__operate,), + operators.mul : (__operate,), + operators.sub : (__operate,), + # Py2K + operators.div : (__operate,), + # end Py2K + operators.mod : (__operate,), + operators.truediv : (__operate,), + operators.lt : (__compare, operators.ge), + operators.le : (__compare, operators.gt), + operators.ne : (__compare, operators.eq), + operators.gt : (__compare, operators.le), + operators.ge : (__compare, operators.lt), + operators.eq : (__compare, operators.ne), + operators.like_op : (__compare, operators.notlike_op), + operators.ilike_op : (__compare, operators.notilike_op), + } + + def operate(self, op, *other, **kwargs): + o = _CompareMixin.operators[op] + return o[0](self, op, other[0], *o[1:], **kwargs) + + def reverse_operate(self, op, other, **kwargs): + o = _CompareMixin.operators[op] + return o[0](self, op, other, reverse=True, *o[1:], **kwargs) + + def in_(self, other): + return self._in_impl(operators.in_op, operators.notin_op, other) + + def _in_impl(self, op, negate_op, seq_or_selectable): + seq_or_selectable = _clause_element_as_expr(seq_or_selectable) + + if isinstance(seq_or_selectable, _ScalarSelect): + return self.__compare( op, seq_or_selectable, negate=negate_op) + + elif isinstance(seq_or_selectable, _SelectBaseMixin): + # TODO: if we ever want to support (x, y, z) IN (select x, y, z from table), + # we would need a multi-column version of as_scalar() to produce a multi- + # column selectable that does not export itself as a FROM clause + return self.__compare( op, seq_or_selectable.as_scalar(), negate=negate_op) + + elif isinstance(seq_or_selectable, Selectable): + return self.__compare( op, seq_or_selectable, negate=negate_op) + + # Handle non selectable arguments as sequences + args = [] + for o in seq_or_selectable: + if not _is_literal(o): + if not isinstance( o, _CompareMixin): + raise exc.InvalidRequestError( + "in() function accepts either a list of non-selectable values, " + "or a selectable: %r" % o) + else: + o = self._bind_param(op, o) + args.append(o) + + if len(args) == 0: + # Special case handling for empty IN's, behave like comparison + # against zero row selectable. We use != to build the + # contradiction as it handles NULL values appropriately, i.e. + # "not (x IN ())" should not return NULL values for x. + util.warn("The IN-predicate on \"%s\" was invoked with an empty sequence. " + "This results in a contradiction, which nonetheless can be " + "expensive to evaluate. Consider alternative strategies for " + "improved performance." % self) + + return self != self + + return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op) + + def __neg__(self): + return _UnaryExpression(self, operator=operators.neg) + + def startswith(self, other, escape=None): + """Produce the clause ``LIKE '%'``""" + + # use __radd__ to force string concat behavior + return self.__compare( + operators.like_op, + literal_column("'%'", type_=sqltypes.String).__radd__( + self._check_literal(operators.like_op, other) + ), + escape=escape) + + def endswith(self, other, escape=None): + """Produce the clause ``LIKE '%'``""" + + return self.__compare( + operators.like_op, + literal_column("'%'", type_=sqltypes.String) + + self._check_literal(operators.like_op, other), + escape=escape) + + def contains(self, other, escape=None): + """Produce the clause ``LIKE '%%'``""" + + return self.__compare( + operators.like_op, + literal_column("'%'", type_=sqltypes.String) + + self._check_literal(operators.like_op, other) + + literal_column("'%'", type_=sqltypes.String), + escape=escape) + + def match(self, other): + """Produce a MATCH clause, i.e. ``MATCH ''`` + + The allowed contents of ``other`` are database backend specific. + + """ + return self.__compare(operators.match_op, self._check_literal(operators.match_op, other)) + + def label(self, name): + """Produce a column label, i.e. `` AS ``. + + if 'name' is None, an anonymous label name will be generated. + + """ + return _Label(name, self, self.type) + + def desc(self): + """Produce a DESC clause, i.e. `` DESC``""" + + return desc(self) + + def asc(self): + """Produce a ASC clause, i.e. `` ASC``""" + + return asc(self) + + def distinct(self): + """Produce a DISTINCT clause, i.e. ``DISTINCT ``""" + return _UnaryExpression(self, operator=operators.distinct_op, type_=self.type) + + def between(self, cleft, cright): + """Produce a BETWEEN clause, i.e. `` BETWEEN AND ``""" + + return _BinaryExpression( + self, + ClauseList( + self._check_literal(operators.and_, cleft), + self._check_literal(operators.and_, cright), + operator=operators.and_, + group=False), + operators.between_op) + + def collate(self, collation): + """Produce a COLLATE clause, i.e. `` COLLATE utf8_bin``""" + + return collate(self, collation) + + def op(self, operator): + """produce a generic operator function. + + e.g.:: + + somecolumn.op("*")(5) + + produces:: + + somecolumn * 5 + + + :param operator: a string which will be output as the infix operator between + this :class:`ClauseElement` and the expression passed to the + generated function. + + This function can also be used to make bitwise operators explicit. For example:: + + somecolumn.op('&')(0xff) + + is a bitwise AND of the value in somecolumn. + + """ + return lambda other: self.__operate(operator, other) + + def _bind_param(self, operator, obj): + return _BindParamClause(None, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + + def _check_literal(self, operator, other): + if isinstance(other, _BindParamClause) and \ + isinstance(other.type, sqltypes.NullType): + # TODO: perhaps we should not mutate the incoming bindparam() + # here and instead make a copy of it. this might + # be the only place that we're mutating an incoming construct. + other.type = self.type + return other + elif hasattr(other, '__clause_element__'): + return other.__clause_element__() + elif not isinstance(other, ClauseElement): + return self._bind_param(operator, other) + elif isinstance(other, (_SelectBaseMixin, Alias)): + return other.as_scalar() + else: + return other + + +class ColumnElement(ClauseElement, _CompareMixin): + """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement. + + This includes columns associated with tables, aliases, and + subqueries, expressions, function calls, SQL keywords such as + ``NULL``, literals, etc. :class:`ColumnElement` is the ultimate base + class for all such elements. + + :class:`ColumnElement` supports the ability to be a *proxy* element, + which indicates that the :class:`ColumnElement` may be associated with + a :class:`Selectable` which was derived from another :class:`Selectable`. + An example of a "derived" :class:`Selectable` is an :class:`Alias` of a + :class:`~sqlalchemy.schema.Table`. + + A :class:`ColumnElement`, by subclassing the :class:`_CompareMixin` mixin + class, provides the ability to generate new :class:`ClauseElement` + objects using Python expressions. See the :class:`_CompareMixin` + docstring for more details. + + """ + + __visit_name__ = 'column' + primary_key = False + foreign_keys = [] + quote = None + _label = None + + @property + def _select_iterable(self): + return (self, ) + + @util.memoized_property + def base_columns(self): + return util.column_set(c for c in self.proxy_set + if not hasattr(c, 'proxies')) + + @util.memoized_property + def proxy_set(self): + s = util.column_set([self]) + if hasattr(self, 'proxies'): + for c in self.proxies: + s.update(c.proxy_set) + return s + + def shares_lineage(self, othercolumn): + """Return True if the given :class:`ColumnElement` + has a common ancestor to this :class:`ColumnElement`.""" + + return bool(self.proxy_set.intersection(othercolumn.proxy_set)) + + def _make_proxy(self, selectable, name=None): + """Create a new :class:`ColumnElement` representing this + :class:`ColumnElement` as it appears in the select list of a + descending selectable. + + """ + + if name: + co = ColumnClause(name, selectable, type_=getattr(self, 'type', None)) + else: + name = str(self) + co = ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) + + co.proxies = [self] + selectable.columns[name] = co + return co + + def compare(self, other, use_proxies=False, equivalents=None, **kw): + """Compare this ColumnElement to another. + + Special arguments understood: + + :param use_proxies: when True, consider two columns that + share a common base column as equivalent (i.e. shares_lineage()) + + :param equivalents: a dictionary of columns as keys mapped to sets + of columns. If the given "other" column is present in this dictionary, + if any of the columns in the correponding set() pass the comparison + test, the result is True. This is used to expand the comparison to + other columns that may be known to be equivalent to this one via + foreign key or other criterion. + + """ + to_compare = (other, ) + if equivalents and other in equivalents: + to_compare = equivalents[other].union(to_compare) + + for oth in to_compare: + if use_proxies and self.shares_lineage(oth): + return True + elif oth is self: + return True + else: + return False + + @util.memoized_property + def anon_label(self): + """provides a constant 'anonymous label' for this ColumnElement. + + This is a label() expression which will be named at compile time. + The same label() is returned each time anon_label is called so + that expressions can reference anon_label multiple times, producing + the same label name at compile time. + + the compiler uses this function automatically at compile time + for expressions that are known to be 'unnamed' like binary + expressions and function calls. + + """ + return _generated_label("%%(%d %s)s" % (id(self), getattr(self, 'name', 'anon'))) + +class ColumnCollection(util.OrderedProperties): + """An ordered dictionary that stores a list of ColumnElement + instances. + + Overrides the ``__eq__()`` method to produce SQL clauses between + sets of correlated columns. + + """ + + def __init__(self, *cols): + super(ColumnCollection, self).__init__() + self.update((c.key, c) for c in cols) + + def __str__(self): + return repr([str(c) for c in self]) + + def replace(self, column): + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. + + e.g.:: + + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) + + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. + + Used by schema.Column to override columns during table reflection. + + """ + if column.name in self and column.key != column.name: + other = self[column.name] + if other.name == other.key: + del self[other.name] + util.OrderedProperties.__setitem__(self, column.key, column) + + def add(self, column): + """Add a column to this collection. + + The key attribute of the column will be used as the hash key + for this dictionary. + + """ + self[column.key] = column + + def __setitem__(self, key, value): + if key in self: + # this warning is primarily to catch select() statements which + # have conflicting column names in their exported columns collection + existing = self[key] + if not existing.shares_lineage(value): + util.warn(("Column %r on table %r being replaced by another " + "column with the same key. Consider use_labels " + "for select() statements.") % (key, getattr(existing, 'table', None))) + util.OrderedProperties.__setitem__(self, key, value) + + def remove(self, column): + del self[column.key] + + def extend(self, iter): + for c in iter: + self.add(c) + + __hash__ = None + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + + def __contains__(self, other): + if not isinstance(other, basestring): + raise exc.ArgumentError("__contains__ requires a string argument") + return util.OrderedProperties.__contains__(self, other) + + def contains_column(self, col): + # have to use a Set here, because it will compare the identity + # of the column, not just using "==" for comparison which will always return a + # "True" value (i.e. a BinaryClause...) + return col in util.column_set(self) + +class ColumnSet(util.ordered_column_set): + def contains_column(self, col): + return col in self + + def extend(self, cols): + for col in cols: + self.add(col) + + def __add__(self, other): + return list(self) + list(other) + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + + def __hash__(self): + return hash(tuple(x for x in self)) + +class Selectable(ClauseElement): + """mark a class as being selectable""" + __visit_name__ = 'selectable' + +class FromClause(Selectable): + """Represent an element that can be used within the ``FROM`` + clause of a ``SELECT`` statement. + + """ + __visit_name__ = 'fromclause' + named_with_column = False + _hide_froms = [] + quote = None + schema = None + + def count(self, whereclause=None, **params): + """return a SELECT COUNT generated against this :class:`FromClause`.""" + + if self.primary_key: + col = list(self.primary_key)[0] + else: + col = list(self.columns)[0] + return select( + [func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) + + def select(self, whereclause=None, **params): + """return a SELECT of this :class:`FromClause`.""" + + return select([self], whereclause, **params) + + def join(self, right, onclause=None, isouter=False): + """return a join of this :class:`FromClause` against another :class:`FromClause`.""" + + return Join(self, right, onclause, isouter) + + def outerjoin(self, right, onclause=None): + """return an outer join of this :class:`FromClause` against another :class:`FromClause`.""" + + return Join(self, right, onclause, True) + + def alias(self, name=None): + """return an alias of this :class:`FromClause`. + + For table objects, this has the effect of the table being rendered + as ``tablename AS aliasname`` in a SELECT statement. + For select objects, the effect is that of creating a named + subquery, i.e. ``(select ...) AS aliasname``. + The :func:`alias()` method is the general way to create + a "subquery" out of an existing SELECT. + + The ``name`` parameter is optional, and if left blank an + "anonymous" name will be generated at compile time, guaranteed + to be unique against other anonymous constructs used in the + same statement. + + """ + + return Alias(self, name) + + def is_derived_from(self, fromclause): + """Return True if this FromClause is 'derived' from the given FromClause. + + An example would be an Alias of a Table is derived from that Table. + + """ + return fromclause in self._cloned_set + + def replace_selectable(self, old, alias): + """replace all occurences of FromClause 'old' with the given Alias + object, returning a copy of this :class:`FromClause`. + + """ + global ClauseAdapter + if ClauseAdapter is None: + from sqlalchemy.sql.util import ClauseAdapter + return ClauseAdapter(alias).traverse(self) + + def correspond_on_equivalents(self, column, equivalents): + """Return corresponding_column for the given column, or if None + search for a match in the given dictionary. + + """ + col = self.corresponding_column(column, require_embedded=True) + if col is None and col in equivalents: + for equiv in equivalents[col]: + nc = self.corresponding_column(equiv, require_embedded=True) + if nc: + return nc + return col + + def corresponding_column(self, column, require_embedded=False): + """Given a :class:`ColumnElement`, return the exported :class:`ColumnElement` + object from this :class:`Selectable` which corresponds to that + original :class:`~sqlalchemy.schema.Column` via a common anscestor column. + + :param column: the target :class:`ColumnElement` to be matched + + :param require_embedded: only return corresponding columns for the given + :class:`ColumnElement`, if the given :class:`ColumnElement` is + actually present within a sub-element of this + :class:`FromClause`. Normally the column will match if it merely + shares a common anscestor with one of the exported columns + of this :class:`FromClause`. + + """ + # dont dig around if the column is locally present + if self.c.contains_column(column): + return column + + col, intersect = None, None + target_set = column.proxy_set + cols = self.c + for c in cols: + i = target_set.intersection(itertools.chain(*[p._cloned_set for p in c.proxy_set])) + + if i and \ + (not require_embedded or c.proxy_set.issuperset(target_set)): + + if col is None: + # no corresponding column yet, pick this one. + col, intersect = c, i + elif len(i) > len(intersect): + # 'c' has a larger field of correspondence than 'col'. + # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches + # a1.c.x->table.c.x better than + # selectable.c.x->table.c.x does. + col, intersect = c, i + elif i == intersect: + # they have the same field of correspondence. + # see which proxy_set has fewer columns in it, which indicates + # a closer relationship with the root column. Also take into + # account the "weight" attribute which CompoundSelect() uses to + # give higher precedence to columns based on vertical position + # in the compound statement, and discard columns that have no + # reference to the target column (also occurs with + # CompoundSelect) + col_distance = util.reduce(operator.add, + [sc._annotations.get('weight', 1) + for sc in col.proxy_set + if sc.shares_lineage(column)] + ) + c_distance = util.reduce(operator.add, + [sc._annotations.get('weight', 1) + for sc in c.proxy_set + if sc.shares_lineage(column)] + ) + if c_distance < col_distance: + col, intersect = c, i + return col + + @property + def description(self): + """a brief description of this FromClause. + + Used primarily for error message formatting. + + """ + return getattr(self, 'name', self.__class__.__name__ + " object") + + def _reset_exported(self): + """delete memoized collections when a FromClause is cloned.""" + + for attr in ('_columns', '_primary_key' '_foreign_keys', 'locate_all_froms'): + self.__dict__.pop(attr, None) + + @util.memoized_property + def _columns(self): + """Return the collection of Column objects contained by this FromClause.""" + + self._export_columns() + return self._columns + + @util.memoized_property + def _primary_key(self): + """Return the collection of Column objects which comprise the + primary key of this FromClause.""" + + self._export_columns() + return self._primary_key + + @util.memoized_property + def _foreign_keys(self): + """Return the collection of ForeignKey objects which this + FromClause references.""" + + self._export_columns() + return self._foreign_keys + + columns = property(attrgetter('_columns'), doc=_columns.__doc__) + primary_key = property( + attrgetter('_primary_key'), + doc=_primary_key.__doc__) + foreign_keys = property( + attrgetter('_foreign_keys'), + doc=_foreign_keys.__doc__) + + # synonyms for 'columns' + c = _select_iterable = property(attrgetter('columns'), doc=_columns.__doc__) + + def _export_columns(self): + """Initialize column collections.""" + + self._columns = ColumnCollection() + self._primary_key = ColumnSet() + self._foreign_keys = set() + self._populate_column_collection() + + def _populate_column_collection(self): + pass + +class _BindParamClause(ColumnElement): + """Represent a bind parameter. + + Public constructor is the :func:`bindparam()` function. + + """ + + __visit_name__ = 'bindparam' + quote = None + + def __init__(self, key, value, type_=None, unique=False, + isoutparam=False, required=False, + _compared_to_operator=None, + _compared_to_type=None): + """Construct a _BindParamClause. + + key + the key for this bind param. Will be used in the generated + SQL statement for dialects that use named parameters. This + value may be modified when part of a compilation operation, + if other :class:`_BindParamClause` objects exist with the same + key, or if its length is too long and truncation is + required. + + value + Initial value for this bind param. This value may be + overridden by the dictionary of parameters sent to statement + compilation/execution. + + type\_ + A ``TypeEngine`` object that will be used to pre-process the + value corresponding to this :class:`_BindParamClause` at + execution time. + + unique + if True, the key name of this BindParamClause will be + modified if another :class:`_BindParamClause` of the same name + already has been located within the containing + :class:`ClauseElement`. + + required + a value is required at execution time. + + isoutparam + if True, the parameter should be treated like a stored procedure "OUT" + parameter. + + """ + if unique: + self.key = _generated_label("%%(%d %s)s" % (id(self), key or 'param')) + else: + self.key = key or _generated_label("%%(%d param)s" % id(self)) + self._orig_key = key or 'param' + self.unique = unique + self.value = value + self.isoutparam = isoutparam + self.required = required + + if type_ is None: + if _compared_to_type is not None: + self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value) + else: + self.type = sqltypes.type_map.get(type(value), sqltypes.NULLTYPE) + elif isinstance(type_, type): + self.type = type_() + else: + self.type = type_ + + def _clone(self): + c = ClauseElement._clone(self) + if self.unique: + c.key = _generated_label("%%(%d %s)s" % (id(c), c._orig_key or 'param')) + return c + + def _convert_to_unique(self): + if not self.unique: + self.unique = True + self.key = _generated_label("%%(%d %s)s" % (id(self), + self._orig_key or 'param')) + + def bind_processor(self, dialect): + return self.type.dialect_impl(dialect).bind_processor(dialect) + + def compare(self, other, **kw): + """Compare this :class:`_BindParamClause` to the given clause.""" + + return isinstance(other, _BindParamClause) and \ + self.type._compare_type_affinity(other.type) and \ + self.value == other.value + + def __getstate__(self): + """execute a deferred value for serialization purposes.""" + + d = self.__dict__.copy() + v = self.value + if util.callable(v): + v = v() + d['value'] = v + return d + + def __repr__(self): + return "_BindParamClause(%r, %r, type_=%r)" % ( + self.key, self.value, self.type + ) + +class _TypeClause(ClauseElement): + """Handle a type keyword in a SQL statement. + + Used by the ``Case`` statement. + + """ + + __visit_name__ = 'typeclause' + + def __init__(self, type): + self.type = type + + +class _Generative(object): + """Allow a ClauseElement to generate itself via the + @_generative decorator. + + """ + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + +class Executable(_Generative): + """Mark a ClauseElement as supporting execution. + + :class:`Executable` is a superclass for all "statement" types + of objects, including :func:`select`, :func:`delete`, :func:`update`, + :func:`insert`, :func:`text`. + + """ + + supports_execution = True + _execution_options = util.frozendict() + + @_generative + def execution_options(self, **kw): + """ Set non-SQL options for the statement which take effect during execution. + + Current options include: + + * autocommit - when True, a COMMIT will be invoked after execution + when executed in 'autocommit' mode, i.e. when an explicit transaction + is not begun on the connection. Note that DBAPI connections by + default are always in a transaction - SQLAlchemy uses rules applied + to different kinds of statements to determine if COMMIT will be invoked + in order to provide its "autocommit" feature. Typically, all + INSERT/UPDATE/DELETE statements as well as CREATE/DROP statements + have autocommit behavior enabled; SELECT constructs do not. Use this + option when invokving a SELECT or other specific SQL construct + where COMMIT is desired (typically when calling stored procedures + and such). + + * stream_results - indicate to the dialect that results should be + "streamed" and not pre-buffered, if possible. This is a limitation + of many DBAPIs. The flag is currently understood only by the + psycopg2 dialect. + + See also: + + :meth:`sqlalchemy.engine.base.Connection.execution_options()` + + :meth:`sqlalchemy.orm.query.Query.execution_options()` + + """ + self._execution_options = self._execution_options.union(kw) + +# legacy, some outside users may be calling this +_Executable = Executable + +class _TextClause(Executable, ClauseElement): + """Represent a literal SQL text fragment. + + Public constructor is the :func:`text()` function. + + """ + + __visit_name__ = 'textclause' + + _bind_params_regex = re.compile(r'(? RIGHT``.""" + + __visit_name__ = 'binary' + + def __init__(self, left, right, operator, type_=None, negate=None, modifiers=None): + self.left = _literal_as_text(left).self_group(against=operator) + self.right = _literal_as_text(right).self_group(against=operator) + self.operator = operator + self.type = sqltypes.to_instance(type_) + self.negate = negate + if modifiers is None: + self.modifiers = {} + else: + self.modifiers = modifiers + + def __nonzero__(self): + try: + return self.operator(hash(self.left), hash(self.right)) + except: + raise TypeError("Boolean value of this clause is not defined") + + @property + def _from_objects(self): + return self.left._from_objects + self.right._from_objects + + def _copy_internals(self, clone=_clone): + self.left = clone(self.left) + self.right = clone(self.right) + + def get_children(self, **kwargs): + return self.left, self.right + + def compare(self, other, **kw): + """Compare this :class:`_BinaryExpression` against the + given :class:`_BinaryExpression`.""" + + return ( + isinstance(other, _BinaryExpression) and + self.operator == other.operator and + ( + self.left.compare(other.left, **kw) and + self.right.compare(other.right, **kw) or + ( + operators.is_commutative(self.operator) and + self.left.compare(other.right, **kw) and + self.right.compare(other.left, **kw) + ) + ) + ) + + def self_group(self, against=None): + # use small/large defaults for comparison so that unknown + # operators are always parenthesized + if self.operator is not against and operators.is_precedent(self.operator, against): + return _Grouping(self) + else: + return self + + def _negate(self): + if self.negate is not None: + return _BinaryExpression( + self.left, + self.right, + self.negate, + negate=self.operator, + type_=sqltypes.BOOLEANTYPE, + modifiers=self.modifiers) + else: + return super(_BinaryExpression, self)._negate() + +class _Exists(_UnaryExpression): + __visit_name__ = _UnaryExpression.__visit_name__ + _from_objects = [] + + def __init__(self, *args, **kwargs): + if args and isinstance(args[0], (_SelectBaseMixin, _ScalarSelect)): + s = args[0] + else: + if not args: + args = ([literal_column('*')],) + s = select(*args, **kwargs).as_scalar().self_group() + + _UnaryExpression.__init__(self, s, operator=operators.exists, type_=sqltypes.Boolean) + + def select(self, whereclause=None, **params): + return select([self], whereclause, **params) + + def correlate(self, fromclause): + e = self._clone() + e.element = self.element.correlate(fromclause).self_group() + return e + + def select_from(self, clause): + """return a new exists() construct with the given expression set as its FROM + clause. + + """ + e = self._clone() + e.element = self.element.select_from(clause).self_group() + return e + + def where(self, clause): + """return a new exists() construct with the given expression added to its WHERE + clause, joined to the existing clause via AND, if any. + + """ + e = self._clone() + e.element = self.element.where(clause).self_group() + return e + +class Join(FromClause): + """represent a ``JOIN`` construct between two :class:`FromClause` elements. + + The public constructor function for :class:`Join` is the module-level + :func:`join()` function, as well as the :func:`join()` method available + off all :class:`FromClause` subclasses. + + """ + __visit_name__ = 'join' + + def __init__(self, left, right, onclause=None, isouter=False): + self.left = _literal_as_text(left) + self.right = _literal_as_text(right).self_group() + + if onclause is None: + self.onclause = self._match_primaries(self.left, self.right) + else: + self.onclause = onclause + + self.isouter = isouter + self.__folded_equivalents = None + + @property + def description(self): + return "Join object on %s(%d) and %s(%d)" % ( + self.left.description, + id(self.left), + self.right.description, + id(self.right)) + + def is_derived_from(self, fromclause): + return fromclause is self or \ + self.left.is_derived_from(fromclause) or\ + self.right.is_derived_from(fromclause) + + def self_group(self, against=None): + return _FromGrouping(self) + + def _populate_column_collection(self): + columns = [c for c in self.left.columns] + [c for c in self.right.columns] + + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + self._primary_key.extend(sql_util.reduce_columns( + (c for c in columns if c.primary_key), self.onclause)) + self._columns.update((col._label, col) for col in columns) + self._foreign_keys.update(itertools.chain(*[col.foreign_keys for col in columns])) + + def _copy_internals(self, clone=_clone): + self._reset_exported() + self.left = clone(self.left) + self.right = clone(self.right) + self.onclause = clone(self.onclause) + self.__folded_equivalents = None + + def get_children(self, **kwargs): + return self.left, self.right, self.onclause + + def _match_primaries(self, left, right): + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + if isinstance(left, Join): + left_right = left.right + else: + left_right = None + return sql_util.join_condition(left, right, a_subset=left_right) + + def select(self, whereclause=None, fold_equivalents=False, **kwargs): + """Create a :class:`Select` from this :class:`Join`. + + :param whereclause: the WHERE criterion that will be sent to + the :func:`select()` function + + :param fold_equivalents: based on the join criterion of this + :class:`Join`, do not include + repeat column names in the column list of the resulting + select, for columns that are calculated to be "equivalent" + based on the join criterion of this :class:`Join`. This will + recursively apply to any joins directly nested by this one + as well. This flag is specific to a particular use case + by the ORM and is deprecated as of 0.6. + + :param \**kwargs: all other kwargs are sent to the + underlying :func:`select()` function. + + """ + if fold_equivalents: + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + util.warn_deprecated("fold_equivalents is deprecated.") + collist = sql_util.folded_equivalents(self) + else: + collist = [self.left, self.right] + + return select(collist, whereclause, from_obj=[self], **kwargs) + + @property + def bind(self): + return self.left.bind or self.right.bind + + def alias(self, name=None): + """Create a :class:`Select` out of this :class:`Join` clause and return an :class:`Alias` of it. + + The :class:`Select` is not correlating. + + """ + return self.select(use_labels=True, correlate=False).alias(name) + + @property + def _hide_froms(self): + return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set]) + + @property + def _from_objects(self): + return [self] + \ + self.onclause._from_objects + \ + self.left._from_objects + \ + self.right._from_objects + +class Alias(FromClause): + """Represents an table or selectable alias (AS). + + Represents an alias, as typically applied to any table or + sub-select within a SQL statement using the ``AS`` keyword (or + without the keyword on certain databases such as Oracle). + + This object is constructed from the :func:`alias()` module level + function as well as the :func:`alias()` method available on all + :class:`FromClause` subclasses. + + """ + + __visit_name__ = 'alias' + named_with_column = True + + def __init__(self, selectable, alias=None): + baseselectable = selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.element + self.original = baseselectable + self.supports_execution = baseselectable.supports_execution + if self.supports_execution: + self._execution_options = baseselectable._execution_options + self.element = selectable + if alias is None: + if self.original.named_with_column: + alias = getattr(self.original, 'name', None) + alias = _generated_label('%%(%d %s)s' % (id(self), alias or 'anon')) + self.name = alias + + @property + def description(self): + # Py3K + #return self.name + # Py2K + return self.name.encode('ascii', 'backslashreplace') + # end Py2K + + def as_scalar(self): + try: + return self.element.as_scalar() + except AttributeError: + raise AttributeError("Element %s does not support 'as_scalar()'" % self.element) + + def is_derived_from(self, fromclause): + if fromclause in self._cloned_set: + return True + return self.element.is_derived_from(fromclause) + + def _populate_column_collection(self): + for col in self.element.columns: + col._make_proxy(self) + + def _copy_internals(self, clone=_clone): + self._reset_exported() + self.element = _clone(self.element) + baseselectable = self.element + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.element + self.original = baseselectable + + def get_children(self, column_collections=True, aliased_selectables=True, **kwargs): + if column_collections: + for c in self.c: + yield c + if aliased_selectables: + yield self.element + + @property + def _from_objects(self): + return [self] + + @property + def bind(self): + return self.element.bind + + +class _Grouping(ColumnElement): + """Represent a grouping within a column expression""" + + __visit_name__ = 'grouping' + + def __init__(self, element): + self.element = element + self.type = getattr(element, 'type', None) + + @property + def _label(self): + return getattr(self.element, '_label', None) or self.anon_label + + def _copy_internals(self, clone=_clone): + self.element = clone(self.element) + + def get_children(self, **kwargs): + return self.element, + + @property + def _from_objects(self): + return self.element._from_objects + + def __getattr__(self, attr): + return getattr(self.element, attr) + + def __getstate__(self): + return {'element':self.element, 'type':self.type} + + def __setstate__(self, state): + self.element = state['element'] + self.type = state['type'] + +class _FromGrouping(FromClause): + """Represent a grouping of a FROM clause""" + __visit_name__ = 'grouping' + + def __init__(self, element): + self.element = element + + @property + def columns(self): + return self.element.columns + + @property + def _hide_froms(self): + return self.element._hide_froms + + def get_children(self, **kwargs): + return self.element, + + def _copy_internals(self, clone=_clone): + self.element = clone(self.element) + + @property + def _from_objects(self): + return self.element._from_objects + + def __getattr__(self, attr): + return getattr(self.element, attr) + + def __getstate__(self): + return {'element':self.element} + + def __setstate__(self, state): + self.element = state['element'] + +class _Label(ColumnElement): + """Represents a column label (AS). + + Represent a label, as typically applied to any column-level + element using the ``AS`` sql keyword. + + This object is constructed from the :func:`label()` module level + function as well as the :func:`label()` method available on all + :class:`ColumnElement` subclasses. + + """ + + __visit_name__ = 'label' + + def __init__(self, name, element, type_=None): + while isinstance(element, _Label): + element = element.element + self.name = self.key = self._label = name or \ + _generated_label("%%(%d %s)s" % ( + id(self), getattr(element, 'name', 'anon')) + ) + self._element = element + self._type = type_ + self.quote = element.quote + + @util.memoized_property + def type(self): + return sqltypes.to_instance( + self._type or getattr(self._element, 'type', None) + ) + + @util.memoized_property + def element(self): + return self._element.self_group(against=operators.as_) + + def _proxy_attr(name): + get = attrgetter(name) + def attr(self): + return get(self.element) + return property(attr) + + proxies = _proxy_attr('proxies') + base_columns = _proxy_attr('base_columns') + proxy_set = _proxy_attr('proxy_set') + primary_key = _proxy_attr('primary_key') + foreign_keys = _proxy_attr('foreign_keys') + + def get_children(self, **kwargs): + return self.element, + + def _copy_internals(self, clone=_clone): + self.element = clone(self.element) + + @property + def _from_objects(self): + return self.element._from_objects + + def _make_proxy(self, selectable, name = None): + if isinstance(self.element, (Selectable, ColumnElement)): + e = self.element._make_proxy(selectable, name=self.name) + else: + e = column(self.name)._make_proxy(selectable=selectable) + e.proxies.append(self) + return e + +class ColumnClause(_Immutable, ColumnElement): + """Represents a generic column expression from any textual string. + + This includes columns associated with tables, aliases and select + statements, but also any arbitrary text. May or may not be bound + to an underlying :class:`Selectable`. :class:`ColumnClause` is usually + created publically via the :func:`column()` function or the + :func:`literal_column()` function. + + text + the text of the element. + + selectable + parent selectable. + + type + ``TypeEngine`` object which can associate this :class:`ColumnClause` + with a type. + + is_literal + if True, the :class:`ColumnClause` is assumed to be an exact + expression that will be delivered to the output with no quoting + rules applied regardless of case sensitive settings. the + :func:`literal_column()` function is usually used to create such a + :class:`ColumnClause`. + + """ + __visit_name__ = 'column' + + onupdate = default = server_default = server_onupdate = None + + def __init__(self, text, selectable=None, type_=None, is_literal=False): + self.key = self.name = text + self.table = selectable + self.type = sqltypes.to_instance(type_) + self.is_literal = is_literal + + @util.memoized_property + def description(self): + # Py3K + #return self.name + # Py2K + return self.name.encode('ascii', 'backslashreplace') + # end Py2K + + @util.memoized_property + def _label(self): + if self.is_literal: + return None + + elif self.table is not None and self.table.named_with_column: + if getattr(self.table, 'schema', None): + label = self.table.schema.replace('.', '_') + "_" + \ + _escape_for_generated(self.table.name) + "_" + \ + _escape_for_generated(self.name) + else: + label = _escape_for_generated(self.table.name) + "_" + \ + _escape_for_generated(self.name) + + return _generated_label(label) + + else: + return self.name + + def label(self, name): + if name is None: + return self + else: + return super(ColumnClause, self).label(name) + + @property + def _from_objects(self): + if self.table is not None: + return [self.table] + else: + return [] + + def _bind_param(self, operator, obj): + return _BindParamClause(self.name, obj, _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + + def _make_proxy(self, selectable, name=None, attach=True): + # propagate the "is_literal" flag only if we are keeping our name, + # otherwise its considered to be a label + is_literal = self.is_literal and (name is None or name == self.name) + c = ColumnClause( + name or self.name, + selectable=selectable, + type_=self.type, + is_literal=is_literal + ) + c.proxies = [self] + if attach: + selectable.columns[c.name] = c + return c + +class TableClause(_Immutable, FromClause): + """Represents a "table" construct. + + Note that this represents tables only as another syntactical + construct within SQL expressions; it does not provide schema-level + functionality. + + """ + + __visit_name__ = 'table' + + named_with_column = True + + def __init__(self, name, *columns): + super(TableClause, self).__init__() + self.name = self.fullname = name + self._columns = ColumnCollection() + self._primary_key = ColumnSet() + self._foreign_keys = set() + for c in columns: + self.append_column(c) + + def _export_columns(self): + raise NotImplementedError() + + @util.memoized_property + def description(self): + # Py3K + #return self.name + # Py2K + return self.name.encode('ascii', 'backslashreplace') + # end Py2K + + def append_column(self, c): + self._columns[c.name] = c + c.table = self + + def get_children(self, column_collections=True, **kwargs): + if column_collections: + return [c for c in self.c] + else: + return [] + + def count(self, whereclause=None, **params): + """return a SELECT COUNT generated against this :class:`TableClause`.""" + + if self.primary_key: + col = list(self.primary_key)[0] + else: + col = list(self.columns)[0] + return select( + [func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) + + def insert(self, values=None, inline=False, **kwargs): + """Generate an :func:`insert()` construct.""" + + return insert(self, values=values, inline=inline, **kwargs) + + def update(self, whereclause=None, values=None, inline=False, **kwargs): + """Generate an :func:`update()` construct.""" + + return update(self, whereclause=whereclause, + values=values, inline=inline, **kwargs) + + def delete(self, whereclause=None, **kwargs): + """Generate a :func:`delete()` construct.""" + + return delete(self, whereclause, **kwargs) + + @property + def _from_objects(self): + return [self] + +class _SelectBaseMixin(Executable): + """Base class for :class:`Select` and ``CompoundSelects``.""" + + def __init__(self, + use_labels=False, + for_update=False, + limit=None, + offset=None, + order_by=None, + group_by=None, + bind=None, + autocommit=None): + self.use_labels = use_labels + self.for_update = for_update + if autocommit is not None: + util.warn_deprecated("autocommit on select() is deprecated. " + "Use .execution_options(autocommit=True)") + self._execution_options = self._execution_options.union({'autocommit':autocommit}) + self._limit = limit + self._offset = offset + self._bind = bind + + self._order_by_clause = ClauseList(*util.to_list(order_by) or []) + self._group_by_clause = ClauseList(*util.to_list(group_by) or []) + + def as_scalar(self): + """return a 'scalar' representation of this selectable, which can be + used as a column expression. + + Typically, a select statement which has only one column in its columns + clause is eligible to be used as a scalar expression. + + The returned object is an instance of + :class:`_ScalarSelect`. + + """ + return _ScalarSelect(self) + + @_generative + def apply_labels(self): + """return a new selectable with the 'use_labels' flag set to True. + + This will result in column expressions being generated using labels + against their table name, such as "SELECT somecolumn AS + tablename_somecolumn". This allows selectables which contain multiple + FROM clauses to produce a unique set of column names regardless of + name conflicts among the individual FROM clauses. + + """ + self.use_labels = True + + def label(self, name): + """return a 'scalar' representation of this selectable, embedded as a + subquery with a label. + + See also ``as_scalar()``. + + """ + return self.as_scalar().label(name) + + @_generative + @util.deprecated(message="autocommit() is deprecated. " + "Use .execution_options(autocommit=True)") + def autocommit(self): + """return a new selectable with the 'autocommit' flag set to True.""" + + self._execution_options = self._execution_options.union({'autocommit':True}) + + def _generate(self): + """Override the default _generate() method to also clear out exported collections.""" + + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + s._reset_exported() + return s + + @_generative + def limit(self, limit): + """return a new selectable with the given LIMIT criterion applied.""" + + self._limit = limit + + @_generative + def offset(self, offset): + """return a new selectable with the given OFFSET criterion applied.""" + + self._offset = offset + + @_generative + def order_by(self, *clauses): + """return a new selectable with the given list of ORDER BY criterion applied. + + The criterion will be appended to any pre-existing ORDER BY criterion. + + """ + self.append_order_by(*clauses) + + @_generative + def group_by(self, *clauses): + """return a new selectable with the given list of GROUP BY criterion applied. + + The criterion will be appended to any pre-existing GROUP BY criterion. + + """ + self.append_group_by(*clauses) + + def append_order_by(self, *clauses): + """Append the given ORDER BY criterion applied to this selectable. + + The criterion will be appended to any pre-existing ORDER BY criterion. + + """ + if len(clauses) == 1 and clauses[0] is None: + self._order_by_clause = ClauseList() + else: + if getattr(self, '_order_by_clause', None) is not None: + clauses = list(self._order_by_clause) + list(clauses) + self._order_by_clause = ClauseList(*clauses) + + def append_group_by(self, *clauses): + """Append the given GROUP BY criterion applied to this selectable. + + The criterion will be appended to any pre-existing GROUP BY criterion. + + """ + if len(clauses) == 1 and clauses[0] is None: + self._group_by_clause = ClauseList() + else: + if getattr(self, '_group_by_clause', None) is not None: + clauses = list(self._group_by_clause) + list(clauses) + self._group_by_clause = ClauseList(*clauses) + + @property + def _from_objects(self): + return [self] + + +class _ScalarSelect(_Grouping): + _from_objects = [] + + def __init__(self, element): + self.element = element + cols = list(element.c) + self.type = cols[0].type + + @property + def columns(self): + raise exc.InvalidRequestError("Scalar Select expression has no columns; " + "use this object directly within a column-level expression.") + c = columns + + def self_group(self, **kwargs): + return self + + def _make_proxy(self, selectable, name): + return list(self.inner_columns)[0]._make_proxy(selectable, name) + +class CompoundSelect(_SelectBaseMixin, FromClause): + """Forms the basis of ``UNION``, ``UNION ALL``, and other + SELECT-based set operations.""" + + __visit_name__ = 'compound_select' + + UNION = util.symbol('UNION') + UNION_ALL = util.symbol('UNION ALL') + EXCEPT = util.symbol('EXCEPT') + EXCEPT_ALL = util.symbol('EXCEPT ALL') + INTERSECT = util.symbol('INTERSECT') + INTERSECT_ALL = util.symbol('INTERSECT ALL') + + def __init__(self, keyword, *selects, **kwargs): + self._should_correlate = kwargs.pop('correlate', False) + self.keyword = keyword + self.selects = [] + + numcols = None + + # some DBs do not like ORDER BY in the inner queries of a UNION, etc. + for n, s in enumerate(selects): + s = _clause_element_as_expr(s) + + if not numcols: + numcols = len(s.c) + elif len(s.c) != numcols: + raise exc.ArgumentError( + "All selectables passed to CompoundSelect must " + "have identical numbers of columns; select #%d has %d columns," + " select #%d has %d" % + (1, len(self.selects[0].c), n+1, len(s.c)) + ) + + self.selects.append(s.self_group(self)) + + _SelectBaseMixin.__init__(self, **kwargs) + + def self_group(self, against=None): + return _FromGrouping(self) + + def is_derived_from(self, fromclause): + for s in self.selects: + if s.is_derived_from(fromclause): + return True + return False + + def _populate_column_collection(self): + for cols in zip(*[s.c for s in self.selects]): + # this is a slightly hacky thing - the union exports a column that + # resembles just that of the *first* selectable. to get at a "composite" column, + # particularly foreign keys, you have to dig through the proxies collection + # which we generate below. We may want to improve upon this, + # such as perhaps _make_proxy can accept a list of other columns that + # are "shared" - schema.column can then copy all the ForeignKeys in. + # this would allow the union() to have all those fks too. + proxy = cols[0]._make_proxy( + self, name=self.use_labels and cols[0]._label or None) + + # hand-construct the "proxies" collection to include all derived columns + # place a 'weight' annotation corresponding to how low in the list of + # select()s the column occurs, so that the corresponding_column() operation + # can resolve conflicts + proxy.proxies = [c._annotate({'weight':i + 1}) for i, c in enumerate(cols)] + + def _copy_internals(self, clone=_clone): + self._reset_exported() + self.selects = [clone(s) for s in self.selects] + if hasattr(self, '_col_map'): + del self._col_map + for attr in ('_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, clone(getattr(self, attr))) + + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.c) or []) + \ + [self._order_by_clause, self._group_by_clause] + list(self.selects) + + def bind(self): + if self._bind: + return self._bind + for s in self.selects: + e = s.bind + if e: + return e + else: + return None + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + +class Select(_SelectBaseMixin, FromClause): + """Represents a ``SELECT`` statement. + + Select statements support appendable clauses, as well as the + ability to execute themselves and return a result set. + + """ + + __visit_name__ = 'select' + + _prefixes = () + _hints = util.frozendict() + + def __init__(self, + columns, + whereclause=None, + from_obj=None, + distinct=False, + having=None, + correlate=True, + prefixes=None, + **kwargs): + """Construct a Select object. + + The public constructor for Select is the + :func:`select` function; see that function for + argument descriptions. + + Additional generative and mutator methods are available on the + :class:`_SelectBaseMixin` superclass. + + """ + self._should_correlate = correlate + self._distinct = distinct + + self._correlate = set() + self._froms = util.OrderedSet() + + try: + cols_present = bool(columns) + except TypeError: + raise exc.ArgumentError("columns argument to select() must " + "be a Python list or other iterable") + + if cols_present: + self._raw_columns = [] + for c in columns: + c = _literal_as_column(c) + if isinstance(c, _ScalarSelect): + c = c.self_group(against=operators.comma_op) + self._raw_columns.append(c) + + self._froms.update(_from_objects(*self._raw_columns)) + else: + self._raw_columns = [] + + if whereclause is not None: + self._whereclause = _literal_as_text(whereclause) + self._froms.update(_from_objects(self._whereclause)) + else: + self._whereclause = None + + if from_obj is not None: + for f in util.to_list(from_obj): + if _is_literal(f): + self._froms.add(_TextClause(f)) + else: + self._froms.add(f) + + if having is not None: + self._having = _literal_as_text(having) + else: + self._having = None + + if prefixes: + self._prefixes = tuple([_literal_as_text(p) for p in prefixes]) + + _SelectBaseMixin.__init__(self, **kwargs) + + def _get_display_froms(self, existing_froms=None): + """Return the full list of 'from' clauses to be displayed. + + Takes into account a set of existing froms which may be + rendered in the FROM clause of enclosing selects; this Select + may want to leave those absent if it is automatically + correlating. + + """ + froms = self._froms + + toremove = itertools.chain(*[f._hide_froms for f in froms]) + if toremove: + froms = froms.difference(toremove) + + if len(froms) > 1 or self._correlate: + if self._correlate: + froms = froms.difference(_cloned_intersection(froms, self._correlate)) + + if self._should_correlate and existing_froms: + froms = froms.difference(_cloned_intersection(froms, existing_froms)) + + if not len(froms): + raise exc.InvalidRequestError( + "Select statement '%s' returned no FROM clauses " + "due to auto-correlation; specify correlate() " + "to control correlation manually." % self) + + return froms + + @property + def froms(self): + """Return the displayed list of FromClause elements.""" + + return self._get_display_froms() + + @_generative + def with_hint(self, selectable, text, dialect_name=None): + """Add an indexing hint for the given selectable to this :class:`Select`. + + The text of the hint is written specific to a specific backend, and + typically uses Python string substitution syntax to render the name + of the table or alias, such as for Oracle:: + + select([mytable]).with_hint(mytable, "+ index(%(name)s ix_mytable)") + + Would render SQL as:: + + select /*+ index(mytable ix_mytable) */ ... from mytable + + The ``dialect_name`` option will limit the rendering of a particular hint + to a particular backend. Such as, to add hints for both Oracle and + Sybase simultaneously:: + + select([mytable]).\ + with_hint(mytable, "+ index(%(name)s ix_mytable)", 'oracle').\ + with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + + """ + if not dialect_name: + dialect_name = '*' + self._hints = self._hints.union({(selectable, dialect_name):text}) + + @property + def type(self): + raise exc.InvalidRequestError("Select objects don't have a type. " + "Call as_scalar() on this Select object " + "to return a 'scalar' version of this Select.") + + @util.memoized_instancemethod + def locate_all_froms(self): + """return a Set of all FromClause elements referenced by this Select. + + This set is a superset of that returned by the ``froms`` property, which + is specifically for those FromClause elements that would actually be rendered. + + """ + return self._froms.union(_from_objects(*list(self._froms))) + + @property + def inner_columns(self): + """an iterator of all ColumnElement expressions which would + be rendered into the columns clause of the resulting SELECT statement. + + """ + return _select_iterables(self._raw_columns) + + def is_derived_from(self, fromclause): + if self in fromclause._cloned_set: + return True + + for f in self.locate_all_froms(): + if f.is_derived_from(fromclause): + return True + return False + + def _copy_internals(self, clone=_clone): + self._reset_exported() + from_cloned = dict((f, clone(f)) + for f in self._froms.union(self._correlate)) + self._froms = util.OrderedSet(from_cloned[f] for f in self._froms) + self._correlate = set(from_cloned[f] for f in self._correlate) + self._raw_columns = [clone(c) for c in self._raw_columns] + for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, clone(getattr(self, attr))) + + def get_children(self, column_collections=True, **kwargs): + """return child elements as per the ClauseElement specification.""" + + return (column_collections and list(self.columns) or []) + \ + self._raw_columns + list(self._froms) + \ + [x for x in + (self._whereclause, self._having, + self._order_by_clause, self._group_by_clause) + if x is not None] + + @_generative + def column(self, column): + """return a new select() construct with the given column expression + added to its columns clause. + + """ + + column = _literal_as_column(column) + + if isinstance(column, _ScalarSelect): + column = column.self_group(against=operators.comma_op) + + self._raw_columns = self._raw_columns + [column] + self._froms = self._froms.union(_from_objects(column)) + + @_generative + def with_only_columns(self, columns): + """return a new select() construct with its columns clause replaced + with the given columns. + + """ + + self._raw_columns = [ + isinstance(c, _ScalarSelect) and + c.self_group(against=operators.comma_op) or c + for c in [_literal_as_column(c) for c in columns] + ] + + @_generative + def where(self, whereclause): + """return a new select() construct with the given expression added to its + WHERE clause, joined to the existing clause via AND, if any. + + """ + + self.append_whereclause(whereclause) + + @_generative + def having(self, having): + """return a new select() construct with the given expression added to its HAVING + clause, joined to the existing clause via AND, if any. + + """ + self.append_having(having) + + @_generative + def distinct(self): + """return a new select() construct which will apply DISTINCT to its columns + clause. + + """ + self._distinct = True + + @_generative + def prefix_with(self, clause): + """return a new select() construct which will apply the given expression to the + start of its columns clause, not using any commas. + + """ + clause = _literal_as_text(clause) + self._prefixes = self._prefixes + (clause,) + + @_generative + def select_from(self, fromclause): + """return a new select() construct with the given FROM expression applied to its + list of FROM objects. + + """ + fromclause = _literal_as_text(fromclause) + self._froms = self._froms.union([fromclause]) + + @_generative + def correlate(self, *fromclauses): + """return a new select() construct which will correlate the given FROM clauses to + that of an enclosing select(), if a match is found. + + By "match", the given fromclause must be present in this select's list of FROM + objects and also present in an enclosing select's list of FROM objects. + + Calling this method turns off the select's default behavior of + "auto-correlation". Normally, select() auto-correlates all of its FROM clauses to + those of an embedded select when compiled. + + If the fromclause is None, correlation is disabled for the returned select(). + + """ + self._should_correlate = False + if fromclauses == (None,): + self._correlate = set() + else: + self._correlate = self._correlate.union(fromclauses) + + def append_correlation(self, fromclause): + """append the given correlation expression to this select() construct.""" + + self._should_correlate = False + self._correlate = self._correlate.union([fromclause]) + + def append_column(self, column): + """append the given column expression to the columns clause of this select() + construct. + + """ + column = _literal_as_column(column) + + if isinstance(column, _ScalarSelect): + column = column.self_group(against=operators.comma_op) + + self._raw_columns = self._raw_columns + [column] + self._froms = self._froms.union(_from_objects(column)) + self._reset_exported() + + def append_prefix(self, clause): + """append the given columns clause prefix expression to this select() + construct. + + """ + clause = _literal_as_text(clause) + self._prefixes = self._prefixes + (clause,) + + def append_whereclause(self, whereclause): + """append the given expression to this select() construct's WHERE criterion. + + The expression will be joined to existing WHERE criterion via AND. + + """ + whereclause = _literal_as_text(whereclause) + self._froms = self._froms.union(_from_objects(whereclause)) + + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, whereclause) + else: + self._whereclause = whereclause + + def append_having(self, having): + """append the given expression to this select() construct's HAVING criterion. + + The expression will be joined to existing HAVING criterion via AND. + + """ + if self._having is not None: + self._having = and_(self._having, _literal_as_text(having)) + else: + self._having = _literal_as_text(having) + + def append_from(self, fromclause): + """append the given FromClause expression to this select() construct's FROM + clause. + + """ + if _is_literal(fromclause): + fromclause = _TextClause(fromclause) + + self._froms = self._froms.union([fromclause]) + + def __exportable_columns(self): + for column in self._raw_columns: + if isinstance(column, Selectable): + for co in column.columns: + yield co + elif isinstance(column, ColumnElement): + yield column + else: + continue + + def _populate_column_collection(self): + for c in self.__exportable_columns(): + c._make_proxy(self, name=self.use_labels and c._label or None) + + def self_group(self, against=None): + """return a 'grouping' construct as per the ClauseElement specification. + + This produces an element that can be embedded in an expression. Note that + this method is called automatically as needed when constructing expressions. + + """ + if isinstance(against, CompoundSelect): + return self + return _FromGrouping(self) + + def union(self, other, **kwargs): + """return a SQL UNION of this select() construct against the given selectable.""" + + return union(self, other, **kwargs) + + def union_all(self, other, **kwargs): + """return a SQL UNION ALL of this select() construct against the given + selectable. + + """ + return union_all(self, other, **kwargs) + + def except_(self, other, **kwargs): + """return a SQL EXCEPT of this select() construct against the given selectable.""" + + return except_(self, other, **kwargs) + + def except_all(self, other, **kwargs): + """return a SQL EXCEPT ALL of this select() construct against the given + selectable. + + """ + return except_all(self, other, **kwargs) + + def intersect(self, other, **kwargs): + """return a SQL INTERSECT of this select() construct against the given + selectable. + + """ + return intersect(self, other, **kwargs) + + def intersect_all(self, other, **kwargs): + """return a SQL INTERSECT ALL of this select() construct against the given + selectable. + + """ + return intersect_all(self, other, **kwargs) + + def bind(self): + if self._bind: + return self._bind + if not self._froms: + for c in self._raw_columns: + e = c.bind + if e: + self._bind = e + return e + else: + e = list(self._froms)[0].bind + if e: + self._bind = e + return e + + return None + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + +class _UpdateBase(Executable, ClauseElement): + """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" + + __visit_name__ = 'update_base' + + _execution_options = Executable._execution_options.union({'autocommit':True}) + kwargs = util.frozendict() + + def _process_colparams(self, parameters): + if isinstance(parameters, (list, tuple)): + pp = {} + for i, c in enumerate(self.table.c): + pp[c.key] = parameters[i] + return pp + else: + return parameters + + def params(self, *arg, **kw): + raise NotImplementedError( + "params() is not supported for INSERT/UPDATE/DELETE statements." + " To set the values for an INSERT or UPDATE statement, use" + " stmt.values(**parameters).") + + def bind(self): + return self._bind or self.table.bind + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + + _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning') + def _process_deprecated_kw(self, kwargs): + for k in list(kwargs): + m = self._returning_re.match(k) + if m: + self._returning = kwargs.pop(k) + util.warn_deprecated( + "The %r argument is deprecated. Please " + "use statement.returning(col1, col2, ...)" % k + ) + return kwargs + + @_generative + def returning(self, *cols): + """Add a RETURNING or equivalent clause to this statement. + + The given list of columns represent columns within the table + that is the target of the INSERT, UPDATE, or DELETE. Each + element can be any column expression. :class:`~sqlalchemy.schema.Table` + objects will be expanded into their individual columns. + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned + are made available via the result set and can be iterated + using ``fetchone()`` and similar. For DBAPIs which do not + natively support returning values (i.e. cx_oracle), + SQLAlchemy will approximate this behavior at the result level + so that a reasonable amount of behavioral neutrality is + provided. + + Note that not all databases/DBAPIs + support RETURNING. For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + """ + self._returning = cols + +class _ValuesBase(_UpdateBase): + + __visit_name__ = 'values_base' + + def __init__(self, table, values): + self.table = table + self.parameters = self._process_colparams(values) + + @_generative + def values(self, *args, **kwargs): + """specify the VALUES clause for an INSERT statement, or the SET clause for an + UPDATE. + + \**kwargs + key= arguments + + \*args + A single dictionary can be sent as the first positional argument. This + allows non-string based keys, such as Column objects, to be used. + + """ + if args: + v = args[0] + else: + v = {} + + if self.parameters is None: + self.parameters = self._process_colparams(v) + self.parameters.update(kwargs) + else: + self.parameters = self.parameters.copy() + self.parameters.update(self._process_colparams(v)) + self.parameters.update(kwargs) + +class Insert(_ValuesBase): + """Represent an INSERT construct. + + The :class:`Insert` object is created using the :func:`insert()` function. + + """ + __visit_name__ = 'insert' + + _prefixes = () + + def __init__(self, + table, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + **kwargs): + _ValuesBase.__init__(self, table, values) + self._bind = bind + self.select = None + self.inline = inline + self._returning = returning + if prefixes: + self._prefixes = tuple([_literal_as_text(p) for p in prefixes]) + + if kwargs: + self.kwargs = self._process_deprecated_kw(kwargs) + + def get_children(self, **kwargs): + if self.select is not None: + return self.select, + else: + return () + + def _copy_internals(self, clone=_clone): + # TODO: coverage + self.parameters = self.parameters.copy() + + @_generative + def prefix_with(self, clause): + """Add a word or expression between INSERT and INTO. Generative. + + If multiple prefixes are supplied, they will be separated with + spaces. + + """ + clause = _literal_as_text(clause) + self._prefixes = self._prefixes + (clause,) + +class Update(_ValuesBase): + """Represent an Update construct. + + The :class:`Update` object is created using the :func:`update()` function. + + """ + __visit_name__ = 'update' + + def __init__(self, + table, + whereclause, + values=None, + inline=False, + bind=None, + returning=None, + **kwargs): + _ValuesBase.__init__(self, table, values) + self._bind = bind + self._returning = returning + if whereclause is not None: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + self.inline = inline + + if kwargs: + self.kwargs = self._process_deprecated_kw(kwargs) + + def get_children(self, **kwargs): + if self._whereclause is not None: + return self._whereclause, + else: + return () + + def _copy_internals(self, clone=_clone): + # TODO: coverage + self._whereclause = clone(self._whereclause) + self.parameters = self.parameters.copy() + + @_generative + def where(self, whereclause): + """return a new update() construct with the given expression added to its WHERE + clause, joined to the existing clause via AND, if any. + + """ + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + else: + self._whereclause = _literal_as_text(whereclause) + + +class Delete(_UpdateBase): + """Represent a DELETE construct. + + The :class:`Delete` object is created using the :func:`delete()` function. + + """ + + __visit_name__ = 'delete' + + def __init__(self, + table, + whereclause, + bind=None, + returning =None, + **kwargs): + self._bind = bind + self.table = table + self._returning = returning + + if whereclause is not None: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + + if kwargs: + self.kwargs = self._process_deprecated_kw(kwargs) + + def get_children(self, **kwargs): + if self._whereclause is not None: + return self._whereclause, + else: + return () + + @_generative + def where(self, whereclause): + """Add the given WHERE clause to a newly returned delete construct.""" + + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + else: + self._whereclause = _literal_as_text(whereclause) + + def _copy_internals(self, clone=_clone): + # TODO: coverage + self._whereclause = clone(self._whereclause) + +class _IdentifiedClause(Executable, ClauseElement): + __visit_name__ = 'identified' + _execution_options = Executable._execution_options.union({'autocommit':False}) + quote = None + + def __init__(self, ident): + self.ident = ident + +class SavepointClause(_IdentifiedClause): + __visit_name__ = 'savepoint' + +class RollbackToSavepointClause(_IdentifiedClause): + __visit_name__ = 'rollback_to_savepoint' + +class ReleaseSavepointClause(_IdentifiedClause): + __visit_name__ = 'release_savepoint' + + diff --git a/sqlalchemy/sql/functions.py b/sqlalchemy/sql/functions.py new file mode 100644 index 0000000..212f81a --- /dev/null +++ b/sqlalchemy/sql/functions.py @@ -0,0 +1,104 @@ +from sqlalchemy import types as sqltypes +from sqlalchemy.sql.expression import ( + ClauseList, Function, _literal_as_binds, text, _type_from_args + ) +from sqlalchemy.sql import operators +from sqlalchemy.sql.visitors import VisitableType + +class _GenericMeta(VisitableType): + def __call__(self, *args, **kwargs): + args = [_literal_as_binds(c) for c in args] + return type.__call__(self, *args, **kwargs) + +class GenericFunction(Function): + __metaclass__ = _GenericMeta + + def __init__(self, type_=None, args=(), **kwargs): + self.packagenames = [] + self.name = self.__class__.__name__ + self._bind = kwargs.get('bind', None) + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *args).self_group() + self.type = sqltypes.to_instance( + type_ or getattr(self, '__return_type__', None)) + +class AnsiFunction(GenericFunction): + def __init__(self, **kwargs): + GenericFunction.__init__(self, **kwargs) + +class ReturnTypeFromArgs(GenericFunction): + """Define a function whose return type is the same as its arguments.""" + + def __init__(self, *args, **kwargs): + kwargs.setdefault('type_', _type_from_args(args)) + GenericFunction.__init__(self, args=args, **kwargs) + +class coalesce(ReturnTypeFromArgs): + pass + +class max(ReturnTypeFromArgs): + pass + +class min(ReturnTypeFromArgs): + pass + +class sum(ReturnTypeFromArgs): + pass + +class now(GenericFunction): + __return_type__ = sqltypes.DateTime + +class concat(GenericFunction): + __return_type__ = sqltypes.String + def __init__(self, *args, **kwargs): + GenericFunction.__init__(self, args=args, **kwargs) + +class char_length(GenericFunction): + __return_type__ = sqltypes.Integer + + def __init__(self, arg, **kwargs): + GenericFunction.__init__(self, args=[arg], **kwargs) + +class random(GenericFunction): + def __init__(self, *args, **kwargs): + kwargs.setdefault('type_', None) + GenericFunction.__init__(self, args=args, **kwargs) + +class count(GenericFunction): + """The ANSI COUNT aggregate function. With no arguments, emits COUNT \*.""" + + __return_type__ = sqltypes.Integer + + def __init__(self, expression=None, **kwargs): + if expression is None: + expression = text('*') + GenericFunction.__init__(self, args=(expression,), **kwargs) + +class current_date(AnsiFunction): + __return_type__ = sqltypes.Date + +class current_time(AnsiFunction): + __return_type__ = sqltypes.Time + +class current_timestamp(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class current_user(AnsiFunction): + __return_type__ = sqltypes.String + +class localtime(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class localtimestamp(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class session_user(AnsiFunction): + __return_type__ = sqltypes.String + +class sysdate(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class user(AnsiFunction): + __return_type__ = sqltypes.String + diff --git a/sqlalchemy/sql/operators.py b/sqlalchemy/sql/operators.py new file mode 100644 index 0000000..6f70b17 --- /dev/null +++ b/sqlalchemy/sql/operators.py @@ -0,0 +1,135 @@ +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines operators used in SQL expressions.""" + +from operator import ( + and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg + ) + +# Py2K +from operator import (div,) +# end Py2K + +from sqlalchemy.util import symbol + + +def from_(): + raise NotImplementedError() + +def as_(): + raise NotImplementedError() + +def exists(): + raise NotImplementedError() + +def is_(): + raise NotImplementedError() + +def isnot(): + raise NotImplementedError() + +def collate(): + raise NotImplementedError() + +def op(a, opstring, b): + return a.op(opstring)(b) + +def like_op(a, b, escape=None): + return a.like(b, escape=escape) + +def notlike_op(a, b, escape=None): + raise NotImplementedError() + +def ilike_op(a, b, escape=None): + return a.ilike(b, escape=escape) + +def notilike_op(a, b, escape=None): + raise NotImplementedError() + +def between_op(a, b, c): + return a.between(b, c) + +def in_op(a, b): + return a.in_(b) + +def notin_op(a, b): + raise NotImplementedError() + +def distinct_op(a): + return a.distinct() + +def startswith_op(a, b, escape=None): + return a.startswith(b, escape=escape) + +def endswith_op(a, b, escape=None): + return a.endswith(b, escape=escape) + +def contains_op(a, b, escape=None): + return a.contains(b, escape=escape) + +def match_op(a, b): + return a.match(b) + +def comma_op(a, b): + raise NotImplementedError() + +def concat_op(a, b): + return a.concat(b) + +def desc_op(a): + return a.desc() + +def asc_op(a): + return a.asc() + +_commutative = set([eq, ne, add, mul]) +def is_commutative(op): + return op in _commutative + +_smallest = symbol('_smallest') +_largest = symbol('_largest') + +_PRECEDENCE = { + from_: 15, + mul: 7, + truediv: 7, + # Py2K + div: 7, + # end Py2K + mod: 7, + neg: 7, + add: 6, + sub: 6, + concat_op: 6, + match_op: 6, + ilike_op: 5, + notilike_op: 5, + like_op: 5, + notlike_op: 5, + in_op: 5, + notin_op: 5, + is_: 5, + isnot: 5, + eq: 5, + ne: 5, + gt: 5, + lt: 5, + ge: 5, + le: 5, + between_op: 5, + distinct_op: 5, + inv: 5, + and_: 3, + or_: 2, + comma_op: -1, + collate: 7, + as_: -1, + exists: 0, + _smallest: -1000, + _largest: 1000 +} + +def is_precedent(operator, against): + return (_PRECEDENCE.get(operator, _PRECEDENCE[_smallest]) <= + _PRECEDENCE.get(against, _PRECEDENCE[_largest])) diff --git a/sqlalchemy/sql/util.py b/sqlalchemy/sql/util.py new file mode 100644 index 0000000..d5575e0 --- /dev/null +++ b/sqlalchemy/sql/util.py @@ -0,0 +1,651 @@ +from sqlalchemy import exc, schema, topological, util, sql, types as sqltypes +from sqlalchemy.sql import expression, operators, visitors +from itertools import chain + +"""Utility functions that build upon SQL and Schema constructs.""" + +def sort_tables(tables): + """sort a collection of Table objects in order of their foreign-key dependency.""" + + tables = list(tables) + tuples = [] + def visit_foreign_key(fkey): + if fkey.use_alter: + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + tuples.append( ( parent_table, child_table ) ) + + for table in tables: + visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) + return topological.sort(tuples, tables) + +def find_join_source(clauses, join_to): + """Given a list of FROM clauses and a selectable, + return the first index and element from the list of + clauses which can be joined against the selectable. returns + None, None if no match is found. + + e.g.:: + + clause1 = table1.join(table2) + clause2 = table4.join(table5) + + join_to = table2.join(table3) + + find_join_source([clause1, clause2], join_to) == clause1 + + """ + + selectables = list(expression._from_objects(join_to)) + for i, f in enumerate(clauses): + for s in selectables: + if f.is_derived_from(s): + return i, f + else: + return None, None + + + +def find_tables(clause, check_columns=False, + include_aliases=False, include_joins=False, + include_selects=False, include_crud=False): + """locate Table objects within the given expression.""" + + tables = [] + _visitors = {} + + if include_selects: + _visitors['select'] = _visitors['compound_select'] = tables.append + + if include_joins: + _visitors['join'] = tables.append + + if include_aliases: + _visitors['alias'] = tables.append + + if include_crud: + _visitors['insert'] = _visitors['update'] = \ + _visitors['delete'] = lambda ent: tables.append(ent.table) + + if check_columns: + def visit_column(column): + tables.append(column.table) + _visitors['column'] = visit_column + + _visitors['table'] = tables.append + + visitors.traverse(clause, {'column_collections':False}, _visitors) + return tables + +def find_columns(clause): + """locate Column objects within the given expression.""" + + cols = util.column_set() + visitors.traverse(clause, {}, {'column':cols.add}) + return cols + +def _quote_ddl_expr(element): + if isinstance(element, basestring): + element = element.replace("'", "''") + return "'%s'" % element + else: + return repr(element) + +def expression_as_ddl(clause): + """Given a SQL expression, convert for usage in DDL, such as + CREATE INDEX and CHECK CONSTRAINT. + + Converts bind params into quoted literals, column identifiers + into detached column constructs so that the parent table + identifier is not included. + + """ + def repl(element): + if isinstance(element, expression._BindParamClause): + return expression.literal_column(_quote_ddl_expr(element.value)) + elif isinstance(element, expression.ColumnClause) and \ + element.table is not None: + return expression.column(element.name) + else: + return None + + return visitors.replacement_traverse(clause, {}, repl) + +def adapt_criterion_to_null(crit, nulls): + """given criterion containing bind params, convert selected elements to IS NULL.""" + + def visit_binary(binary): + if isinstance(binary.left, expression._BindParamClause) and binary.left.key in nulls: + # reverse order if the NULL is on the left side + binary.left = binary.right + binary.right = expression.null() + binary.operator = operators.is_ + binary.negate = operators.isnot + elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in nulls: + binary.right = expression.null() + binary.operator = operators.is_ + binary.negate = operators.isnot + + return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) + + +def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): + """create a join condition between two tables or selectables. + + e.g.:: + + join_condition(tablea, tableb) + + would produce an expression along the lines of:: + + tablea.c.id==tableb.c.tablea_id + + The join is determined based on the foreign key relationships + between the two selectables. If there are multiple ways + to join, or no way to join, an error is raised. + + :param ignore_nonexistent_tables: This flag will cause the + function to silently skip over foreign key resolution errors + due to nonexistent tables - the assumption is that these + tables have not yet been defined within an initialization process + and are not significant to the operation. + + :param a_subset: An optional expression that is a sub-component + of ``a``. An attempt will be made to join to just this sub-component + first before looking at the full ``a`` construct, and if found + will be successful even if there are other ways to join to ``a``. + This allows the "right side" of a join to be passed thereby + providing a "natural join". + + """ + crit = [] + constraints = set() + + for left in (a_subset, a): + if left is None: + continue + for fk in b.foreign_keys: + try: + col = fk.get_referent(left) + except exc.NoReferencedTableError: + if ignore_nonexistent_tables: + continue + else: + raise + + if col is not None: + crit.append(col == fk.parent) + constraints.add(fk.constraint) + if left is not b: + for fk in left.foreign_keys: + try: + col = fk.get_referent(b) + except exc.NoReferencedTableError: + if ignore_nonexistent_tables: + continue + else: + raise + + if col is not None: + crit.append(col == fk.parent) + constraints.add(fk.constraint) + if crit: + break + + if len(crit) == 0: + if isinstance(b, expression._FromGrouping): + hint = " Perhaps you meant to convert the right side to a subquery using alias()?" + else: + hint = "" + raise exc.ArgumentError( + "Can't find any foreign key relationships " + "between '%s' and '%s'.%s" % (a.description, b.description, hint)) + elif len(constraints) > 1: + raise exc.ArgumentError( + "Can't determine join between '%s' and '%s'; " + "tables have more than one foreign key " + "constraint relationship between them. " + "Please specify the 'onclause' of this " + "join explicitly." % (a.description, b.description)) + elif len(crit) == 1: + return (crit[0]) + else: + return sql.and_(*crit) + + +class Annotated(object): + """clones a ClauseElement and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __cmp__() of the original element so that it takes its place + in hashed collections. + + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + """ + + def __new__(cls, *args): + if not args: + # clone constructor + return object.__new__(cls) + else: + element, values = args + # pull appropriate subclass from registry of annotated + # classes + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = annotated_classes[element.__class__] = type.__new__(type, + "Annotated%s" % element.__class__.__name__, + (Annotated, element.__class__), {}) + return object.__new__(cls) + + def __init__(self, element, values): + # force FromClause to generate their internal + # collections into __dict__ + if isinstance(element, expression.FromClause): + element.c + + self.__dict__ = element.__dict__.copy() + self.__element = element + self._annotations = values + + def _annotate(self, values): + _values = self._annotations.copy() + _values.update(values) + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone._annotations = _values + return clone + + def _deannotate(self): + return self.__element + + def _clone(self): + clone = self.__element._clone() + if clone is self.__element: + # detect immutable, don't change anything + return self + else: + # update the clone with any changes that have occured + # to this object's __dict__. + clone.__dict__.update(self.__dict__) + return Annotated(clone, self._annotations) + + def __hash__(self): + return hash(self.__element) + + def __cmp__(self, other): + return cmp(hash(self.__element), hash(other)) + +# hard-generate Annotated subclasses. this technique +# is used instead of on-the-fly types (i.e. type.__new__()) +# so that the resulting objects are pickleable. +annotated_classes = {} + +from sqlalchemy.sql import expression +for cls in expression.__dict__.values() + [schema.Column, schema.Table]: + if isinstance(cls, type) and issubclass(cls, expression.ClauseElement): + exec "class Annotated%s(Annotated, cls):\n" \ + " __visit_name__ = cls.__visit_name__\n"\ + " pass" % (cls.__name__, ) in locals() + exec "annotated_classes[cls] = Annotated%s" % (cls.__name__) + +def _deep_annotate(element, annotations, exclude=None): + """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. + + Elements within the exclude collection will be cloned but not annotated. + + """ + def clone(elem): + # check if element is present in the exclude list. + # take into account proxying relationships. + if exclude and \ + hasattr(elem, 'proxy_set') and \ + elem.proxy_set.intersection(exclude): + elem = elem._clone() + elif annotations != elem._annotations: + elem = elem._annotate(annotations.copy()) + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + +def _deep_deannotate(element): + """Deep copy the given element, removing all annotations.""" + + def clone(elem): + elem = elem._deannotate() + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + + +def splice_joins(left, right, stop_on=None): + if left is None: + return right + + stack = [(right, None)] + + adapter = ClauseAdapter(left) + ret = None + while stack: + (right, prevright) = stack.pop() + if isinstance(right, expression.Join) and right is not stop_on: + right = right._clone() + right._reset_exported() + right.onclause = adapter.traverse(right.onclause) + stack.append((right.left, right)) + else: + right = adapter.traverse(right) + if prevright is not None: + prevright.left = right + if ret is None: + ret = right + + return ret + +def reduce_columns(columns, *clauses, **kw): + """given a list of columns, return a 'reduced' set based on natural equivalents. + + the set is reduced to the smallest list of columns which have no natural + equivalent present in the list. A "natural equivalent" means that two columns + will ultimately represent the same value because they are related by a foreign key. + + \*clauses is an optional list of join clauses which will be traversed + to further identify columns that are "equivalent". + + \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys + whose tables are not yet configured. + + This function is primarily used to determine the most minimal "primary key" + from a selectable, by reducing the set of primary key columns present + in the the selectable to just those that are not repeated. + + """ + ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) + + columns = util.ordered_column_set(columns) + + omit = util.column_set() + for col in columns: + for fk in chain(*[c.foreign_keys for c in col.proxy_set]): + for c in columns: + if c is col: + continue + try: + fk_col = fk.column + except exc.NoReferencedTableError: + if ignore_nonexistent_tables: + continue + else: + raise + if fk_col.shares_lineage(c): + omit.add(col) + break + + if clauses: + def visit_binary(binary): + if binary.operator == operators.eq: + cols = util.column_set(chain(*[c.proxy_set for c in columns.difference(omit)])) + if binary.left in cols and binary.right in cols: + for c in columns: + if c.shares_lineage(binary.right): + omit.add(c) + break + for clause in clauses: + visitors.traverse(clause, {}, {'binary':visit_binary}) + + return expression.ColumnSet(columns.difference(omit)) + +def criterion_as_pairs(expression, consider_as_foreign_keys=None, + consider_as_referenced_keys=None, any_operator=False): + """traverse an expression and locate binary criterion pairs.""" + + if consider_as_foreign_keys and consider_as_referenced_keys: + raise exc.ArgumentError("Can only specify one of " + "'consider_as_foreign_keys' or " + "'consider_as_referenced_keys'") + + def visit_binary(binary): + if not any_operator and binary.operator is not operators.eq: + return + if not isinstance(binary.left, sql.ColumnElement) or \ + not isinstance(binary.right, sql.ColumnElement): + return + + if consider_as_foreign_keys: + if binary.left in consider_as_foreign_keys and \ + (binary.right is binary.left or + binary.right not in consider_as_foreign_keys): + pairs.append((binary.right, binary.left)) + elif binary.right in consider_as_foreign_keys and \ + (binary.left is binary.right or + binary.left not in consider_as_foreign_keys): + pairs.append((binary.left, binary.right)) + elif consider_as_referenced_keys: + if binary.left in consider_as_referenced_keys and \ + (binary.right is binary.left or + binary.right not in consider_as_referenced_keys): + pairs.append((binary.left, binary.right)) + elif binary.right in consider_as_referenced_keys and \ + (binary.left is binary.right or + binary.left not in consider_as_referenced_keys): + pairs.append((binary.right, binary.left)) + else: + if isinstance(binary.left, schema.Column) and \ + isinstance(binary.right, schema.Column): + if binary.left.references(binary.right): + pairs.append((binary.right, binary.left)) + elif binary.right.references(binary.left): + pairs.append((binary.left, binary.right)) + pairs = [] + visitors.traverse(expression, {}, {'binary':visit_binary}) + return pairs + +def folded_equivalents(join, equivs=None): + """Return a list of uniquely named columns. + + The column list of the given Join will be narrowed + down to a list of all equivalently-named, + equated columns folded into one column, where 'equated' means they are + equated to each other in the ON clause of this join. + + This function is used by Join.select(fold_equivalents=True). + + Deprecated. This function is used for a certain kind of + "polymorphic_union" which is designed to achieve joined + table inheritance where the base table has no "discriminator" + column; [ticket:1131] will provide a better way to + achieve this. + + """ + if equivs is None: + equivs = set() + def visit_binary(binary): + if binary.operator == operators.eq and binary.left.name == binary.right.name: + equivs.add(binary.right) + equivs.add(binary.left) + visitors.traverse(join.onclause, {}, {'binary':visit_binary}) + collist = [] + if isinstance(join.left, expression.Join): + left = folded_equivalents(join.left, equivs) + else: + left = list(join.left.columns) + if isinstance(join.right, expression.Join): + right = folded_equivalents(join.right, equivs) + else: + right = list(join.right.columns) + used = set() + for c in left + right: + if c in equivs: + if c.name not in used: + collist.append(c) + used.add(c.name) + else: + collist.append(c) + return collist + +class AliasedRow(object): + """Wrap a RowProxy with a translation map. + + This object allows a set of keys to be translated + to those present in a RowProxy. + + """ + def __init__(self, row, map): + # AliasedRow objects don't nest, so un-nest + # if another AliasedRow was passed + if isinstance(row, AliasedRow): + self.row = row.row + else: + self.row = row + self.map = map + + def __contains__(self, key): + return self.map[key] in self.row + + def has_key(self, key): + return key in self + + def __getitem__(self, key): + return self.row[self.map[key]] + + def keys(self): + return self.row.keys() + + +class ClauseAdapter(visitors.ReplacingCloningVisitor): + """Clones and modifies clauses based on column correspondence. + + E.g.:: + + table1 = Table('sometable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + table2 = Table('someothertable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + + condition = table1.c.col1 == table2.c.col1 + + make an alias of table1:: + + s = table1.alias('foo') + + calling ``ClauseAdapter(s).traverse(condition)`` converts + condition to read:: + + s.c.col1 == table2.c.col1 + + """ + def __init__(self, selectable, equivalents=None, include=None, exclude=None): + self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]} + self.selectable = selectable + self.include = include + self.exclude = exclude + self.equivalents = util.column_dict(equivalents or {}) + + def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): + newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) + + if newcol is None and col in self.equivalents and col not in _seen: + for equiv in self.equivalents[col]: + newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col])) + if newcol is not None: + return newcol + return newcol + + def replace(self, col): + if isinstance(col, expression.FromClause): + if self.selectable.is_derived_from(col): + return self.selectable + + if not isinstance(col, expression.ColumnElement): + return None + + if self.include and col not in self.include: + return None + elif self.exclude and col in self.exclude: + return None + + return self._corresponding_column(col, True) + +class ColumnAdapter(ClauseAdapter): + """Extends ClauseAdapter with extra utility functions. + + Provides the ability to "wrap" this ClauseAdapter + around another, a columns dictionary which returns + adapted elements given an original, and an + adapted_row() factory. + + """ + def __init__(self, selectable, equivalents=None, + chain_to=None, include=None, + exclude=None, adapt_required=False): + ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) + if chain_to: + self.chain(chain_to) + self.columns = util.populate_column_dict(self._locate_col) + self.adapt_required = adapt_required + + def wrap(self, adapter): + ac = self.__class__.__new__(self.__class__) + ac.__dict__ = self.__dict__.copy() + ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col) + ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) + ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) + ac.columns = util.populate_column_dict(ac._locate_col) + return ac + + adapt_clause = ClauseAdapter.traverse + adapt_list = ClauseAdapter.copy_and_process + + def _wrap(self, local, wrapped): + def locate(col): + col = local(col) + return wrapped(col) + return locate + + def _locate_col(self, col): + c = self._corresponding_column(col, True) + if c is None: + c = self.adapt_clause(col) + + # anonymize labels in case they have a hardcoded name + if isinstance(c, expression._Label): + c = c.label(None) + + # adapt_required indicates that if we got the same column + # back which we put in (i.e. it passed through), + # it's not correct. this is used by eagerloading which + # knows that all columns and expressions need to be adapted + # to a result row, and a "passthrough" is definitely targeting + # the wrong column. + if self.adapt_required and c is col: + return None + + return c + + def adapted_row(self, row): + return AliasedRow(row, self.columns) + + def __getstate__(self): + d = self.__dict__.copy() + del d['columns'] + return d + + def __setstate__(self, state): + self.__dict__.update(state) + self.columns = util.PopulateDict(self._locate_col) diff --git a/sqlalchemy/sql/visitors.py b/sqlalchemy/sql/visitors.py new file mode 100644 index 0000000..4a54375 --- /dev/null +++ b/sqlalchemy/sql/visitors.py @@ -0,0 +1,256 @@ +"""Visitor/traversal interface and library functions. + +SQLAlchemy schema and expression constructs rely on a Python-centric +version of the classic "visitor" pattern as the primary way in which +they apply functionality. The most common use of this pattern +is statement compilation, where individual expression classes match +up to rendering methods that produce a string result. Beyond this, +the visitor system is also used to inspect expressions for various +information and patterns, as well as for usage in +some kinds of expression transformation. Other kinds of transformation +use a non-visitor traversal system. + +For many examples of how the visit system is used, see the +sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules. +For an introduction to clause adaption, see +http://techspot.zzzeek.org/?p=19 . + +""" + +from collections import deque +import re +from sqlalchemy import util +import operator + +__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', + 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', + 'iterate_depthfirst', 'traverse_using', 'traverse', + 'cloned_traverse', 'replacement_traverse'] + +class VisitableType(type): + """Metaclass which checks for a `__visit_name__` attribute and + applies `_compiler_dispatch` method to classes. + + """ + + def __init__(cls, clsname, bases, clsdict): + if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): + super(VisitableType, cls).__init__(clsname, bases, clsdict) + return + + # set up an optimized visit dispatch function + # for use by the compiler + visit_name = cls.__visit_name__ + if isinstance(visit_name, str): + getter = operator.attrgetter("visit_%s" % visit_name) + def _compiler_dispatch(self, visitor, **kw): + return getter(visitor)(self, **kw) + else: + def _compiler_dispatch(self, visitor, **kw): + return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + + cls._compiler_dispatch = _compiler_dispatch + + super(VisitableType, cls).__init__(clsname, bases, clsdict) + +class Visitable(object): + """Base class for visitable objects, applies the + ``VisitableType`` metaclass. + + """ + + __metaclass__ = VisitableType + +class ClauseVisitor(object): + """Base class for visitor objects which can traverse using + the traverse() function. + + """ + + __traverse_options__ = {} + + def traverse_single(self, obj): + for v in self._visitor_iterator: + meth = getattr(v, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj) + + def iterate(self, obj): + """traverse the given expression structure, returning an iterator of all elements.""" + + return iterate(obj, self.__traverse_options__) + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + return traverse(obj, self.__traverse_options__, self._visitor_dict) + + @util.memoized_property + def _visitor_dict(self): + visitors = {} + + for name in dir(self): + if name.startswith('visit_'): + visitors[name[6:]] = getattr(self, name) + return visitors + + @property + def _visitor_iterator(self): + """iterate through this visitor and each 'chained' visitor.""" + + v = self + while v: + yield v + v = getattr(v, '_next', None) + + def chain(self, visitor): + """'chain' an additional ClauseVisitor onto this ClauseVisitor. + + the chained visitor will receive all visit events after this one. + + """ + tail = list(self._visitor_iterator)[-1] + tail._next = visitor + return self + +class CloningVisitor(ClauseVisitor): + """Base class for visitor objects which can traverse using + the cloned_traverse() function. + + """ + + def copy_and_process(self, list_): + """Apply cloned traversal to the given list of elements, and return the new list.""" + + return [self.traverse(x) for x in list_] + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict) + +class ReplacingCloningVisitor(CloningVisitor): + """Base class for visitor objects which can traverse using + the replacement_traverse() function. + + """ + + def replace(self, elem): + """receive pre-copied elements during a cloning traversal. + + If the method returns a new element, the element is used + instead of creating a simple copy of the element. Traversal + will halt on the newly returned element if it is re-encountered. + """ + return None + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + def replace(elem): + for v in self._visitor_iterator: + e = v.replace(elem) + if e is not None: + return e + return replacement_traverse(obj, self.__traverse_options__, replace) + +def iterate(obj, opts): + """traverse the given expression structure, returning an iterator. + + traversal is configured to be breadth-first. + + """ + stack = deque([obj]) + while stack: + t = stack.popleft() + yield t + for c in t.get_children(**opts): + stack.append(c) + +def iterate_depthfirst(obj, opts): + """traverse the given expression structure, returning an iterator. + + traversal is configured to be depth-first. + + """ + stack = deque([obj]) + traversal = deque() + while stack: + t = stack.pop() + traversal.appendleft(t) + for c in t.get_children(**opts): + stack.append(c) + return iter(traversal) + +def traverse_using(iterator, obj, visitors): + """visit the given expression structure using the given iterator of objects.""" + + for target in iterator: + meth = visitors.get(target.__visit_name__, None) + if meth: + meth(target) + return obj + +def traverse(obj, opts, visitors): + """traverse and visit the given expression structure using the default iterator.""" + + return traverse_using(iterate(obj, opts), obj, visitors) + +def traverse_depthfirst(obj, opts, visitors): + """traverse and visit the given expression structure using the depth-first iterator.""" + + return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) + +def cloned_traverse(obj, opts, visitors): + """clone the given expression structure, allowing modifications by visitors.""" + + cloned = util.column_dict() + + def clone(element): + if element not in cloned: + cloned[element] = element._clone() + return cloned[element] + + obj = clone(obj) + stack = [obj] + + while stack: + t = stack.pop() + if t in cloned: + continue + t._copy_internals(clone=clone) + + meth = visitors.get(t.__visit_name__, None) + if meth: + meth(t) + + for c in t.get_children(**opts): + stack.append(c) + return obj + +def replacement_traverse(obj, opts, replace): + """clone the given expression structure, allowing element replacement by a given replacement function.""" + + cloned = util.column_dict() + stop_on = util.column_set(opts.get('stop_on', [])) + + def clone(element): + newelem = replace(element) + if newelem is not None: + stop_on.add(newelem) + return newelem + + if element not in cloned: + cloned[element] = element._clone() + return cloned[element] + + obj = clone(obj) + stack = [obj] + while stack: + t = stack.pop() + if t in stop_on: + continue + t._copy_internals(clone=clone) + for c in t.get_children(**opts): + stack.append(c) + return obj diff --git a/sqlalchemy/test/__init__.py b/sqlalchemy/test/__init__.py new file mode 100644 index 0000000..d69cede --- /dev/null +++ b/sqlalchemy/test/__init__.py @@ -0,0 +1,26 @@ +"""Testing environment and utilities. + +This package contains base classes and routines used by +the unit tests. Tests are based on Nose and bootstrapped +by noseplugin.NoseSQLAlchemy. + +""" + +from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config +from sqlalchemy.test.schema import Column, Table +from sqlalchemy.test.testing import \ + AssertsCompiledSQL, \ + AssertsExecutionResults, \ + ComparesTables, \ + TestBase, \ + rowset + + +__all__ = ('testing', + 'Column', 'Table', + 'rowset', + 'TestBase', 'AssertsExecutionResults', + 'AssertsCompiledSQL', 'ComparesTables', + 'engines', 'profiling', 'pickleable') + + diff --git a/sqlalchemy/test/assertsql.py b/sqlalchemy/test/assertsql.py new file mode 100644 index 0000000..1417c2e --- /dev/null +++ b/sqlalchemy/test/assertsql.py @@ -0,0 +1,285 @@ + +from sqlalchemy.interfaces import ConnectionProxy +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.base import Connection +from sqlalchemy import util +import re + +class AssertRule(object): + def process_execute(self, clauseelement, *multiparams, **params): + pass + + def process_cursor_execute(self, statement, parameters, context, executemany): + pass + + def is_consumed(self): + """Return True if this rule has been consumed, False if not. + + Should raise an AssertionError if this rule's condition has definitely failed. + + """ + raise NotImplementedError() + + def rule_passed(self): + """Return True if the last test of this rule passed, False if failed, None if no test was applied.""" + + raise NotImplementedError() + + def consume_final(self): + """Return True if this rule has been consumed. + + Should raise an AssertionError if this rule's condition has not been consumed or has failed. + + """ + + if self._result is None: + assert False, "Rule has not been consumed" + + return self.is_consumed() + +class SQLMatchRule(AssertRule): + def __init__(self): + self._result = None + self._errmsg = "" + + def rule_passed(self): + return self._result + + def is_consumed(self): + if self._result is None: + return False + + assert self._result, self._errmsg + + return True + +class ExactSQL(SQLMatchRule): + def __init__(self, sql, params=None): + SQLMatchRule.__init__(self) + self.sql = sql + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_statement = _process_engine_statement(context.unicode_statement, context) + _received_parameters = context.compiled_parameters + + # TODO: remove this step once all unit tests + # are migrated, as ExactSQL should really be *exact* SQL + sql = _process_assertion_statement(self.sql, context) + + equivalent = _received_statement == sql + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + equivalent = equivalent and params == context.compiled_parameters + else: + params = {} + + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for exact statement %r exact params %r, " \ + "received %r with params %r" % (sql, params, _received_statement, _received_parameters) + + +class RegexSQL(SQLMatchRule): + def __init__(self, regex, params=None): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_statement = _process_engine_statement(context.unicode_statement, context) + _received_parameters = context.compiled_parameters + + equivalent = bool(self.regex.match(_received_statement)) + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + + # do a positive compare only + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for regex %r partial params %r, "\ + "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters) + +class CompiledSQL(SQLMatchRule): + def __init__(self, statement, params): + SQLMatchRule.__init__(self) + self.statement = statement + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_parameters = context.compiled_parameters + + # recompile from the context, using the default dialect + compiled = context.compiled.statement.\ + compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys) + + _received_statement = re.sub(r'\n', '', str(compiled)) + + equivalent = self.statement == _received_statement + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + + # do a positive compare only + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for compiled statement %r partial params %r, " \ + "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters) + + +class CountStatements(AssertRule): + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_execute(self, clauseelement, *multiparams, **params): + self._statement_count += 1 + + def process_cursor_execute(self, statement, parameters, context, executemany): + pass + + def is_consumed(self): + return False + + def consume_final(self): + assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count) + return True + +class AllOf(AssertRule): + def __init__(self, *rules): + self.rules = set(rules) + + def process_execute(self, clauseelement, *multiparams, **params): + for rule in self.rules: + rule.process_execute(clauseelement, *multiparams, **params) + + def process_cursor_execute(self, statement, parameters, context, executemany): + for rule in self.rules: + rule.process_cursor_execute(statement, parameters, context, executemany) + + def is_consumed(self): + if not self.rules: + return True + + for rule in list(self.rules): + if rule.rule_passed(): # a rule passed, move on + self.rules.remove(rule) + return len(self.rules) == 0 + + assert False, "No assertion rules were satisfied for statement" + + def consume_final(self): + return len(self.rules) == 0 + +def _process_engine_statement(query, context): + if util.jython: + # oracle+zxjdbc passes a PyStatement when returning into + query = unicode(query) + if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'): + query = query[:-25] + + query = re.sub(r'\n', '', query) + + return query + +def _process_assertion_statement(query, context): + paramstyle = context.dialect.paramstyle + if paramstyle == 'named': + pass + elif paramstyle =='pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle=='qmark': + repl = "?" + elif paramstyle=='format': + repl = r"%s" + elif paramstyle=='numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + + return query + +class SQLAssert(ConnectionProxy): + rules = None + + def add_rules(self, rules): + self.rules = list(rules) + + def statement_complete(self): + for rule in self.rules: + if not rule.consume_final(): + assert False, "All statements are complete, but pending assertion rules remain" + + def clear_rules(self): + del self.rules + + def execute(self, conn, execute, clauseelement, *multiparams, **params): + result = execute(clauseelement, *multiparams, **params) + + if self.rules is not None: + if not self.rules: + assert False, "All rules have been exhausted, but further statements remain" + rule = self.rules[0] + rule.process_execute(clauseelement, *multiparams, **params) + if rule.is_consumed(): + self.rules.pop(0) + + return result + + def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + result = execute(cursor, statement, parameters, context) + + if self.rules: + rule = self.rules[0] + rule.process_cursor_execute(statement, parameters, context, executemany) + + return result + +asserter = SQLAssert() + diff --git a/sqlalchemy/test/config.py b/sqlalchemy/test/config.py new file mode 100644 index 0000000..efbe00f --- /dev/null +++ b/sqlalchemy/test/config.py @@ -0,0 +1,180 @@ +import optparse, os, sys, re, ConfigParser, time, warnings + + +# 2to3 +import StringIO + +logging = None + +__all__ = 'parser', 'configure', 'options', + +db = None +db_label, db_url, db_opts = None, None, {} + +options = None +file_config = None + +base_config = """ +[db] +sqlite=sqlite:///:memory: +sqlite_file=sqlite:///querytest.db +postgresql=postgresql://scott:tiger@127.0.0.1:5432/test +postgres=postgresql://scott:tiger@127.0.0.1:5432/test +pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test +postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test +mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test +mysql=mysql://scott:tiger@127.0.0.1:3306/test +oracle=oracle://scott:tiger@127.0.0.1:1521 +oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 +mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test +firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb +maxdb=maxdb://MONA:RED@/maxdb1 +""" + +def _log(option, opt_str, value, parser): + global logging + if not logging: + import logging + logging.basicConfig() + + if opt_str.endswith('-info'): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith('-debug'): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + print "Available --db options (use --dburi to override)" + for macro in sorted(file_config.options('db')): + print "%20s\t%s" % (macro, file_config.get('db', macro)) + sys.exit(0) + +def _server_side_cursors(options, opt_str, value, parser): + db_opts['server_side_cursors'] = True + +def _engine_strategy(options, opt_str, value, parser): + if value: + db_opts['strategy'] = value + +class _ordered_map(object): + def __init__(self): + self._keys = list() + self._data = dict() + + def __setitem__(self, key, value): + if key not in self._keys: + self._keys.append(key) + self._data[key] = value + + def __iter__(self): + for key in self._keys: + yield self._data[key] + +# at one point in refactoring, modules were injecting into the config +# process. this could probably just become a list now. +post_configure = _ordered_map() + +def _engine_uri(options, file_config): + global db_label, db_url + db_label = 'sqlite' + if options.dburi: + db_url = options.dburi + db_label = db_url[:db_url.index(':')] + elif options.db: + db_label = options.db + db_url = None + + if db_url is None: + if db_label not in file_config.options('db'): + raise RuntimeError( + "Unknown engine. Specify --dbs for known engines.") + db_url = file_config.get('db', db_label) +post_configure['engine_uri'] = _engine_uri + +def _require(options, file_config): + if not(options.require or + (file_config.has_section('require') and + file_config.items('require'))): + return + + try: + import pkg_resources + except ImportError: + raise RuntimeError("setuptools is required for version requirements") + + cmdline = [] + for requirement in options.require: + pkg_resources.require(requirement) + cmdline.append(re.split('\s*(=)', requirement, 1)[0]) + + if file_config.has_section('require'): + for label, requirement in file_config.items('require'): + if not label == db_label or label.startswith('%s.' % db_label): + continue + seen = [c for c in cmdline if requirement.startswith(c)] + if seen: + continue + pkg_resources.require(requirement) +post_configure['require'] = _require + +def _engine_pool(options, file_config): + if options.mockpool: + from sqlalchemy import pool + db_opts['poolclass'] = pool.AssertionPool +post_configure['engine_pool'] = _engine_pool + +def _create_testing_engine(options, file_config): + from sqlalchemy.test import engines, testing + global db + db = engines.testing_engine(db_url, db_opts) + testing.db = db +post_configure['create_engine'] = _create_testing_engine + +def _prep_testing_database(options, file_config): + from sqlalchemy.test import engines + from sqlalchemy import schema + + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + +post_configure['prep_db'] = _prep_testing_database + +def _set_table_options(options, file_config): + from sqlalchemy.test import schema + + table_options = schema.table_options + for spec in options.tableopts: + key, value = spec.split('=') + table_options[key] = value + + if options.mysql_engine: + table_options['mysql_engine'] = options.mysql_engine +post_configure['table_options'] = _set_table_options + +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm import unitofwork + from sqlalchemy import topological + class RevQueueDepSort(topological.QueueDependencySorter): + def __init__(self, tuples, allitems): + self.tuples = list(tuples) + self.allitems = list(allitems) + self.tuples.reverse() + self.allitems.reverse() + topological.QueueDependencySorter = RevQueueDepSort + unitofwork.DependencySorter = RevQueueDepSort +post_configure['topological'] = _reverse_topological + diff --git a/sqlalchemy/test/engines.py b/sqlalchemy/test/engines.py new file mode 100644 index 0000000..0cfd58d --- /dev/null +++ b/sqlalchemy/test/engines.py @@ -0,0 +1,300 @@ +import sys, types, weakref +from collections import deque +import config +from sqlalchemy.util import function_named, callable +import re +import warnings + +class ConnectionKiller(object): + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + + def checkout(self, dbapi_con, con_record, con_proxy): + self.proxy_refs[con_proxy] = True + + def _apply_all(self, methods): + # must copy keys atomically + for rec in self.proxy_refs.keys(): + if rec is not None and rec.is_valid: + try: + for name in methods: + if callable(name): + name(rec) + else: + getattr(rec, name)() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + warnings.warn("testing_reaper couldn't close connection: %s" % e) + + def rollback_all(self): + self._apply_all(('rollback',)) + + def close_all(self): + self._apply_all(('rollback', 'close')) + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + +testing_reaper = ConnectionKiller() + +def drop_all_tables(metadata): + testing_reaper.close_all() + metadata.drop_all() + +def assert_conns_closed(fn): + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + return function_named(decorated, fn.__name__) + +def rollback_open_connections(fn): + """Decorator that rolls back all open connections after fn execution.""" + + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + return function_named(decorated, fn.__name__) + +def close_first(fn): + """Decorator that closes all connections before fn execution.""" + def decorated(*args, **kw): + testing_reaper.close_all() + fn(*args, **kw) + return function_named(decorated, fn.__name__) + + +def close_open_connections(fn): + """Decorator that closes all connections after fn execution.""" + + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.close_all() + return function_named(decorated, fn.__name__) + +def all_dialects(exclude=None): + import sqlalchemy.databases as d + for name in d.__all__: + # TEMPORARY + if exclude and name in exclude: + continue + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + yield mod.dialect() + +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def shutdown(self): + for c in list(self.connections): + c.close() + self.connections = [] + +def reconnecting_engine(url=None, options=None): + url = url or config.db_url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options['module'] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + engine.test_shutdown = engine.dialect.dbapi.shutdown + return engine + +def testing_engine(url=None, options=None): + """Produce an engine configured by --options with optional overrides.""" + + from sqlalchemy import create_engine + from sqlalchemy.test.assertsql import asserter + + url = url or config.db_url + options = options or config.db_opts + + options.setdefault('proxy', asserter) + + listeners = options.setdefault('listeners', []) + listeners.append(testing_reaper) + + engine = create_engine(url, **options) + + # may want to call this, results + # in first-connect initializers + #engine.connect() + + return engine + +def utf8_engine(url=None, options=None): + """Hook for dialects or drivers that don't handle utf8 by default.""" + + from sqlalchemy.engine import url as engine_url + + if config.db.driver == 'mysqldb': + dbapi_ver = config.db.dialect.dbapi.version_info + if (dbapi_ver < (1, 2, 1) or + dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), + (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))): + raise RuntimeError('Character set support unavailable with this ' + 'driver version: %s' % repr(dbapi_ver)) + else: + url = url or config.db_url + url = engine_url.make_url(url) + url.query['charset'] = 'utf8' + url.query['use_unicode'] = '0' + url = str(url) + + return testing_engine(url, options) + +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ + + from sqlalchemy import create_engine + + if not dialect_name: + dialect_name = config.db.name + + buffer = [] + def executor(sql, *a, **kw): + buffer.append(sql) + def assert_sql(stmts): + recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + assert recv == stmts, recv + + engine = create_engine(dialect_name + '://', + strategy='mock', executor=executor) + assert not hasattr(engine, 'mock') + engine.mock = buffer + engine.assert_sql = assert_sql + return engine + +class ReplayableSession(object): + """A simple record/playback tool. + + This is *not* a mock testing class. It only records a session for later + playback and makes no assertions on call consistency whatsoever. It's + unlikely to be suitable for anything other than DB-API recording. + + """ + + Callable = object() + NoAttribute = object() + Natives = set([getattr(types, t) + for t in dir(types) if not t.startswith('_')]). \ + difference([getattr(types, t) + # Py3K + #for t in ('FunctionType', 'BuiltinFunctionType', + # 'MethodType', 'BuiltinMethodType', + # 'LambdaType', )]) + + # Py2K + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', 'UnboundMethodType',)]) + # end Py2K + def __init__(self): + self.buffer = deque() + + def recorder(self, base): + return self.Recorder(self.buffer, base) + + def player(self): + return self.Player(self.buffer) + + class Recorder(object): + def __init__(self, buffer, subject): + self._buffer = buffer + self._subject = subject + + def __call__(self, *args, **kw): + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + + result = subject(*args, **kw) + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + @property + def _sqla_unwrap(self): + return self._subject + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + try: + result = type(subject).__getattribute__(subject, key) + except AttributeError: + buffer.append(ReplayableSession.NoAttribute) + raise + else: + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + class Player(object): + def __init__(self, buffer): + self._buffer = buffer + + def __call__(self, *args, **kw): + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + else: + return result + + @property + def _sqla_unwrap(self): + return None + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + elif result is ReplayableSession.NoAttribute: + raise AttributeError(key) + else: + return result + diff --git a/sqlalchemy/test/entities.py b/sqlalchemy/test/entities.py new file mode 100644 index 0000000..0ec677e --- /dev/null +++ b/sqlalchemy/test/entities.py @@ -0,0 +1,83 @@ +import sqlalchemy as sa +from sqlalchemy import exc as sa_exc + +_repr_stack = set() +class BasicEntity(object): + def __init__(self, **kw): + for key, value in kw.iteritems(): + setattr(self, key, value) + + def __repr__(self): + if id(self) in _repr_stack: + return object.__repr__(self) + _repr_stack.add(id(self)) + try: + return "%s(%s)" % ( + (self.__class__.__name__), + ', '.join(["%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith('_')])) + finally: + _repr_stack.remove(id(self)) + +_recursion_stack = set() +class ComparableEntity(BasicEntity): + def __hash__(self): + return hash(self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'Deep, sparse compare. + + Deeply compare two entities, following the non-None attributes of the + non-persisted object, if possible. + + """ + if other is self: + return True + elif not self.__class__ == other.__class__: + return False + + if id(self) in _recursion_stack: + return True + _recursion_stack.add(id(self)) + + try: + # pick the entity thats not SA persisted as the source + try: + self_key = sa.orm.attributes.instance_state(self).key + except sa.orm.exc.NO_STATE: + self_key = None + + if other is None: + a = self + b = other + elif self_key is not None: + a = other + b = self + else: + a = self + b = other + + for attr in a.__dict__.keys(): + if attr.startswith('_'): + continue + value = getattr(a, attr) + + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, '__iter__'): + if list(value) != list(battr): + return False + else: + if value is not None and value != battr: + return False + return True + finally: + _recursion_stack.remove(id(self)) diff --git a/sqlalchemy/test/noseplugin.py b/sqlalchemy/test/noseplugin.py new file mode 100644 index 0000000..5e8e21e --- /dev/null +++ b/sqlalchemy/test/noseplugin.py @@ -0,0 +1,162 @@ +import logging +import os +import re +import sys +import time +import warnings +import ConfigParser +import StringIO + +import nose.case +from nose.plugins import Plugin + +from sqlalchemy import util, log as sqla_log +from sqlalchemy.test import testing, config, requires +from sqlalchemy.test.config import ( + _create_testing_engine, _engine_pool, _engine_strategy, _engine_uri, _list_dbs, _log, + _prep_testing_database, _require, _reverse_topological, _server_side_cursors, + _set_table_options, base_config, db, db_label, db_url, file_config, post_configure) + +log = logging.getLogger('nose.plugins.sqlalchemy') + +class NoseSQLAlchemy(Plugin): + """ + Handles the setup and extra properties required for testing SQLAlchemy + """ + enabled = True + name = 'sqlalchemy' + score = 100 + + def options(self, parser, env=os.environ): + Plugin.options(self, parser, env) + opt = parser.add_option + opt("--log-info", action="callback", type="string", callback=_log, + help="turn on info logging for (multiple OK)") + opt("--log-debug", action="callback", type="string", callback=_log, + help="turn on debug logging for (multiple OK)") + opt("--require", action="append", dest="require", default=[], + help="require a particular driver or module version (multiple OK)") + opt("--db", action="store", dest="db", default="sqlite", + help="Use prefab database uri") + opt('--dbs', action='callback', callback=_list_dbs, + help="List available prefab dbs") + opt("--dburi", action="store", dest="dburi", + help="Database uri (overrides --db)") + opt("--dropfirst", action="store_true", dest="dropfirst", + help="Drop all tables in the target database first (use with caution on Oracle, " + "MS-SQL)") + opt("--mockpool", action="store_true", dest="mockpool", + help="Use mock pool (asserts only one connection used)") + opt("--enginestrategy", action="callback", type="string", + callback=_engine_strategy, + help="Engine strategy (plain or threadlocal, defaults to plain)") + opt("--reversetop", action="store_true", dest="reversetop", default=False, + help="Reverse the collection ordering for topological sorts (helps " + "reveal dependency issues)") + opt("--unhashable", action="store_true", dest="unhashable", default=False, + help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") + opt("--noncomparable", action="store_true", dest="noncomparable", default=False, + help="Disallow SQLAlchemy from performing == on mapped test objects.") + opt("--truthless", action="store_true", dest="truthless", default=False, + help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") + opt("--serverside", action="callback", callback=_server_side_cursors, + help="Turn on server side cursors for PG") + opt("--mysql-engine", action="store", dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, default is " + "a db-default/InnoDB combo.") + opt("--table-option", action="append", dest="tableopts", default=[], + help="Add a dialect-specific table option, key=value") + + global file_config + file_config = ConfigParser.ConfigParser() + file_config.readfp(StringIO.StringIO(base_config)) + file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) + config.file_config = file_config + + def configure(self, options, conf): + Plugin.configure(self, options, conf) + self.options = options + + def begin(self): + testing.db = db + testing.requires = requires + + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(self.options, file_config) + + def describeTest(self, test): + return "" + + def wantClass(self, cls): + """Return true if you want the main test selector to collect + tests from this class, false if you don't, and None if you don't + care. + + :Parameters: + cls : class + The class being examined by the selector + + """ + + if not issubclass(cls, testing.TestBase): + return False + else: + if (hasattr(cls, '__whitelist__') and testing.db.name in cls.__whitelist__): + return True + else: + return not self.__should_skip_for(cls) + + def __should_skip_for(self, cls): + if hasattr(cls, '__requires__'): + def test_suite(): return 'ok' + test_suite.__name__ = cls.__name__ + for requirement in cls.__requires__: + check = getattr(requires, requirement) + if check(test_suite)() != 'ok': + # The requirement will perform messaging. + return True + + if cls.__unsupported_on__: + spec = testing.db_spec(*cls.__unsupported_on__) + if spec(testing.db): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + + if getattr(cls, '__only_on__', None): + spec = testing.db_spec(*util.to_list(cls.__only_on__)) + if not spec(testing.db): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + + if getattr(cls, '__skip_if__', False): + for c in getattr(cls, '__skip_if__'): + if c(): + print "'%s' skipped by %s" % ( + cls.__class__.__name__, c.__name__) + return True + + for rule in getattr(cls, '__excluded_on__', ()): + if testing._is_excluded(*rule): + print "'%s' unsupported on DB %s version %s" % ( + cls.__class__.__name__, testing.db.name, + _server_version()) + return True + return False + + def beforeTest(self, test): + testing.resetwarnings() + + def afterTest(self, test): + testing.resetwarnings() + + def afterContext(self): + testing.global_cleanup_assertions() + + #def handleError(self, test, err): + #pass + + #def finalize(self, result=None): + #pass diff --git a/sqlalchemy/test/orm.py b/sqlalchemy/test/orm.py new file mode 100644 index 0000000..7ec13c5 --- /dev/null +++ b/sqlalchemy/test/orm.py @@ -0,0 +1,111 @@ +import inspect, re +import config, testing +from sqlalchemy import orm + +__all__ = 'mapper', + + +_whitespace = re.compile(r'^(\s+)') + +def _find_pragma(lines, current): + m = _whitespace.match(lines[current]) + basis = m and m.group() or '' + + for line in reversed(lines[0:current]): + if 'testlib.pragma' in line: + return line + m = _whitespace.match(line) + indent = m and m.group() or '' + + # simplistic detection: + + # >> # testlib.pragma foo + # >> center_line() + if indent == basis: + break + # >> # testlib.pragma foo + # >> if fleem: + # >> center_line() + if line.endswith(':'): + break + return None + +def _make_blocker(method_name, fallback): + """Creates tripwired variant of a method, raising when called. + + To excempt an invocation from blockage, there are two options. + + 1) add a pragma in a comment:: + + # testlib.pragma exempt:methodname + offending_line() + + 2) add a magic cookie to the function's namespace:: + __sa_baremethodname_exempt__ = True + ... + offending_line() + another_offending_lines() + + The second is useful for testing and development. + """ + + if method_name.startswith('__') and method_name.endswith('__'): + frame_marker = '__sa_%s_exempt__' % method_name[2:-2] + else: + frame_marker = '__sa_%s_exempt__' % method_name + pragma_marker = 'exempt:' + method_name + + def method(self, *args, **kw): + frame_r = None + try: + frame = inspect.stack()[1][0] + frame_r = inspect.getframeinfo(frame, 9) + + module = frame.f_globals.get('__name__', '') + + type_ = type(self) + + pragma = _find_pragma(*frame_r[3:5]) + + exempt = ( + (not module.startswith('sqlalchemy')) or + (pragma and pragma_marker in pragma) or + (frame_marker in frame.f_locals) or + ('self' in frame.f_locals and + getattr(frame.f_locals['self'], frame_marker, False))) + + if exempt: + supermeth = getattr(super(type_, self), method_name, None) + if (supermeth is None or + getattr(supermeth, 'im_func', None) is method): + return fallback(self, *args, **kw) + else: + return supermeth(*args, **kw) + else: + raise AssertionError( + "%s.%s called in %s, line %s in %s" % ( + type_.__name__, method_name, module, frame_r[1], frame_r[2])) + finally: + del frame + method.__name__ = method_name + return method + +def mapper(type_, *args, **kw): + forbidden = [ + ('__hash__', 'unhashable', lambda s: id(s)), + ('__eq__', 'noncomparable', lambda s, o: s is o), + ('__ne__', 'noncomparable', lambda s, o: s is not o), + ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)), + ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)), + ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)), + ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)), + ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)), + ('__nonzero__', 'truthless', lambda s: 1), ] + + if isinstance(type_, type) and type_.__bases__ == (object,): + for method_name, option, fallback in forbidden: + if (getattr(config.options, option, False) and + method_name not in type_.__dict__): + setattr(type_, method_name, _make_blocker(method_name, fallback)) + + return orm.mapper(type_, *args, **kw) diff --git a/sqlalchemy/test/pickleable.py b/sqlalchemy/test/pickleable.py new file mode 100644 index 0000000..9794e42 --- /dev/null +++ b/sqlalchemy/test/pickleable.py @@ -0,0 +1,75 @@ +""" + +some objects used for pickle tests, declared in their own module so that they +are easily pickleable. + +""" + + +class Foo(object): + def __init__(self, moredata): + self.data = 'im data' + self.stuff = 'im stuff' + self.moredata = moredata + __hash__ = object.__hash__ + def __eq__(self, other): + return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata + + +class Bar(object): + def __init__(self, x, y): + self.x = x + self.y = y + __hash__ = object.__hash__ + def __eq__(self, other): + return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + +class OldSchool: + def __init__(self, x, y): + self.x = x + self.y = y + def __eq__(self, other): + return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y + +class OldSchoolWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + +class BarWithoutCompare(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class NotComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + +class BrokenComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + raise NotImplementedError + diff --git a/sqlalchemy/test/profiling.py b/sqlalchemy/test/profiling.py new file mode 100644 index 0000000..c5256af --- /dev/null +++ b/sqlalchemy/test/profiling.py @@ -0,0 +1,222 @@ +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" + +import os, sys +from sqlalchemy.test import config +from sqlalchemy.test.util import function_named, gc_collect +from nose import SkipTest + +__all__ = 'profiled', 'function_call_count', 'conditional_call_count' + +all_targets = set() +profile_config = { 'targets': set(), + 'report': True, + 'sort': ('time', 'calls'), + 'limit': None } +profiler = None + +def profiled(target=None, **target_opts): + """Optional function profiling. + + @profiled('label') + or + @profiled('label', report=True, sort=('calls',), limit=20) + + Enables profiling for a function when 'label' is targetted for + profiling. Report options can be supplied, and override the global + configuration and command-line options. + """ + + # manual or automatic namespacing by module would remove conflict issues + if target is None: + target = 'anonymous_target' + elif target in all_targets: + print "Warning: redefining profile target '%s'" % target + all_targets.add(target) + + filename = "%s.prof" % target + + def decorator(fn): + def profiled(*args, **kw): + if (target not in profile_config['targets'] and + not target_opts.get('always', None)): + return fn(*args, **kw) + + elapsed, load_stats, result = _profile( + filename, fn, *args, **kw) + + report = target_opts.get('report', profile_config['report']) + if report: + sort_ = target_opts.get('sort', profile_config['sort']) + limit = target_opts.get('limit', profile_config['limit']) + print "Profile report for target '%s' (%s)" % ( + target, filename) + + stats = load_stats() + stats.sort_stats(*sort_) + if limit: + stats.print_stats(limit) + else: + stats.print_stats() + #stats.print_callers() + os.unlink(filename) + return result + return function_named(profiled, fn.__name__) + return decorator + +def function_call_count(count=None, versions={}, variance=0.05): + """Assert a target for a test case's function call count. + + count + Optional, general target function call count. + + versions + Optional, a dictionary of Python version strings to counts, + for example:: + + { '2.5.1': 110, + '2.5': 100, + '2.4': 150 } + + The best match for the current running python will be used. + If none match, 'count' will be used as the fallback. + + variance + An +/- deviation percentage, defaults to 5%. + """ + + # this could easily dump the profile report if --verbose is in effect + + version_info = list(sys.version_info) + py_version = '.'.join([str(v) for v in sys.version_info]) + try: + from sqlalchemy.cprocessors import to_float + cextension = True + except ImportError: + cextension = False + + while version_info: + version = '.'.join([str(v) for v in version_info]) + if cextension: + version += "+cextension" + if version in versions: + count = versions[version] + break + version_info.pop() + + if count is None: + return lambda fn: fn + + def decorator(fn): + def counted(*args, **kw): + try: + filename = "%s.prof" % fn.__name__ + + elapsed, stat_loader, result = _profile( + filename, fn, *args, **kw) + + stats = stat_loader() + calls = stats.total_calls + + stats.sort_stats('calls', 'cumulative') + stats.print_stats() + #stats.print_callers() + deviance = int(count * variance) + if (calls < (count - deviance) or + calls > (count + deviance)): + raise AssertionError( + "Function call count %s not within %s%% " + "of expected %s. (Python version %s)" % ( + calls, (variance * 100), count, py_version)) + + return result + finally: + if os.path.exists(filename): + os.unlink(filename) + return function_named(counted, fn.__name__) + return decorator + +def conditional_call_count(discriminator, categories): + """Apply a function call count conditionally at runtime. + + Takes two arguments, a callable that returns a key value, and a dict + mapping key values to a tuple of arguments to function_call_count. + + The callable is not evaluated until the decorated function is actually + invoked. If the `discriminator` returns a key not present in the + `categories` dictionary, no call count assertion is applied. + + Useful for integration tests, where running a named test in isolation may + have a function count penalty not seen in the full suite, due to lazy + initialization in the DB-API, SA, etc. + """ + + def decorator(fn): + def at_runtime(*args, **kw): + criteria = categories.get(discriminator(), None) + if criteria is None: + return fn(*args, **kw) + + rewrapped = function_call_count(*criteria)(fn) + return rewrapped(*args, **kw) + return function_named(at_runtime, fn.__name__) + return decorator + + +def _profile(filename, fn, *args, **kw): + global profiler + if not profiler: + if sys.version_info > (2, 5): + try: + import cProfile + profiler = 'cProfile' + except ImportError: + pass + if not profiler: + try: + import hotshot + profiler = 'hotshot' + except ImportError: + profiler = 'skip' + + if profiler == 'skip': + raise SkipTest('Profiling not supported on this platform') + elif profiler == 'cProfile': + return _profile_cProfile(filename, fn, *args, **kw) + else: + return _profile_hotshot(filename, fn, *args, **kw) + +def _profile_cProfile(filename, fn, *args, **kw): + import cProfile, gc, pstats, time + + load_stats = lambda: pstats.Stats(filename) + gc_collect() + + began = time.time() + cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), + filename=filename) + ended = time.time() + + return ended - began, load_stats, locals()['result'] + +def _profile_hotshot(filename, fn, *args, **kw): + import gc, hotshot, hotshot.stats, time + load_stats = lambda: hotshot.stats.load(filename) + + gc_collect() + prof = hotshot.Profile(filename) + began = time.time() + prof.start() + try: + result = fn(*args, **kw) + finally: + prof.stop() + ended = time.time() + prof.close() + + return ended - began, load_stats, result + diff --git a/sqlalchemy/test/requires.py b/sqlalchemy/test/requires.py new file mode 100644 index 0000000..73b2120 --- /dev/null +++ b/sqlalchemy/test/requires.py @@ -0,0 +1,259 @@ +"""Global database feature support policy. + +Provides decorators to mark tests requiring specific feature support from the +target database. + +""" + +from testing import \ + _block_unconditionally as no_support, \ + _chain_decorators_on, \ + exclude, \ + emits_warning_on,\ + skip_if,\ + fails_on + +import testing +import sys + +def deferrable_constraints(fn): + """Target database must support derferable constraints.""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('mysql', 'not supported by database'), + no_support('mssql', 'not supported by database'), + ) + +def foreign_keys(fn): + """Target database must support foreign keys.""" + return _chain_decorators_on( + fn, + no_support('sqlite', 'not supported by database'), + ) + + +def unbounded_varchar(fn): + """Target database must support VARCHAR with no length""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mysql', 'not supported by database'), + ) + +def boolean_col_expressions(fn): + """Target database must support boolean expressions as columns""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mssql', 'not supported by database'), + no_support('sybase', 'not supported by database'), + no_support('maxdb', 'FIXME: verify not supported by database'), + ) + +def identity(fn): + """Target database must support GENERATED AS IDENTITY or a facsimile. + + Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other + column DDL feature that fills in a DB-generated identifier at INSERT-time + without requiring pre-execution of a SEQUENCE or other artifact. + + """ + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('postgresql', 'not supported by database'), + no_support('sybase', 'not supported by database'), + ) + +def independent_cursors(fn): + """Target must support simultaneous, independent database cursors on a single connection.""" + + return _chain_decorators_on( + fn, + no_support('mssql+pyodbc', 'no driver support'), + no_support('mssql+mxodbc', 'no driver support'), + ) + +def independent_connections(fn): + """Target must support simultaneous, independent database connections.""" + + # This is also true of some configurations of UnixODBC and probably win32 + # ODBC as well. + return _chain_decorators_on( + fn, + no_support('sqlite', 'no driver support'), + exclude('mssql', '<', (9, 0, 0), + 'SQL Server 2005+ is required for independent connections'), + ) + +def row_triggers(fn): + """Target must support standard statement-running EACH ROW triggers.""" + return _chain_decorators_on( + fn, + # no access to same table + no_support('mysql', 'requires SUPER priv'), + exclude('mysql', '<', (5, 0, 10), 'not supported by database'), + + # huh? TODO: implement triggers for PG tests, remove this + no_support('postgresql', 'PG triggers need to be implemented for tests'), + ) + +def correlated_outer_joins(fn): + """Target must support an outer join to a subquery which correlates to the parent.""" + + return _chain_decorators_on( + fn, + no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"') + ) + +def savepoints(fn): + """Target database must support savepoints.""" + return _chain_decorators_on( + fn, + emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'), + no_support('access', 'not supported by database'), + no_support('sqlite', 'not supported by database'), + no_support('sybase', 'FIXME: guessing, needs confirmation'), + exclude('mysql', '<', (5, 0, 3), 'not supported by database'), + ) + +def denormalized_names(fn): + """Target database must have 'denormalized', i.e. UPPERCASE as case insensitive names.""" + + return skip_if( + lambda: not testing.db.dialect.requires_name_normalize, + "Backend does not require denomralized names." + )(fn) + +def schemas(fn): + """Target database must support external schemas, and have one named 'test_schema'.""" + + return _chain_decorators_on( + fn, + no_support('sqlite', 'no schema support'), + no_support('firebird', 'no schema support') + ) + +def sequences(fn): + """Target database must support SEQUENCEs.""" + return _chain_decorators_on( + fn, + no_support('access', 'no SEQUENCE support'), + no_support('mssql', 'no SEQUENCE support'), + no_support('mysql', 'no SEQUENCE support'), + no_support('sqlite', 'no SEQUENCE support'), + no_support('sybase', 'no SEQUENCE support'), + ) + +def subqueries(fn): + """Target database must support subqueries.""" + return _chain_decorators_on( + fn, + exclude('mysql', '<', (4, 1, 1), 'no subquery support'), + ) + +def intersect(fn): + """Target database must support INTERSECT or equivlaent.""" + return _chain_decorators_on( + fn, + fails_on('firebird', 'no support for INTERSECT'), + fails_on('mysql', 'no support for INTERSECT'), + fails_on('sybase', 'no support for INTERSECT'), + ) + +def except_(fn): + """Target database must support EXCEPT or equivlaent (i.e. MINUS).""" + return _chain_decorators_on( + fn, + fails_on('firebird', 'no support for EXCEPT'), + fails_on('mysql', 'no support for EXCEPT'), + fails_on('sybase', 'no support for EXCEPT'), + ) + +def offset(fn): + """Target database must support some method of adding OFFSET or equivalent to a result set.""" + return _chain_decorators_on( + fn, + fails_on('sybase', 'no support for OFFSET or equivalent'), + ) + +def returning(fn): + return _chain_decorators_on( + fn, + no_support('access', 'not supported by database'), + no_support('sqlite', 'not supported by database'), + no_support('mysql', 'not supported by database'), + no_support('maxdb', 'not supported by database'), + no_support('sybase', 'not supported by database'), + no_support('informix', 'not supported by database'), + ) + +def two_phase_transactions(fn): + """Target database must support two-phase transactions.""" + return _chain_decorators_on( + fn, + no_support('access', 'not supported by database'), + no_support('firebird', 'no SA implementation'), + no_support('maxdb', 'not supported by database'), + no_support('mssql', 'FIXME: guessing, needs confirmation'), + no_support('oracle', 'no SA implementation'), + no_support('sqlite', 'not supported by database'), + no_support('sybase', 'FIXME: guessing, needs confirmation'), + no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may ' + 'need separate XA implementation'), + exclude('mysql', '<', (5, 0, 3), 'not supported by database'), + ) + +def unicode_connections(fn): + """Target driver must support some encoding of Unicode across the wire.""" + # TODO: expand to exclude MySQLdb versions w/ broken unicode + return _chain_decorators_on( + fn, + exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), + ) + +def unicode_ddl(fn): + """Target driver must support some encoding of Unicode across the wire.""" + # TODO: expand to exclude MySQLdb versions w/ broken unicode + return _chain_decorators_on( + fn, + no_support('maxdb', 'database support flakey'), + no_support('oracle', 'FIXME: no support in database?'), + no_support('sybase', 'FIXME: guessing, needs confirmation'), + no_support('mssql+pymssql', 'no FreeTDS support'), + exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), + ) + +def sane_rowcount(fn): + return _chain_decorators_on( + fn, + skip_if(lambda: not testing.db.dialect.supports_sane_rowcount) + ) + +def python2(fn): + return _chain_decorators_on( + fn, + skip_if( + lambda: sys.version_info >= (3,), + "Python version 2.xx is required." + ) + ) + +def _has_sqlite(): + from sqlalchemy import create_engine + try: + e = create_engine('sqlite://') + return True + except ImportError: + return False + +def sqlite(fn): + return _chain_decorators_on( + fn, + skip_if(lambda: not _has_sqlite()) + ) + diff --git a/sqlalchemy/test/schema.py b/sqlalchemy/test/schema.py new file mode 100644 index 0000000..d33d75e --- /dev/null +++ b/sqlalchemy/test/schema.py @@ -0,0 +1,79 @@ +"""Enhanced versions of schema.Table and schema.Column which establish +desired state for different backends. +""" + +from sqlalchemy.test import testing +from sqlalchemy import schema + +__all__ = 'Table', 'Column', + +table_options = {} + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k,kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + kw.update(table_options) + + if testing.against('mysql'): + if 'mysql_engine' not in kw and 'mysql_type' not in kw: + if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: + kw['mysql_engine'] = 'InnoDB' + + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around seleting self-refs too. + if testing.against('firebird'): + table_name = args[0] + unpack = (testing.config.db.dialect. + identifier_preparer.unformat_identifiers) + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. + fks = [fk + for col in args if isinstance(col, schema.Column) + for fk in col.foreign_keys] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + if name == table_name: + if fk.ondelete is None: + fk.ondelete = 'CASCADE' + if fk.onupdate is None: + fk.onupdate = 'CASCADE' + + return schema.Table(*args, **kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k,kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + col = schema.Column(*args, **kw) + if 'test_needs_autoincrement' in test_opts and \ + kw.get('primary_key', False) and \ + testing.against('firebird', 'oracle'): + def add_seq(tbl, c): + c._init_items( + schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + c.name + '_seq'), optional=True) + ) + col._on_table_attach(add_seq) + return col + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:] + else: + return name + diff --git a/sqlalchemy/test/testing.py b/sqlalchemy/test/testing.py new file mode 100644 index 0000000..771b8c9 --- /dev/null +++ b/sqlalchemy/test/testing.py @@ -0,0 +1,779 @@ +"""TestCase and TestSuite artifacts and testing decorators.""" + +import itertools +import operator +import re +import sys +import types +import warnings +from cStringIO import StringIO + +from sqlalchemy.test import config, assertsql, util as testutil +from sqlalchemy.util import function_named, py3k +from engines import drop_all_tables + +from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool, orm +from sqlalchemy.engine import default +from nose import SkipTest + + +_ops = { '<': operator.lt, + '>': operator.gt, + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '>=': operator.ge, + 'in': operator.contains, + 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + } + +# sugar ('testing.db'); set here by config() at runtime +db = None + +# more sugar, installed by __init__ +requires = None + +def fails_if(callable_, reason=None): + """Mark a test as expected to fail if callable_ returns True. + + If the callable returns false, the test is run and reported as normal. + However if the callable returns true, the test is expected to fail and the + unit test logic is inverted: if the test fails, a success is reported. If + the test succeeds, a failure is reported. + """ + + docstring = getattr(callable_, '__doc__', None) or callable_.__name__ + description = docstring.split('\n')[0] + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if not callable_(): + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected (condition: %s): %s " % ( + fn_name, description, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' (condition: %s)" % + (fn_name, description)) + return function_named(maybe, fn_name) + return decorate + + +def future(fn): + """Mark a test as expected to unconditionally fail. + + Takes no arguments, omit parens when using as a decorator. + """ + + fn_name = fn.__name__ + def decorated(*args, **kw): + try: + fn(*args, **kw) + except Exception, ex: + print ("Future test '%s' failed as expected: %s " % ( + fn_name, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for future test '%s'" % fn_name) + return function_named(decorated, fn_name) + +def db_spec(*dbs): + dialects = set([x for x in dbs if '+' not in x]) + drivers = set([x[1:] for x in dbs if x.startswith('+')]) + specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers]) + + def check(engine): + return engine.name in dialects or \ + engine.driver in drivers or \ + (engine.name, engine.driver) in specs + + return check + + +def fails_on(dbs, reason): + """Mark a test as expected to fail on the specified database + implementation. + + Unlike ``crashes``, tests marked as ``fails_on`` will be run + for the named databases. The test is expected to fail and the unit test + logic is inverted: if the test fails, a success is reported. If the test + succeeds, a failure is reported. + """ + + spec = db_spec(dbs) + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if not spec(config.db): + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected on DB implementation " + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason)) + return True + else: + raise AssertionError( + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) + return function_named(maybe, fn_name) + return decorate + +def fails_on_everything_except(*dbs): + """Mark a test as expected to fail on most database implementations. + + Like ``fails_on``, except failure is the expected outcome on all + databases except those listed. + """ + + spec = db_spec(*dbs) + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if spec(config.db): + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected on DB implementation " + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) + return function_named(maybe, fn_name) + return decorate + +def crashes(db, reason): + """Mark a test as unsupported by a database implementation. + + ``crashes`` tests will be skipped unconditionally. Use for feature tests + that cause deadlocks or other fatal problems. + + """ + carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) + print msg + if carp: + print >> sys.stderr, msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + +def _block_unconditionally(db, reason): + """Mark a test as unsupported by a database implementation. + + Will never run the test against any version of the given database, ever, + no matter what. Use when your assumptions are infallible; past, present + and future. + + """ + carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) + print msg + if carp: + print >> sys.stderr, msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + +def only_on(db, reason): + carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if spec(config.db): + return fn(*args, **kw) + else: + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) + print msg + if carp: + print >> sys.stderr, msg + return True + return function_named(maybe, fn_name) + return decorate + +def exclude(db, op, spec, reason): + """Mark a test as unsupported by specific database server versions. + + Stackable, both with other excludes and other decorators. Examples:: + + # Not supported by mydb versions less than 1, 0 + @exclude('mydb', '<', (1,0)) + # Other operators work too + @exclude('bigdb', '==', (9,0,9)) + @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) + + """ + carp = _should_carp_about_exclusion(reason) + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if _is_excluded(db, op, spec): + msg = "'%s' unsupported on DB %s version '%s': %s" % ( + fn_name, config.db.name, _server_version(), reason) + print msg + if carp: + print >> sys.stderr, msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + +def _should_carp_about_exclusion(reason): + """Guard against forgotten exclusions.""" + assert reason + for _ in ('todo', 'fixme', 'xxx'): + if _ in reason.lower(): + return True + else: + if len(reason) < 4: + return True + +def _is_excluded(db, op, spec): + """Return True if the configured db matches an exclusion specification. + + db: + A dialect name + op: + An operator or stringified operator, such as '==' + spec: + A value that will be compared to the dialect's server_version_info + using the supplied operator. + + Examples:: + # Not supported by mydb versions less than 1, 0 + _is_excluded('mydb', '<', (1,0)) + # Other operators work too + _is_excluded('bigdb', '==', (9,0,9)) + _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) + """ + + vendor_spec = db_spec(db) + + if not vendor_spec(config.db): + return False + + version = _server_version() + + oper = hasattr(op, '__call__') and op or _ops[op] + return oper(version, spec) + +def _server_version(bind=None): + """Return a server_version_info tuple.""" + + if bind is None: + bind = config.db + + # force metadata to be retrieved + conn = bind.connect() + version = getattr(bind.dialect, 'server_version_info', ()) + conn.close() + return version + +def skip_if(predicate, reason=None): + """Skip a test if predicate is true.""" + reason = reason or predicate.__name__ + carp = _should_carp_about_exclusion(reason) + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if predicate(): + msg = "'%s' skipped on DB %s version '%s': %s" % ( + fn_name, config.db.name, _server_version(), reason) + print msg + if carp: + print >> sys.stderr, msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + +def emits_warning(*messages): + """Mark a test as emitting a warning. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + + # TODO: it would be nice to assert that a named warning was + # emitted. should work with some monkeypatching of warnings, + # and may work on non-CPython if they keep to the spirit of + # warnings.showwarning's docstring. + # - update: jython looks ok, it uses cpython's module + def decorate(fn): + def safe(*args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SAWarning)) + else: + filters.extend(dict(action='ignore', + message=message, + category=sa_exc.SAWarning) + for message in messages) + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return function_named(safe, fn.__name__) + return decorate + +def emits_warning_on(db, *warnings): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + spec = db_spec(db) + + def decorate(fn): + def maybe(*args, **kw): + if isinstance(db, basestring): + if not spec(config.db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + else: + if not _is_excluded(*db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + return function_named(maybe, fn.__name__) + return decorate + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + """ + + def decorate(fn): + def safe(*args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SADeprecationWarning)) + else: + filters.extend( + [dict(action='ignore', + message=message, + category=sa_exc.SADeprecationWarning) + for message in + [ (m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages] ]) + + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return function_named(safe, fn.__name__) + return decorate + +def resetwarnings(): + """Reset warning behavior to testing defaults.""" + + warnings.filterwarnings('ignore', + category=sa_exc.SAPendingDeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SAWarning) + +# warnings.simplefilter('error') + + if sys.version_info < (2, 4): + warnings.filterwarnings('ignore', category=FutureWarning) + +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + + testutil.lazy_gc() + assert not pool._refs + + + +def against(*queries): + """Boolean predicate, compares to testing database configuration. + + Given one or more dialect names, returns True if one is the configured + database engine. + + Also supports comparison to database version when provided with one or + more 3-tuples of dialect name, operator, and version specification:: + + testing.against('mysql', 'postgresql') + testing.against(('mysql', '>=', (5, 0, 0)) + """ + + for query in queries: + if isinstance(query, basestring): + if db_spec(query)(config.db): + return True + else: + name, op, spec = query + if not db_spec(name)(config.db): + continue + + have = _server_version() + + oper = hasattr(op, '__call__') and op or _ops[op] + if oper(have, spec): + return True + return False + +def _chain_decorators_on(fn, *decorators): + """Apply a series of decorators to fn, returning a decorated function.""" + for decorator in reversed(decorators): + fn = decorator(fn) + return fn + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. + """ + + return set([tuple(row) for row in results]) + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + +def is_not_(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, fragment) + +def assert_raises(except_cls, callable_, *args, **kw): + try: + callable_(*args, **kw) + success = False + except except_cls, e: + success = True + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + try: + callable_(*args, **kwargs) + assert False, "Callable did not raise an exception" + except except_cls, e: + assert re.search(msg, str(e)), "%r !~ %s" % (msg, e) + +def fail(msg): + assert False, msg + +def fixture(table, columns, *rows): + """Insert data into table after creation.""" + def onload(event, schema_item, connection): + insert = table.insert() + column_names = [col.key for col in columns] + connection.execute(insert, [dict(zip(column_names, column_values)) + for column_values in rows]) + table.append_ddl_listener('after-create', onload) + +def resolve_artifact_names(fn): + """Decorator, augment function globals with tables and classes. + + Swaps out the function's globals at execution time. The 'global' statement + will not work as expected inside a decorated function. + + """ + # This could be automatically applied to framework and test_ methods in + # the MappedTest-derived test suites but... *some* explicitness for this + # magic is probably good. Especially as 'global' won't work- these + # rebound functions aren't regular Python.. + # + # Also: it's lame that CPython accepts a dict-subclass for globals, but + # only calls dict methods. That would allow 'global' to pass through to + # the func_globals. + def resolved(*args, **kwargs): + self = args[0] + context = dict(fn.func_globals) + for source in self._artifact_registries: + context.update(getattr(self, source)) + # jython bug #1034 + rebound = types.FunctionType( + fn.func_code, context, fn.func_name, fn.func_defaults, + fn.func_closure) + return rebound(*args, **kwargs) + return function_named(resolved, fn.func_name) + +class adict(dict): + """Dict keys available as attributes. Shadows.""" + def __getattribute__(self, key): + try: + return self[key] + except KeyError: + return dict.__getattribute__(self, key) + + def get_all(self, *keys): + return tuple([self[key] for key in keys]) + + +class TestBase(object): + # A sequence of database names to always run, regardless of the + # constraints below. + __whitelist__ = () + + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + _artifact_registries = () + + def assert_(self, val, msg=None): + assert val, msg + +class AssertsCompiledSQL(object): + def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False): + if use_default_dialect: + dialect = default.DefaultDialect() + + if dialect is None: + dialect = getattr(self, '__dialect__', None) + + kw = {} + if params is not None: + kw['column_keys'] = params.keys() + + if isinstance(clause, orm.Query): + context = clause._compile_context() + context.statement.use_labels = True + clause = context.statement + + c = clause.compile(dialect=dialect, **kw) + + param_str = repr(getattr(c, 'params', {})) + # Py3K + #param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + + print "\nSQL String:\n" + str(c) + param_str + + cc = re.sub(r'[\n\t]', '', str(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + eq_(c.construct_params(params), checkparams) + +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table, strict_types=False): + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + + if strict_types: + assert type(reflected_c.type) is type(c.type), \ + "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + else: + self.assert_types_base(reflected_c, c) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + if c.server_default: + assert isinstance(reflected_c.server_default, + schema.FetchedValue) + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None + + def assert_types_base(self, c1, c2): + assert c1.type._compare_type_affinity(c2.type),\ + "On column %r, type '%s' doesn't correspond to type '%s'" % \ + (c1.name, c1.type, c2.type) + +class AssertsExecutionResults(object): + def assert_result(self, result, class_, *objects): + result = list(result) + print repr(result) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list): + self.assert_(len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + class_.__name__) + for i in range(0, len(list)): + self.assert_row(class_, result[i], list[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_(rowobj.__class__ is class_, + "item class is not " + repr(class_)) + for key, value in desc.iteritems(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_(getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" % ( + key, getattr(rowobj, key), value)) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class frozendict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = set([frozendict(e) for e in expected]) + + for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): + fail('Unexpected type "%s", expected "%s"' % ( + type(wrong).__name__, cls.__name__)) + + if len(found) != len(expected): + fail('Unexpected object count "%s", expected "%s"' % ( + len(found), len(expected))) + + NOVALUE = object() + def _compare_item(obj, spec): + for key, value in spec.iteritems(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1]) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." % ( + cls.__name__, repr(expected_item))) + return True + + def assert_sql_execution(self, db, callable_, *rules): + assertsql.asserter.add_rules(rules) + try: + callable_() + assertsql.asserter.statement_complete() + finally: + assertsql.asserter.clear_rules() + + def assert_sql(self, db, callable_, list_, with_sequences=None): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'): + rules = with_sequences + else: + rules = list_ + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf(*[ + assertsql.ExactSQL(k, v) for k, v in rule.iteritems() + ]) + else: + newrule = assertsql.ExactSQL(*rule) + newrules.append(newrule) + + self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + + diff --git a/sqlalchemy/test/util.py b/sqlalchemy/test/util.py new file mode 100644 index 0000000..8a3a0e7 --- /dev/null +++ b/sqlalchemy/test/util.py @@ -0,0 +1,53 @@ +from sqlalchemy.util import jython, function_named + +import gc +import time + +if jython: + def gc_collect(*args): + """aggressive gc.collect for tests.""" + gc.collect() + time.sleep(0.1) + gc.collect() + gc.collect() + return 0 + + # "lazy" gc, for VM's that don't GC on refcount == 0 + lazy_gc = gc_collect + +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + def lazy_gc(): + pass + + + +def picklers(): + picklers = set() + # Py2K + try: + import cPickle + picklers.add(cPickle) + except ImportError: + pass + # end Py2K + import pickle + picklers.add(pickle) + + # yes, this thing needs this much testing + for pickle in picklers: + for protocol in -1, 0, 1, 2: + yield pickle.loads, lambda d:pickle.dumps(d, protocol) + + +def round_decimal(value, prec): + if isinstance(value, float): + return round(value, prec) + + import decimal + + # can also use shift() here but that is 2.6 only + return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \ + pow(10, prec) + \ No newline at end of file diff --git a/sqlalchemy/topological.py b/sqlalchemy/topological.py new file mode 100644 index 0000000..d35213f --- /dev/null +++ b/sqlalchemy/topological.py @@ -0,0 +1,297 @@ +# topological.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Topological sorting algorithms. + +The topological sort is an algorithm that receives this list of +dependencies as a *partial ordering*, that is a list of pairs which +might say, *X is dependent on Y*, *Q is dependent on Z*, but does not +necessarily tell you anything about Q being dependent on X. Therefore, +its not a straight sort where every element can be compared to +another... only some of the elements have any sorting preference, and +then only towards just some of the other elements. For a particular +partial ordering, there can be many possible sorts that satisfy the +conditions. + +""" + +from sqlalchemy.exc import CircularDependencyError +from sqlalchemy import util + +__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree'] + +def sort(tuples, allitems): + """sort the given list of items by dependency. + + 'tuples' is a list of tuples representing a partial ordering. + """ + + return [n.item for n in _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=True)] + +def sort_with_cycles(tuples, allitems): + """sort the given list of items by dependency, cutting out cycles. + + returns results as an iterable of 2-tuples, containing the item, + and a list containing items involved in a cycle with this item, if any. + + 'tuples' is a list of tuples representing a partial ordering. + """ + + return [(n.item, [n.item for n in n.cycles or []]) for n in _sort(tuples, allitems, allow_cycles=True)] + +def sort_as_tree(tuples, allitems, with_cycles=False): + """sort the given list of items by dependency, and return results + as a hierarchical tree structure. + + returns results as an iterable of 3-tuples, containing the item, + a list containing items involved in a cycle with this item, if any, + and a list of child tuples. + + if with_cycles is False, the returned structure is of the same form + but the second element of each tuple, i.e. the 'cycles', is an empty list. + + 'tuples' is a list of tuples representing a partial ordering. + """ + + return _organize_as_tree(_sort(tuples, allitems, allow_cycles=with_cycles)) + + +class _Node(object): + """Represent each item in the sort.""" + + def __init__(self, item): + self.item = item + self.dependencies = set() + self.children = [] + self.cycles = None + + def __str__(self): + return self.safestr() + + def safestr(self, indent=0): + return (' ' * indent * 2) + \ + str(self.item) + \ + (self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \ + "\n" + \ + ''.join(str(n) for n in self.children) + + def __repr__(self): + return str(self.item) + + def all_deps(self): + """Return a set of dependencies for this node and all its cycles.""" + + deps = set(self.dependencies) + if self.cycles is not None: + for c in self.cycles: + deps.update(c.dependencies) + return deps + +class _EdgeCollection(object): + """A collection of directed edges.""" + + def __init__(self): + self.parent_to_children = util.defaultdict(set) + self.child_to_parents = util.defaultdict(set) + + def add(self, edge): + """Add an edge to this collection.""" + + parentnode, childnode = edge + self.parent_to_children[parentnode].add(childnode) + self.child_to_parents[childnode].add(parentnode) + parentnode.dependencies.add(childnode) + + def remove(self, edge): + """Remove an edge from this collection. + + Return the childnode if it has no other parents. + """ + + (parentnode, childnode) = edge + self.parent_to_children[parentnode].remove(childnode) + self.child_to_parents[childnode].remove(parentnode) + if not self.child_to_parents[childnode]: + return childnode + else: + return None + + def has_parents(self, node): + return node in self.child_to_parents and bool(self.child_to_parents[node]) + + def edges_by_parent(self, node): + if node in self.parent_to_children: + return [(node, child) for child in self.parent_to_children[node]] + else: + return [] + + def get_parents(self): + return self.parent_to_children.keys() + + def pop_node(self, node): + """Remove all edges where the given node is a parent. + + Return the collection of all nodes which were children of the + given node, and have no further parents. + """ + + children = self.parent_to_children.pop(node, None) + if children is not None: + for child in children: + self.child_to_parents[child].remove(node) + if not self.child_to_parents[child]: + yield child + + def __len__(self): + return sum(len(x) for x in self.parent_to_children.values()) + + def __iter__(self): + for parent, children in self.parent_to_children.iteritems(): + for child in children: + yield (parent, child) + + def __repr__(self): + return repr(list(self)) + +def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False): + nodes = {} + edges = _EdgeCollection() + + for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]: + item_id = id(item) + if item_id not in nodes: + nodes[item_id] = _Node(item) + + for t in tuples: + id0, id1 = id(t[0]), id(t[1]) + if t[0] is t[1]: + if allow_cycles: + n = nodes[id0] + n.cycles = set([n]) + elif not ignore_self_cycles: + raise CircularDependencyError("Self-referential dependency detected: %r" % t) + continue + childnode = nodes[id1] + parentnode = nodes[id0] + edges.add((parentnode, childnode)) + + queue = [] + for n in nodes.values(): + if not edges.has_parents(n): + queue.append(n) + + output = [] + while nodes: + if not queue: + # edges remain but no edgeless nodes to remove; this indicates + # a cycle + if allow_cycles: + for cycle in _find_cycles(edges): + lead = cycle[0][0] + lead.cycles = set() + for edge in cycle: + n = edges.remove(edge) + lead.cycles.add(edge[0]) + lead.cycles.add(edge[1]) + if n is not None: + queue.append(n) + for n in lead.cycles: + if n is not lead: + n._cyclical = True + for (n, k) in list(edges.edges_by_parent(n)): + edges.add((lead, k)) + edges.remove((n, k)) + continue + else: + # long cycles not allowed + raise CircularDependencyError("Circular dependency detected: %r %r " % (edges, queue)) + node = queue.pop() + if not hasattr(node, '_cyclical'): + output.append(node) + del nodes[id(node.item)] + for childnode in edges.pop_node(node): + queue.append(childnode) + return output + +def _organize_as_tree(nodes): + """Given a list of nodes from a topological sort, organize the + nodes into a tree structure, with as many non-dependent nodes + set as siblings to each other as possible. + + returns nodes as 3-tuples (item, cycles, children). + """ + + if not nodes: + return None + # a list of all currently independent subtrees as a tuple of + # (root_node, set_of_all_tree_nodes, set_of_all_cycle_nodes_in_tree) + # order of the list has no semantics for the algorithmic + independents = [] + # in reverse topological order + for node in reversed(nodes): + # nodes subtree and cycles contain the node itself + subtree = set([node]) + if node.cycles is not None: + cycles = set(node.cycles) + else: + cycles = set() + # get a set of dependent nodes of node and its cycles + nodealldeps = node.all_deps() + if nodealldeps: + # iterate over independent node indexes in reverse order so we can efficiently remove them + for index in xrange(len(independents) - 1, -1, -1): + child, childsubtree, childcycles = independents[index] + # if there is a dependency between this node and an independent node + if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)): + # prepend child to nodes children + # (append should be fine, but previous implemetation used prepend) + node.children[0:0] = [(child.item, [n.item for n in child.cycles or []], child.children)] + # merge childs subtree and cycles + subtree.update(childsubtree) + cycles.update(childcycles) + # remove the child from list of independent subtrees + independents[index:index+1] = [] + # add node as a new independent subtree + independents.append((node, subtree, cycles)) + # choose an arbitrary node from list of all independent subtrees + head = independents.pop()[0] + # add all other independent subtrees as a child of the chosen root + # used prepend [0:0] instead of extend to maintain exact behaviour of previous implementation + head.children[0:0] = [(i[0].item, [n.item for n in i[0].cycles or []], i[0].children) for i in independents] + return (head.item, [n.item for n in head.cycles or []], head.children) + +def _find_cycles(edges): + cycles = {} + + def traverse(node, cycle, goal): + for (n, key) in edges.edges_by_parent(node): + if key in cycle: + continue + cycle.add(key) + if key is goal: + cycset = set(cycle) + for x in cycle: + if x in cycles: + existing_set = cycles[x] + existing_set.update(cycset) + for y in existing_set: + cycles[y] = existing_set + cycset = existing_set + else: + cycles[x] = cycset + else: + traverse(key, cycle, goal) + cycle.pop() + + for parent in edges.get_parents(): + traverse(parent, set(), parent) + + unique_cycles = set(tuple(s) for s in cycles.values()) + + for cycle in unique_cycles: + edgecollection = [edge for edge in edges + if edge[0] in cycle and edge[1] in cycle] + yield edgecollection diff --git a/sqlalchemy/types.py b/sqlalchemy/types.py new file mode 100644 index 0000000..16cd57f --- /dev/null +++ b/sqlalchemy/types.py @@ -0,0 +1,1742 @@ +# types.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""defines genericized SQL types, each represented by a subclass of +:class:`~sqlalchemy.types.AbstractType`. Dialects define further subclasses of these +types. + +For more information see the SQLAlchemy documentation on types. + +""" +__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', + 'FLOAT', 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', + 'BLOB', 'BOOLEAN', 'SMALLINT', 'INTEGER', 'DATE', 'TIME', + 'String', 'Integer', 'SmallInteger', 'BigInteger', 'Numeric', + 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', + 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', + 'PickleType', 'Interval', 'type_map', 'Enum' ] + +import inspect +import datetime as dt +from decimal import Decimal as _python_Decimal +import codecs + +from sqlalchemy import exc, schema +from sqlalchemy.sql import expression, operators +import sys +schema.types = expression.sqltypes =sys.modules['sqlalchemy.types'] +from sqlalchemy.util import pickle +from sqlalchemy.sql.visitors import Visitable +from sqlalchemy import util +from sqlalchemy import processors +import collections + +NoneType = type(None) +if util.jython: + import array + +class AbstractType(Visitable): + + def __init__(self, *args, **kwargs): + pass + + def compile(self, dialect): + return dialect.type_compiler.process(self) + + def copy_value(self, value): + return value + + def bind_processor(self, dialect): + """Defines a bind parameter processing function. + + :param dialect: Dialect instance in use. + + """ + + return None + + def result_processor(self, dialect, coltype): + """Defines a result-column processing function. + + :param dialect: Dialect instance in use. + + :param coltype: DBAPI coltype argument received in cursor.description. + + """ + + return None + + def compare_values(self, x, y): + """Compare two values for equality.""" + + return x == y + + def is_mutable(self): + """Return True if the target Python type is 'mutable'. + + This allows systems like the ORM to know if a column value can + be considered 'not changed' by comparing the identity of + objects alone. + + Use the :class:`MutableType` mixin or override this method to + return True in custom types that hold mutable values such as + ``dict``, ``list`` and custom objects. + + """ + return False + + def get_dbapi_type(self, dbapi): + """Return the corresponding type object from the underlying DB-API, if + any. + + This can be useful for calling ``setinputsizes()``, for example. + + """ + return None + + def _adapt_expression(self, op, othertype): + """evaluate the return type of , + and apply any adaptations to the given operator. + + """ + return op, self + + @util.memoized_property + def _type_affinity(self): + """Return a rudimental 'affinity' value expressing the general class of type.""" + + typ = None + for t in self.__class__.__mro__: + if t is TypeEngine or t is UserDefinedType: + return typ + elif issubclass(t, TypeEngine): + typ = t + else: + return self.__class__ + + def _coerce_compared_value(self, op, value): + _coerced_type = type_map.get(type(value), NULLTYPE) + if _coerced_type is NULLTYPE or _coerced_type._type_affinity is self._type_affinity: + return self + else: + return _coerced_type + + def _compare_type_affinity(self, other): + return self._type_affinity is other._type_affinity + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join("%s=%r" % (k, getattr(self, k, None)) + for k in inspect.getargspec(self.__init__)[0][1:])) + +class TypeEngine(AbstractType): + """Base for built-in types.""" + + @util.memoized_property + def _impl_dict(self): + return {} + + def dialect_impl(self, dialect, **kwargs): + key = (dialect.__class__, dialect.server_version_info) + + try: + return self._impl_dict[key] + except KeyError: + return self._impl_dict.setdefault(key, dialect.type_descriptor(self)) + + def __getstate__(self): + d = self.__dict__.copy() + d.pop('_impl_dict', None) + return d + + def bind_processor(self, dialect): + """Return a conversion function for processing bind values. + + Returns a callable which will receive a bind parameter value + as the sole positional argument and will return a value to + send to the DB-API. + + If processing is not necessary, the method should return ``None``. + + """ + return None + + def result_processor(self, dialect, coltype): + """Return a conversion function for processing result row values. + + Returns a callable which will receive a result row column + value as the sole positional argument and will return a value + to return to the user. + + If processing is not necessary, the method should return ``None``. + + """ + return None + + def adapt(self, cls): + return cls() + +class UserDefinedType(TypeEngine): + """Base for user defined types. + + This should be the base of new types. Note that + for most cases, :class:`TypeDecorator` is probably + more appropriate:: + + import sqlalchemy.types as types + + class MyType(types.UserDefinedType): + def __init__(self, precision = 8): + self.precision = precision + + def get_col_spec(self): + return "MYTYPE(%s)" % self.precision + + def bind_processor(self, dialect): + def process(value): + return value + return process + + def result_processor(self, dialect, coltype): + def process(value): + return value + return process + + Once the type is made, it's immediately usable:: + + table = Table('foo', meta, + Column('id', Integer, primary_key=True), + Column('data', MyType(16)) + ) + + """ + __visit_name__ = "user_defined" + + def _adapt_expression(self, op, othertype): + """evaluate the return type of , + and apply any adaptations to the given operator. + + """ + return self.adapt_operator(op), self + + def adapt_operator(self, op): + """A hook which allows the given operator to be adapted + to something new. + + See also UserDefinedType._adapt_expression(), an as-yet- + semi-public method with greater capability in this regard. + + """ + return op + +class TypeDecorator(AbstractType): + """Allows the creation of types which add additional functionality + to an existing type. + + This method is preferred to direct subclassing of SQLAlchemy's + built-in types as it ensures that all required functionality of + the underlying type is kept in place. + + Typical usage:: + + import sqlalchemy.types as types + + class MyType(types.TypeDecorator): + '''Prefixes Unicode values with "PREFIX:" on the way in and + strips it off on the way out. + ''' + + impl = types.Unicode + + def process_bind_param(self, value, dialect): + return "PREFIX:" + value + + def process_result_value(self, value, dialect): + return value[7:] + + def copy(self): + return MyType(self.impl.length) + + The class-level "impl" variable is required, and can reference any + TypeEngine class. Alternatively, the load_dialect_impl() method + can be used to provide different type classes based on the dialect + given; in this case, the "impl" variable can reference + ``TypeEngine`` as a placeholder. + + Types that receive a Python type that isn't similar to the + ultimate type used may want to define the :meth:`TypeDecorator.coerce_compared_value` + method. This is used to give the expression system a hint + when coercing Python objects into bind parameters within expressions. + Consider this expression:: + + mytable.c.somecol + datetime.date(2009, 5, 15) + + Above, if "somecol" is an ``Integer`` variant, it makes sense that + we're doing date arithmetic, where above is usually interpreted + by databases as adding a number of days to the given date. + The expression system does the right thing by not attempting to + coerce the "date()" value into an integer-oriented bind parameter. + + However, in the case of ``TypeDecorator``, we are usually changing + an incoming Python type to something new - ``TypeDecorator`` by + default will "coerce" the non-typed side to be the same type as itself. + Such as below, we define an "epoch" type that stores a date value as an integer:: + + class MyEpochType(types.TypeDecorator): + impl = types.Integer + + epoch = datetime.date(1970, 1, 1) + + def process_bind_param(self, value, dialect): + return (value - self.epoch).days + + def process_result_value(self, value, dialect): + return self.epoch + timedelta(days=value) + + Our expression of ``somecol + date`` with the above type will coerce the + "date" on the right side to also be treated as ``MyEpochType``. + + This behavior can be overridden via the :meth:`~TypeDecorator.coerce_compared_value` + method, which returns a type that should be used for the value of the expression. + Below we set it such that an integer value will be treated as an ``Integer``, + and any other value is assumed to be a date and will be treated as a ``MyEpochType``:: + + def coerce_compared_value(self, op, value): + if isinstance(value, int): + return Integer() + else: + return self + + """ + + __visit_name__ = "type_decorator" + + def __init__(self, *args, **kwargs): + if not hasattr(self.__class__, 'impl'): + raise AssertionError("TypeDecorator implementations require a class-level " + "variable 'impl' which refers to the class of type being decorated") + self.impl = self.__class__.impl(*args, **kwargs) + + def adapt(self, cls): + return cls() + + def dialect_impl(self, dialect): + key = (dialect.__class__, dialect.server_version_info) + try: + return self._impl_dict[key] + except KeyError: + pass + + # adapt the TypeDecorator first, in + # the case that the dialect maps the TD + # to one of its native types (i.e. PGInterval) + adapted = dialect.type_descriptor(self) + if adapted is not self: + self._impl_dict[key] = adapted + return adapted + + # otherwise adapt the impl type, link + # to a copy of this TypeDecorator and return + # that. + typedesc = self.load_dialect_impl(dialect) + tt = self.copy() + if not isinstance(tt, self.__class__): + raise AssertionError("Type object %s does not properly implement the copy() " + "method, it must return an object of type %s" % (self, self.__class__)) + tt.impl = typedesc + self._impl_dict[key] = tt + return tt + + @util.memoized_property + def _type_affinity(self): + return self.impl._type_affinity + + def type_engine(self, dialect): + impl = self.dialect_impl(dialect) + if not isinstance(impl, TypeDecorator): + return impl + else: + return impl.impl + + def load_dialect_impl(self, dialect): + """Loads the dialect-specific implementation of this type. + + by default calls dialect.type_descriptor(self.impl), but + can be overridden to provide different behavior. + + """ + if isinstance(self.impl, TypeDecorator): + return self.impl.dialect_impl(dialect) + else: + return dialect.type_descriptor(self.impl) + + def __getattr__(self, key): + """Proxy all other undefined accessors to the underlying implementation.""" + + return getattr(self.impl, key) + + def process_bind_param(self, value, dialect): + raise NotImplementedError() + + def process_result_value(self, value, dialect): + raise NotImplementedError() + + def bind_processor(self, dialect): + if self.__class__.process_bind_param.func_code is not TypeDecorator.process_bind_param.func_code: + process_param = self.process_bind_param + impl_processor = self.impl.bind_processor(dialect) + if impl_processor: + def process(value): + return impl_processor(process_param(value, dialect)) + else: + def process(value): + return process_param(value, dialect) + return process + else: + return self.impl.bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code: + process_value = self.process_result_value + impl_processor = self.impl.result_processor(dialect, coltype) + if impl_processor: + def process(value): + return process_value(impl_processor(value), dialect) + else: + def process(value): + return process_value(value, dialect) + return process + else: + return self.impl.result_processor(dialect, coltype) + + def coerce_compared_value(self, op, value): + """Suggest a type for a 'coerced' Python value in an expression. + + By default, returns self. This method is called by + the expression system when an object using this type is + on the left or right side of an expression against a plain Python + object which does not yet have a SQLAlchemy type assigned:: + + expr = table.c.somecolumn + 35 + + Where above, if ``somecolumn`` uses this type, this method will + be called with the value ``operator.add`` + and ``35``. The return value is whatever SQLAlchemy type should + be used for ``35`` for this particular operation. + + """ + return self + + def _coerce_compared_value(self, op, value): + return self.coerce_compared_value(op, value) + + def copy(self): + instance = self.__class__.__new__(self.__class__) + instance.__dict__.update(self.__dict__) + instance._impl_dict = {} + return instance + + def get_dbapi_type(self, dbapi): + return self.impl.get_dbapi_type(dbapi) + + def copy_value(self, value): + return self.impl.copy_value(value) + + def compare_values(self, x, y): + return self.impl.compare_values(x, y) + + def is_mutable(self): + return self.impl.is_mutable() + + def _adapt_expression(self, op, othertype): + return self.impl._adapt_expression(op, othertype) + + + +class MutableType(object): + """A mixin that marks a Type as holding a mutable object. + + :meth:`copy_value` and :meth:`compare_values` should be customized + as needed to match the needs of the object. + + """ + + def is_mutable(self): + """Return True, mutable.""" + return True + + def copy_value(self, value): + """Unimplemented.""" + raise NotImplementedError() + + def compare_values(self, x, y): + """Compare *x* == *y*.""" + return x == y + +def to_instance(typeobj): + if typeobj is None: + return NULLTYPE + + if util.callable(typeobj): + return typeobj() + else: + return typeobj + +def adapt_type(typeobj, colspecs): + if isinstance(typeobj, type): + typeobj = typeobj() + for t in typeobj.__class__.__mro__[0:-1]: + try: + impltype = colspecs[t] + break + except KeyError: + pass + else: + # couldnt adapt - so just return the type itself + # (it may be a user-defined type) + return typeobj + # if we adapted the given generic type to a database-specific type, + # but it turns out the originally given "generic" type + # is actually a subclass of our resulting type, then we were already + # given a more specific type than that required; so use that. + if (issubclass(typeobj.__class__, impltype)): + return typeobj + return typeobj.adapt(impltype) + +class NullType(TypeEngine): + """An unknown type. + + NullTypes will stand in if :class:`~sqlalchemy.Table` reflection + encounters a column data type unknown to SQLAlchemy. The + resulting columns are nearly fully usable: the DB-API adapter will + handle all translation to and from the database data type. + + NullType does not have sufficient information to particpate in a + ``CREATE TABLE`` statement and will raise an exception if + encountered during a :meth:`~sqlalchemy.Table.create` operation. + + """ + __visit_name__ = 'null' + + def _adapt_expression(self, op, othertype): + if othertype is NullType or not operators.is_commutative(op): + return op, self + else: + return othertype._adapt_expression(op, self) + +NullTypeEngine = NullType + +class Concatenable(object): + """A mixin that marks a type as supporting 'concatenation', typically strings.""" + + def _adapt_expression(self, op, othertype): + if op is operators.add and issubclass(othertype._type_affinity, (Concatenable, NullType)): + return operators.concat_op, self + else: + return op, self + +class _DateAffinity(object): + """Mixin date/time specific expression adaptations. + + Rules are implemented within Date,Time,Interval,DateTime, Numeric, Integer. + Based on http://www.postgresql.org/docs/current/static/functions-datetime.html. + + """ + + @property + def _expression_adaptations(self): + raise NotImplementedError() + + _blank_dict = util.frozendict() + def _adapt_expression(self, op, othertype): + othertype = othertype._type_affinity + return op, \ + self._expression_adaptations.get(op, self._blank_dict).\ + get(othertype, NULLTYPE) + +class String(Concatenable, TypeEngine): + """The base for all string and character types. + + In SQL, corresponds to VARCHAR. Can also take Python unicode objects + and encode to the database's encoding in bind params (and the reverse for + result sets.) + + The `length` field is usually required when the `String` type is + used within a CREATE TABLE statement, as VARCHAR requires a length + on most databases. + + """ + + __visit_name__ = 'string' + + def __init__(self, length=None, convert_unicode=False, + assert_unicode=None, unicode_error=None, + _warn_on_bytestring=False + ): + """ + Create a string-holding type. + + :param length: optional, a length for the column for use in + DDL statements. May be safely omitted if no ``CREATE + TABLE`` will be issued. Certain databases may require a + *length* for use in DDL, and will raise an exception when + the ``CREATE TABLE`` DDL is issued. Whether the value is + interpreted as bytes or characters is database specific. + + :param convert_unicode: defaults to False. If True, the + type will do what is necessary in order to accept + Python Unicode objects as bind parameters, and to return + Python Unicode objects in result rows. This may + require SQLAlchemy to explicitly coerce incoming Python + unicodes into an encoding, and from an encoding + back to Unicode, or it may not require any interaction + from SQLAlchemy at all, depending on the DBAPI in use. + + When SQLAlchemy performs the encoding/decoding, + the encoding used is configured via + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + The "convert_unicode" behavior can also be turned on + for all String types by setting + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode` + on create_engine(). + + To instruct SQLAlchemy to perform Unicode encoding/decoding + even on a platform that already handles Unicode natively, + set convert_unicode='force'. This will incur significant + performance overhead when fetching unicode result columns. + + :param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode + object is passed when SQLAlchemy would coerce into an encoding + (note: but **not** when the DBAPI handles unicode objects natively). + To suppress or raise this warning to an + error, use the Python warnings filter documented at: + http://docs.python.org/library/warnings.html + + :param unicode_error: Optional, a method to use to handle Unicode + conversion errors. Behaves like the 'errors' keyword argument to + the standard library's string.decode() functions. This flag + requires that `convert_unicode` is set to `"force"` - otherwise, + SQLAlchemy is not guaranteed to handle the task of unicode + conversion. Note that this flag adds significant performance + overhead to row-fetching operations for backends that already + return unicode objects natively (which most DBAPIs do). This + flag should only be used as an absolute last resort for reading + strings from a column with varied or corrupted encodings, + which only applies to databases that accept invalid encodings + in the first place (i.e. MySQL. *not* PG, Sqlite, etc.) + + """ + if unicode_error is not None and convert_unicode != 'force': + raise exc.ArgumentError("convert_unicode must be 'force' " + "when unicode_error is set.") + + if assert_unicode: + util.warn_deprecated("assert_unicode is deprecated. " + "SQLAlchemy emits a warning in all cases where it " + "would otherwise like to encode a Python unicode object " + "into a specific encoding but a plain bytestring is received. " + "This does *not* apply to DBAPIs that coerce Unicode natively." + ) + self.length = length + self.convert_unicode = convert_unicode + self.unicode_error = unicode_error + self._warn_on_bytestring = _warn_on_bytestring + + def adapt(self, impltype): + return impltype( + length=self.length, + convert_unicode=self.convert_unicode, + unicode_error=self.unicode_error, + _warn_on_bytestring=True, + ) + + def bind_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + if dialect.supports_unicode_binds and self.convert_unicode != 'force': + if self._warn_on_bytestring: + def process(value): + # Py3K + #if isinstance(value, bytes): + # Py2K + if isinstance(value, str): + # end Py2K + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + return process + else: + return None + else: + encoder = codecs.getencoder(dialect.encoding) + def process(value): + if isinstance(value, unicode): + return encoder(value, self.unicode_error)[0] + elif value is not None: + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + return process + else: + return None + + def result_processor(self, dialect, coltype): + wants_unicode = self.convert_unicode or dialect.convert_unicode + needs_convert = wants_unicode and \ + (dialect.returns_unicode_strings is not True or + self.convert_unicode == 'force') + + if needs_convert: + to_unicode = processors.to_unicode_processor_factory( + dialect.encoding, self.unicode_error) + + if dialect.returns_unicode_strings: + # we wouldn't be here unless convert_unicode='force' + # was specified, or the driver has erratic unicode-returning + # habits. since we will be getting back unicode + # in most cases, we check for it (decode will fail). + def process(value): + if isinstance(value, unicode): + return value + else: + return to_unicode(value) + return process + else: + # here, we assume that the object is not unicode, + # avoiding expensive isinstance() check. + return to_unicode + else: + return None + + def get_dbapi_type(self, dbapi): + return dbapi.STRING + +class Text(String): + """A variably sized string type. + + In SQL, usually corresponds to CLOB or TEXT. Can also take Python + unicode objects and encode to the database's encoding in bind + params (and the reverse for result sets.) + + """ + __visit_name__ = 'text' + +class Unicode(String): + """A variable length Unicode string. + + The ``Unicode`` type is a :class:`String` which converts Python + ``unicode`` objects (i.e., strings that are defined as + ``u'somevalue'``) into encoded bytestrings when passing the value + to the database driver, and similarly decodes values from the + database back into Python ``unicode`` objects. + + It's roughly equivalent to using a ``String`` object with + ``convert_unicode=True``, however + the type has other significances in that it implies the usage + of a unicode-capable type being used on the backend, such as NVARCHAR. + This may affect what type is emitted when issuing CREATE TABLE + and also may effect some DBAPI-specific details, such as type + information passed along to ``setinputsizes()``. + + When using the ``Unicode`` type, it is only appropriate to pass + Python ``unicode`` objects, and not plain ``str``. If a + bytestring (``str``) is passed, a runtime warning is issued. If + you notice your application raising these warnings but you're not + sure where, the Python ``warnings`` filter can be used to turn + these warnings into exceptions which will illustrate a stack + trace:: + + import warnings + warnings.simplefilter('error') + + Bytestrings sent to and received from the database are encoded + using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which defaults + to `utf-8`. + + """ + + __visit_name__ = 'unicode' + + def __init__(self, length=None, **kwargs): + """ + Create a Unicode-converting String type. + + :param length: optional, a length for the column for use in + DDL statements. May be safely omitted if no ``CREATE + TABLE`` will be issued. Certain databases may require a + *length* for use in DDL, and will raise an exception when + the ``CREATE TABLE`` DDL is issued. Whether the value is + interpreted as bytes or characters is database specific. + + :param \**kwargs: passed through to the underlying ``String`` + type. + + """ + kwargs.setdefault('convert_unicode', True) + kwargs.setdefault('_warn_on_bytestring', True) + super(Unicode, self).__init__(length=length, **kwargs) + +class UnicodeText(Text): + """An unbounded-length Unicode string. + + See :class:`Unicode` for details on the unicode + behavior of this object. + + Like ``Unicode``, usage the ``UnicodeText`` type implies a + unicode-capable type being used on the backend, such as NCLOB. + + """ + + __visit_name__ = 'unicode_text' + + def __init__(self, length=None, **kwargs): + """ + Create a Unicode-converting Text type. + + :param length: optional, a length for the column for use in + DDL statements. May be safely omitted if no ``CREATE + TABLE`` will be issued. Certain databases may require a + *length* for use in DDL, and will raise an exception when + the ``CREATE TABLE`` DDL is issued. Whether the value is + interpreted as bytes or characters is database specific. + + """ + kwargs.setdefault('convert_unicode', True) + kwargs.setdefault('_warn_on_bytestring', True) + super(UnicodeText, self).__init__(length=length, **kwargs) + + +class Integer(_DateAffinity, TypeEngine): + """A type for ``int`` integers.""" + + __visit_name__ = 'integer' + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + @util.memoized_property + def _expression_adaptations(self): + # TODO: need a dictionary object that will + # handle operators generically here, this is incomplete + return { + operators.add:{ + Date:Date, + Integer:Integer, + Numeric:Numeric, + }, + operators.mul:{ + Interval:Interval, + Integer:Integer, + Numeric:Numeric, + }, + # Py2K + operators.div:{ + Integer:Integer, + Numeric:Numeric, + }, + # end Py2K + operators.truediv:{ + Integer:Integer, + Numeric:Numeric, + }, + operators.sub:{ + Integer:Integer, + Numeric:Numeric, + }, + } + +class SmallInteger(Integer): + """A type for smaller ``int`` integers. + + Typically generates a ``SMALLINT`` in DDL, and otherwise acts like + a normal :class:`Integer` on the Python side. + + """ + + __visit_name__ = 'small_integer' + +class BigInteger(Integer): + """A type for bigger ``int`` integers. + + Typically generates a ``BIGINT`` in DDL, and otherwise acts like + a normal :class:`Integer` on the Python side. + + """ + + __visit_name__ = 'big_integer' + +class Numeric(_DateAffinity, TypeEngine): + """A type for fixed precision numbers. + + Typically generates DECIMAL or NUMERIC. Returns + ``decimal.Decimal`` objects by default, applying + conversion as needed. + + """ + + __visit_name__ = 'numeric' + + def __init__(self, precision=None, scale=None, asdecimal=True): + """ + Construct a Numeric. + + :param precision: the numeric precision for use in DDL ``CREATE TABLE``. + + :param scale: the numeric scale for use in DDL ``CREATE TABLE``. + + :param asdecimal: default True. Return whether or not + values should be sent as Python Decimal objects, or + as floats. Different DBAPIs send one or the other based on + datatypes - the Numeric type will ensure that return values + are one or the other across DBAPIs consistently. + + When using the ``Numeric`` type, care should be taken to ensure + that the asdecimal setting is apppropriate for the DBAPI in use - + when Numeric applies a conversion from Decimal->float or float-> + Decimal, this conversion incurs an additional performance overhead + for all result columns received. + + DBAPIs that return Decimal natively (e.g. psycopg2) will have + better accuracy and higher performance with a setting of ``True``, + as the native translation to Decimal reduces the amount of floating- + point issues at play, and the Numeric type itself doesn't need + to apply any further conversions. However, another DBAPI which + returns floats natively *will* incur an additional conversion + overhead, and is still subject to floating point data loss - in + which case ``asdecimal=False`` will at least remove the extra + conversion overhead. + + """ + self.precision = precision + self.scale = scale + self.asdecimal = asdecimal + + def adapt(self, impltype): + return impltype( + precision=self.precision, + scale=self.scale, + asdecimal=self.asdecimal) + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def bind_processor(self, dialect): + if dialect.supports_native_decimal: + return None + else: + return processors.to_float + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if dialect.supports_native_decimal: + # we're a "numeric", DBAPI will give us Decimal directly + return None + else: + # we're a "numeric", DBAPI returns floats, convert. + if self.scale is not None: + return processors.to_decimal_processor_factory(_python_Decimal, self.scale) + else: + return processors.to_decimal_processor_factory(_python_Decimal) + else: + if dialect.supports_native_decimal: + return processors.to_float + else: + return None + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.mul:{ + Interval:Interval + }, + } + + +class Float(Numeric): + """A type for ``float`` numbers. + + Returns Python ``float`` objects by default, applying + conversion as needed. + + """ + + __visit_name__ = 'float' + + def __init__(self, precision=None, asdecimal=False, **kwargs): + """ + Construct a Float. + + :param precision: the numeric precision for use in DDL ``CREATE TABLE``. + + :param asdecimal: the same flag as that of :class:`Numeric`, but + defaults to ``False``. + + """ + self.precision = precision + self.asdecimal = asdecimal + + def adapt(self, impltype): + return impltype(precision=self.precision, asdecimal=self.asdecimal) + + def result_processor(self, dialect, coltype): + if self.asdecimal: + return processors.to_decimal_processor_factory(_python_Decimal) + else: + return None + + +class DateTime(_DateAffinity, TypeEngine): + """A type for ``datetime.datetime()`` objects. + + Date and time types return objects from the Python ``datetime`` + module. Most DBAPIs have built in support for the datetime + module, with the noted exception of SQLite. In the case of + SQLite, date and time types are stored as strings which are then + converted back to datetime objects when rows are returned. + + """ + + __visit_name__ = 'datetime' + + def __init__(self, timezone=False): + self.timezone = timezone + + def adapt(self, impltype): + return impltype(timezone=self.timezone) + + def get_dbapi_type(self, dbapi): + return dbapi.DATETIME + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Interval:DateTime, + }, + operators.sub:{ + Interval:DateTime, + DateTime:Interval, + }, + } + + +class Date(_DateAffinity,TypeEngine): + """A type for ``datetime.date()`` objects.""" + + __visit_name__ = 'date' + + def get_dbapi_type(self, dbapi): + return dbapi.DATETIME + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Integer:Date, + Interval:DateTime, + Time:DateTime, + }, + operators.sub:{ + # date - integer = date + Integer:Date, + + # date - date = integer. + Date:Integer, + + Interval:DateTime, + + # date - datetime = interval, + # this one is not in the PG docs + # but works + DateTime:Interval, + }, + } + + +class Time(_DateAffinity,TypeEngine): + """A type for ``datetime.time()`` objects.""" + + __visit_name__ = 'time' + + def __init__(self, timezone=False): + self.timezone = timezone + + def adapt(self, impltype): + return impltype(timezone=self.timezone) + + def get_dbapi_type(self, dbapi): + return dbapi.DATETIME + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Date:DateTime, + Interval:Time + }, + operators.sub:{ + Time:Interval, + Interval:Time, + }, + } + + +class _Binary(TypeEngine): + """Define base behavior for binary types.""" + + def __init__(self, length=None): + self.length = length + + # Python 3 - sqlite3 doesn't need the `Binary` conversion + # here, though pg8000 does to indicate "bytea" + def bind_processor(self, dialect): + DBAPIBinary = dialect.dbapi.Binary + def process(value): + if value is not None: + return DBAPIBinary(value) + else: + return None + return process + + # Python 3 has native bytes() type + # both sqlite3 and pg8000 seem to return it + # (i.e. and not 'memoryview') + # Py2K + def result_processor(self, dialect, coltype): + if util.jython: + def process(value): + if value is not None: + if isinstance(value, array.array): + return value.tostring() + return str(value) + else: + return None + else: + process = processors.to_str + return process + # end Py2K + + def adapt(self, impltype): + return impltype(length=self.length) + + def get_dbapi_type(self, dbapi): + return dbapi.BINARY + +class LargeBinary(_Binary): + """A type for large binary byte data. + + The Binary type generates BLOB or BYTEA when tables are created, + and also converts incoming values using the ``Binary`` callable + provided by each DB-API. + + """ + + __visit_name__ = 'large_binary' + + def __init__(self, length=None): + """ + Construct a LargeBinary type. + + :param length: optional, a length for the column for use in + DDL statements, for those BLOB types that accept a length + (i.e. MySQL). It does *not* produce a small BINARY/VARBINARY + type - use the BINARY/VARBINARY types specifically for those. + May be safely omitted if no ``CREATE + TABLE`` will be issued. Certain databases may require a + *length* for use in DDL, and will raise an exception when + the ``CREATE TABLE`` DDL is issued. + + """ + _Binary.__init__(self, length=length) + +class Binary(LargeBinary): + """Deprecated. Renamed to LargeBinary.""" + + def __init__(self, *arg, **kw): + util.warn_deprecated("The Binary type has been renamed to LargeBinary.") + LargeBinary.__init__(self, *arg, **kw) + +class SchemaType(object): + """Mark a type as possibly requiring schema-level DDL for usage. + + Supports types that must be explicitly created/dropped (i.e. PG ENUM type) + as well as types that are complimented by table or schema level + constraints, triggers, and other rules. + + """ + + def __init__(self, **kw): + self.name = kw.pop('name', None) + self.quote = kw.pop('quote', None) + self.schema = kw.pop('schema', None) + self.metadata = kw.pop('metadata', None) + if self.metadata: + self.metadata.append_ddl_listener( + 'before-create', + util.portable_instancemethod(self._on_metadata_create) + ) + self.metadata.append_ddl_listener( + 'after-drop', + util.portable_instancemethod(self._on_metadata_drop) + ) + + def _set_parent(self, column): + column._on_table_attach(util.portable_instancemethod(self._set_table)) + + def _set_table(self, table, column): + table.append_ddl_listener( + 'before-create', + util.portable_instancemethod(self._on_table_create) + ) + table.append_ddl_listener( + 'after-drop', + util.portable_instancemethod(self._on_table_drop) + ) + if self.metadata is None: + table.metadata.append_ddl_listener( + 'before-create', + util.portable_instancemethod(self._on_metadata_create) + ) + table.metadata.append_ddl_listener( + 'after-drop', + util.portable_instancemethod(self._on_metadata_drop) + ) + + @property + def bind(self): + return self.metadata and self.metadata.bind or None + + def create(self, bind=None, checkfirst=False): + """Issue CREATE ddl for this type, if applicable.""" + + from sqlalchemy.schema import _bind_or_error + if bind is None: + bind = _bind_or_error(self) + t = self.dialect_impl(bind.dialect) + if t is not self and isinstance(t, SchemaType): + t.create(bind=bind, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=False): + """Issue DROP ddl for this type, if applicable.""" + + from sqlalchemy.schema import _bind_or_error + if bind is None: + bind = _bind_or_error(self) + t = self.dialect_impl(bind.dialect) + if t is not self and isinstance(t, SchemaType): + t.drop(bind=bind, checkfirst=checkfirst) + + def _on_table_create(self, event, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t is not self and isinstance(t, SchemaType): + t._on_table_create(event, target, bind, **kw) + + def _on_table_drop(self, event, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t is not self and isinstance(t, SchemaType): + t._on_table_drop(event, target, bind, **kw) + + def _on_metadata_create(self, event, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t is not self and isinstance(t, SchemaType): + t._on_metadata_create(event, target, bind, **kw) + + def _on_metadata_drop(self, event, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t is not self and isinstance(t, SchemaType): + t._on_metadata_drop(event, target, bind, **kw) + +class Enum(String, SchemaType): + """Generic Enum Type. + + The Enum type provides a set of possible string values which the + column is constrained towards. + + By default, uses the backend's native ENUM type if available, + else uses VARCHAR + a CHECK constraint. + """ + + __visit_name__ = 'enum' + + def __init__(self, *enums, **kw): + """Construct an enum. + + Keyword arguments which don't apply to a specific backend are ignored + by that backend. + + :param \*enums: string or unicode enumeration labels. If unicode labels + are present, the `convert_unicode` flag is auto-enabled. + + :param convert_unicode: Enable unicode-aware bind parameter and result-set + processing for this Enum's data. This is set automatically based on + the presence of unicode label strings. + + :param metadata: Associate this type directly with a ``MetaData`` object. + For types that exist on the target database as an independent schema + construct (Postgresql), this type will be created and dropped within + ``create_all()`` and ``drop_all()`` operations. If the type is not + associated with any ``MetaData`` object, it will associate itself with + each ``Table`` in which it is used, and will be created when any of + those individual tables are created, after a check is performed for + it's existence. The type is only dropped when ``drop_all()`` is called + for that ``Table`` object's metadata, however. + + :param name: The name of this type. This is required for Postgresql and + any future supported database which requires an explicitly named type, + or an explicitly named constraint in order to generate the type and/or + a table that uses it. + + :param native_enum: Use the database's native ENUM type when available. + Defaults to True. When False, uses VARCHAR + check constraint + for all backends. + + :param schema: Schemaname of this type. For types that exist on the target + database as an independent schema construct (Postgresql), this + parameter specifies the named schema in which the type is present. + + :param quote: Force quoting to be on or off on the type's name. If left as + the default of `None`, the usual schema-level "case + sensitive"/"reserved name" rules are used to determine if this type's + name should be quoted. + + """ + self.enums = enums + self.native_enum = kw.pop('native_enum', True) + convert_unicode= kw.pop('convert_unicode', None) + if convert_unicode is None: + for e in enums: + if isinstance(e, unicode): + convert_unicode = True + break + else: + convert_unicode = False + + if self.enums: + length =max(len(x) for x in self.enums) + else: + length = 0 + String.__init__(self, + length =length, + convert_unicode=convert_unicode, + ) + SchemaType.__init__(self, **kw) + + def _should_create_constraint(self, compiler): + return not self.native_enum or \ + not compiler.dialect.supports_native_enum + + def _set_table(self, table, column): + if self.native_enum: + SchemaType._set_table(self, table, column) + + + e = schema.CheckConstraint( + column.in_(self.enums), + name=self.name, + _create_rule=util.portable_instancemethod(self._should_create_constraint) + ) + table.append_constraint(e) + + def adapt(self, impltype): + if issubclass(impltype, Enum): + return impltype(name=self.name, + quote=self.quote, + schema=self.schema, + metadata=self.metadata, + convert_unicode=self.convert_unicode, + *self.enums + ) + else: + return super(Enum, self).adapt(impltype) + +class PickleType(MutableType, TypeDecorator): + """Holds Python objects. + + PickleType builds upon the Binary type to apply Python's + ``pickle.dumps()`` to incoming objects, and ``pickle.loads()`` on + the way out, allowing any pickleable Python object to be stored as + a serialized binary field. + + """ + + impl = LargeBinary + + def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, mutable=True, comparator=None): + """ + Construct a PickleType. + + :param protocol: defaults to ``pickle.HIGHEST_PROTOCOL``. + + :param pickler: defaults to cPickle.pickle or pickle.pickle if + cPickle is not available. May be any object with + pickle-compatible ``dumps` and ``loads`` methods. + + :param mutable: defaults to True; implements + :meth:`AbstractType.is_mutable`. When ``True``, incoming + objects should provide an ``__eq__()`` method which + performs the desired deep comparison of members, or the + ``comparator`` argument must be present. + + :param comparator: optional. a 2-arg callable predicate used + to compare values of this type. Otherwise, + the == operator is used to compare values. + + """ + self.protocol = protocol + self.pickler = pickler or pickle + self.mutable = mutable + self.comparator = comparator + super(PickleType, self).__init__() + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + dumps = self.pickler.dumps + protocol = self.protocol + if impl_processor: + def process(value): + if value is not None: + value = dumps(value, protocol) + return impl_processor(value) + else: + def process(value): + if value is not None: + value = dumps(value, protocol) + return value + return process + + def result_processor(self, dialect, coltype): + impl_processor = self.impl.result_processor(dialect, coltype) + loads = self.pickler.loads + if impl_processor: + def process(value): + value = impl_processor(value) + if value is None: + return None + return loads(value) + else: + def process(value): + if value is None: + return None + return loads(value) + return process + + def copy_value(self, value): + if self.mutable: + return self.pickler.loads(self.pickler.dumps(value, self.protocol)) + else: + return value + + def compare_values(self, x, y): + if self.comparator: + return self.comparator(x, y) + else: + return x == y + + def is_mutable(self): + return self.mutable + + +class Boolean(TypeEngine, SchemaType): + """A bool datatype. + + Boolean typically uses BOOLEAN or SMALLINT on the DDL side, and on + the Python side deals in ``True`` or ``False``. + + """ + + __visit_name__ = 'boolean' + + def __init__(self, create_constraint=True, name=None): + """Construct a Boolean. + + :param create_constraint: defaults to True. If the boolean + is generated as an int/smallint, also create a CHECK constraint + on the table that ensures 1 or 0 as a value. + + :param name: if a CHECK constraint is generated, specify + the name of the constraint. + + """ + self.create_constraint = create_constraint + self.name = name + + def _should_create_constraint(self, compiler): + return not compiler.dialect.supports_native_boolean + + def _set_table(self, table, column): + if not self.create_constraint: + return + + e = schema.CheckConstraint( + column.in_([0, 1]), + name=self.name, + _create_rule=util.portable_instancemethod(self._should_create_constraint) + ) + table.append_constraint(e) + + def result_processor(self, dialect, coltype): + if dialect.supports_native_boolean: + return None + else: + return processors.int_to_boolean + +class Interval(_DateAffinity, TypeDecorator): + """A type for ``datetime.timedelta()`` objects. + + The Interval type deals with ``datetime.timedelta`` objects. In + PostgreSQL, the native ``INTERVAL`` type is used; for others, the + value is stored as a date which is relative to the "epoch" + (Jan. 1, 1970). + + Note that the ``Interval`` type does not currently provide + date arithmetic operations on platforms which do not support + interval types natively. Such operations usually require + transformation of both sides of the expression (such as, conversion + of both sides into integer epoch values first) which currently + is a manual procedure (such as via :attr:`~sqlalchemy.sql.expression.func`). + + """ + + impl = DateTime + epoch = dt.datetime.utcfromtimestamp(0) + + def __init__(self, native=True, + second_precision=None, + day_precision=None): + """Construct an Interval object. + + :param native: when True, use the actual + INTERVAL type provided by the database, if + supported (currently Postgresql, Oracle). + Otherwise, represent the interval data as + an epoch value regardless. + + :param second_precision: For native interval types + which support a "fractional seconds precision" parameter, + i.e. Oracle and Postgresql + + :param day_precision: for native interval types which + support a "day precision" parameter, i.e. Oracle. + + """ + super(Interval, self).__init__() + self.native = native + self.second_precision = second_precision + self.day_precision = day_precision + + def adapt(self, cls): + if self.native: + return cls._adapt_from_generic_interval(self) + else: + return self + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + epoch = self.epoch + if impl_processor: + def process(value): + if value is not None: + value = epoch + value + return impl_processor(value) + else: + def process(value): + if value is not None: + value = epoch + value + return value + return process + + def result_processor(self, dialect, coltype): + impl_processor = self.impl.result_processor(dialect, coltype) + epoch = self.epoch + if impl_processor: + def process(value): + value = impl_processor(value) + if value is None: + return None + return value - epoch + else: + def process(value): + if value is None: + return None + return value - epoch + return process + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Date:DateTime, + Interval:Interval, + DateTime:DateTime, + Time:Time, + }, + operators.sub:{ + Interval:Interval + }, + operators.mul:{ + Numeric:Interval + }, + operators.truediv: { + Numeric:Interval + }, + # Py2K + operators.div: { + Numeric:Interval + } + # end Py2K + } + + @property + def _type_affinity(self): + return Interval + + def _coerce_compared_value(self, op, value): + return self.impl._coerce_compared_value(op, value) + + +class FLOAT(Float): + """The SQL FLOAT type.""" + + __visit_name__ = 'FLOAT' + +class NUMERIC(Numeric): + """The SQL NUMERIC type.""" + + __visit_name__ = 'NUMERIC' + + +class DECIMAL(Numeric): + """The SQL DECIMAL type.""" + + __visit_name__ = 'DECIMAL' + + +class INTEGER(Integer): + """The SQL INT or INTEGER type.""" + + __visit_name__ = 'INTEGER' +INT = INTEGER + + +class SMALLINT(SmallInteger): + """The SQL SMALLINT type.""" + + __visit_name__ = 'SMALLINT' + + +class BIGINT(BigInteger): + """The SQL BIGINT type.""" + + __visit_name__ = 'BIGINT' + +class TIMESTAMP(DateTime): + """The SQL TIMESTAMP type.""" + + __visit_name__ = 'TIMESTAMP' + + def get_dbapi_type(self, dbapi): + return dbapi.TIMESTAMP + +class DATETIME(DateTime): + """The SQL DATETIME type.""" + + __visit_name__ = 'DATETIME' + + +class DATE(Date): + """The SQL DATE type.""" + + __visit_name__ = 'DATE' + + +class TIME(Time): + """The SQL TIME type.""" + + __visit_name__ = 'TIME' + +class TEXT(Text): + """The SQL TEXT type.""" + + __visit_name__ = 'TEXT' + +class CLOB(Text): + """The CLOB type. + + This type is found in Oracle and Informix. + """ + + __visit_name__ = 'CLOB' + +class VARCHAR(String): + """The SQL VARCHAR type.""" + + __visit_name__ = 'VARCHAR' + +class NVARCHAR(Unicode): + """The SQL NVARCHAR type.""" + + __visit_name__ = 'NVARCHAR' + +class CHAR(String): + """The SQL CHAR type.""" + + __visit_name__ = 'CHAR' + + +class NCHAR(Unicode): + """The SQL NCHAR type.""" + + __visit_name__ = 'NCHAR' + + +class BLOB(LargeBinary): + """The SQL BLOB type.""" + + __visit_name__ = 'BLOB' + +class BINARY(_Binary): + """The SQL BINARY type.""" + + __visit_name__ = 'BINARY' + +class VARBINARY(_Binary): + """The SQL VARBINARY type.""" + + __visit_name__ = 'VARBINARY' + + +class BOOLEAN(Boolean): + """The SQL BOOLEAN type.""" + + __visit_name__ = 'BOOLEAN' + +NULLTYPE = NullType() +BOOLEANTYPE = Boolean() + +# using VARCHAR/NCHAR so that we dont get the genericized "String" +# type which usually resolves to TEXT/CLOB +type_map = { + str: String(), + # Py3K + #bytes : LargeBinary(), + # Py2K + unicode : Unicode(), + # end Py2K + int : Integer(), + float : Numeric(), + bool: BOOLEANTYPE, + _python_Decimal : Numeric(), + dt.date : Date(), + dt.datetime : DateTime(), + dt.time : Time(), + dt.timedelta : Interval(), + NoneType: NULLTYPE +} + diff --git a/sqlalchemy/util.py b/sqlalchemy/util.py new file mode 100644 index 0000000..9727000 --- /dev/null +++ b/sqlalchemy/util.py @@ -0,0 +1,1651 @@ +# util.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import inspect, itertools, operator, sys, warnings, weakref, gc +# Py2K +import __builtin__ +# end Py2K +types = __import__('types') + +from sqlalchemy import exc + +try: + import threading +except ImportError: + import dummy_threading as threading + +py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) +jython = sys.platform.startswith('java') + +if py3k: + set_types = set +elif sys.version_info < (2, 6): + import sets + set_types = set, sets.Set +else: + # 2.6 deprecates sets.Set, but we still need to be able to detect them + # in user code and as return values from DB-APIs + ignore = ('ignore', None, DeprecationWarning, None, 0) + try: + warnings.filters.insert(0, ignore) + except Exception: + import sets + else: + import sets + warnings.filters.remove(ignore) + + set_types = set, sets.Set + +EMPTY_SET = frozenset() + +NoneType = type(None) + +if py3k: + import pickle +else: + try: + import cPickle as pickle + except ImportError: + import pickle + +# Py2K +# a controversial feature, required by MySQLdb currently +def buffer(x): + return x + +buffer = getattr(__builtin__, 'buffer', buffer) +# end Py2K + +if sys.version_info >= (2, 5): + class PopulateDict(dict): + """A dict which populates missing values via a creation function. + + Note the creation function takes a key, unlike + collections.defaultdict. + + """ + + def __init__(self, creator): + self.creator = creator + + def __missing__(self, key): + self[key] = val = self.creator(key) + return val +else: + class PopulateDict(dict): + """A dict which populates missing values via a creation function.""" + + def __init__(self, creator): + self.creator = creator + + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + self[key] = value = self.creator(key) + return value + +if py3k: + def callable(fn): + return hasattr(fn, '__call__') + def cmp(a, b): + return (a > b) - (a < b) + + from functools import reduce +else: + callable = __builtin__.callable + cmp = __builtin__.cmp + reduce = __builtin__.reduce + +try: + from collections import defaultdict +except ImportError: + class defaultdict(dict): + def __init__(self, default_factory=None, *a, **kw): + if (default_factory is not None and + not hasattr(default_factory, '__call__')): + raise TypeError('first argument must be callable') + dict.__init__(self, *a, **kw) + self.default_factory = default_factory + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + return self.__missing__(key) + def __missing__(self, key): + if self.default_factory is None: + raise KeyError(key) + self[key] = value = self.default_factory() + return value + def __reduce__(self): + if self.default_factory is None: + args = tuple() + else: + args = self.default_factory, + return type(self), args, None, None, self.iteritems() + def copy(self): + return self.__copy__() + def __copy__(self): + return type(self)(self.default_factory, self) + def __deepcopy__(self, memo): + import copy + return type(self)(self.default_factory, + copy.deepcopy(self.items())) + def __repr__(self): + return 'defaultdict(%s, %s)' % (self.default_factory, + dict.__repr__(self)) + +class frozendict(dict): + def _blocked_attribute(obj): + raise AttributeError, "A frozendict cannot be modified." + _blocked_attribute = property(_blocked_attribute) + + __delitem__ = __setitem__ = clear = _blocked_attribute + pop = popitem = setdefault = update = _blocked_attribute + + def __new__(cls, *args): + new = dict.__new__(cls) + dict.__init__(new, *args) + return new + + def __init__(self, *args): + pass + + def __reduce__(self): + return frozendict, (dict(self), ) + + def union(self, d): + if not self: + return frozendict(d) + else: + d2 = self.copy() + d2.update(d) + return frozendict(d2) + + def __repr__(self): + return "frozendict(%s)" % dict.__repr__(self) + +def to_list(x, default=None): + if x is None: + return default + if not isinstance(x, (list, tuple)): + return [x] + else: + return x + +def to_set(x): + if x is None: + return set() + if not isinstance(x, set): + return set(to_list(x)) + else: + return x + +def to_column_set(x): + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + + +try: + from functools import update_wrapper +except ImportError: + def update_wrapper(wrapper, wrapped, + assigned=('__doc__', '__module__', '__name__'), + updated=('__dict__',)): + for attr in assigned: + setattr(wrapper, attr, getattr(wrapped, attr)) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, ())) + return wrapper + +try: + from functools import partial +except: + def partial(func, *args, **keywords): + def newfunc(*fargs, **fkeywords): + newkeywords = keywords.copy() + newkeywords.update(fkeywords) + return func(*(args + fargs), **newkeywords) + return newfunc + + +def accepts_a_list_as_starargs(list_deprecation=None): + def decorate(fn): + + spec = inspect.getargspec(fn) + assert spec[1], 'Decorated function does not accept *args' + + def _deprecate(): + if list_deprecation: + if list_deprecation == 'pending': + warning_type = exc.SAPendingDeprecationWarning + else: + warning_type = exc.SADeprecationWarning + msg = ( + "%s%s now accepts multiple %s arguments as a " + "variable argument list. Supplying %s as a single " + "list is deprecated and support will be removed " + "in a future release." % ( + fn.func_name, + inspect.formatargspec(*spec), + spec[1], spec[1])) + warnings.warn(msg, warning_type, stacklevel=3) + + def go(fn, *args, **kw): + if isinstance(args[-1], list): + _deprecate() + return fn(*(list(args[0:-1]) + args[-1]), **kw) + else: + return fn(*args, **kw) + + return decorator(go)(fn) + + return decorate + +def unique_symbols(used, *bases): + used = set(used) + for base in bases: + pool = itertools.chain((base,), + itertools.imap(lambda i: base + str(i), + xrange(1000))) + for sym in pool: + if sym not in used: + used.add(sym) + yield sym + break + else: + raise NameError("exhausted namespace for symbol base %s" % base) + +def decorator(target): + """A signature-matching decorator factory.""" + + def decorate(fn): + spec = inspect.getargspec(fn) + names = tuple(spec[0]) + spec[1:3] + (fn.func_name,) + targ_name, fn_name = unique_symbols(names, 'target', 'fn') + + metadata = dict(target=targ_name, fn=fn_name) + metadata.update(format_argspec_plus(spec, grouped=False)) + + code = 'lambda %(args)s: %(target)s(%(fn)s, %(apply_kw)s)' % ( + metadata) + decorated = eval(code, {targ_name:target, fn_name:fn}) + decorated.func_defaults = getattr(fn, 'im_func', fn).func_defaults + return update_wrapper(decorated, fn) + return update_wrapper(decorate, target) + + +if sys.version_info >= (2, 5): + def decode_slice(slc): + """decode a slice object as sent to __getitem__. + + takes into account the 2.5 __index__() method, basically. + + """ + ret = [] + for x in slc.start, slc.stop, slc.step: + if hasattr(x, '__index__'): + x = x.__index__() + ret.append(x) + return tuple(ret) +else: + def decode_slice(slc): + return (slc.start, slc.stop, slc.step) + +def update_copy(d, _new=None, **kw): + """Copy the given dict and update with the given values.""" + + d = d.copy() + if _new: + d.update(_new) + d.update(**kw) + return d + +def flatten_iterator(x): + """Given an iterator of which further sub-elements may also be + iterators, flatten the sub-elements into a single iterator. + + """ + for elem in x: + if not isinstance(elem, basestring) and hasattr(elem, '__iter__'): + for y in flatten_iterator(elem): + yield y + else: + yield elem + +def get_cls_kwargs(cls): + """Return the full set of inherited kwargs for the given `cls`. + + Probes a class's __init__ method, collecting all named arguments. If the + __init__ defines a \**kwargs catch-all, then the constructor is presumed to + pass along unrecognized keywords to it's base classes, and the collection + process is repeated recursively on each of the bases. + + """ + + for c in cls.__mro__: + if '__init__' in c.__dict__: + stack = set([c]) + break + else: + return [] + + args = set() + while stack: + class_ = stack.pop() + ctr = class_.__dict__.get('__init__', False) + if not ctr or not isinstance(ctr, types.FunctionType): + stack.update(class_.__bases__) + continue + names, _, has_kw, _ = inspect.getargspec(ctr) + args.update(names) + if has_kw: + stack.update(class_.__bases__) + args.discard('self') + return args + +def get_func_kwargs(func): + """Return the full set of legal kwargs for the given `func`.""" + return inspect.getargspec(func)[0] + +def format_argspec_plus(fn, grouped=True): + """Returns a dictionary of formatted, introspected function arguments. + + A enhanced variant of inspect.formatargspec to support code generation. + + fn + An inspectable callable or tuple of inspect getargspec() results. + grouped + Defaults to True; include (parens, around, argument) lists + + Returns: + + args + Full inspect.formatargspec for fn + self_arg + The name of the first positional argument, varargs[0], or None + if the function defines no positional arguments. + apply_pos + args, re-written in calling rather than receiving syntax. Arguments are + passed positionally. + apply_kw + Like apply_pos, except keyword-ish args are passed as keywords. + + Example:: + + >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123) + {'args': '(self, a, b, c=3, **d)', + 'self_arg': 'self', + 'apply_kw': '(self, a, b, c=c, **d)', + 'apply_pos': '(self, a, b, c, **d)'} + + """ + spec = callable(fn) and inspect.getargspec(fn) or fn + args = inspect.formatargspec(*spec) + if spec[0]: + self_arg = spec[0][0] + elif spec[1]: + self_arg = '%s[0]' % spec[1] + else: + self_arg = None + apply_pos = inspect.formatargspec(spec[0], spec[1], spec[2]) + defaulted_vals = spec[3] is not None and spec[0][0-len(spec[3]):] or () + apply_kw = inspect.formatargspec(spec[0], spec[1], spec[2], defaulted_vals, + formatvalue=lambda x: '=' + x) + if grouped: + return dict(args=args, self_arg=self_arg, + apply_pos=apply_pos, apply_kw=apply_kw) + else: + return dict(args=args[1:-1], self_arg=self_arg, + apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1]) + +def format_argspec_init(method, grouped=True): + """format_argspec_plus with considerations for typical __init__ methods + + Wraps format_argspec_plus with error handling strategies for typical + __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return format_argspec_plus(method, grouped=grouped) + except TypeError: + self_arg = 'self' + if method is object.__init__: + args = grouped and '(self)' or 'self' + else: + args = (grouped and '(self, *args, **kwargs)' + or 'self, *args, **kwargs') + return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args) + +def getargspec_init(method): + """inspect.getargspec with considerations for typical __init__ methods + + Wraps inspect.getargspec with error handling for typical __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return inspect.getargspec(method) + except TypeError: + if method is object.__init__: + return (['self'], None, None, None) + else: + return (['self'], 'args', 'kwargs', None) + + +def unbound_method_to_callable(func_or_cls): + """Adjust the incoming callable such that a 'self' argument is not required.""" + + if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self: + return func_or_cls.im_func + else: + return func_or_cls + +class portable_instancemethod(object): + """Turn an instancemethod into a (parent, name) pair + to produce a serializable callable. + + """ + def __init__(self, meth): + self.target = meth.im_self + self.name = meth.__name__ + + def __call__(self, *arg, **kw): + return getattr(self.target, self.name)(*arg, **kw) + +def class_hierarchy(cls): + """Return an unordered sequence of all classes related to cls. + + Traverses diamond hierarchies. + + Fibs slightly: subclasses of builtin types are not returned. Thus + class_hierarchy(class A(object)) returns (A, object), not A plus every + class systemwide that derives from object. + + Old-style classes are discarded and hierarchies rooted on them + will not be descended. + + """ + # Py2K + if isinstance(cls, types.ClassType): + return list() + # end Py2K + hier = set([cls]) + process = list(cls.__mro__) + while process: + c = process.pop() + # Py2K + if isinstance(c, types.ClassType): + continue + for b in (_ for _ in c.__bases__ + if _ not in hier and not isinstance(_, types.ClassType)): + # end Py2K + # Py3K + #for b in (_ for _ in c.__bases__ + # if _ not in hier): + process.append(b) + hier.add(b) + # Py3K + #if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'): + # continue + # Py2K + if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'): + continue + # end Py2K + for s in [_ for _ in c.__subclasses__() if _ not in hier]: + process.append(s) + hier.add(s) + return list(hier) + +def iterate_attributes(cls): + """iterate all the keys and attributes associated + with a class, without using getattr(). + + Does not use getattr() so that class-sensitive + descriptors (i.e. property.__get__()) are not called. + + """ + keys = dir(cls) + for key in keys: + for c in cls.__mro__: + if key in c.__dict__: + yield (key, c.__dict__[key]) + break + +# from paste.deploy.converters +def asbool(obj): + if isinstance(obj, (str, unicode)): + obj = obj.strip().lower() + if obj in ['true', 'yes', 'on', 'y', 't', '1']: + return True + elif obj in ['false', 'no', 'off', 'n', 'f', '0']: + return False + else: + raise ValueError("String is not true/false: %r" % obj) + return bool(obj) + +def coerce_kw_type(kw, key, type_, flexi_bool=True): + """If 'key' is present in dict 'kw', coerce its value to type 'type\_' if + necessary. If 'flexi_bool' is True, the string '0' is considered false + when coercing to boolean. + """ + + if key in kw and type(kw[key]) is not type_ and kw[key] is not None: + if type_ is bool and flexi_bool: + kw[key] = asbool(kw[key]) + else: + kw[key] = type_(kw[key]) + +def duck_type_collection(specimen, default=None): + """Given an instance or class, guess if it is or is acting as one of + the basic collection types: list, set and dict. If the __emulates__ + property is present, return that preferentially. + """ + + if hasattr(specimen, '__emulates__'): + # canonicalize set vs sets.Set to a standard: the builtin set + if (specimen.__emulates__ is not None and + issubclass(specimen.__emulates__, set_types)): + return set + else: + return specimen.__emulates__ + + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): + return list + elif isa(specimen, set_types): + return set + elif isa(specimen, dict): + return dict + + if hasattr(specimen, 'append'): + return list + elif hasattr(specimen, 'add'): + return set + elif hasattr(specimen, 'set'): + return dict + else: + return default + +def dictlike_iteritems(dictlike): + """Return a (key, value) iterator for almost any dict-like object.""" + + # Py3K + #if hasattr(dictlike, 'items'): + # return dictlike.items() + # Py2K + if hasattr(dictlike, 'iteritems'): + return dictlike.iteritems() + elif hasattr(dictlike, 'items'): + return iter(dictlike.items()) + # end Py2K + + getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None)) + if getter is None: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + + if hasattr(dictlike, 'iterkeys'): + def iterator(): + for key in dictlike.iterkeys(): + yield key, getter(key) + return iterator() + elif hasattr(dictlike, 'keys'): + return iter((key, getter(key)) for key in dictlike.keys()) + else: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + +def assert_arg_type(arg, argtype, name): + if isinstance(arg, argtype): + return arg + else: + if isinstance(argtype, tuple): + raise exc.ArgumentError( + "Argument '%s' is expected to be one of type %s, got '%s'" % + (name, ' or '.join("'%s'" % a for a in argtype), type(arg))) + else: + raise exc.ArgumentError( + "Argument '%s' is expected to be of type '%s', got '%s'" % + (name, argtype, type(arg))) + +_creation_order = 1 +def set_creation_order(instance): + """Assign a '_creation_order' sequence to the given instance. + + This allows multiple instances to be sorted in order of creation + (typically within a single thread; the counter is not particularly + threadsafe). + + """ + global _creation_order + instance._creation_order = _creation_order + _creation_order +=1 + +def warn_exception(func, *args, **kwargs): + """executes the given function, catches all exceptions and converts to a warning.""" + try: + return func(*args, **kwargs) + except: + warn("%s('%s') ignored" % sys.exc_info()[0:2]) + +def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, + name='self.proxy', from_instance=None): + """Automates delegation of __specials__ for a proxying type.""" + + if only: + dunders = only + else: + if skip is None: + skip = ('__slots__', '__del__', '__getattribute__', + '__metaclass__', '__getstate__', '__setstate__') + dunders = [m for m in dir(from_cls) + if (m.startswith('__') and m.endswith('__') and + not hasattr(into_cls, m) and m not in skip)] + for method in dunders: + try: + fn = getattr(from_cls, method) + if not hasattr(fn, '__call__'): + continue + fn = getattr(fn, 'im_func', fn) + except AttributeError: + continue + try: + spec = inspect.getargspec(fn) + fn_args = inspect.formatargspec(spec[0]) + d_args = inspect.formatargspec(spec[0][1:]) + except TypeError: + fn_args = '(self, *args, **kw)' + d_args = '(*args, **kw)' + + py = ("def %(method)s%(fn_args)s: " + "return %(name)s.%(method)s%(d_args)s" % locals()) + + env = from_instance is not None and {name: from_instance} or {} + exec py in env + try: + env[method].func_defaults = fn.func_defaults + except AttributeError: + pass + setattr(into_cls, method, env[method]) + +class NamedTuple(tuple): + """tuple() subclass that adds labeled names. + + Is also pickleable. + + """ + + def __new__(cls, vals, labels=None): + vals = list(vals) + t = tuple.__new__(cls, vals) + if labels: + t.__dict__ = dict(itertools.izip(labels, vals)) + t._labels = labels + return t + + def keys(self): + return self._labels + + +class OrderedProperties(object): + """An object that maintains the order in which attributes are set upon it. + + Also provides an iterator and a very basic getitem/setitem + interface to those attributes. + + (Not really a dict, since it iterates over values, not keys. Not really + a list, either, since each value must have a key associated; hence there is + no append or extend.) + """ + + def __init__(self): + self.__dict__['_data'] = OrderedDict() + + def __len__(self): + return len(self._data) + + def __iter__(self): + return self._data.itervalues() + + def __add__(self, other): + return list(self) + list(other) + + def __setitem__(self, key, object): + self._data[key] = object + + def __getitem__(self, key): + return self._data[key] + + def __delitem__(self, key): + del self._data[key] + + def __setattr__(self, key, object): + self._data[key] = object + + def __getstate__(self): + return {'_data': self.__dict__['_data']} + + def __setstate__(self, state): + self.__dict__['_data'] = state['_data'] + + def __getattr__(self, key): + try: + return self._data[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + return key in self._data + + def update(self, value): + self._data.update(value) + + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def keys(self): + return self._data.keys() + + def has_key(self, key): + return key in self._data + + def clear(self): + self._data.clear() + +class OrderedDict(dict): + """A dict that returns keys/values/items in the order they were added.""" + + def __init__(self, ____sequence=None, **kwargs): + self._list = [] + if ____sequence is None: + if kwargs: + self.update(**kwargs) + else: + self.update(____sequence, **kwargs) + + def clear(self): + self._list = [] + dict.clear(self) + + def copy(self): + return self.__copy__() + + def __copy__(self): + return OrderedDict(self) + + def sort(self, *arg, **kw): + self._list.sort(*arg, **kw) + + def update(self, ____sequence=None, **kwargs): + if ____sequence is not None: + if hasattr(____sequence, 'keys'): + for key in ____sequence.keys(): + self.__setitem__(key, ____sequence[key]) + else: + for key, value in ____sequence: + self[key] = value + if kwargs: + self.update(kwargs) + + def setdefault(self, key, value): + if key not in self: + self.__setitem__(key, value) + return value + else: + return self.__getitem__(key) + + def __iter__(self): + return iter(self._list) + + def values(self): + return [self[key] for key in self._list] + + def itervalues(self): + return iter(self.values()) + + def keys(self): + return list(self._list) + + def iterkeys(self): + return iter(self.keys()) + + def items(self): + return [(key, self[key]) for key in self.keys()] + + def iteritems(self): + return iter(self.items()) + + def __setitem__(self, key, object): + if key not in self: + try: + self._list.append(key) + except AttributeError: + # work around Python pickle loads() with + # dict subclass (seems to ignore __setstate__?) + self._list = [key] + dict.__setitem__(self, key, object) + + def __delitem__(self, key): + dict.__delitem__(self, key) + self._list.remove(key) + + def pop(self, key, *default): + present = key in self + value = dict.pop(self, key, *default) + if present: + self._list.remove(key) + return value + + def popitem(self): + item = dict.popitem(self) + self._list.remove(item[0]) + return item + +class OrderedSet(set): + def __init__(self, d=None): + set.__init__(self) + self._list = [] + if d is not None: + self.update(d) + + def add(self, element): + if element not in self: + self._list.append(element) + set.add(self, element) + + def remove(self, element): + set.remove(self, element) + self._list.remove(element) + + def insert(self, pos, element): + if element not in self: + self._list.insert(pos, element) + set.add(self, element) + + def discard(self, element): + if element in self: + self._list.remove(element) + set.remove(self, element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + add = self.add + for i in iterable: + add(i) + return self + + __ior__ = update + + def union(self, other): + result = self.__class__(self) + result.update(other) + return result + + __or__ = union + + def intersection(self, other): + other = set(other) + return self.__class__(a for a in self if a in other) + + __and__ = intersection + + def symmetric_difference(self, other): + other = set(other) + result = self.__class__(a for a in self if a not in other) + result.update(a for a in other if a not in self) + return result + + __xor__ = symmetric_difference + + def difference(self, other): + other = set(other) + return self.__class__(a for a in self if a not in other) + + __sub__ = difference + + def intersection_update(self, other): + other = set(other) + set.intersection_update(self, other) + self._list = [ a for a in self._list if a in other] + return self + + __iand__ = intersection_update + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + self._list = [ a for a in self._list if a in self] + self._list += [ a for a in other._list if a in self] + return self + + __ixor__ = symmetric_difference_update + + def difference_update(self, other): + set.difference_update(self, other) + self._list = [ a for a in self._list if a in self] + return self + + __isub__ = difference_update + + +class IdentitySet(object): + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. + + """ + + _working_set = set + + def __init__(self, iterable=None): + self._members = dict() + if iterable: + for o in iterable: + self.add(o) + + def add(self, value): + self._members[id(value)] = value + + def __contains__(self, value): + return id(value) in self._members + + def remove(self, value): + del self._members[id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError('pop from an empty set') + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError('cannot compare sets using cmp()') + + def __eq__(self, other): + if isinstance(other, IdentitySet): + return self._members == other._members + else: + return False + + def __ne__(self, other): + if isinstance(other, IdentitySet): + return self._members != other._members + else: + return True + + def issubset(self, iterable): + other = type(self)(iterable) + + if len(self) > len(other): + return False + for m in itertools.ifilterfalse(other._members.__contains__, + self._members.iterkeys()): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + def issuperset(self, iterable): + other = type(self)(iterable) + + if len(self) < len(other): + return False + + for m in itertools.ifilterfalse(self._members.__contains__, + other._members.iterkeys()): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + def union(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).union(_iter_id(iterable))) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.union(other) + + def update(self, iterable): + self._members = self.union(iterable)._members + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + def difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).difference(_iter_id(iterable))) + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.difference(other) + + def difference_update(self, iterable): + self._members = self.difference(iterable)._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + def intersection(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable))) + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.intersection(other) + + def intersection_update(self, iterable): + self._members = self.intersection(iterable)._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + def symmetric_difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable))) + return result + + def _member_id_tuples(self): + return ((id(v), v) for v in self._members.itervalues()) + + def __xor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference_update(self, iterable): + self._members = self.symmetric_difference(iterable)._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + def copy(self): + return type(self)(self._members.itervalues()) + + __copy__ = copy + + def __len__(self): + return len(self._members) + + def __iter__(self): + return self._members.itervalues() + + def __hash__(self): + raise TypeError('set objects are unhashable') + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, self._members.values()) + + +class OrderedIdentitySet(IdentitySet): + class _working_set(OrderedSet): + # a testing pragma: exempt the OIDS working set from the test suite's + # "never call the user's __hash__" assertions. this is a big hammer, + # but it's safe here: IDS operates on (id, instance) tuples in the + # working set. + __sa_hash_exempt__ = True + + def __init__(self, iterable=None): + IdentitySet.__init__(self) + self._members = OrderedDict() + if iterable: + for o in iterable: + self.add(o) + +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + + for item in iterable: + yield id(item), item + +# define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet +populate_column_dict = PopulateDict + +def unique_list(seq, compare_with=set): + seen = compare_with() + return [x for x in seq if x not in seen and not seen.add(x)] + +class UniqueAppender(object): + """Appends items to a collection ensuring uniqueness. + + Additional appends() of the same object are ignored. Membership is + determined by identity (``is a``) not equality (``==``). + """ + + def __init__(self, data, via=None): + self.data = data + self._unique = IdentitySet() + if via: + self._data_appender = getattr(data, via) + elif hasattr(data, 'append'): + self._data_appender = data.append + elif hasattr(data, 'add'): + # TODO: we think its a set here. bypass unneeded uniquing logic ? + self._data_appender = data.add + + def append(self, item): + if item not in self._unique: + self._data_appender(item) + self._unique.add(item) + + def __iter__(self): + return iter(self.data) + + +class ScopedRegistry(object): + """A Registry that can store one or multiple instances of a single + class on a per-thread scoped basis, or on a customized scope. + + createfunc + a callable that returns a new object to be placed in the registry + + scopefunc + a callable that will return a key to store/retrieve an object. + """ + + def __init__(self, createfunc, scopefunc): + self.createfunc = createfunc + self.scopefunc = scopefunc + self.registry = {} + + def __call__(self): + key = self.scopefunc() + try: + return self.registry[key] + except KeyError: + return self.registry.setdefault(key, self.createfunc()) + + def has(self): + return self.scopefunc() in self.registry + + def set(self, obj): + self.registry[self.scopefunc()] = obj + + def clear(self): + try: + del self.registry[self.scopefunc()] + except KeyError: + pass + +class ThreadLocalRegistry(ScopedRegistry): + def __init__(self, createfunc): + self.createfunc = createfunc + self.registry = threading.local() + + def __call__(self): + try: + return self.registry.value + except AttributeError: + val = self.registry.value = self.createfunc() + return val + + def has(self): + return hasattr(self.registry, "value") + + def set(self, obj): + self.registry.value = obj + + def clear(self): + try: + del self.registry.value + except AttributeError: + pass + +class _symbol(object): + def __init__(self, name): + """Construct a new named symbol.""" + assert isinstance(name, str) + self.name = name + def __reduce__(self): + return symbol, (self.name,) + def __repr__(self): + return "" % self.name +_symbol.__name__ = 'symbol' + + +class symbol(object): + """A constant symbol. + + >>> symbol('foo') is symbol('foo') + True + >>> symbol('foo') + + + A slight refinement of the MAGICCOOKIE=object() pattern. The primary + advantage of symbol() is its repr(). They are also singletons. + + Repeated calls of symbol('name') will all return the same instance. + + """ + symbols = {} + _lock = threading.Lock() + + def __new__(cls, name): + cls._lock.acquire() + try: + sym = cls.symbols.get(name) + if sym is None: + cls.symbols[name] = sym = _symbol(name) + return sym + finally: + symbol._lock.release() + + +def as_interface(obj, cls=None, methods=None, required=None): + """Ensure basic interface compliance for an instance or dict of callables. + + Checks that ``obj`` implements public methods of ``cls`` or has members + listed in ``methods``. If ``required`` is not supplied, implementing at + least one interface method is sufficient. Methods present on ``obj`` that + are not in the interface are ignored. + + If ``obj`` is a dict and ``dict`` does not meet the interface + requirements, the keys of the dictionary are inspected. Keys present in + ``obj`` that are not in the interface will raise TypeErrors. + + Raises TypeError if ``obj`` does not meet the interface criteria. + + In all passing cases, an object with callable members is returned. In the + simple case, ``obj`` is returned as-is; if dict processing kicks in then + an anonymous class is returned. + + obj + A type, instance, or dictionary of callables. + cls + Optional, a type. All public methods of cls are considered the + interface. An ``obj`` instance of cls will always pass, ignoring + ``required``.. + methods + Optional, a sequence of method names to consider as the interface. + required + Optional, a sequence of mandatory implementations. If omitted, an + ``obj`` that provides at least one interface method is considered + sufficient. As a convenience, required may be a type, in which case + all public methods of the type are required. + + """ + if not cls and not methods: + raise TypeError('a class or collection of method names are required') + + if isinstance(cls, type) and isinstance(obj, cls): + return obj + + interface = set(methods or [m for m in dir(cls) if not m.startswith('_')]) + implemented = set(dir(obj)) + + complies = operator.ge + if isinstance(required, type): + required = interface + elif not required: + required = set() + complies = operator.gt + else: + required = set(required) + + if complies(implemented.intersection(interface), required): + return obj + + # No dict duck typing here. + if not type(obj) is dict: + qualifier = complies is operator.gt and 'any of' or 'all of' + raise TypeError("%r does not implement %s: %s" % ( + obj, qualifier, ', '.join(interface))) + + class AnonymousInterface(object): + """A callable-holding shell.""" + + if cls: + AnonymousInterface.__name__ = 'Anonymous' + cls.__name__ + found = set() + + for method, impl in dictlike_iteritems(obj): + if method not in interface: + raise TypeError("%r: unknown in this interface" % method) + if not callable(impl): + raise TypeError("%r=%r is not callable" % (method, impl)) + setattr(AnonymousInterface, method, staticmethod(impl)) + found.add(method) + + if complies(found, required): + return AnonymousInterface + + raise TypeError("dictionary does not contain required keys %s" % + ', '.join(required - found)) + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + """ + try: + fn.__name__ = name + except TypeError: + fn = types.FunctionType(fn.func_code, fn.func_globals, name, + fn.func_defaults, fn.func_closure) + return fn + +class memoized_property(object): + """A read-only @property that is only evaluated once.""" + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return None + obj.__dict__[self.__name__] = result = self.fget(obj) + return result + + +class memoized_instancemethod(object): + """Decorate a method memoize its return value. + + Best applied to no-arg methods: memoization is not sensitive to + argument values, and will always return the same value even when + called with different arguments. + + """ + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return None + def oneshot(*args, **kw): + result = self.fget(obj, *args, **kw) + memo = lambda *a, **kw: result + memo.__name__ = self.__name__ + memo.__doc__ = self.__doc__ + obj.__dict__[self.__name__] = memo + return result + oneshot.__name__ = self.__name__ + oneshot.__doc__ = self.__doc__ + return oneshot + +def reset_memoized(instance, name): + instance.__dict__.pop(name, None) + +class WeakIdentityMapping(weakref.WeakKeyDictionary): + """A WeakKeyDictionary with an object identity index. + + Adds a .by_id dictionary to a regular WeakKeyDictionary. Trades + performance during mutation operations for accelerated lookups by id(). + + The usual cautions about weak dictionaries and iteration also apply to + this subclass. + + """ + _none = symbol('none') + + def __init__(self): + weakref.WeakKeyDictionary.__init__(self) + self.by_id = {} + self._weakrefs = {} + + def __setitem__(self, object, value): + oid = id(object) + self.by_id[oid] = value + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + weakref.WeakKeyDictionary.__setitem__(self, object, value) + + def __delitem__(self, object): + del self._weakrefs[id(object)] + del self.by_id[id(object)] + weakref.WeakKeyDictionary.__delitem__(self, object) + + def setdefault(self, object, default=None): + value = weakref.WeakKeyDictionary.setdefault(self, object, default) + oid = id(object) + if value is default: + self.by_id[oid] = default + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + return value + + def pop(self, object, default=_none): + if default is self._none: + value = weakref.WeakKeyDictionary.pop(self, object) + else: + value = weakref.WeakKeyDictionary.pop(self, object, default) + if id(object) in self.by_id: + del self._weakrefs[id(object)] + del self.by_id[id(object)] + return value + + def popitem(self): + item = weakref.WeakKeyDictionary.popitem(self) + oid = id(item[0]) + del self._weakrefs[oid] + del self.by_id[oid] + return item + + def clear(self): + # Py2K + # in 3k, MutableMapping calls popitem() + self._weakrefs.clear() + self.by_id.clear() + # end Py2K + weakref.WeakKeyDictionary.clear(self) + + def update(self, *a, **kw): + raise NotImplementedError + + def _cleanup(self, wr, key=None): + if key is None: + key = wr.key + try: + del self._weakrefs[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + try: + del self.by_id[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + + class _keyed_weakref(weakref.ref): + def __init__(self, object, callback): + weakref.ref.__init__(self, object, callback) + self.key = id(object) + + def _ref(self, object): + return self._keyed_weakref(object, self._cleanup) + + +def warn(msg, stacklevel=3): + if isinstance(msg, basestring): + warnings.warn(msg, exc.SAWarning, stacklevel=stacklevel) + else: + warnings.warn(msg, stacklevel=stacklevel) + +def warn_deprecated(msg, stacklevel=3): + warnings.warn(msg, exc.SADeprecationWarning, stacklevel=stacklevel) + +def warn_pending_deprecation(msg, stacklevel=3): + warnings.warn(msg, exc.SAPendingDeprecationWarning, stacklevel=stacklevel) + +def deprecated(message=None, add_deprecation_to_docstring=True): + """Decorates a function and issues a deprecation warning on use. + + message + If provided, issue message in the warning. A sensible default + is used if not provided. + + add_deprecation_to_docstring + Default True. If False, the wrapped function's __doc__ is left + as-is. If True, the 'message' is prepended to the docs if + provided, or sensible default if message is omitted. + """ + + if add_deprecation_to_docstring: + header = message is not None and message or 'Deprecated.' + else: + header = None + + if message is None: + message = "Call to deprecated function %(func)s" + + def decorate(fn): + return _decorate_with_warning( + fn, exc.SADeprecationWarning, + message % dict(func=fn.__name__), header) + return decorate + +def pending_deprecation(version, message=None, + add_deprecation_to_docstring=True): + """Decorates a function and issues a pending deprecation warning on use. + + version + An approximate future version at which point the pending deprecation + will become deprecated. Not used in messaging. + + message + If provided, issue message in the warning. A sensible default + is used if not provided. + + add_deprecation_to_docstring + Default True. If False, the wrapped function's __doc__ is left + as-is. If True, the 'message' is prepended to the docs if + provided, or sensible default if message is omitted. + """ + + if add_deprecation_to_docstring: + header = message is not None and message or 'Deprecated.' + else: + header = None + + if message is None: + message = "Call to deprecated function %(func)s" + + def decorate(fn): + return _decorate_with_warning( + fn, exc.SAPendingDeprecationWarning, + message % dict(func=fn.__name__), header) + return decorate + +def _decorate_with_warning(func, wtype, message, docstring_header=None): + """Wrap a function with a warnings.warn and augmented docstring.""" + + @decorator + def warned(fn, *args, **kwargs): + warnings.warn(wtype(message), stacklevel=3) + return fn(*args, **kwargs) + + doc = func.__doc__ is not None and func.__doc__ or '' + if docstring_header is not None: + docstring_header %= dict(func=func.__name__) + docs = doc and doc.expandtabs().split('\n') or [] + indent = '' + for line in docs[1:]: + text = line.lstrip() + if text: + indent = line[0:len(line) - len(text)] + break + point = min(len(docs), 1) + docs.insert(point, '\n' + indent + docstring_header.rstrip()) + doc = '\n'.join(docs) + + decorated = warned(func) + decorated.__doc__ = doc + return decorated + +class classproperty(property): + """A decorator that behaves like @property except that operates + on classes rather than instances. + + This is helpful when you need to compute __table_args__ and/or + __mapper_args__ when using declarative.""" + def __get__(desc, self, cls): + return desc.fget(cls) +