This commit is contained in:
Øyvind Almelid 2010-05-07 17:33:49 +00:00
parent c7c7498f19
commit cae7e001e3
127 changed files with 57530 additions and 0 deletions

View File

@ -0,0 +1,45 @@
Metadata-Version: 1.0
Name: SQLAlchemy
Version: 0.6beta3
Summary: Database Abstraction Library
Home-page: http://www.sqlalchemy.org
Author: Mike Bayer
Author-email: mike_mp@zzzcomputing.com
License: MIT License
Description: SQLAlchemy is:
* The Python SQL toolkit and Object Relational Mapper that gives application developers the full power and flexibility of SQL. SQLAlchemy provides a full suite of well known enterprise-level persistence patterns, designed for efficient and high-performing database access, adapted into a simple and Pythonic domain language.
* extremely easy to use for all the basic tasks, such as: accessing pooled connections, constructing SQL from Python expressions, finding object instances, and commiting object modifications back to the database.
* powerful enough for complicated tasks, such as: eager load a graph of objects and their dependencies via joins; map recursive adjacency structures automatically; map objects to not just tables but to any arbitrary join or select statement; combine multiple tables together to load whole sets of otherwise unrelated objects from a single result set; commit entire graphs of object changes in one step.
* built to conform to what DBAs demand, including the ability to swap out generated SQL with hand-optimized statements, full usage of bind parameters for all literal values, fully transactionalized and consistent updates using Unit of Work.
* modular. Different parts of SQLAlchemy can be used independently of the rest, including the connection pool, SQL construction, and ORM. SQLAlchemy is constructed in an open style that allows plenty of customization, with an architecture that supports custom datatypes, custom SQL extensions, and ORM plugins which can augment or extend mapping functionality.
SQLAlchemy's Philosophy:
* SQL databases behave less and less like object collections the more size and performance start to matter; object collections behave less and less like tables and rows the more abstraction starts to matter. SQLAlchemy aims to accomodate both of these principles.
* Your classes aren't tables, and your objects aren't rows. Databases aren't just collections of tables; they're relational algebra engines. You don't have to select from just tables, you can select from joins, subqueries, and unions. Database and domain concepts should be visibly decoupled from the beginning, allowing both sides to develop to their full potential.
* For example, table metadata (objects that describe tables) are declared distinctly from the classes theyre designed to store. That way database relationship concepts don't interfere with your object design concepts, and vice-versa; the transition from table-mapping to selectable-mapping is seamless; a class can be mapped against the database in more than one way. SQLAlchemy provides a powerful mapping layer that can work as automatically or as manually as you choose, determining relationships based on foreign keys or letting you define the join conditions explicitly, to bridge the gap between database and domain.
SQLAlchemy's Advantages:
* The Unit Of Work system organizes pending CRUD operations into queues and commits them all in one batch. It then performs a topological "dependency sort" of all items to be committed and deleted and groups redundant statements together. This produces the maxiumum efficiency and transaction safety, and minimizes chances of deadlocks. Modeled after Fowler's "Unit of Work" pattern as well as Java Hibernate.
* Function-based query construction allows boolean expressions, operators, functions, table aliases, selectable subqueries, create/update/insert/delete queries, correlated updates, correlated EXISTS clauses, UNION clauses, inner and outer joins, bind parameters, free mixing of literal text within expressions, as little or as much as desired. Query-compilation is vendor-specific; the same query object can be compiled into any number of resulting SQL strings depending on its compilation algorithm.
* Database mapping and class design are totally separate. Persisted objects have no subclassing requirement (other than 'object') and are POPO's : plain old Python objects. They retain serializability (pickling) for usage in various caching systems and session objects. SQLAlchemy "decorates" classes with non-intrusive property accessors to automatically log object creates and modifications with the UnitOfWork engine, to lazyload related data, as well as to track attribute change histories.
* Custom list classes can be used with eagerly or lazily loaded child object lists, allowing rich relationships to be created on the fly as SQLAlchemy appends child objects to an object attribute.
* Composite (multiple-column) primary keys are supported, as are "association" objects that represent the middle of a "many-to-many" relationship.
* Self-referential tables and mappers are supported. Adjacency list structures can be created, saved, and deleted with proper cascading, with no extra programming.
* Data mapping can be used in a row-based manner. Any bizarre hyper-optimized query that you or your DBA can cook up, you can run in SQLAlchemy, and as long as it returns the expected columns within a rowset, you can get your objects from it. For a rowset that contains more than one kind of object per row, multiple mappers can be chained together to return multiple object instance lists from a single database round trip.
* The type system allows pre- and post- processing of data, both at the bind parameter and the result set level. User-defined types can be freely mixed with built-in types. Generic types as well as SQL-specific types are available.
SVN version:
<http://svn.sqlalchemy.org/sqlalchemy/trunk#egg=SQLAlchemy-dev>
Platform: UNKNOWN
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Database :: Front-Ends
Classifier: Operating System :: OS Independent

View File

@ -0,0 +1,499 @@
.hgignore
.hgtags
CHANGES
CHANGES_PRE_05
LICENSE
MANIFEST.in
README
README.py3k
README.unittests
distribute_setup.py
ez_setup.py
sa2to3.py
setup.cfg
setup.py
sqla_nose.py
doc/copyright.html
doc/dbengine.html
doc/examples.html
doc/genindex.html
doc/index.html
doc/intro.html
doc/mappers.html
doc/metadata.html
doc/ormtutorial.html
doc/search.html
doc/searchindex.js
doc/session.html
doc/sqlexpression.html
doc/_images/sqla_arch_small.jpg
doc/_sources/copyright.txt
doc/_sources/dbengine.txt
doc/_sources/examples.txt
doc/_sources/index.txt
doc/_sources/intro.txt
doc/_sources/mappers.txt
doc/_sources/metadata.txt
doc/_sources/ormtutorial.txt
doc/_sources/session.txt
doc/_sources/sqlexpression.txt
doc/_sources/reference/index.txt
doc/_sources/reference/dialects/access.txt
doc/_sources/reference/dialects/firebird.txt
doc/_sources/reference/dialects/index.txt
doc/_sources/reference/dialects/informix.txt
doc/_sources/reference/dialects/maxdb.txt
doc/_sources/reference/dialects/mssql.txt
doc/_sources/reference/dialects/mysql.txt
doc/_sources/reference/dialects/oracle.txt
doc/_sources/reference/dialects/postgresql.txt
doc/_sources/reference/dialects/sqlite.txt
doc/_sources/reference/dialects/sybase.txt
doc/_sources/reference/ext/associationproxy.txt
doc/_sources/reference/ext/compiler.txt
doc/_sources/reference/ext/declarative.txt
doc/_sources/reference/ext/horizontal_shard.txt
doc/_sources/reference/ext/index.txt
doc/_sources/reference/ext/orderinglist.txt
doc/_sources/reference/ext/serializer.txt
doc/_sources/reference/ext/sqlsoup.txt
doc/_sources/reference/orm/collections.txt
doc/_sources/reference/orm/index.txt
doc/_sources/reference/orm/interfaces.txt
doc/_sources/reference/orm/mapping.txt
doc/_sources/reference/orm/query.txt
doc/_sources/reference/orm/sessions.txt
doc/_sources/reference/orm/utilities.txt
doc/_sources/reference/sqlalchemy/connections.txt
doc/_sources/reference/sqlalchemy/expressions.txt
doc/_sources/reference/sqlalchemy/index.txt
doc/_sources/reference/sqlalchemy/interfaces.txt
doc/_sources/reference/sqlalchemy/pooling.txt
doc/_sources/reference/sqlalchemy/schema.txt
doc/_sources/reference/sqlalchemy/types.txt
doc/_static/basic.css
doc/_static/default.css
doc/_static/docs.css
doc/_static/doctools.js
doc/_static/file.png
doc/_static/init.js
doc/_static/jquery.js
doc/_static/minus.png
doc/_static/plus.png
doc/_static/pygments.css
doc/_static/searchtools.js
doc/build/Makefile
doc/build/conf.py
doc/build/copyright.rst
doc/build/dbengine.rst
doc/build/examples.rst
doc/build/index.rst
doc/build/intro.rst
doc/build/mappers.rst
doc/build/metadata.rst
doc/build/ormtutorial.rst
doc/build/session.rst
doc/build/sqla_arch_small.jpg
doc/build/sqlexpression.rst
doc/build/testdocs.py
doc/build/builder/__init__.py
doc/build/builder/builders.py
doc/build/builder/util.py
doc/build/reference/index.rst
doc/build/reference/dialects/access.rst
doc/build/reference/dialects/firebird.rst
doc/build/reference/dialects/index.rst
doc/build/reference/dialects/informix.rst
doc/build/reference/dialects/maxdb.rst
doc/build/reference/dialects/mssql.rst
doc/build/reference/dialects/mysql.rst
doc/build/reference/dialects/oracle.rst
doc/build/reference/dialects/postgresql.rst
doc/build/reference/dialects/sqlite.rst
doc/build/reference/dialects/sybase.rst
doc/build/reference/ext/associationproxy.rst
doc/build/reference/ext/compiler.rst
doc/build/reference/ext/declarative.rst
doc/build/reference/ext/horizontal_shard.rst
doc/build/reference/ext/index.rst
doc/build/reference/ext/orderinglist.rst
doc/build/reference/ext/serializer.rst
doc/build/reference/ext/sqlsoup.rst
doc/build/reference/orm/collections.rst
doc/build/reference/orm/index.rst
doc/build/reference/orm/interfaces.rst
doc/build/reference/orm/mapping.rst
doc/build/reference/orm/query.rst
doc/build/reference/orm/sessions.rst
doc/build/reference/orm/utilities.rst
doc/build/reference/sqlalchemy/connections.rst
doc/build/reference/sqlalchemy/expressions.rst
doc/build/reference/sqlalchemy/index.rst
doc/build/reference/sqlalchemy/interfaces.rst
doc/build/reference/sqlalchemy/pooling.rst
doc/build/reference/sqlalchemy/schema.rst
doc/build/reference/sqlalchemy/types.rst
doc/build/static/docs.css
doc/build/static/init.js
doc/build/templates/genindex.mako
doc/build/templates/layout.mako
doc/build/templates/page.mako
doc/build/templates/search.mako
doc/build/templates/site_base.mako
doc/build/templates/static_base.mako
doc/build/texinputs/sphinx.sty
doc/reference/index.html
doc/reference/dialects/access.html
doc/reference/dialects/firebird.html
doc/reference/dialects/index.html
doc/reference/dialects/informix.html
doc/reference/dialects/maxdb.html
doc/reference/dialects/mssql.html
doc/reference/dialects/mysql.html
doc/reference/dialects/oracle.html
doc/reference/dialects/postgresql.html
doc/reference/dialects/sqlite.html
doc/reference/dialects/sybase.html
doc/reference/ext/associationproxy.html
doc/reference/ext/compiler.html
doc/reference/ext/declarative.html
doc/reference/ext/horizontal_shard.html
doc/reference/ext/index.html
doc/reference/ext/orderinglist.html
doc/reference/ext/serializer.html
doc/reference/ext/sqlsoup.html
doc/reference/orm/collections.html
doc/reference/orm/index.html
doc/reference/orm/interfaces.html
doc/reference/orm/mapping.html
doc/reference/orm/query.html
doc/reference/orm/sessions.html
doc/reference/orm/utilities.html
doc/reference/sqlalchemy/connections.html
doc/reference/sqlalchemy/expressions.html
doc/reference/sqlalchemy/index.html
doc/reference/sqlalchemy/interfaces.html
doc/reference/sqlalchemy/pooling.html
doc/reference/sqlalchemy/schema.html
doc/reference/sqlalchemy/types.html
examples/__init__.py
examples/adjacency_list/__init__.py
examples/adjacency_list/adjacency_list.py
examples/association/__init__.py
examples/association/basic_association.py
examples/association/proxied_association.py
examples/beaker_caching/__init__.py
examples/beaker_caching/advanced.py
examples/beaker_caching/environment.py
examples/beaker_caching/fixture_data.py
examples/beaker_caching/helloworld.py
examples/beaker_caching/local_session_caching.py
examples/beaker_caching/meta.py
examples/beaker_caching/model.py
examples/beaker_caching/relation_caching.py
examples/custom_attributes/__init__.py
examples/custom_attributes/custom_management.py
examples/custom_attributes/listen_for_events.py
examples/derived_attributes/__init__.py
examples/derived_attributes/attributes.py
examples/dynamic_dict/__init__.py
examples/dynamic_dict/dynamic_dict.py
examples/elementtree/__init__.py
examples/elementtree/adjacency_list.py
examples/elementtree/optimized_al.py
examples/elementtree/pickle.py
examples/elementtree/test.xml
examples/elementtree/test2.xml
examples/elementtree/test3.xml
examples/graphs/__init__.py
examples/graphs/directed_graph.py
examples/inheritance/__init__.py
examples/inheritance/concrete.py
examples/inheritance/polymorph.py
examples/inheritance/single.py
examples/large_collection/__init__.py
examples/large_collection/large_collection.py
examples/nested_sets/__init__.py
examples/nested_sets/nested_sets.py
examples/poly_assoc/__init__.py
examples/poly_assoc/poly_assoc.py
examples/poly_assoc/poly_assoc_fk.py
examples/poly_assoc/poly_assoc_generic.py
examples/postgis/__init__.py
examples/postgis/postgis.py
examples/sharding/__init__.py
examples/sharding/attribute_shard.py
examples/versioning/__init__.py
examples/versioning/history_meta.py
examples/versioning/test_versioning.py
examples/vertical/__init__.py
examples/vertical/dictlike-polymorphic.py
examples/vertical/dictlike.py
lib/SQLAlchemy.egg-info/PKG-INFO
lib/SQLAlchemy.egg-info/SOURCES.txt
lib/SQLAlchemy.egg-info/dependency_links.txt
lib/SQLAlchemy.egg-info/entry_points.txt
lib/SQLAlchemy.egg-info/top_level.txt
lib/sqlalchemy/__init__.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/interfaces.py
lib/sqlalchemy/log.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/processors.py
lib/sqlalchemy/queue.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/topological.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
lib/sqlalchemy/cextension/processors.c
lib/sqlalchemy/cextension/resultproxy.c
lib/sqlalchemy/connectors/__init__.py
lib/sqlalchemy/connectors/mxodbc.py
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/connectors/zxJDBC.py
lib/sqlalchemy/databases/__init__.py
lib/sqlalchemy/dialects/__init__.py
lib/sqlalchemy/dialects/postgres.py
lib/sqlalchemy/dialects/type_migration_guidelines.txt
lib/sqlalchemy/dialects/access/__init__.py
lib/sqlalchemy/dialects/access/base.py
lib/sqlalchemy/dialects/firebird/__init__.py
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/firebird/kinterbasdb.py
lib/sqlalchemy/dialects/informix/__init__.py
lib/sqlalchemy/dialects/informix/base.py
lib/sqlalchemy/dialects/informix/informixdb.py
lib/sqlalchemy/dialects/maxdb/__init__.py
lib/sqlalchemy/dialects/maxdb/base.py
lib/sqlalchemy/dialects/maxdb/sapdb.py
lib/sqlalchemy/dialects/mssql/__init__.py
lib/sqlalchemy/dialects/mssql/adodbapi.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/information_schema.py
lib/sqlalchemy/dialects/mssql/mxodbc.py
lib/sqlalchemy/dialects/mssql/pymssql.py
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/dialects/mssql/zxjdbc.py
lib/sqlalchemy/dialects/mysql/__init__.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/oursql.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
lib/sqlalchemy/dialects/mysql/zxjdbc.py
lib/sqlalchemy/dialects/oracle/__init__.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/oracle/zxjdbc.py
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/postgresql/pypostgresql.py
lib/sqlalchemy/dialects/postgresql/zxjdbc.py
lib/sqlalchemy/dialects/sqlite/__init__.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py
lib/sqlalchemy/dialects/sybase/__init__.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/dialects/sybase/mxodbc.py
lib/sqlalchemy/dialects/sybase/pyodbc.py
lib/sqlalchemy/dialects/sybase/pysybase.py
lib/sqlalchemy/engine/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/ddl.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/engine/url.py
lib/sqlalchemy/ext/__init__.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/compiler.py
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/ext/orderinglist.py
lib/sqlalchemy/ext/serializer.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/evaluator.py
lib/sqlalchemy/orm/exc.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/shard.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/uowdumper.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/test/__init__.py
lib/sqlalchemy/test/assertsql.py
lib/sqlalchemy/test/config.py
lib/sqlalchemy/test/engines.py
lib/sqlalchemy/test/entities.py
lib/sqlalchemy/test/noseplugin.py
lib/sqlalchemy/test/orm.py
lib/sqlalchemy/test/pickleable.py
lib/sqlalchemy/test/profiling.py
lib/sqlalchemy/test/requires.py
lib/sqlalchemy/test/schema.py
lib/sqlalchemy/test/testing.py
lib/sqlalchemy/test/util.py
test/__init__.py
test/binary_data_one.dat
test/binary_data_two.dat
test/aaa_profiling/__init__.py
test/aaa_profiling/test_compiler.py
test/aaa_profiling/test_memusage.py
test/aaa_profiling/test_orm.py
test/aaa_profiling/test_pool.py
test/aaa_profiling/test_resultset.py
test/aaa_profiling/test_zoomark.py
test/aaa_profiling/test_zoomark_orm.py
test/base/__init__.py
test/base/test_dependency.py
test/base/test_except.py
test/base/test_utils.py
test/dialect/__init__.py
test/dialect/test_access.py
test/dialect/test_firebird.py
test/dialect/test_informix.py
test/dialect/test_maxdb.py
test/dialect/test_mssql.py
test/dialect/test_mxodbc.py
test/dialect/test_mysql.py
test/dialect/test_oracle.py
test/dialect/test_postgresql.py
test/dialect/test_sqlite.py
test/dialect/test_sybase.py
test/engine/__init__.py
test/engine/_base.py
test/engine/test_bind.py
test/engine/test_ddlevents.py
test/engine/test_execute.py
test/engine/test_metadata.py
test/engine/test_parseconnect.py
test/engine/test_pool.py
test/engine/test_reconnect.py
test/engine/test_reflection.py
test/engine/test_transaction.py
test/ex/__init__.py
test/ex/test_examples.py
test/ext/__init__.py
test/ext/test_associationproxy.py
test/ext/test_compiler.py
test/ext/test_declarative.py
test/ext/test_horizontal_shard.py
test/ext/test_orderinglist.py
test/ext/test_serializer.py
test/ext/test_sqlsoup.py
test/orm/__init__.py
test/orm/_base.py
test/orm/_fixtures.py
test/orm/test_association.py
test/orm/test_assorted_eager.py
test/orm/test_attributes.py
test/orm/test_backref_mutations.py
test/orm/test_bind.py
test/orm/test_cascade.py
test/orm/test_collection.py
test/orm/test_compile.py
test/orm/test_cycles.py
test/orm/test_defaults.py
test/orm/test_deprecations.py
test/orm/test_dynamic.py
test/orm/test_eager_relations.py
test/orm/test_evaluator.py
test/orm/test_expire.py
test/orm/test_extendedattr.py
test/orm/test_generative.py
test/orm/test_instrumentation.py
test/orm/test_lazy_relations.py
test/orm/test_lazytest1.py
test/orm/test_manytomany.py
test/orm/test_mapper.py
test/orm/test_merge.py
test/orm/test_naturalpks.py
test/orm/test_onetoone.py
test/orm/test_pickled.py
test/orm/test_query.py
test/orm/test_relationships.py
test/orm/test_scoping.py
test/orm/test_selectable.py
test/orm/test_session.py
test/orm/test_subquery_relations.py
test/orm/test_transaction.py
test/orm/test_unitofwork.py
test/orm/test_utils.py
test/orm/test_versioning.py
test/orm/inheritance/__init__.py
test/orm/inheritance/test_abc_inheritance.py
test/orm/inheritance/test_abc_polymorphic.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_concrete.py
test/orm/inheritance/test_magazine.py
test/orm/inheritance/test_manytomany.py
test/orm/inheritance/test_poly_linked_list.py
test/orm/inheritance/test_polymorph.py
test/orm/inheritance/test_polymorph2.py
test/orm/inheritance/test_productspec.py
test/orm/inheritance/test_query.py
test/orm/inheritance/test_selects.py
test/orm/inheritance/test_single.py
test/perf/cascade_speed.py
test/perf/insertspeed.py
test/perf/masscreate.py
test/perf/masscreate2.py
test/perf/masseagerload.py
test/perf/massload.py
test/perf/massload2.py
test/perf/masssave.py
test/perf/objselectspeed.py
test/perf/objupdatespeed.py
test/perf/ormsession.py
test/perf/poolload.py
test/perf/sessions.py
test/perf/stress_all.py
test/perf/stresstest.py
test/perf/threaded_compile.py
test/perf/wsgi.py
test/sql/__init__.py
test/sql/_base.py
test/sql/test_case_statement.py
test/sql/test_columns.py
test/sql/test_compiler.py
test/sql/test_constraints.py
test/sql/test_defaults.py
test/sql/test_functions.py
test/sql/test_generative.py
test/sql/test_labels.py
test/sql/test_query.py
test/sql/test_quote.py
test/sql/test_returning.py
test/sql/test_rowcount.py
test/sql/test_selectable.py
test/sql/test_types.py
test/sql/test_unicode.py
test/zblog/__init__.py
test/zblog/blog.py
test/zblog/mappers.py
test/zblog/tables.py
test/zblog/test_zblog.py
test/zblog/user.py

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,3 @@
[nose.plugins.0.10]
sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy

View File

@ -0,0 +1 @@
sqlalchemy

119
sqlalchemy/__init__.py Normal file
View File

@ -0,0 +1,119 @@
# __init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import inspect
import sys
import sqlalchemy.exc as exceptions
sys.modules['sqlalchemy.exceptions'] = exceptions
from sqlalchemy.sql import (
alias,
and_,
asc,
between,
bindparam,
case,
cast,
collate,
delete,
desc,
distinct,
except_,
except_all,
exists,
extract,
func,
insert,
intersect,
intersect_all,
join,
literal,
literal_column,
modifier,
not_,
null,
or_,
outerjoin,
outparam,
select,
subquery,
text,
tuple_,
union,
union_all,
update,
)
from sqlalchemy.types import (
BLOB,
BOOLEAN,
BigInteger,
Binary,
Boolean,
CHAR,
CLOB,
DATE,
DATETIME,
DECIMAL,
Date,
DateTime,
Enum,
FLOAT,
Float,
INT,
INTEGER,
Integer,
Interval,
LargeBinary,
NCHAR,
NVARCHAR,
NUMERIC,
Numeric,
PickleType,
SMALLINT,
SmallInteger,
String,
TEXT,
TIME,
TIMESTAMP,
Text,
Time,
Unicode,
UnicodeText,
VARCHAR,
)
from sqlalchemy.schema import (
CheckConstraint,
Column,
ColumnDefault,
Constraint,
DDL,
DefaultClause,
FetchedValue,
ForeignKey,
ForeignKeyConstraint,
Index,
MetaData,
PassiveDefault,
PrimaryKeyConstraint,
Sequence,
Table,
ThreadLocalMetaData,
UniqueConstraint,
)
from sqlalchemy.engine import create_engine, engine_from_config
__all__ = sorted(name for name, obj in locals().items()
if not (name.startswith('_') or inspect.ismodule(obj)))
__version__ = '0.6beta3'
del inspect, sys

View File

@ -0,0 +1,393 @@
/*
processors.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#include <datetime.h>
static PyObject *
int_to_boolean(PyObject *self, PyObject *arg)
{
long l = 0;
PyObject *res;
if (arg == Py_None)
Py_RETURN_NONE;
l = PyInt_AsLong(arg);
if (l == 0) {
res = Py_False;
} else if (l == 1) {
res = Py_True;
} else if ((l == -1) && PyErr_Occurred()) {
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else {
PyErr_SetString(PyExc_ValueError,
"int_to_boolean only accepts None, 0 or 1");
return NULL;
}
Py_INCREF(res);
return res;
}
static PyObject *
to_str(PyObject *self, PyObject *arg)
{
if (arg == Py_None)
Py_RETURN_NONE;
return PyObject_Str(arg);
}
static PyObject *
to_float(PyObject *self, PyObject *arg)
{
if (arg == Py_None)
Py_RETURN_NONE;
return PyNumber_Float(arg);
}
static PyObject *
str_to_datetime(PyObject *self, PyObject *arg)
{
const char *str;
unsigned int year, month, day, hour, minute, second, microsecond = 0;
if (arg == Py_None)
Py_RETURN_NONE;
str = PyString_AsString(arg);
if (str == NULL)
return NULL;
/* microseconds are optional */
/*
TODO: this is slightly less picky than the Python version which would
not accept "2000-01-01 00:00:00.". I don't know which is better, but they
should be coherent.
*/
if (sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
&hour, &minute, &second, &microsecond) < 6) {
PyErr_SetString(PyExc_ValueError, "Couldn't parse datetime string.");
return NULL;
}
return PyDateTime_FromDateAndTime(year, month, day,
hour, minute, second, microsecond);
}
static PyObject *
str_to_time(PyObject *self, PyObject *arg)
{
const char *str;
unsigned int hour, minute, second, microsecond = 0;
if (arg == Py_None)
Py_RETURN_NONE;
str = PyString_AsString(arg);
if (str == NULL)
return NULL;
/* microseconds are optional */
/*
TODO: this is slightly less picky than the Python version which would
not accept "00:00:00.". I don't know which is better, but they should be
coherent.
*/
if (sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
&microsecond) < 3) {
PyErr_SetString(PyExc_ValueError, "Couldn't parse time string.");
return NULL;
}
return PyTime_FromTime(hour, minute, second, microsecond);
}
static PyObject *
str_to_date(PyObject *self, PyObject *arg)
{
const char *str;
unsigned int year, month, day;
if (arg == Py_None)
Py_RETURN_NONE;
str = PyString_AsString(arg);
if (str == NULL)
return NULL;
if (sscanf(str, "%4u-%2u-%2u", &year, &month, &day) != 3) {
PyErr_SetString(PyExc_ValueError, "Couldn't parse date string.");
return NULL;
}
return PyDate_FromDate(year, month, day);
}
/***********
* Structs *
***********/
typedef struct {
PyObject_HEAD
PyObject *encoding;
PyObject *errors;
} UnicodeResultProcessor;
typedef struct {
PyObject_HEAD
PyObject *type;
PyObject *format;
} DecimalResultProcessor;
/**************************
* UnicodeResultProcessor *
**************************/
static int
UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
PyObject *kwds)
{
PyObject *encoding, *errors = NULL;
static char *kwlist[] = {"encoding", "errors", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist,
&encoding, &errors))
return -1;
Py_INCREF(encoding);
self->encoding = encoding;
if (errors) {
Py_INCREF(errors);
} else {
errors = PyString_FromString("strict");
if (errors == NULL)
return -1;
}
self->errors = errors;
return 0;
}
static PyObject *
UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
{
const char *encoding, *errors;
char *str;
Py_ssize_t len;
if (value == Py_None)
Py_RETURN_NONE;
if (PyString_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors);
return PyUnicode_Decode(str, len, encoding, errors);
}
static PyMethodDef UnicodeResultProcessor_methods[] = {
{"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
"The value processor itself."},
{NULL} /* Sentinel */
};
static PyTypeObject UnicodeResultProcessorType = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
"sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */
sizeof(UnicodeResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"UnicodeResultProcessor objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
UnicodeResultProcessor_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)UnicodeResultProcessor_init, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
};
/**************************
* DecimalResultProcessor *
**************************/
static int
DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
PyObject *kwds)
{
PyObject *type, *format;
if (!PyArg_ParseTuple(args, "OS", &type, &format))
return -1;
Py_INCREF(type);
self->type = type;
Py_INCREF(format);
self->format = format;
return 0;
}
static PyObject *
DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
{
PyObject *str, *result, *args;
if (value == Py_None)
Py_RETURN_NONE;
if (PyFloat_CheckExact(value)) {
/* Decimal does not accept float values directly */
args = PyTuple_Pack(1, value);
if (args == NULL)
return NULL;
str = PyString_Format(self->format, args);
if (str == NULL)
return NULL;
result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
Py_DECREF(str);
return result;
} else {
return PyObject_CallFunctionObjArgs(self->type, value, NULL);
}
}
static PyMethodDef DecimalResultProcessor_methods[] = {
{"process", (PyCFunction)DecimalResultProcessor_process, METH_O,
"The value processor itself."},
{NULL} /* Sentinel */
};
static PyTypeObject DecimalResultProcessorType = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
"sqlalchemy.DecimalResultProcessor", /* tp_name */
sizeof(DecimalResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"DecimalResultProcessor objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
DecimalResultProcessor_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)DecimalResultProcessor_init, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
static PyMethodDef module_methods[] = {
{"int_to_boolean", int_to_boolean, METH_O,
"Convert an integer to a boolean."},
{"to_str", to_str, METH_O,
"Convert any value to its string representation."},
{"to_float", to_float, METH_O,
"Convert any value to its floating point representation."},
{"str_to_datetime", str_to_datetime, METH_O,
"Convert an ISO string to a datetime.datetime object."},
{"str_to_time", str_to_time, METH_O,
"Convert an ISO string to a datetime.time object."},
{"str_to_date", str_to_date, METH_O,
"Convert an ISO string to a datetime.date object."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
PyMODINIT_FUNC
initcprocessors(void)
{
PyObject *m;
UnicodeResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&UnicodeResultProcessorType) < 0)
return;
DecimalResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&DecimalResultProcessorType) < 0)
return;
m = Py_InitModule3("cprocessors", module_methods,
"Module containing C versions of data processing functions.");
if (m == NULL)
return;
PyDateTime_IMPORT;
Py_INCREF(&UnicodeResultProcessorType);
PyModule_AddObject(m, "UnicodeResultProcessor",
(PyObject *)&UnicodeResultProcessorType);
Py_INCREF(&DecimalResultProcessorType);
PyModule_AddObject(m, "DecimalResultProcessor",
(PyObject *)&DecimalResultProcessorType);
}

View File

@ -0,0 +1,586 @@
/*
resultproxy.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
/***********
* Structs *
***********/
typedef struct {
PyObject_HEAD
PyObject *parent;
PyObject *row;
PyObject *processors;
PyObject *keymap;
} BaseRowProxy;
/****************
* BaseRowProxy *
****************/
static PyObject *
safe_rowproxy_reconstructor(PyObject *self, PyObject *args)
{
PyObject *cls, *state, *tmp;
BaseRowProxy *obj;
if (!PyArg_ParseTuple(args, "OO", &cls, &state))
return NULL;
obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls);
if (obj == NULL)
return NULL;
tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state);
if (tmp == NULL) {
Py_DECREF(obj);
return NULL;
}
Py_DECREF(tmp);
if (obj->parent == NULL || obj->row == NULL ||
obj->processors == NULL || obj->keymap == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"__setstate__ for BaseRowProxy subclasses must set values "
"for parent, row, processors and keymap");
Py_DECREF(obj);
return NULL;
}
return (PyObject *)obj;
}
static int
BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
{
PyObject *parent, *row, *processors, *keymap;
if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4,
&parent, &row, &processors, &keymap))
return -1;
Py_INCREF(parent);
self->parent = parent;
if (!PyTuple_CheckExact(row)) {
PyErr_SetString(PyExc_TypeError, "row must be a tuple");
return -1;
}
Py_INCREF(row);
self->row = row;
if (!PyList_CheckExact(processors)) {
PyErr_SetString(PyExc_TypeError, "processors must be a list");
return -1;
}
Py_INCREF(processors);
self->processors = processors;
if (!PyDict_CheckExact(keymap)) {
PyErr_SetString(PyExc_TypeError, "keymap must be a dict");
return -1;
}
Py_INCREF(keymap);
self->keymap = keymap;
return 0;
}
/* We need the reduce method because otherwise the default implementation
* does very weird stuff for pickle protocol 0 and 1. It calls
* BaseRowProxy.__new__(RowProxy_instance) upon *pickling*.
*/
static PyObject *
BaseRowProxy_reduce(PyObject *self)
{
PyObject *method, *state;
PyObject *module, *reconstructor, *cls;
method = PyObject_GetAttrString(self, "__getstate__");
if (method == NULL)
return NULL;
state = PyObject_CallObject(method, NULL);
Py_DECREF(method);
if (state == NULL)
return NULL;
module = PyImport_ImportModule("sqlalchemy.engine.base");
if (module == NULL)
return NULL;
reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor");
Py_DECREF(module);
if (reconstructor == NULL) {
Py_DECREF(state);
return NULL;
}
cls = PyObject_GetAttrString(self, "__class__");
if (cls == NULL) {
Py_DECREF(reconstructor);
Py_DECREF(state);
return NULL;
}
return Py_BuildValue("(N(NN))", reconstructor, cls, state);
}
static void
BaseRowProxy_dealloc(BaseRowProxy *self)
{
Py_XDECREF(self->parent);
Py_XDECREF(self->row);
Py_XDECREF(self->processors);
Py_XDECREF(self->keymap);
self->ob_type->tp_free((PyObject *)self);
}
static PyObject *
BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
{
Py_ssize_t num_values, num_processors;
PyObject **valueptr, **funcptr, **resultptr;
PyObject *func, *result, *processed_value;
num_values = Py_SIZE(values);
num_processors = Py_SIZE(processors);
if (num_values != num_processors) {
PyErr_SetString(PyExc_RuntimeError,
"number of values in row differ from number of column processors");
return NULL;
}
if (astuple) {
result = PyTuple_New(num_values);
} else {
result = PyList_New(num_values);
}
if (result == NULL)
return NULL;
/* we don't need to use PySequence_Fast as long as values, processors and
* result are simple tuple or lists. */
valueptr = PySequence_Fast_ITEMS(values);
funcptr = PySequence_Fast_ITEMS(processors);
resultptr = PySequence_Fast_ITEMS(result);
while (--num_values >= 0) {
func = *funcptr;
if (func != Py_None) {
processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
NULL);
if (processed_value == NULL) {
Py_DECREF(result);
return NULL;
}
*resultptr = processed_value;
} else {
Py_INCREF(*valueptr);
*resultptr = *valueptr;
}
valueptr++;
funcptr++;
resultptr++;
}
return result;
}
static PyListObject *
BaseRowProxy_values(BaseRowProxy *self)
{
return (PyListObject *)BaseRowProxy_processvalues(self->row,
self->processors, 0);
}
static PyTupleObject *
BaseRowProxy_tuplevalues(BaseRowProxy *self)
{
return (PyTupleObject *)BaseRowProxy_processvalues(self->row,
self->processors, 1);
}
static PyObject *
BaseRowProxy_iter(BaseRowProxy *self)
{
PyObject *values, *result;
values = (PyObject *)BaseRowProxy_tuplevalues(self);
if (values == NULL)
return NULL;
result = PyObject_GetIter(values);
Py_DECREF(values);
if (result == NULL)
return NULL;
return result;
}
static Py_ssize_t
BaseRowProxy_length(BaseRowProxy *self)
{
return Py_SIZE(self->row);
}
static PyObject *
BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
{
PyObject *processors, *values;
PyObject *processor, *value;
PyObject *record, *result, *indexobject;
PyObject *exc_module, *exception;
char *cstr_key;
long index;
if (PyInt_CheckExact(key)) {
index = PyInt_AS_LONG(key);
} else if (PyLong_CheckExact(key)) {
index = PyLong_AsLong(key);
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else if (PySlice_Check(key)) {
values = PyObject_GetItem(self->row, key);
if (values == NULL)
return NULL;
processors = PyObject_GetItem(self->processors, key);
if (processors == NULL) {
Py_DECREF(values);
return NULL;
}
result = BaseRowProxy_processvalues(values, processors, 1);
Py_DECREF(values);
Py_DECREF(processors);
return result;
} else {
record = PyDict_GetItem((PyObject *)self->keymap, key);
if (record == NULL) {
record = PyObject_CallMethod(self->parent, "_key_fallback",
"O", key);
if (record == NULL)
return NULL;
}
indexobject = PyTuple_GetItem(record, 1);
if (indexobject == NULL)
return NULL;
if (indexobject == Py_None) {
exc_module = PyImport_ImportModule("sqlalchemy.exc");
if (exc_module == NULL)
return NULL;
exception = PyObject_GetAttrString(exc_module,
"InvalidRequestError");
Py_DECREF(exc_module);
if (exception == NULL)
return NULL;
cstr_key = PyString_AsString(key);
if (cstr_key == NULL)
return NULL;
PyErr_Format(exception,
"Ambiguous column name '%s' in result set! "
"try 'use_labels' option on select statement.", cstr_key);
return NULL;
}
index = PyInt_AsLong(indexobject);
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
}
processor = PyList_GetItem(self->processors, index);
if (processor == NULL)
return NULL;
value = PyTuple_GetItem(self->row, index);
if (value == NULL)
return NULL;
if (processor != Py_None) {
return PyObject_CallFunctionObjArgs(processor, value, NULL);
} else {
Py_INCREF(value);
return value;
}
}
static PyObject *
BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
{
PyObject *tmp;
if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
else
return tmp;
return BaseRowProxy_subscript(self, name);
}
/***********************
* getters and setters *
***********************/
static PyObject *
BaseRowProxy_getparent(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->parent);
return self->parent;
}
static int
BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
{
PyObject *module, *cls;
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'parent' attribute");
return -1;
}
module = PyImport_ImportModule("sqlalchemy.engine.base");
if (module == NULL)
return -1;
cls = PyObject_GetAttrString(module, "ResultMetaData");
Py_DECREF(module);
if (cls == NULL)
return -1;
if (PyObject_IsInstance(value, cls) != 1) {
PyErr_SetString(PyExc_TypeError,
"The 'parent' attribute value must be an instance of "
"ResultMetaData");
return -1;
}
Py_DECREF(cls);
Py_XDECREF(self->parent);
Py_INCREF(value);
self->parent = value;
return 0;
}
static PyObject *
BaseRowProxy_getrow(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->row);
return self->row;
}
static int
BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'row' attribute");
return -1;
}
if (!PyTuple_CheckExact(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'row' attribute value must be a tuple");
return -1;
}
Py_XDECREF(self->row);
Py_INCREF(value);
self->row = value;
return 0;
}
static PyObject *
BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->processors);
return self->processors;
}
static int
BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'processors' attribute");
return -1;
}
if (!PyList_CheckExact(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'processors' attribute value must be a list");
return -1;
}
Py_XDECREF(self->processors);
Py_INCREF(value);
self->processors = value;
return 0;
}
static PyObject *
BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->keymap);
return self->keymap;
}
static int
BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'keymap' attribute");
return -1;
}
if (!PyDict_CheckExact(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'keymap' attribute value must be a dict");
return -1;
}
Py_XDECREF(self->keymap);
Py_INCREF(value);
self->keymap = value;
return 0;
}
static PyGetSetDef BaseRowProxy_getseters[] = {
{"_parent",
(getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent,
"ResultMetaData",
NULL},
{"_row",
(getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow,
"Original row tuple",
NULL},
{"_processors",
(getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors,
"list of type processors",
NULL},
{"_keymap",
(getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap,
"Key to (processor, index) dict",
NULL},
{NULL}
};
static PyMethodDef BaseRowProxy_methods[] = {
{"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
"Return the values represented by this BaseRowProxy as a list."},
{"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
"Pickle support method."},
{NULL} /* Sentinel */
};
static PySequenceMethods BaseRowProxy_as_sequence = {
(lenfunc)BaseRowProxy_length, /* sq_length */
0, /* sq_concat */
0, /* sq_repeat */
0, /* sq_item */
0, /* sq_slice */
0, /* sq_ass_item */
0, /* sq_ass_slice */
0, /* sq_contains */
0, /* sq_inplace_concat */
0, /* sq_inplace_repeat */
};
static PyMappingMethods BaseRowProxy_as_mapping = {
(lenfunc)BaseRowProxy_length, /* mp_length */
(binaryfunc)BaseRowProxy_subscript, /* mp_subscript */
0 /* mp_ass_subscript */
};
static PyTypeObject BaseRowProxyType = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
"sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */
sizeof(BaseRowProxy), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)BaseRowProxy_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
&BaseRowProxy_as_sequence, /* tp_as_sequence */
&BaseRowProxy_as_mapping, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
(getattrofunc)BaseRowProxy_getattro,/* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"BaseRowProxy is a abstract base class for RowProxy", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
(getiterfunc)BaseRowProxy_iter, /* tp_iter */
0, /* tp_iternext */
BaseRowProxy_methods, /* tp_methods */
0, /* tp_members */
BaseRowProxy_getseters, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)BaseRowProxy_init, /* tp_init */
0, /* tp_alloc */
0 /* tp_new */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
static PyMethodDef module_methods[] = {
{"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS,
"reconstruct a RowProxy instance from its pickled form."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
PyMODINIT_FUNC
initcresultproxy(void)
{
PyObject *m;
BaseRowProxyType.tp_new = PyType_GenericNew;
if (PyType_Ready(&BaseRowProxyType) < 0)
return;
m = Py_InitModule3("cresultproxy", module_methods,
"Module containing C versions of core ResultProxy classes.");
if (m == NULL)
return;
Py_INCREF(&BaseRowProxyType);
PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);
}

View File

@ -0,0 +1,6 @@
class Connector(object):
pass

View File

@ -0,0 +1,146 @@
"""
Provide an SQLALchemy connector for the eGenix mxODBC commercial
Python adapter for ODBC. This is not a free product, but eGenix
provides SQLAlchemy with a license for use in continuous integration
testing.
This has been tested for use with mxODBC 3.1.2 on SQL Server 2005
and 2008, using the SQL Server Native driver. However, it is
possible for this to be used on other database platforms.
For more info on mxODBC, see http://www.egenix.com/
"""
import sys
import re
import warnings
from decimal import Decimal
from sqlalchemy.connectors import Connector
from sqlalchemy import types as sqltypes
import sqlalchemy.processors as processors
class MxODBCConnector(Connector):
driver='mxodbc'
supports_sane_multi_rowcount = False
supports_unicode_statements = False
supports_unicode_binds = False
supports_native_decimal = True
@classmethod
def dbapi(cls):
# this classmethod will normally be replaced by an instance
# attribute of the same name, so this is normally only called once.
cls._load_mx_exceptions()
platform = sys.platform
if platform == 'win32':
from mx.ODBC import Windows as module
# this can be the string "linux2", and possibly others
elif 'linux' in platform:
from mx.ODBC import unixODBC as module
elif platform == 'darwin':
from mx.ODBC import iODBC as module
else:
raise ImportError, "Unrecognized platform for mxODBC import"
return module
@classmethod
def _load_mx_exceptions(cls):
""" Import mxODBC exception classes into the module namespace,
as if they had been imported normally. This is done here
to avoid requiring all SQLAlchemy users to install mxODBC.
"""
global InterfaceError, ProgrammingError
from mx.ODBC import InterfaceError
from mx.ODBC import ProgrammingError
def on_connect(self):
def connect(conn):
conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
conn.errorhandler = self._error_handler()
return connect
def _error_handler(self):
""" Return a handler that adjusts mxODBC's raised Warnings to
emit Python standard warnings.
"""
from mx.ODBC.Error import Warning as MxOdbcWarning
def error_handler(connection, cursor, errorclass, errorvalue):
if issubclass(errorclass, MxOdbcWarning):
errorclass.__bases__ = (Warning,)
warnings.warn(message=str(errorvalue),
category=errorclass,
stacklevel=2)
else:
raise errorclass, errorvalue
return error_handler
def create_connect_args(self, url):
""" Return a tuple of *args,**kwargs for creating a connection.
The mxODBC 3.x connection constructor looks like this:
connect(dsn, user='', password='',
clear_auto_commit=1, errorhandler=None)
This method translates the values in the provided uri
into args and kwargs needed to instantiate an mxODBC Connection.
The arg 'errorhandler' is not used by SQLAlchemy and will
not be populated.
"""
opts = url.translate_connect_args(username='user')
opts.update(url.query)
args = opts.pop('host')
opts.pop('port', None)
opts.pop('database', None)
return (args,), opts
def is_disconnect(self, e):
# eGenix recommends checking connection.closed here,
# but how can we get a handle on the current connection?
if isinstance(e, self.dbapi.ProgrammingError):
return "connection already closed" in str(e)
elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e)
else:
return False
def _get_server_version_info(self, connection):
# eGenix suggests using conn.dbms_version instead of what we're doing here
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
# 18 == pyodbc.SQL_DBMS_VER
for n in r.split(dbapi_con.getinfo(18)[1]):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def do_execute(self, cursor, statement, parameters, context=None):
if context:
native_odbc_execute = context.execution_options.\
get('native_odbc_execute', 'auto')
if native_odbc_execute is True:
# user specified native_odbc_execute=True
cursor.execute(statement, parameters)
elif native_odbc_execute is False:
# user specified native_odbc_execute=False
cursor.executedirect(statement, parameters)
elif context.is_crud:
# statement is UPDATE, DELETE, INSERT
cursor.execute(statement, parameters)
else:
# all other statements
cursor.executedirect(statement, parameters)
else:
cursor.executedirect(statement, parameters)

View File

@ -0,0 +1,113 @@
from sqlalchemy.connectors import Connector
from sqlalchemy.util import asbool
import sys
import re
import urllib
import decimal
class PyODBCConnector(Connector):
driver='pyodbc'
supports_sane_multi_rowcount = False
# PyODBC unicode is broken on UCS-4 builds
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = supports_unicode
supports_native_decimal = True
default_paramstyle = 'named'
# for non-DSN connections, this should
# hold the desired driver name
pyodbc_driver_name = None
# will be set to True after initialize()
# if the freetds.so is detected
freetds = False
@classmethod
def dbapi(cls):
return __import__('pyodbc')
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
keys = opts
query = url.query
connect_args = {}
for param in ('ansi', 'unicode_results', 'autocommit'):
if param in keys:
connect_args[param] = asbool(keys.pop(param))
if 'odbc_connect' in keys:
connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
else:
dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
if dsn_connection:
connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
else:
port = ''
if 'port' in keys and not 'port' in query:
port = ',%d' % int(keys.pop('port'))
connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '') ]
user = keys.pop("user", None)
if user:
connectors.append("UID=%s" % user)
connectors.append("PWD=%s" % keys.pop('password', ''))
else:
connectors.append("Trusted_Connection=Yes")
# if set to 'Yes', the ODBC layer will try to automagically convert
# textual data from your database encoding to your client encoding
# This should obviously be set to 'No' if you query a cp1253 encoded
# database from a latin1 client...
if 'odbc_autotranslate' in keys:
connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
return [[";".join (connectors)], connect_args]
def is_disconnect(self, e):
if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(e) or \
'Attempt to use a closed connection.' in str(e)
elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e)
else:
return False
def initialize(self, connection):
# determine FreeTDS first. can't issue SQL easily
# without getting unicode_statements/binds set up.
pyodbc = self.dbapi
dbapi_con = connection.connection
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)))
# the "Py2K only" part here is theoretical.
# have not tried pyodbc + python3.1 yet.
# Py2K
self.supports_unicode_statements = not self.freetds
self.supports_unicode_binds = not self.freetds
# end Py2K
# run other initialization which asks for user name, etc.
super(PyODBCConnector, self).initialize(connection)
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)

View File

@ -0,0 +1,48 @@
import sys
from sqlalchemy.connectors import Connector
class ZxJDBCConnector(Connector):
driver = 'zxjdbc'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_unicode_binds = True
supports_unicode_statements = sys.version > '2.5.0+'
description_encoding = None
default_paramstyle = 'qmark'
jdbc_db_name = None
jdbc_driver_name = None
@classmethod
def dbapi(cls):
from com.ziclix.python.sql import zxJDBC
return zxJDBC
def _driver_kwargs(self):
"""Return kw arg dict to be sent to connect()."""
return {}
def _create_jdbc_url(self, url):
"""Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
url.port is not None and ':%s' % url.port or '',
url.database)
def create_connect_args(self, url):
opts = self._driver_kwargs()
opts.update(url.query)
return [[self._create_jdbc_url(url), url.username, url.password, self.jdbc_driver_name],
opts]
def is_disconnect(self, e):
if not isinstance(e, self.dbapi.ProgrammingError):
return False
e = str(e)
return 'connection is closed' in e or 'cursor is closed' in e
def _get_server_version_info(self, connection):
# use connection.connection.dbversion, and parse appropriately
# to get a tuple
raise NotImplementedError()

View File

@ -0,0 +1,31 @@
# __init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.postgresql import base as postgresql
postgres = postgresql
from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.dialects.oracle import base as oracle
from sqlalchemy.dialects.firebird import base as firebird
from sqlalchemy.dialects.maxdb import base as maxdb
from sqlalchemy.dialects.informix import base as informix
from sqlalchemy.dialects.mssql import base as mssql
from sqlalchemy.dialects.access import base as access
from sqlalchemy.dialects.sybase import base as sybase
__all__ = (
'access',
'firebird',
'informix',
'maxdb',
'mssql',
'mysql',
'postgresql',
'sqlite',
'oracle',
'sybase',
)

View File

@ -0,0 +1,12 @@
__all__ = (
# 'access',
# 'firebird',
# 'informix',
# 'maxdb',
# 'mssql',
'mysql',
'oracle',
'postgresql',
'sqlite',
# 'sybase',
)

View File

View File

@ -0,0 +1,418 @@
# access.py
# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk
# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Support for the Microsoft Access database.
This dialect is *not* ported to SQLAlchemy 0.6.
This dialect is *not* tested on SQLAlchemy 0.6.
"""
from sqlalchemy import sql, schema, types, exc, pool
from sqlalchemy.sql import compiler, expression
from sqlalchemy.engine import default, base, reflection
from sqlalchemy import processors
class AcNumeric(types.Numeric):
def get_col_spec(self):
return "NUMERIC"
def bind_processor(self, dialect):
return processors.to_str
def result_processor(self, dialect, coltype):
return None
class AcFloat(types.Float):
def get_col_spec(self):
return "FLOAT"
def bind_processor(self, dialect):
"""By converting to string, we can use Decimal types round-trip."""
return processors.to_str
class AcInteger(types.Integer):
def get_col_spec(self):
return "INTEGER"
class AcTinyInteger(types.Integer):
def get_col_spec(self):
return "TINYINT"
class AcSmallInteger(types.SmallInteger):
def get_col_spec(self):
return "SMALLINT"
class AcDateTime(types.DateTime):
def __init__(self, *a, **kw):
super(AcDateTime, self).__init__(False)
def get_col_spec(self):
return "DATETIME"
class AcDate(types.Date):
def __init__(self, *a, **kw):
super(AcDate, self).__init__(False)
def get_col_spec(self):
return "DATETIME"
class AcText(types.Text):
def get_col_spec(self):
return "MEMO"
class AcString(types.String):
def get_col_spec(self):
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
class AcUnicode(types.Unicode):
def get_col_spec(self):
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
def bind_processor(self, dialect):
return None
def result_processor(self, dialect, coltype):
return None
class AcChar(types.CHAR):
def get_col_spec(self):
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
class AcBinary(types.LargeBinary):
def get_col_spec(self):
return "BINARY"
class AcBoolean(types.Boolean):
def get_col_spec(self):
return "YESNO"
class AcTimeStamp(types.TIMESTAMP):
def get_col_spec(self):
return "TIMESTAMP"
class AccessExecutionContext(default.DefaultExecutionContext):
def _has_implicit_sequence(self, column):
if column.primary_key and column.autoincrement:
if isinstance(column.type, types.Integer) and not column.foreign_keys:
if column.default is None or (isinstance(column.default, schema.Sequence) and \
column.default.optional):
return True
return False
def post_exec(self):
"""If we inserted into a row with a COUNTER column, fetch the ID"""
if self.compiled.isinsert:
tbl = self.compiled.statement.table
if not hasattr(tbl, 'has_sequence'):
tbl.has_sequence = None
for column in tbl.c:
if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
tbl.has_sequence = column
break
if bool(tbl.has_sequence):
# TBD: for some reason _last_inserted_ids doesn't exist here
# (but it does at corresponding point in mssql???)
#if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self.cursor.execute("SELECT @@identity AS lastrowid")
row = self.cursor.fetchone()
self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:]
# print "LAST ROW ID", self._last_inserted_ids
super(AccessExecutionContext, self).post_exec()
const, daoEngine = None, None
class AccessDialect(default.DefaultDialect):
colspecs = {
types.Unicode : AcUnicode,
types.Integer : AcInteger,
types.SmallInteger: AcSmallInteger,
types.Numeric : AcNumeric,
types.Float : AcFloat,
types.DateTime : AcDateTime,
types.Date : AcDate,
types.String : AcString,
types.LargeBinary : AcBinary,
types.Boolean : AcBoolean,
types.Text : AcText,
types.CHAR: AcChar,
types.TIMESTAMP: AcTimeStamp,
}
name = 'access'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
ported_sqla_06 = False
def type_descriptor(self, typeobj):
newobj = types.adapt_type(typeobj, self.colspecs)
return newobj
def __init__(self, **params):
super(AccessDialect, self).__init__(**params)
self.text_as_varchar = False
self._dtbs = None
def dbapi(cls):
import win32com.client, pythoncom
global const, daoEngine
if const is None:
const = win32com.client.constants
for suffix in (".36", ".35", ".30"):
try:
daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix)
break
except pythoncom.com_error:
pass
else:
raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
import pyodbc as module
return module
dbapi = classmethod(dbapi)
def create_connect_args(self, url):
opts = url.translate_connect_args()
connectors = ["Driver={Microsoft Access Driver (*.mdb)}"]
connectors.append("Dbq=%s" % opts["database"])
user = opts.get("username", None)
if user:
connectors.append("UID=%s" % user)
connectors.append("PWD=%s" % opts.get("password", ""))
return [[";".join(connectors)], {}]
def last_inserted_ids(self):
return self.context.last_inserted_ids
def do_execute(self, cursor, statement, params, **kwargs):
if params == {}:
params = ()
super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs)
def _execute(self, c, statement, parameters):
try:
if parameters == {}:
parameters = ()
c.execute(statement, parameters)
self.context.rowcount = c.rowcount
except Exception, e:
raise exc.DBAPIError.instance(statement, parameters, e)
def has_table(self, connection, tablename, schema=None):
# This approach seems to be more reliable that using DAO
try:
connection.execute('select top 1 * from [%s]' % tablename)
return True
except Exception, e:
return False
def reflecttable(self, connection, table, include_columns):
# This is defined in the function, as it relies on win32com constants,
# that aren't imported until dbapi method is called
if not hasattr(self, 'ischema_names'):
self.ischema_names = {
const.dbByte: AcBinary,
const.dbInteger: AcInteger,
const.dbLong: AcInteger,
const.dbSingle: AcFloat,
const.dbDouble: AcFloat,
const.dbDate: AcDateTime,
const.dbLongBinary: AcBinary,
const.dbMemo: AcText,
const.dbBoolean: AcBoolean,
const.dbText: AcUnicode, # All Access strings are unicode
const.dbCurrency: AcNumeric,
}
# A fresh DAO connection is opened for each reflection
# This is necessary, so we get the latest updates
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
try:
for tbl in dtbs.TableDefs:
if tbl.Name.lower() == table.name.lower():
break
else:
raise exc.NoSuchTableError(table.name)
for col in tbl.Fields:
coltype = self.ischema_names[col.Type]
if col.Type == const.dbText:
coltype = coltype(col.Size)
colargs = \
{
'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField),
}
default = col.DefaultValue
if col.Attributes & const.dbAutoIncrField:
colargs['default'] = schema.Sequence(col.Name + '_seq')
elif default:
if col.Type == const.dbBoolean:
default = default == 'Yes' and '1' or '0'
colargs['server_default'] = schema.DefaultClause(sql.text(default))
table.append_column(schema.Column(col.Name, coltype, **colargs))
# TBD: check constraints
# Find primary key columns first
for idx in tbl.Indexes:
if idx.Primary:
for col in idx.Fields:
thecol = table.c[col.Name]
table.primary_key.add(thecol)
if isinstance(thecol.type, AcInteger) and \
not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)):
thecol.autoincrement = False
# Then add other indexes
for idx in tbl.Indexes:
if not idx.Primary:
if len(idx.Fields) == 1:
col = table.c[idx.Fields[0].Name]
if not col.primary_key:
col.index = True
col.unique = idx.Unique
else:
pass # TBD: multi-column indexes
for fk in dtbs.Relations:
if fk.ForeignTable != table.name:
continue
scols = [c.ForeignName for c in fk.Fields]
rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields]
table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True))
finally:
dtbs.Close()
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
# A fresh DAO connection is opened for each reflection
# This is necessary, so we get the latest updates
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"]
dtbs.Close()
return names
class AccessCompiler(compiler.SQLCompiler):
extract_map = compiler.SQLCompiler.extract_map.copy()
extract_map.update ({
'month': 'm',
'day': 'd',
'year': 'yyyy',
'second': 's',
'hour': 'h',
'doy': 'y',
'minute': 'n',
'quarter': 'q',
'dow': 'w',
'week': 'ww'
})
def visit_select_precolumns(self, select):
"""Access puts TOP, it's version of LIMIT here """
s = select.distinct and "DISTINCT " or ""
if select.limit:
s += "TOP %s " % (select.limit)
if select.offset:
raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
return s
def limit_clause(self, select):
"""Limit in access is after the select keyword"""
return ""
def binary_operator_string(self, binary):
"""Access uses "mod" instead of "%" """
return binary.operator == '%' and 'mod' or binary.operator
def label_select_column(self, select, column, asfrom):
if isinstance(column, expression.Function):
return column.label()
else:
return super(AccessCompiler, self).label_select_column(select, column, asfrom)
function_rewrites = {'current_date': 'now',
'current_timestamp': 'now',
'length': 'len',
}
def visit_function(self, func):
"""Access function names differ from the ANSI SQL names; rewrite common ones"""
func.name = self.function_rewrites.get(func.name, func.name)
return super(AccessCompiler, self).visit_function(func)
def for_update_clause(self, select):
"""FOR UPDATE is not supported by Access; silently ignore"""
return ''
# Strip schema
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
return self.preparer.quote(table.name, table.quote)
else:
return ""
def visit_join(self, join, asfrom=False, **kwargs):
return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \
self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
class AccessDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
# install a sequence if we have an implicit IDENTITY column
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys:
if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
column.sequence = schema.Sequence(column.name + '_seq')
if not column.nullable:
colspec += " NOT NULL"
if hasattr(column, 'sequence'):
column.table.has_sequence = column
colspec = self.preparer.format_column(column) + " counter"
else:
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
return colspec
def visit_drop_index(self, drop):
index = drop.element
self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False)))
class AccessIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = compiler.RESERVED_WORDS.copy()
reserved_words.update(['value', 'text'])
def __init__(self, dialect):
super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
dialect = AccessDialect
dialect.poolclass = pool.SingletonThreadPool
dialect.statement_compiler = AccessCompiler
dialect.ddlcompiler = AccessDDLCompiler
dialect.preparer = AccessIdentifierPreparer
dialect.execution_ctx_cls = AccessExecutionContext

View File

@ -0,0 +1,16 @@
from sqlalchemy.dialects.firebird import base, kinterbasdb
base.dialect = kinterbasdb.dialect
from sqlalchemy.dialects.firebird.base import \
SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \
TEXT, NUMERIC, FLOAT, TIMESTAMP, VARCHAR, CHAR, BLOB,\
dialect
__all__ = (
'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME',
'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
'dialect'
)

View File

@ -0,0 +1,619 @@
# firebird.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Support for the Firebird database.
Connectivity is usually supplied via the kinterbasdb_ DBAPI module.
Dialects
~~~~~~~~
Firebird offers two distinct dialects_ (not to be confused with a
SQLAlchemy ``Dialect``):
dialect 1
This is the old syntax and behaviour, inherited from Interbase pre-6.0.
dialect 3
This is the newer and supported syntax, introduced in Interbase 6.0.
The SQLAlchemy Firebird dialect detects these versions and
adjusts its representation of SQL accordingly. However,
support for dialect 1 is not well tested and probably has
incompatibilities.
Locking Behavior
~~~~~~~~~~~~~~~~
Firebird locks tables aggressively. For this reason, a DROP TABLE may
hang until other transactions are released. SQLAlchemy does its best
to release transactions as quickly as possible. The most common cause
of hanging transactions is a non-fully consumed result set, i.e.::
result = engine.execute("select * from table")
row = result.fetchone()
return
Where above, the ``ResultProxy`` has not been fully consumed. The
connection will be returned to the pool and the transactional state
rolled back once the Python garbage collector reclaims the objects
which hold onto the connection, which often occurs asynchronously.
The above use case can be alleviated by calling ``first()`` on the
``ResultProxy`` which will fetch the first row and immediately close
all remaining cursor/connection resources.
RETURNING support
~~~~~~~~~~~~~~~~~
Firebird 2.0 supports returning a result set from inserts, and 2.1
extends that to deletes and updates. This is generically exposed by
the SQLAlchemy ``returning()`` method, such as::
# INSERT..RETURNING
result = table.insert().returning(table.c.col1, table.c.col2).\\
values(name='foo')
print result.fetchall()
# UPDATE..RETURNING
raises = empl.update().returning(empl.c.id, empl.c.salary).\\
where(empl.c.sales>100).\\
values(dict(salary=empl.c.salary * 1.1))
print raises.fetchall()
.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
"""
import datetime, re
from sqlalchemy import schema as sa_schema
from sqlalchemy import exc, types as sqltypes, sql, util
from sqlalchemy.sql import expression
from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler
from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
FLOAT, INTEGER, NUMERIC, SMALLINT,
TEXT, TIME, TIMESTAMP, VARCHAR)
RESERVED_WORDS = set([
"active", "add", "admin", "after", "all", "alter", "and", "any", "as",
"asc", "ascending", "at", "auto", "avg", "before", "begin", "between",
"bigint", "bit_length", "blob", "both", "by", "case", "cast", "char",
"character", "character_length", "char_length", "check", "close",
"collate", "column", "commit", "committed", "computed", "conditional",
"connect", "constraint", "containing", "count", "create", "cross",
"cstring", "current", "current_connection", "current_date",
"current_role", "current_time", "current_timestamp",
"current_transaction", "current_user", "cursor", "database", "date",
"day", "dec", "decimal", "declare", "default", "delete", "desc",
"descending", "disconnect", "distinct", "do", "domain", "double",
"drop", "else", "end", "entry_point", "escape", "exception",
"execute", "exists", "exit", "external", "extract", "fetch", "file",
"filter", "float", "for", "foreign", "from", "full", "function",
"gdscode", "generator", "gen_id", "global", "grant", "group",
"having", "hour", "if", "in", "inactive", "index", "inner",
"input_type", "insensitive", "insert", "int", "integer", "into", "is",
"isolation", "join", "key", "leading", "left", "length", "level",
"like", "long", "lower", "manual", "max", "maximum_segment", "merge",
"min", "minute", "module_name", "month", "names", "national",
"natural", "nchar", "no", "not", "null", "numeric", "octet_length",
"of", "on", "only", "open", "option", "or", "order", "outer",
"output_type", "overflow", "page", "pages", "page_size", "parameter",
"password", "plan", "position", "post_event", "precision", "primary",
"privileges", "procedure", "protected", "rdb$db_key", "read", "real",
"record_version", "recreate", "recursive", "references", "release",
"reserv", "reserving", "retain", "returning_values", "returns",
"revoke", "right", "rollback", "rows", "row_count", "savepoint",
"schema", "second", "segment", "select", "sensitive", "set", "shadow",
"shared", "singular", "size", "smallint", "snapshot", "some", "sort",
"sqlcode", "stability", "start", "starting", "starts", "statistics",
"sub_type", "sum", "suspend", "table", "then", "time", "timestamp",
"to", "trailing", "transaction", "trigger", "trim", "uncommitted",
"union", "unique", "update", "upper", "user", "using", "value",
"values", "varchar", "variable", "varying", "view", "wait", "when",
"where", "while", "with", "work", "write", "year",
])
colspecs = {
}
ischema_names = {
'SHORT': SMALLINT,
'LONG': BIGINT,
'QUAD': FLOAT,
'FLOAT': FLOAT,
'DATE': DATE,
'TIME': TIME,
'TEXT': TEXT,
'INT64': NUMERIC,
'DOUBLE': FLOAT,
'TIMESTAMP': TIMESTAMP,
'VARYING': VARCHAR,
'CSTRING': CHAR,
'BLOB': BLOB,
}
# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
# as bind/result functionality is required)
class FBTypeCompiler(compiler.GenericTypeCompiler):
def visit_boolean(self, type_):
return self.visit_SMALLINT(type_)
def visit_datetime(self, type_):
return self.visit_TIMESTAMP(type_)
def visit_TEXT(self, type_):
return "BLOB SUB_TYPE 1"
def visit_BLOB(self, type_):
return "BLOB SUB_TYPE 0"
class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosincrasies"""
def visit_mod(self, binary, **kw):
# Firebird lacks a builtin modulo operator, but there is
# an equivalent function in the ib_udf library.
return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
def visit_alias(self, alias, asfrom=False, **kwargs):
if self.dialect._version_two:
return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs)
else:
# Override to not use the AS keyword which FB 1.5 does not like
if asfrom:
alias_name = isinstance(alias.name, expression._generated_label) and \
self._truncated_identifier("alias", alias.name) or alias.name
return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \
self.preparer.format_alias(alias, alias_name)
else:
return self.process(alias.original, **kwargs)
def visit_substring_func(self, func, **kw):
s = self.process(func.clauses.clauses[0])
start = self.process(func.clauses.clauses[1])
if len(func.clauses.clauses) > 2:
length = self.process(func.clauses.clauses[2])
return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
else:
return "SUBSTRING(%s FROM %s)" % (s, start)
def visit_length_func(self, function, **kw):
if self.dialect._version_two:
return "char_length" + self.function_argspec(function)
else:
return "strlen" + self.function_argspec(function)
visit_char_length_func = visit_length_func
def function_argspec(self, func, **kw):
if func.clauses is not None and len(func.clauses):
return self.process(func.clause_expr)
else:
return ""
def default_from(self):
return " FROM rdb$database"
def visit_sequence(self, seq):
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right
after the ``SELECT``...
"""
result = ""
if select._limit:
result += "FIRST %d " % select._limit
if select._offset:
result +="SKIP %d " % select._offset
if select._distinct:
result += "DISTINCT "
return result
def limit_clause(self, select):
"""Already taken care of in the `get_select_precolumns` method."""
return ""
def returning_clause(self, stmt, returning_cols):
columns = [
self.process(
self.label_select_column(None, c, asfrom=False),
within_columns_clause=True,
result_map=self.result_map
)
for c in expression._select_iterables(returning_cols)
]
return 'RETURNING ' + ', '.join(columns)
class FBDDLCompiler(sql.compiler.DDLCompiler):
"""Firebird syntactic idiosincrasies"""
def visit_create_sequence(self, create):
"""Generate a ``CREATE GENERATOR`` statement for the sequence."""
# no syntax for these
# http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
if create.element.start is not None:
raise NotImplemented("Firebird SEQUENCE doesn't support START WITH")
if create.element.increment is not None:
raise NotImplemented("Firebird SEQUENCE doesn't support INCREMENT BY")
if self.dialect._version_two:
return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
else:
return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element)
def visit_drop_sequence(self, drop):
"""Generate a ``DROP GENERATOR`` statement for the sequence."""
if self.dialect._version_two:
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
else:
return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
"""Install Firebird specific reserved words."""
reserved_words = RESERVED_WORDS
def __init__(self, dialect):
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
class FBExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq):
"""Get the next value from the sequence using ``gen_id()``."""
return self._execute_scalar("SELECT gen_id(%s, 1) FROM rdb$database" % \
self.dialect.identifier_preparer.format_sequence(seq))
class FBDialect(default.DefaultDialect):
"""Firebird dialect"""
name = 'firebird'
max_identifier_length = 31
supports_sequences = True
sequences_optional = False
supports_default_values = True
postfetch_lastrowid = False
supports_native_boolean = False
requires_name_normalize = True
supports_empty_insert = False
statement_compiler = FBCompiler
ddl_compiler = FBDDLCompiler
preparer = FBIdentifierPreparer
type_compiler = FBTypeCompiler
execution_ctx_cls = FBExecutionContext
colspecs = colspecs
ischema_names = ischema_names
# defaults to dialect ver. 3,
# will be autodetected off upon
# first connect
_version_two = True
def initialize(self, connection):
super(FBDialect, self).initialize(connection)
self._version_two = self.server_version_info > (2, )
if not self._version_two:
# TODO: whatever other pre < 2.0 stuff goes here
self.ischema_names = ischema_names.copy()
self.ischema_names['TIMESTAMP'] = sqltypes.DATE
self.colspecs = {
sqltypes.DateTime: sqltypes.DATE
}
else:
self.implicit_returning = True
def normalize_name(self, name):
# Remove trailing spaces: FB uses a CHAR() type,
# that is padded with spaces
name = name and name.rstrip()
if name is None:
return None
elif name.upper() == name and \
not self.identifier_preparer._requires_quotes(name.lower()):
return name.lower()
else:
return name
def denormalize_name(self, name):
if name is None:
return None
elif name.lower() == name and \
not self.identifier_preparer._requires_quotes(name.lower()):
return name.upper()
else:
return name
def has_table(self, connection, table_name, schema=None):
"""Return ``True`` if the given table exists, ignoring the `schema`."""
tblqry = """
SELECT 1 FROM rdb$database
WHERE EXISTS (SELECT rdb$relation_name
FROM rdb$relations
WHERE rdb$relation_name=?)
"""
c = connection.execute(tblqry, [self.denormalize_name(table_name)])
return c.first() is not None
def has_sequence(self, connection, sequence_name, schema=None):
"""Return ``True`` if the given sequence (generator) exists."""
genqry = """
SELECT 1 FROM rdb$database
WHERE EXISTS (SELECT rdb$generator_name
FROM rdb$generators
WHERE rdb$generator_name=?)
"""
c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
return c.first() is not None
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
s = """
SELECT DISTINCT rdb$relation_name
FROM rdb$relation_fields
WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
"""
return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
s = """
SELECT distinct rdb$view_name
FROM rdb$view_relations
"""
return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
qry = """
SELECT rdb$view_source AS view_source
FROM rdb$relations
WHERE rdb$relation_name=?
"""
rp = connection.execute(qry, [self.denormalize_name(view_name)])
row = rp.first()
if row:
return row['view_source']
else:
return None
@reflection.cache
def get_primary_keys(self, connection, table_name, schema=None, **kw):
# Query to extract the PK/FK constrained fields of the given table
keyqry = """
SELECT se.rdb$field_name AS fname
FROM rdb$relation_constraints rc
JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
"""
tablename = self.denormalize_name(table_name)
# get primary key fields
c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
return pkfields
@reflection.cache
def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw):
tablename = self.denormalize_name(table_name)
colname = self.denormalize_name(column_name)
# Heuristic-query to determine the generator associated to a PK field
genqry = """
SELECT trigdep.rdb$depended_on_name AS fgenerator
FROM rdb$dependencies tabdep
JOIN rdb$dependencies trigdep
ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
AND trigdep.rdb$depended_on_type=14
AND trigdep.rdb$dependent_type=2
JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name
WHERE tabdep.rdb$depended_on_name=?
AND tabdep.rdb$depended_on_type=0
AND trig.rdb$trigger_type=1
AND tabdep.rdb$field_name=?
AND (SELECT count(*)
FROM rdb$dependencies trigdep2
WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
"""
genr = connection.execute(genqry, [tablename, colname]).first()
if genr is not None:
return dict(name=self.normalize_name(genr['fgenerator']))
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
# Query to extract the details of all the fields of the given table
tblqry = """
SELECT DISTINCT r.rdb$field_name AS fname,
r.rdb$null_flag AS null_flag,
t.rdb$type_name AS ftype,
f.rdb$field_sub_type AS stype,
f.rdb$field_length/COALESCE(cs.rdb$bytes_per_character,1) AS flen,
f.rdb$field_precision AS fprec,
f.rdb$field_scale AS fscale,
COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
FROM rdb$relation_fields r
JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
JOIN rdb$types t
ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
LEFT JOIN rdb$character_sets cs ON f.rdb$character_set_id=cs.rdb$character_set_id
WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
ORDER BY r.rdb$field_position
"""
# get the PK, used to determine the eventual associated sequence
pkey_cols = self.get_primary_keys(connection, table_name)
tablename = self.denormalize_name(table_name)
# get all of the fields for this table
c = connection.execute(tblqry, [tablename])
cols = []
while True:
row = c.fetchone()
if row is None:
break
name = self.normalize_name(row['fname'])
orig_colname = row['fname']
# get the data type
colspec = row['ftype'].rstrip()
coltype = self.ischema_names.get(colspec)
if coltype is None:
util.warn("Did not recognize type '%s' of column '%s'" %
(colspec, name))
coltype = sqltypes.NULLTYPE
elif colspec == 'INT64':
coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1)
elif colspec in ('VARYING', 'CSTRING'):
coltype = coltype(row['flen'])
elif colspec == 'TEXT':
coltype = TEXT(row['flen'])
elif colspec == 'BLOB':
if row['stype'] == 1:
coltype = TEXT()
else:
coltype = BLOB()
else:
coltype = coltype(row)
# does it have a default value?
defvalue = None
if row['fdefault'] is not None:
# the value comes down as "DEFAULT 'value'": there may be
# more than one whitespace around the "DEFAULT" keyword
# (see also http://tracker.firebirdsql.org/browse/CORE-356)
defexpr = row['fdefault'].lstrip()
assert defexpr[:8].rstrip()=='DEFAULT', "Unrecognized default value: %s" % defexpr
defvalue = defexpr[8:].strip()
if defvalue == 'NULL':
# Redundant
defvalue = None
col_d = {
'name' : name,
'type' : coltype,
'nullable' : not bool(row['null_flag']),
'default' : defvalue
}
if orig_colname.lower() == orig_colname:
col_d['quote'] = True
# if the PK is a single field, try to see if its linked to
# a sequence thru a trigger
if len(pkey_cols)==1 and name==pkey_cols[0]:
seq_d = self.get_column_sequence(connection, tablename, name)
if seq_d is not None:
col_d['sequence'] = seq_d
cols.append(col_d)
return cols
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Query to extract the details of each UK/FK of the given table
fkqry = """
SELECT rc.rdb$constraint_name AS cname,
cse.rdb$field_name AS fname,
ix2.rdb$relation_name AS targetrname,
se.rdb$field_name AS targetfname
FROM rdb$relation_constraints rc
JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
JOIN rdb$index_segments se
ON se.rdb$index_name=ix2.rdb$index_name
AND se.rdb$field_position=cse.rdb$field_position
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
ORDER BY se.rdb$index_name, se.rdb$field_position
"""
tablename = self.denormalize_name(table_name)
c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
fks = util.defaultdict(lambda:{
'name' : None,
'constrained_columns' : [],
'referred_schema' : None,
'referred_table' : None,
'referred_columns' : []
})
for row in c:
cname = self.normalize_name(row['cname'])
fk = fks[cname]
if not fk['name']:
fk['name'] = cname
fk['referred_table'] = self.normalize_name(row['targetrname'])
fk['constrained_columns'].append(self.normalize_name(row['fname']))
fk['referred_columns'].append(
self.normalize_name(row['targetfname']))
return fks.values()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
qry = """
SELECT ix.rdb$index_name AS index_name,
ix.rdb$unique_flag AS unique_flag,
ic.rdb$field_name AS field_name
FROM rdb$indices ix
JOIN rdb$index_segments ic
ON ix.rdb$index_name=ic.rdb$index_name
LEFT OUTER JOIN rdb$relation_constraints
ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name
WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
AND rdb$relation_constraints.rdb$constraint_type IS NULL
ORDER BY index_name, field_name
"""
c = connection.execute(qry, [self.denormalize_name(table_name)])
indexes = util.defaultdict(dict)
for row in c:
indexrec = indexes[row['index_name']]
if 'name' not in indexrec:
indexrec['name'] = self.normalize_name(row['index_name'])
indexrec['column_names'] = []
indexrec['unique'] = bool(row['unique_flag'])
indexrec['column_names'].append(self.normalize_name(row['field_name']))
return indexes.values()
def do_execute(self, cursor, statement, parameters, **kwargs):
# kinterbase does not accept a None, but wants an empty list
# when there are no arguments.
cursor.execute(statement, parameters or [])
def do_rollback(self, connection):
# Use the retaining feature, that keeps the transaction going
connection.rollback(True)
def do_commit(self, connection):
# Use the retaining feature, that keeps the transaction going
connection.commit(True)

View File

@ -0,0 +1,120 @@
# kinterbasdb.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
The most common way to connect to a Firebird engine is implemented by
kinterbasdb__, currently maintained__ directly by the Firebird people.
The connection URL is of the form
``firebird[+kinterbasdb]://user:password@host:port/path/to/db[?key=value&key=value...]``.
Kinterbasedb backend specific keyword arguments are:
type_conv
select the kind of mapping done on the types: by default SQLAlchemy
uses 200 with Unicode, datetime and decimal support (see details__).
concurrency_level
set the backend policy with regards to threading issues: by default
SQLAlchemy uses policy 1 (see details__).
__ http://sourceforge.net/projects/kinterbasdb
__ http://firebirdsql.org/index.php?op=devel&sub=python
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
"""
from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
from sqlalchemy import util, types as sqltypes
class _FBNumeric_kinterbasdb(sqltypes.Numeric):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return str(value)
else:
return value
return process
class FBDialect_kinterbasdb(FBDialect):
driver = 'kinterbasdb'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_native_decimal = True
colspecs = util.update_copy(
FBDialect.colspecs,
{
sqltypes.Numeric:_FBNumeric_kinterbasdb
}
)
def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
super(FBDialect_kinterbasdb, self).__init__(**kwargs)
self.type_conv = type_conv
self.concurrency_level = concurrency_level
@classmethod
def dbapi(cls):
k = __import__('kinterbasdb')
return k
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if opts.get('port'):
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
type_conv = opts.pop('type_conv', self.type_conv)
concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
if self.dbapi is not None:
initialized = getattr(self.dbapi, 'initialized', None)
if initialized is None:
# CVS rev 1.96 changed the name of the attribute:
# http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
initialized = getattr(self.dbapi, '_initialized', False)
if not initialized:
self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of (`major`, `minor`, `build`), three integers
representing the version of the attached server.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature. This is more than enough for our purposes,
# as this is mainly (only?) used by the testsuite.
from re import match
fbconn = connection.connection
version = fbconn.server_version
m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
if not m:
raise AssertionError("Could not determine version from string '%s'" % version)
return tuple([int(x) for x in m.group(5, 6, 4)])
def is_disconnect(self, e):
if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)):
msg = str(e)
return ('Unable to complete network request to host' in msg or
'Invalid connection state' in msg or
'Invalid cursor state' in msg)
else:
return False
dialect = FBDialect_kinterbasdb

View File

@ -0,0 +1,3 @@
from sqlalchemy.dialects.informix import base, informixdb
base.dialect = informixdb.dialect

View File

@ -0,0 +1,306 @@
# informix.py
# Copyright (C) 2005,2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# coding: gbk
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for the Informix database.
This dialect is *not* tested on SQLAlchemy 0.6.
"""
import datetime
from sqlalchemy import sql, schema, exc, pool, util
from sqlalchemy.sql import compiler
from sqlalchemy.engine import default, reflection
from sqlalchemy import types as sqltypes
class InfoDateTime(sqltypes.DateTime):
def bind_processor(self, dialect):
def process(value):
if value is not None:
if value.microsecond:
value = value.replace(microsecond=0)
return value
return process
class InfoTime(sqltypes.Time):
def bind_processor(self, dialect):
def process(value):
if value is not None:
if value.microsecond:
value = value.replace(microsecond=0)
return value
return process
def result_processor(self, dialect, coltype):
def process(value):
if isinstance(value, datetime.datetime):
return value.time()
else:
return value
return process
colspecs = {
sqltypes.DateTime : InfoDateTime,
sqltypes.Time: InfoTime,
}
ischema_names = {
0 : sqltypes.CHAR, # CHAR
1 : sqltypes.SMALLINT, # SMALLINT
2 : sqltypes.INTEGER, # INT
3 : sqltypes.FLOAT, # Float
3 : sqltypes.Float, # SmallFloat
5 : sqltypes.DECIMAL, # DECIMAL
6 : sqltypes.Integer, # Serial
7 : sqltypes.DATE, # DATE
8 : sqltypes.Numeric, # MONEY
10 : sqltypes.DATETIME, # DATETIME
11 : sqltypes.LargeBinary, # BYTE
12 : sqltypes.TEXT, # TEXT
13 : sqltypes.VARCHAR, # VARCHAR
15 : sqltypes.NCHAR, # NCHAR
16 : sqltypes.NVARCHAR, # NVARCHAR
17 : sqltypes.Integer, # INT8
18 : sqltypes.Integer, # Serial8
43 : sqltypes.String, # LVARCHAR
-1 : sqltypes.BLOB, # BLOB
-1 : sqltypes.CLOB, # CLOB
}
class InfoTypeCompiler(compiler.GenericTypeCompiler):
def visit_DATETIME(self, type_):
return "DATETIME YEAR TO SECOND"
def visit_TIME(self, type_):
return "DATETIME HOUR TO SECOND"
def visit_large_binary(self, type_):
return "BYTE"
def visit_boolean(self, type_):
return "SMALLINT"
class InfoSQLCompiler(compiler.SQLCompiler):
def default_from(self):
return " from systables where tabname = 'systables' "
def get_select_precolumns(self, select):
s = select._distinct and "DISTINCT " or ""
# only has limit
if select._limit:
s += " FIRST %s " % select._limit
else:
s += ""
return s
def visit_select(self, select):
# the column in order by clause must in select too
def __label(c):
try:
return c._label.lower()
except:
return ''
# TODO: dont modify the original select, generate a new one
a = [__label(c) for c in select._raw_columns]
for c in select._order_by_clause.clauses:
if __label(c) not in a:
select.append_column(c)
return compiler.SQLCompiler.visit_select(self, select)
def limit_clause(self, select):
if select._offset is not None and select._offset > 0:
raise NotImplementedError("Informix does not support OFFSET")
return ""
def visit_function(self, func):
if func.name.lower() == 'current_date':
return "today"
elif func.name.lower() == 'current_time':
return "CURRENT HOUR TO SECOND"
elif func.name.lower() in ('current_timestamp', 'now'):
return "CURRENT YEAR TO SECOND"
else:
return compiler.SQLCompiler.visit_function(self, func)
class InfoDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, first_pk=False):
colspec = self.preparer.format_column(column)
if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
isinstance(column.type, sqltypes.Integer) and first_pk:
colspec += " SERIAL"
else:
colspec += " " + self.dialect.type_compiler.process(column.type)
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
if not column.nullable:
colspec += " NOT NULL"
return colspec
class InfoIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
def format_constraint(self, constraint):
# informix doesnt support names for constraints
return ''
def _requires_quotes(self, value):
return False
class InformixDialect(default.DefaultDialect):
name = 'informix'
max_identifier_length = 128 # adjusts at runtime based on server version
type_compiler = InfoTypeCompiler
statement_compiler = InfoSQLCompiler
ddl_compiler = InfoDDLCompiler
preparer = InfoIdentifierPreparer
colspecs = colspecs
ischema_names = ischema_names
def initialize(self, connection):
super(InformixDialect, self).initialize(connection)
# http://www.querix.com/support/knowledge-base/error_number_message/error_200
if self.server_version_info < (9, 2):
self.max_identifier_length = 18
else:
self.max_identifier_length = 128
def do_begin(self, connect):
cu = connect.cursor()
cu.execute('SET LOCK MODE TO WAIT')
#cu.execute('SET ISOLATION TO REPEATABLE READ')
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
s = "select tabname from systables"
return [row[0] for row in connection.execute(s)]
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower())
return cursor.first() is not None
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from
syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
where t1.tabid = t2.tabid and t2.tabname=?
and t3.tabid = t2.tabid and t3.colno = t1.colno
order by t1.colno""", table.name.lower())
columns = []
for name, colattr, collength, default, colno in rows:
name = name.lower()
if include_columns and name not in include_columns:
continue
# in 7.31, coltype = 0x000
# ^^-- column type
# ^-- 1 not null, 0 null
nullable, coltype = divmod(colattr, 256)
if coltype not in (0, 13) and default:
default = default.split()[-1]
if coltype == 0 or coltype == 13: # char, varchar
coltype = ischema_names[coltype](collength)
if default:
default = "'%s'" % default
elif coltype == 5: # decimal
precision, scale = (collength & 0xFF00) >> 8, collength & 0xFF
if scale == 255:
scale = 0
coltype = sqltypes.Numeric(precision, scale)
else:
try:
coltype = ischema_names[coltype]
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" %
(coltype, name))
coltype = sqltypes.NULLTYPE
# TODO: nullability ??
nullable = True
column_info = dict(name=name, type=coltype, nullable=nullable,
default=default)
columns.append(column_info)
return columns
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# FK
c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type ,
t4.colname as local_column , t7.tabname as remote_table ,
t6.colname as remote_column
from sysconstraints as t1 , systables as t2 ,
sysindexes as t3 , syscolumns as t4 ,
sysreferences as t5 , syscolumns as t6 , systables as t7 ,
sysconstraints as t8 , sysindexes as t9
where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'R'
and t3.tabid = t2.tabid and t3.idxname = t1.idxname
and t4.tabid = t2.tabid and t4.colno = t3.part1
and t5.constrid = t1.constrid and t8.constrid = t5.primary
and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname
and t7.tabid = t5.ptabid""", table.name.lower())
def fkey_rec():
return {
'name' : None,
'constrained_columns' : [],
'referred_schema' : None,
'referred_table' : None,
'referred_columns' : []
}
fkeys = util.defaultdict(fkey_rec)
for cons_name, cons_type, local_column, remote_table, remote_column in rows:
rec = fkeys[cons_name]
rec['name'] = cons_name
local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
if not rec['referred_table']:
rec['referred_table'] = remote_table
local_cols.append(local_column)
remote_cols.append(remote_column)
return fkeys.values()
@reflection.cache
def get_primary_keys(self, connection, table_name, schema=None, **kw):
c = connection.execute("""select t4.colname as local_column
from sysconstraints as t1 , systables as t2 ,
sysindexes as t3 , syscolumns as t4
where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'P'
and t3.tabid = t2.tabid and t3.idxname = t1.idxname
and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower())
return [r[0] for r in c.fetchall()]
@reflection.cache
def get_indexes(self, connection, table_name, schema, **kw):
# TODO
return []

View File

@ -0,0 +1,46 @@
from sqlalchemy.dialects.informix.base import InformixDialect
from sqlalchemy.engine import default
class InformixExecutionContext_informixdb(default.DefaultExecutionContext):
def post_exec(self):
if self.isinsert:
self._lastrowid = [self.cursor.sqlerrd[1]]
class InformixDialect_informixdb(InformixDialect):
driver = 'informixdb'
default_paramstyle = 'qmark'
execution_context_cls = InformixExecutionContext_informixdb
@classmethod
def dbapi(cls):
return __import__('informixdb')
def create_connect_args(self, url):
if url.host:
dsn = '%s@%s' % (url.database, url.host)
else:
dsn = url.database
if url.username:
opt = {'user': url.username, 'password': url.password}
else:
opt = {}
return ([dsn], opt)
def _get_server_version_info(self, connection):
# http://informixdb.sourceforge.net/manual.html#inspecting-version-numbers
vers = connection.dbms_version
# TODO: not tested
return tuple([int(x) for x in vers.split('.')])
def is_disconnect(self, e):
if isinstance(e, self.dbapi.OperationalError):
return 'closed the connection' in str(e) or 'connection not open' in str(e)
else:
return False
dialect = InformixDialect_informixdb

View File

@ -0,0 +1,3 @@
from sqlalchemy.dialects.maxdb import base, sapdb
base.dialect = sapdb.dialect

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,17 @@
from sqlalchemy.dialects.maxdb.base import MaxDBDialect
class MaxDBDialect_sapdb(MaxDBDialect):
driver = 'sapdb'
@classmethod
def dbapi(cls):
from sapdb import dbapi as _dbapi
return _dbapi
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
return [], opts
dialect = MaxDBDialect_sapdb

View File

@ -0,0 +1,19 @@
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc
base.dialect = pyodbc.dialect
from sqlalchemy.dialects.mssql.base import \
INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \
NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\
DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \
BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP,\
MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, dialect
__all__ = (
'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME',
'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP',
'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect'
)

View File

@ -0,0 +1,59 @@
"""
The adodbapi dialect is not implemented for 0.6 at this time.
"""
from sqlalchemy import types as sqltypes, util
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
import sys
class MSDateTime_adodbapi(MSDateTime):
def result_processor(self, dialect, coltype):
def process(value):
# adodbapi will return datetimes with empty time values as datetime.date() objects.
# Promote them back to full datetime.datetime()
if type(value) is datetime.date:
return datetime.datetime(value.year, value.month, value.day)
return value
return process
class MSDialect_adodbapi(MSDialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = True
driver = 'adodbapi'
@classmethod
def import_dbapi(cls):
import adodbapi as module
return module
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.DateTime:MSDateTime_adodbapi
}
)
def create_connect_args(self, url):
keys = url.query
connectors = ["Provider=SQLOLEDB"]
if 'port' in keys:
connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
else:
connectors.append ("Data Source=%s" % keys.get("host"))
connectors.append ("Initial Catalog=%s" % keys.get("database"))
user = keys.get("user")
if user:
connectors.append("User Id=%s" % user)
connectors.append("Password=%s" % keys.get("password", ""))
else:
connectors.append("Integrated Security=SSPI")
return [[";".join (connectors)], {}]
def is_disconnect(self, e):
return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
dialect = MSDialect_adodbapi

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,83 @@
from sqlalchemy import Table, MetaData, Column, ForeignKey
from sqlalchemy.types import String, Unicode, Integer, TypeDecorator
ischema = MetaData()
class CoerceUnicode(TypeDecorator):
impl = Unicode
def process_bind_param(self, value, dialect):
if isinstance(value, str):
value = value.decode(dialect.encoding)
return value
schemata = Table("SCHEMATA", ischema,
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
schema="INFORMATION_SCHEMA")
tables = Table("TABLES", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
schema="INFORMATION_SCHEMA")
columns = Table("COLUMNS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="INFORMATION_SCHEMA")
constraints = Table("TABLE_CONSTRAINTS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"),
schema="INFORMATION_SCHEMA")
column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
schema="INFORMATION_SCHEMA")
key_constraints = Table("KEY_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
schema="INFORMATION_SCHEMA")
ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"), # TODO: is CATLOG misspelled ?
Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"),
Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"),
Column("MATCH_OPTION", String, key="match_option"),
Column("UPDATE_RULE", String, key="update_rule"),
Column("DELETE_RULE", String, key="delete_rule"),
schema="INFORMATION_SCHEMA")
views = Table("VIEWS", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
Column("CHECK_OPTION", String, key="check_option"),
Column("IS_UPDATABLE", String, key="is_updatable"),
schema="INFORMATION_SCHEMA")

View File

@ -0,0 +1,83 @@
"""
Support for MS-SQL via mxODBC.
mxODBC is available at:
http://www.egenix.com/
This was tested with mxODBC 3.1.2 and the SQL Server Native
Client connected to MSSQL 2005 and 2008 Express Editions.
Connecting
~~~~~~~~~~
Connection is via DSN::
mssql+mxodbc://<username>:<password>@<dsnname>
Execution Modes
~~~~~~~~~~~~~~~
mxODBC features two styles of statement execution, using the ``cursor.execute()``
and ``cursor.executedirect()`` methods (the second being an extension to the
DBAPI specification). The former makes use of the native
parameter binding services of the ODBC driver, while the latter uses string escaping.
The primary advantage to native parameter binding is that the same statement, when
executed many times, is only prepared once. Whereas the primary advantage to the
latter is that the rules for bind parameter placement are relaxed. MS-SQL has very
strict rules for native binds, including that they cannot be placed within the argument
lists of function calls, anywhere outside the FROM, or even within subqueries within the
FROM clause - making the usage of bind parameters within SELECT statements impossible for
all but the most simplistic statements. For this reason, the mxODBC dialect uses the
"native" mode by default only for INSERT, UPDATE, and DELETE statements, and uses the
escaped string mode for all other statements. This behavior can be controlled completely
via :meth:`~sqlalchemy.sql.expression.Executable.execution_options`
using the ``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a value of
``True`` will unconditionally use native bind parameters and a value of ``False`` will
uncondtionally use string-escaped parameters.
"""
import re
import sys
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.connectors.mxodbc import MxODBCConnector
from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc
from sqlalchemy.dialects.mssql.base import (MSExecutionContext, MSDialect,
MSSQLCompiler, MSSQLStrictCompiler,
_MSDateTime, _MSDate, TIME)
class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
"""
The pyodbc execution context is useful for enabling
SELECT SCOPE_IDENTITY in cases where OUTPUT clause
does not work (tables with insert triggers).
"""
#todo - investigate whether the pyodbc execution context
# is really only being used in cases where OUTPUT
# won't work.
class MSDialect_mxodbc(MxODBCConnector, MSDialect):
# TODO: may want to use this only if FreeTDS is not in use,
# since FreeTDS doesn't seem to use native binds.
statement_compiler = MSSQLStrictCompiler
execution_ctx_cls = MSExecutionContext_mxodbc
colspecs = {
#sqltypes.Numeric : _MSNumeric,
sqltypes.DateTime : _MSDateTime,
sqltypes.Date : _MSDate,
sqltypes.Time : TIME,
}
def __init__(self, description_encoding='latin-1', **params):
super(MSDialect_mxodbc, self).__init__(**params)
self.description_encoding = description_encoding
dialect = MSDialect_mxodbc

View File

@ -0,0 +1,101 @@
"""
Support for the pymssql dialect.
This dialect supports pymssql 1.0 and greater.
pymssql is available at:
http://pymssql.sourceforge.net/
Connecting
^^^^^^^^^^
Sample connect string::
mssql+pymssql://<username>:<password>@<freetds_name>
Adding "?charset=utf8" or similar will cause pymssql to return
strings as Python unicode objects. This can potentially improve
performance in some scenarios as decoding of strings is
handled natively.
Limitations
^^^^^^^^^^^
pymssql inherits a lot of limitations from FreeTDS, including:
* no support for multibyte schema identifiers
* poor support for large decimals
* poor support for binary fields
* poor support for VARCHAR/CHAR fields over 255 characters
Please consult the pymssql documentation for further information.
"""
from sqlalchemy.dialects.mssql.base import MSDialect
from sqlalchemy import types as sqltypes, util, processors
import re
import decimal
class _MSNumeric_pymssql(sqltypes.Numeric):
def result_processor(self, dialect, type_):
if not self.asdecimal:
return processors.to_float
else:
return sqltypes.Numeric.result_processor(self, dialect, type_)
class MSDialect_pymssql(MSDialect):
supports_sane_rowcount = False
max_identifier_length = 30
driver = 'pymssql'
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric:_MSNumeric_pymssql,
sqltypes.Float:sqltypes.Float,
}
)
@classmethod
def dbapi(cls):
module = __import__('pymssql')
# pymmsql doesn't have a Binary method. we use string
# TODO: monkeypatching here is less than ideal
module.Binary = str
client_ver = tuple(int(x) for x in module.__version__.split("."))
if client_ver < (1, ):
util.warn("The pymssql dialect expects at least "
"the 1.0 series of the pymssql DBAPI.")
return module
def __init__(self, **params):
super(MSDialect_pymssql, self).__init__(**params)
self.use_scope_identity = True
def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version")
m = re.match(r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4))
else:
return None
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
opts.pop('port', None)
return [[], opts]
def is_disconnect(self, e):
for msg in (
"Error 10054",
"Not connected to any MS SQL server",
"Connection is closed"
):
if msg in str(e):
return True
else:
return False
dialect = MSDialect_pymssql

View File

@ -0,0 +1,197 @@
"""
Support for MS-SQL via pyodbc.
pyodbc is available at:
http://pypi.python.org/pypi/pyodbc/
Connecting
^^^^^^^^^^
Examples of pyodbc connection string URLs:
* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``.
The connection string that is created will appear like::
dsn=mydsn;Trusted_Connection=Yes
* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
connection string that is created will appear like::
dsn=mydsn;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
information, plus the additional connection configuration option
``LANGUAGE``. The connection string that is created will appear
like::
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection string
dynamically created that would appear like::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection
string that is dynamically created, which also includes the port
information using the comma syntax. If your connection string
requires the port information to be passed as a ``port`` keyword
see the next example. This will create the following connection
string::
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection
string that is dynamically created that includes the port
information as a separate ``port`` keyword. This will create the
following connection string::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
If you require a connection string that is outside the options
presented above, use the ``odbc_connect`` keyword to pass in a
urlencoded connection string. What gets passed in will be urldecoded
and passed directly.
For example::
mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
would create the following connection string::
dsn=mydsn;Database=db
Encoding your connection string can be easily accomplished through
the python shell. For example::
>>> import urllib
>>> urllib.quote_plus('dsn=mydsn;Database=db')
'dsn%3Dmydsn%3BDatabase%3Ddb'
"""
from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes, util
import decimal
class _MSNumeric_pyodbc(sqltypes.Numeric):
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
This is the only method that is proven to work with Pyodbc+MSSQL
without crashing (floats can be used but seem to cause sporadic
crashes).
"""
def bind_processor(self, dialect):
super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect)
def process(value):
if self.asdecimal and \
isinstance(value, decimal.Decimal):
adjusted = value.adjusted()
if adjusted < 0:
return self._small_dec_to_string(value)
elif adjusted > 7:
return self._large_dec_to_string(value)
if super_process:
return super_process(value)
else:
return value
return process
def _small_dec_to_string(self, value):
return "%s0.%s%s" % (
(value < 0 and '-' or ''),
'0' * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value._int]))
def _large_dec_to_string(self, value):
if 'E' in str(value):
result = "%s%s%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in value._int]),
"0" * (value.adjusted() - (len(value._int)-1)))
else:
if (len(value._int) - 1) > value.adjusted():
result = "%s%s.%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in value._int][0:value.adjusted() + 1]),
"".join([str(s) for s in value._int][value.adjusted() + 1:]))
else:
result = "%s%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in value._int][0:value.adjusted() + 1]))
return result
class MSExecutionContext_pyodbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same statement.
Background on why "scope_identity()" is preferable to "@@identity":
http://msdn.microsoft.com/en-us/library/ms190315.aspx
Background on why we attempt to embed "scope_identity()" into the same
statement as the INSERT:
http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
"""
super(MSExecutionContext_pyodbc, self).pre_exec()
# don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES"
if self._select_lastrowid and \
self.dialect.use_scope_identity and \
len(self.parameters[0]):
self._embedded_scope_identity = True
self.statement += "; select scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with no data (due to triggers, etc.)
while True:
try:
# fetchall() ensures the cursor is consumed
# without closing it (FreeTDS particularly)
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error, e:
# no way around this - nextset() consumes the previous set
# so we need to just keep flipping
self.cursor.nextset()
self._lastrowid = int(row[0])
else:
super(MSExecutionContext_pyodbc, self).post_exec()
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
execution_ctx_cls = MSExecutionContext_pyodbc
pyodbc_driver_name = 'SQL Server'
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric:_MSNumeric_pyodbc
}
)
def __init__(self, description_encoding='latin-1', **params):
super(MSDialect_pyodbc, self).__init__(**params)
self.description_encoding = description_encoding
self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
dialect = MSDialect_pyodbc

View File

@ -0,0 +1,64 @@
"""Support for the Microsoft SQL Server database via the zxjdbc JDBC
connector.
JDBC Driver
-----------
Requires the jTDS driver, available from: http://jtds.sourceforge.net/
Connecting
----------
URLs are of the standard form of
``mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]``.
Additional arguments which may be specified either as query string
arguments on the URL, or as keyword arguments to
:func:`~sqlalchemy.create_engine()` will be passed as Connection
properties to the underlying JDBC driver.
"""
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.mssql.base import MSDialect, MSExecutionContext
from sqlalchemy.engine import base
class MSExecutionContext_zxjdbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
super(MSExecutionContext_zxjdbc, self).pre_exec()
# scope_identity after the fact returns null in jTDS so we must
# embed it
if self._select_lastrowid and self.dialect.use_scope_identity:
self._embedded_scope_identity = True
self.statement += "; SELECT scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
while True:
try:
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error, e:
self.cursor.nextset()
self._lastrowid = int(row[0])
if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning:
self._result_proxy = base.FullyBufferedResultProxy(self)
if self._enable_identity_insert:
table = self.dialect.identifier_preparer.format_table(self.compiled.statement.table)
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)
class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
jdbc_db_name = 'jtds:sqlserver'
jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver'
execution_ctx_cls = MSExecutionContext_zxjdbc
def _get_server_version_info(self, connection):
return tuple(int(x) for x in connection.connection.dbversion.split('.'))
dialect = MSDialect_zxjdbc

View File

@ -0,0 +1,17 @@
from sqlalchemy.dialects.mysql import base, mysqldb, oursql, pyodbc, zxjdbc, mysqlconnector
# default dialect
base.dialect = mysqldb.dialect
from sqlalchemy.dialects.mysql.base import \
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, DOUBLE, ENUM, DECIMAL,\
FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, MEDIUMTEXT, NCHAR, \
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT,\
VARBINARY, VARCHAR, YEAR, dialect
__all__ = (
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE',
'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT',
'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP',
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect'
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,132 @@
"""Support for the MySQL database via the MySQL Connector/Python adapter.
MySQL Connector/Python is available at:
https://launchpad.net/myconnpy
Connecting
-----------
Connect string format::
mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
"""
import re
from sqlalchemy.dialects.mysql.base import (MySQLDialect,
MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
BIT)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
def get_lastrowid(self):
return self.cursor.lastrowid
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)
def post_process_text(self, text):
return text.replace('%', '%%')
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
class _myconnpyBIT(BIT):
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""
return None
class MySQLDialect_mysqlconnector(MySQLDialect):
driver = 'mysqlconnector'
supports_unicode_statements = True
supports_unicode_binds = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = 'format'
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
statement_compiler = MySQLCompiler_mysqlconnector
preparer = MySQLIdentifierPreparer_mysqlconnector
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
BIT: _myconnpyBIT,
}
)
@classmethod
def dbapi(cls):
from mysql import connector
return connector
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
util.coerce_kw_type(opts, 'buffered', bool)
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
opts['buffered'] = True
opts['raise_on_warnings'] = True
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector.constants import ClientFlag
client_flags = opts.get('client_flags', ClientFlag.get_default())
client_flags |= ClientFlag.FOUND_ROWS
opts['client_flags'] = client_flags
except:
pass
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
from mysql.connector.constants import ClientFlag
dbapi_con.set_client_flag(ClientFlag.FOUND_ROWS)
version = dbapi_con.get_server_version()
return tuple(version)
def _detect_charset(self, connection):
return connection.connection.get_characterset_info()
def _extract_error_code(self, exception):
try:
return exception.orig.errno
except AttributeError:
return None
def is_disconnect(self, e):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError,self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return e.errno in errnos
else:
return False
def _compat_fetchall(self, rp, charset=None):
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
dialect = MySQLDialect_mysqlconnector

View File

@ -0,0 +1,202 @@
"""Support for the MySQL database via the MySQL-python adapter.
MySQL-Python is available at:
http://sourceforge.net/projects/mysql-python
At least version 1.2.1 or 1.2.2 should be used.
Connecting
-----------
Connect string format::
mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
Character Sets
--------------
Many MySQL server installations default to a ``latin1`` encoding for client
connections. All data sent through the connection will be converted into
``latin1``, even if you have ``utf8`` or another character set on your tables
and columns. With versions 4.1 and higher, you can change the connection
character set either through server configuration or by including the
``charset`` parameter in the URL used for ``create_engine``. The ``charset``
option is passed through to MySQL-Python and has the side-effect of also
enabling ``use_unicode`` in the driver by default. For regular encoded
strings, also pass ``use_unicode=0`` in the connection arguments::
# set client encoding to utf8; all strings come back as unicode
create_engine('mysql+mysqldb:///mydb?charset=utf8')
# set client encoding to utf8; all strings come back as utf8 str
create_engine('mysql+mysqldb:///mydb?charset=utf8&use_unicode=0')
Known Issues
-------------
MySQL-python at least as of version 1.2.2 has a serious memory leak related
to unicode conversion, a feature which is disabled via ``use_unicode=0``.
The recommended connection form with SQLAlchemy is::
engine = create_engine('mysql://scott:tiger@localhost/test?charset=utf8&use_unicode=0', pool_recycle=3600)
"""
import re
from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
@property
def rowcount(self):
if hasattr(self, '_rowcount'):
return self._rowcount
else:
return self.cursor.rowcount
class MySQLCompiler_mysqldb(MySQLCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)
def post_process_text(self, text):
return text.replace('%', '%%')
class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
class MySQLDialect_mysqldb(MySQLDialect):
driver = 'mysqldb'
supports_unicode_statements = False
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = 'format'
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer_mysqldb
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
}
)
@classmethod
def dbapi(cls):
return __import__('MySQLdb')
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user',
password='passwd')
opts.update(url.query)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'client_flag', int)
util.coerce_kw_type(opts, 'local_infile', int)
# Note: using either of the below will cause all strings to be returned
# as Unicode, both in raw SQL operations and with column types like
# String and MSString.
util.coerce_kw_type(opts, 'use_unicode', bool)
util.coerce_kw_type(opts, 'charset', str)
# Rich values 'cursorclass' and 'conv' are not supported via
# query string.
ssl = {}
for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
opts['ssl'] = ssl
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get('client_flag', 0)
if self.dbapi is not None:
try:
from MySQLdb.constants import CLIENT as CLIENT_FLAGS
client_flag |= CLIENT_FLAGS.FOUND_ROWS
except:
pass
opts['client_flag'] = client_flag
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.get_server_info()):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _extract_error_code(self, exception):
try:
return exception.orig.args[0]
except AttributeError:
return None
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Note: MySQL-python 1.2.1c7 seems to ignore changes made
# on a connection via set_character_set()
if self.server_version_info < (4, 1, 0):
try:
return connection.connection.character_set_name()
except AttributeError:
# < 1.2.1 final MySQL-python drivers have no charset support.
# a query is needed.
pass
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
if 'character_set_results' in opts:
return opts['character_set_results']
try:
return connection.connection.character_set_name()
except AttributeError:
# Still no charset on < 1.2.1 final...
if 'character_set' in opts:
return opts['character_set']
else:
util.warn(
"Could not detect the connection character set with this "
"combination of MySQL server and MySQL-python. "
"MySQL-python >= 1.2.2 is recommended. Assuming latin1.")
return 'latin1'
dialect = MySQLDialect_mysqldb

View File

@ -0,0 +1,255 @@
"""Support for the MySQL database via the oursql adapter.
OurSQL is available at:
http://packages.python.org/oursql/
Connecting
-----------
Connect string format::
mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
Character Sets
--------------
oursql defaults to using ``utf8`` as the connection charset, but other
encodings may be used instead. Like the MySQL-Python driver, unicode support
can be completely disabled::
# oursql sets the connection charset to utf8 automatically; all strings come
# back as utf8 str
create_engine('mysql+oursql:///mydb?use_unicode=0')
To not automatically use ``utf8`` and instead use whatever the connection
defaults to, there is a separate parameter::
# use the default connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?default_charset=1')
# use latin1 as the connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?charset=latin1')
"""
import re
from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors
class _oursqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""oursql already converts mysql bits, so."""
return None
class MySQLExecutionContext_oursql(MySQLExecutionContext):
@property
def plain_query(self):
return self.execution_options.get('_oursql_plain_query', False)
class MySQLDialect_oursql(MySQLDialect):
driver = 'oursql'
# Py3K
# description_encoding = None
# Py2K
supports_unicode_binds = True
supports_unicode_statements = True
# end Py2K
supports_native_decimal = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
execution_ctx_cls = MySQLExecutionContext_oursql
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
sqltypes.Time: sqltypes.Time,
BIT: _oursqlBIT,
}
)
@classmethod
def dbapi(cls):
return __import__('oursql')
def do_execute(self, cursor, statement, parameters, context=None):
"""Provide an implementation of *cursor.execute(statement, parameters)*."""
if context and context.plain_query:
cursor.execute(statement, plain_query=True)
else:
cursor.execute(statement, parameters)
def do_begin(self, connection):
connection.cursor().execute('BEGIN', plain_query=True)
def _xa_query(self, connection, query, xid):
# Py2K
arg = connection.connection._escape_string(xid)
# end Py2K
# Py3K
# charset = self._connection_charset
# arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
connection.execution_options(_oursql_plain_query=True).execute(query % arg)
# Because mysql is bad, these methods have to be
# reimplemented to use _PlainQuery. Basically, some queries
# refuse to return any data if they're run through
# the parameterized query API, or refuse to be parameterized
# in the first place.
def do_begin_twophase(self, connection, xid):
self._xa_query(connection, 'XA BEGIN "%s"', xid)
def do_prepare_twophase(self, connection, xid):
self._xa_query(connection, 'XA END "%s"', xid)
self._xa_query(connection, 'XA PREPARE "%s"', xid)
def do_rollback_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self._xa_query(connection, 'XA END "%s"', xid)
self._xa_query(connection, 'XA ROLLBACK "%s"', xid)
def do_commit_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
self._xa_query(connection, 'XA COMMIT "%s"', xid)
# Q: why didn't we need all these "plain_query" overrides earlier ?
# am i on a newer/older version of OurSQL ?
def has_table(self, connection, table_name, schema=None):
return MySQLDialect.has_table(self,
connection.connect().\
execution_options(_oursql_plain_query=True),
table_name, schema)
def get_table_options(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_table_options(self,
connection.connect().\
execution_options(_oursql_plain_query=True),
table_name,
schema = schema,
**kw
)
def get_columns(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_columns(self,
connection.connect().\
execution_options(_oursql_plain_query=True),
table_name,
schema=schema,
**kw
)
def get_view_names(self, connection, schema=None, **kw):
return MySQLDialect.get_view_names(self,
connection.connect().\
execution_options(_oursql_plain_query=True),
schema=schema,
**kw
)
def get_table_names(self, connection, schema=None, **kw):
return MySQLDialect.get_table_names(self,
connection.connect().\
execution_options(_oursql_plain_query=True),
schema
)
def get_schema_names(self, connection, **kw):
return MySQLDialect.get_schema_names(self,
connection.connect().\
execution_options(_oursql_plain_query=True),
**kw
)
def initialize(self, connection):
return MySQLDialect.initialize(
self,
connection.execution_options(_oursql_plain_query=True)
)
def _show_create_table(self, connection, table, charset=None,
full_name=None):
return MySQLDialect._show_create_table(self,
connection.contextual_connect(close_with_result=True).
execution_options(_oursql_plain_query=True),
table, charset, full_name)
def is_disconnect(self, e):
if isinstance(e, self.dbapi.ProgrammingError):
return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
else:
return e.errno in (2006, 2013, 2014, 2045, 2055)
def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user',
password='passwd')
opts.update(url.query)
util.coerce_kw_type(opts, 'port', int)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'autoping', bool)
util.coerce_kw_type(opts, 'default_charset', bool)
if opts.pop('default_charset', False):
opts['charset'] = None
else:
util.coerce_kw_type(opts, 'charset', str)
opts['use_unicode'] = opts.get('use_unicode', True)
util.coerce_kw_type(opts, 'use_unicode', bool)
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
opts.setdefault('found_rows', True)
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.server_info):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _extract_error_code(self, exception):
try:
return exception.orig.errno
except AttributeError:
return None
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
return connection.connection.charset
def _compat_fetchall(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchone()
def _compat_first(self, rp, charset=None):
return rp.first()
dialect = MySQLDialect_oursql

View File

@ -0,0 +1,76 @@
"""Support for the MySQL database via the pyodbc adapter.
pyodbc is available at:
http://pypi.python.org/pypi/pyodbc/
Connecting
----------
Connect string::
mysql+pyodbc://<username>:<password>@<dsnname>
Limitations
-----------
The mysql-pyodbc dialect is subject to unresolved character encoding issues
which exist within the current ODBC drivers available.
(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage
of OurSQL, MySQLdb, or MySQL-connector/Python.
"""
from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.engine import base as engine_base
from sqlalchemy import util
import re
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
supports_unicode_statements = False
execution_ctx_cls = MySQLExecutionContext_pyodbc
pyodbc_driver_name = "MySQL"
def __init__(self, **kw):
# deal with http://code.google.com/p/pyodbc/issues/detail?id=25
kw.setdefault('convert_unicode', True)
super(MySQLDialect_pyodbc, self).__init__(**kw)
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
for key in ('character_set_connection', 'character_set'):
if opts.get(key, None):
return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.")
return 'latin1'
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.orig.args))
c = m.group(1)
if c:
return int(c)
else:
return None
dialect = MySQLDialect_pyodbc

View File

@ -0,0 +1,111 @@
"""Support for the MySQL database via Jython's zxjdbc JDBC connector.
JDBC Driver
-----------
The official MySQL JDBC driver is at
http://dev.mysql.com/downloads/connector/j/.
Connecting
----------
Connect string format:
mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
Character Sets
--------------
SQLAlchemy zxjdbc dialects pass unicode straight through to the
zxjdbc/JDBC layer. To allow multiple character sets to be sent from the
MySQL Connector/J JDBC driver, by default SQLAlchemy sets its
``characterEncoding`` connection property to ``UTF-8``. It may be
overriden via a ``create_engine`` URL parameter.
"""
import re
from sqlalchemy import types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext
class _ZxJDBCBit(BIT):
def result_processor(self, dialect, coltype):
"""Converts boolean or byte arrays from MySQL Connector/J to longs."""
def process(value):
if value is None:
return value
if isinstance(value, bool):
return int(value)
v = 0L
for i in value:
v = v << 8 | (i & 0xff)
value = v
return value
return process
class MySQLExecutionContext_zxjdbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
jdbc_db_name = 'mysql'
jdbc_driver_name = 'com.mysql.jdbc.Driver'
execution_ctx_cls = MySQLExecutionContext_zxjdbc
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
sqltypes.Time: sqltypes.Time,
BIT: _ZxJDBCBit
}
)
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
for key in ('character_set_connection', 'character_set'):
if opts.get(key, None):
return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.")
return 'latin1'
def _driver_kwargs(self):
"""return kw arg dict to be sent to connect()."""
return dict(characterEncoding='UTF-8', yearIsDateType='false')
def _extract_error_code(self, exception):
# e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
# [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args))
c = m.group(1)
if c:
return int(c)
def _get_server_version_info(self,connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.dbversion):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
dialect = MySQLDialect_zxjdbc

View File

@ -0,0 +1,17 @@
from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
base.dialect = cx_oracle.dialect
from sqlalchemy.dialects.oracle.base import \
VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, NUMBER,\
BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\
FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\
VARCHAR2, NVARCHAR2
__all__ = (
'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'DATETIME', 'NUMBER',
'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
'VARCHAR2', 'NVARCHAR2'
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,529 @@
"""Support for the Oracle database via the cx_oracle driver.
Driver
------
The Oracle dialect uses the cx_oracle driver, available at
http://cx-oracle.sourceforge.net/ . The dialect has several behaviors
which are specifically tailored towards compatibility with this module.
Connecting
----------
Connecting with create_engine() uses the standard URL approach of
``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the
host, port, and dbname tokens are converted to a TNS name using the cx_oracle
:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name.
Additional arguments which may be specified either as query string arguments on the
URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
* *allow_twophase* - enable two-phase transactions. Defaults to ``True``.
* *arraysize* - set the cx_oracle.arraysize value on cursors, in SQLAlchemy
it defaults to 50. See the section on "LOB Objects" below.
* *auto_convert_lobs* - defaults to True, see the section on LOB objects.
* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
This is required for LOB datatypes but can be disabled to reduce overhead. Defaults
to ``True``.
* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
integer value. This value is only available as a URL query string argument.
* *threaded* - enable multithreaded access to cx_oracle connections. Defaults
to ``True``. Note that this is the opposite default of cx_oracle itself.
Unicode
-------
As of cx_oracle 5, Python unicode objects can be bound directly to statements,
and it appears that cx_oracle can handle these even without NLS_LANG being set.
SQLAlchemy tests for version 5 and will pass unicode objects straight to cx_oracle
if this is the case. For older versions of cx_oracle, SQLAlchemy will encode bind
parameters normally using dialect.encoding as the encoding.
LOB Objects
-----------
cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set
is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default,
SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First,
the LOB object requires an active cursor association, meaning if you were to fetch many rows
at once such that cx_oracle had to go back to the database and fetch a new batch of rows,
the LOB objects in the already-fetched rows are now unreadable and will raise an error.
SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read.
The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
defaults to 50 (cx_oracle normally defaults this to one).
Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to
"normalize" the results to look more like that of other DBAPIs.
The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
for all statement executions, even plain string-based statements for which SQLA has no awareness
of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases
without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch"
of LOB objects, can be disabled using auto_convert_lobs=False.
Two Phase Transaction Support
-----------------------------
Two Phase transactions are implemented using XA transactions. Success has been reported
with this feature but it should be regarded as experimental.
"""
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, \
RESERVED_WORDS, OracleExecutionContext
from sqlalchemy.dialects.oracle import base as oracle
from sqlalchemy.engine import base
from sqlalchemy import types as sqltypes, util, exc
from datetime import datetime
import random
class _OracleNumeric(sqltypes.Numeric):
# cx_oracle accepts Decimal objects, but returns
# floats
def bind_processor(self, dialect):
return None
class _OracleDate(sqltypes.Date):
def bind_processor(self, dialect):
return None
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
return value.date()
else:
return value
return process
class _LOBMixin(object):
def result_processor(self, dialect, coltype):
if not dialect.auto_convert_lobs:
# return the cx_oracle.LOB directly.
return None
def process(value):
if value is not None:
return value.read()
else:
return value
return process
class _NativeUnicodeMixin(object):
# Py2K
def bind_processor(self, dialect):
if dialect._cx_oracle_with_unicode:
def process(value):
if value is None:
return value
else:
return unicode(value)
return process
else:
return super(_NativeUnicodeMixin, self).bind_processor(dialect)
# end Py2K
def result_processor(self, dialect, coltype):
# if we know cx_Oracle will return unicode,
# don't process results
if dialect._cx_oracle_with_unicode:
return None
elif self.convert_unicode != 'force' and \
dialect._cx_oracle_native_nvarchar and \
coltype in dialect._cx_oracle_unicode_types:
return None
else:
return super(_NativeUnicodeMixin, self).result_processor(dialect, coltype)
class _OracleChar(_NativeUnicodeMixin, sqltypes.CHAR):
def get_dbapi_type(self, dbapi):
return dbapi.FIXED_CHAR
class _OracleNVarChar(_NativeUnicodeMixin, sqltypes.NVARCHAR):
def get_dbapi_type(self, dbapi):
return getattr(dbapi, 'UNICODE', dbapi.STRING)
class _OracleText(_LOBMixin, sqltypes.Text):
def get_dbapi_type(self, dbapi):
return dbapi.CLOB
class _OracleString(_NativeUnicodeMixin, sqltypes.String):
pass
class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText):
def get_dbapi_type(self, dbapi):
return dbapi.NCLOB
def result_processor(self, dialect, coltype):
lob_processor = _LOBMixin.result_processor(self, dialect, coltype)
if lob_processor is None:
return None
string_processor = _NativeUnicodeMixin.result_processor(self, dialect, coltype)
if string_processor is None:
return lob_processor
else:
def process(value):
return string_processor(lob_processor(value))
return process
class _OracleInteger(sqltypes.Integer):
def result_processor(self, dialect, coltype):
def to_int(val):
if val is not None:
val = int(val)
return val
return to_int
class _OracleBinary(_LOBMixin, sqltypes.LargeBinary):
def get_dbapi_type(self, dbapi):
return dbapi.BLOB
def bind_processor(self, dialect):
return None
class _OracleInterval(oracle.INTERVAL):
def get_dbapi_type(self, dbapi):
return dbapi.INTERVAL
class _OracleRaw(oracle.RAW):
pass
class OracleCompiler_cx_oracle(OracleCompiler):
def bindparam_string(self, name):
if self.preparer._bindparam_requires_quotes(name):
quoted_name = '"%s"' % name
self._quoted_bind_names[name] = quoted_name
return OracleCompiler.bindparam_string(self, quoted_name)
else:
return OracleCompiler.bindparam_string(self, name)
class OracleExecutionContext_cx_oracle(OracleExecutionContext):
def pre_exec(self):
quoted_bind_names = \
getattr(self.compiled, '_quoted_bind_names', None)
if quoted_bind_names:
if not self.dialect.supports_unicode_binds:
quoted_bind_names = \
dict(
(fromname, toname.encode(self.dialect.encoding))
for fromname, toname in
quoted_bind_names.items()
)
for param in self.parameters:
for fromname, toname in quoted_bind_names.items():
param[toname] = param[fromname]
del param[fromname]
if self.dialect.auto_setinputsizes:
# cx_oracle really has issues when you setinputsizes
# on String, including that outparams/RETURNING
# breaks for varchars
self.set_input_sizes(quoted_bind_names,
exclude_types=self.dialect._cx_oracle_string_types
)
# if a single execute, check for outparams
if len(self.compiled_parameters) == 1:
for bindparam in self.compiled.binds.values():
if bindparam.isoutparam:
dbtype = bindparam.type.dialect_impl(self.dialect).\
get_dbapi_type(self.dialect.dbapi)
if not hasattr(self, 'out_parameters'):
self.out_parameters = {}
if dbtype is None:
raise exc.InvalidRequestError("Cannot create out parameter for parameter "
"%r - it's type %r is not supported by"
" cx_oracle" %
(name, bindparam.type)
)
name = self.compiled.bind_names[bindparam]
self.out_parameters[name] = self.cursor.var(dbtype)
self.parameters[0][quoted_bind_names.get(name, name)] = \
self.out_parameters[name]
def create_cursor(self):
c = self._connection.connection.cursor()
if self.dialect.arraysize:
c.arraysize = self.dialect.arraysize
return c
def get_result_proxy(self):
if hasattr(self, 'out_parameters') and self.compiled.returning:
returning_params = dict(
(k, v.getvalue())
for k, v in self.out_parameters.items()
)
return ReturningResultProxy(self, returning_params)
result = None
if self.cursor.description is not None:
for column in self.cursor.description:
type_code = column[1]
if type_code in self.dialect._cx_oracle_binary_types:
result = base.BufferedColumnResultProxy(self)
if result is None:
result = base.ResultProxy(self)
if hasattr(self, 'out_parameters'):
if self.compiled_parameters is not None and \
len(self.compiled_parameters) == 1:
result.out_parameters = out_parameters = {}
for bind, name in self.compiled.bind_names.items():
if name in self.out_parameters:
type = bind.type
impl_type = type.dialect_impl(self.dialect)
dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
result_processor = impl_type.\
result_processor(self.dialect,
dbapi_type)
if result_processor is not None:
out_parameters[name] = \
result_processor(self.out_parameters[name].getvalue())
else:
out_parameters[name] = self.out_parameters[name].getvalue()
else:
result.out_parameters = dict(
(k, v.getvalue())
for k, v in self.out_parameters.items()
)
return result
class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_oracle):
"""Support WITH_UNICODE in Python 2.xx.
WITH_UNICODE allows cx_Oracle's Python 3 unicode handling
behavior under Python 2.x. This mode in some cases disallows
and in other cases silently passes corrupted data when
non-Python-unicode strings (a.k.a. plain old Python strings)
are passed as arguments to connect(), the statement sent to execute(),
or any of the bind parameter keys or values sent to execute().
This optional context therefore ensures that all statements are
passed as Python unicode objects.
"""
def __init__(self, *arg, **kw):
OracleExecutionContext_cx_oracle.__init__(self, *arg, **kw)
self.statement = unicode(self.statement)
def _execute_scalar(self, stmt):
return super(OracleExecutionContext_cx_oracle_with_unicode, self).\
_execute_scalar(unicode(stmt))
class ReturningResultProxy(base.FullyBufferedResultProxy):
"""Result proxy which stuffs the _returning clause + outparams into the fetch."""
def __init__(self, context, returning_params):
self._returning_params = returning_params
super(ReturningResultProxy, self).__init__(context)
def _cursor_description(self):
returning = self.context.compiled.returning
ret = []
for c in returning:
if hasattr(c, 'name'):
ret.append((c.name, c.type))
else:
ret.append((c.anon_label, c.type))
return ret
def _buffer_rows(self):
return [tuple(self._returning_params["ret_%d" % i]
for i, c in enumerate(self._returning_params))]
class OracleDialect_cx_oracle(OracleDialect):
execution_ctx_cls = OracleExecutionContext_cx_oracle
statement_compiler = OracleCompiler_cx_oracle
driver = "cx_oracle"
colspecs = colspecs = {
sqltypes.Numeric: _OracleNumeric,
sqltypes.Date : _OracleDate, # generic type, assume datetime.date is desired
oracle.DATE: oracle.DATE, # non generic type - passthru
sqltypes.LargeBinary : _OracleBinary,
sqltypes.Boolean : oracle._OracleBoolean,
sqltypes.Interval : _OracleInterval,
oracle.INTERVAL : _OracleInterval,
sqltypes.Text : _OracleText,
sqltypes.String : _OracleString,
sqltypes.UnicodeText : _OracleUnicodeText,
sqltypes.CHAR : _OracleChar,
sqltypes.Integer : _OracleInteger, # this is only needed for OUT parameters.
# it would be nice if we could not use it otherwise.
oracle.NUMBER : oracle.NUMBER, # don't let this get converted
oracle.RAW: _OracleRaw,
sqltypes.Unicode: _OracleNVarChar,
sqltypes.NVARCHAR : _OracleNVarChar,
}
execute_sequence_format = list
def __init__(self,
auto_setinputsizes=True,
auto_convert_lobs=True,
threaded=True,
allow_twophase=True,
arraysize=50, **kwargs):
OracleDialect.__init__(self, **kwargs)
self.threaded = threaded
self.arraysize = arraysize
self.allow_twophase = allow_twophase
self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
self.auto_convert_lobs = auto_convert_lobs
if hasattr(self.dbapi, 'version'):
cx_oracle_ver = tuple([int(x) for x in self.dbapi.version.split('.')])
else:
cx_oracle_ver = (0, 0, 0)
def types(*names):
return set([
getattr(self.dbapi, name, None) for name in names
]).difference([None])
self._cx_oracle_string_types = types("STRING", "UNICODE", "NCLOB", "CLOB")
self._cx_oracle_unicode_types = types("UNICODE", "NCLOB")
self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB")
self.supports_unicode_binds = cx_oracle_ver >= (5, 0)
self._cx_oracle_native_nvarchar = cx_oracle_ver >= (5, 0)
if cx_oracle_ver is None:
# this occurs in tests with mock DBAPIs
self._cx_oracle_string_types = set()
self._cx_oracle_with_unicode = False
elif cx_oracle_ver >= (5,) and not hasattr(self.dbapi, 'UNICODE'):
# cx_Oracle WITH_UNICODE mode. *only* python
# unicode objects accepted for anything
self.supports_unicode_statements = True
self.supports_unicode_binds = True
self._cx_oracle_with_unicode = True
# Py2K
# There's really no reason to run with WITH_UNICODE under Python 2.x.
# Give the user a hint.
util.warn("cx_Oracle is compiled under Python 2.xx using the "
"WITH_UNICODE flag. Consider recompiling cx_Oracle without "
"this flag, which is in no way necessary for full support of Unicode. "
"Otherwise, all string-holding bind parameters must "
"be explicitly typed using SQLAlchemy's String type or one of its subtypes,"
"or otherwise be passed as Python unicode. Plain Python strings "
"passed as bind parameters will be silently corrupted by cx_Oracle."
)
self.execution_ctx_cls = OracleExecutionContext_cx_oracle_with_unicode
# end Py2K
else:
self._cx_oracle_with_unicode = False
if cx_oracle_ver is None or \
not self.auto_convert_lobs or \
not hasattr(self.dbapi, 'CLOB'):
self.dbapi_type_map = {}
else:
# only use this for LOB objects. using it for strings, dates
# etc. leads to a little too much magic, reflection doesn't know if it should
# expect encoded strings or unicodes, etc.
self.dbapi_type_map = {
self.dbapi.CLOB: oracle.CLOB(),
self.dbapi.NCLOB:oracle.NCLOB(),
self.dbapi.BLOB: oracle.BLOB(),
self.dbapi.BINARY: oracle.RAW(),
}
@classmethod
def dbapi(cls):
import cx_Oracle
return cx_Oracle
def create_connect_args(self, url):
dialect_opts = dict(url.query)
for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
'threaded', 'allow_twophase'):
if opt in dialect_opts:
util.coerce_kw_type(dialect_opts, opt, bool)
setattr(self, opt, dialect_opts[opt])
if url.database:
# if we have a database, then we have a remote host
port = url.port
if port:
port = int(port)
else:
port = 1521
dsn = self.dbapi.makedsn(url.host, port, url.database)
else:
# we have a local tnsname
dsn = url.host
opts = dict(
user=url.username,
password=url.password,
dsn=dsn,
threaded=self.threaded,
twophase=self.allow_twophase,
)
# Py2K
if self._cx_oracle_with_unicode:
for k, v in opts.items():
if isinstance(v, str):
opts[k] = unicode(v)
# end Py2K
if 'mode' in url.query:
opts['mode'] = url.query['mode']
if isinstance(opts['mode'], basestring):
mode = opts['mode'].upper()
if mode == 'SYSDBA':
opts['mode'] = self.dbapi.SYSDBA
elif mode == 'SYSOPER':
opts['mode'] = self.dbapi.SYSOPER
else:
util.coerce_kw_type(opts, 'mode', int)
return ([], opts)
def _get_server_version_info(self, connection):
return tuple(int(x) for x in connection.connection.version.split('.'))
def is_disconnect(self, e):
if isinstance(e, self.dbapi.InterfaceError):
return "not connected" in str(e)
else:
return "ORA-03114" in str(e) or "ORA-03113" in str(e)
def create_xid(self):
"""create a two-phase transaction ID.
this id will be passed to do_begin_twophase(), do_rollback_twophase(),
do_commit_twophase(). its format is unspecified."""
id = random.randint(0, 2 ** 128)
return (0x1234, "%032x" % id, "%032x" % 9)
def do_begin_twophase(self, connection, xid):
connection.connection.begin(*xid)
def do_prepare_twophase(self, connection, xid):
connection.connection.prepare()
def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
self.do_rollback(connection.connection)
def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
self.do_commit(connection.connection)
def do_recover_twophase(self, connection):
pass
dialect = OracleDialect_cx_oracle

View File

@ -0,0 +1,209 @@
"""Support for the Oracle database via the zxjdbc JDBC connector.
JDBC Driver
-----------
The official Oracle JDBC driver is at
http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html.
"""
import decimal
import re
from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext
from sqlalchemy.engine import base, default
from sqlalchemy.sql import expression
SQLException = zxJDBC = None
class _ZxJDBCDate(sqltypes.Date):
def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return None
else:
return value.date()
return process
class _ZxJDBCNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype):
#XXX: does the dialect return Decimal or not???
# if it does (in all cases), we could use a None processor as well as
# the to_float generic processor
if self.asdecimal:
def process(value):
if isinstance(value, decimal.Decimal):
return value
else:
return decimal.Decimal(str(value))
else:
def process(value):
if isinstance(value, decimal.Decimal):
return float(value)
else:
return value
return process
class OracleCompiler_zxjdbc(OracleCompiler):
def returning_clause(self, stmt, returning_cols):
self.returning_cols = list(expression._select_iterables(returning_cols))
# within_columns_clause=False so that labels (foo AS bar) don't render
columns = [self.process(c, within_columns_clause=False, result_map=self.result_map)
for c in self.returning_cols]
if not hasattr(self, 'returning_parameters'):
self.returning_parameters = []
binds = []
for i, col in enumerate(self.returning_cols):
dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
self.returning_parameters.append((i + 1, dbtype))
bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype))
self.binds[bindparam.key] = bindparam
binds.append(self.bindparam_string(self._truncate_bindparam(bindparam)))
return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
class OracleExecutionContext_zxjdbc(OracleExecutionContext):
def pre_exec(self):
if hasattr(self.compiled, 'returning_parameters'):
# prepare a zxJDBC statement so we can grab its underlying
# OraclePreparedStatement's getReturnResultSet later
self.statement = self.cursor.prepare(self.statement)
def get_result_proxy(self):
if hasattr(self.compiled, 'returning_parameters'):
rrs = None
try:
try:
rrs = self.statement.__statement__.getReturnResultSet()
rrs.next()
except SQLException, sqle:
msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode())
if sqle.getSQLState() is not None:
msg += ' [SQLState: %s]' % sqle.getSQLState()
raise zxJDBC.Error(msg)
else:
row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype)
for index, dbtype in self.compiled.returning_parameters)
return ReturningResultProxy(self, row)
finally:
if rrs is not None:
try:
rrs.close()
except SQLException:
pass
self.statement.close()
return base.ResultProxy(self)
def create_cursor(self):
cursor = self._connection.connection.cursor()
cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
return cursor
class ReturningResultProxy(base.FullyBufferedResultProxy):
"""ResultProxy backed by the RETURNING ResultSet results."""
def __init__(self, context, returning_row):
self._returning_row = returning_row
super(ReturningResultProxy, self).__init__(context)
def _cursor_description(self):
ret = []
for c in self.context.compiled.returning_cols:
if hasattr(c, 'name'):
ret.append((c.name, c.type))
else:
ret.append((c.anon_label, c.type))
return ret
def _buffer_rows(self):
return [self._returning_row]
class ReturningParam(object):
"""A bindparam value representing a RETURNING parameter.
Specially handled by OracleReturningDataHandler.
"""
def __init__(self, type):
self.type = type
def __eq__(self, other):
if isinstance(other, ReturningParam):
return self.type == other.type
return NotImplemented
def __ne__(self, other):
if isinstance(other, ReturningParam):
return self.type != other.type
return NotImplemented
def __repr__(self):
kls = self.__class__
return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self),
self.type)
class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
jdbc_db_name = 'oracle'
jdbc_driver_name = 'oracle.jdbc.OracleDriver'
statement_compiler = OracleCompiler_zxjdbc
execution_ctx_cls = OracleExecutionContext_zxjdbc
colspecs = util.update_copy(
OracleDialect.colspecs,
{
sqltypes.Date : _ZxJDBCDate,
sqltypes.Numeric: _ZxJDBCNumeric
}
)
def __init__(self, *args, **kwargs):
super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs)
global SQLException, zxJDBC
from java.sql import SQLException
from com.ziclix.python.sql import zxJDBC
from com.ziclix.python.sql.handler import OracleDataHandler
class OracleReturningDataHandler(OracleDataHandler):
"""zxJDBC DataHandler that specially handles ReturningParam."""
def setJDBCObject(self, statement, index, object, dbtype=None):
if type(object) is ReturningParam:
statement.registerReturnParameter(index, object.type)
elif dbtype is None:
OracleDataHandler.setJDBCObject(self, statement, index, object)
else:
OracleDataHandler.setJDBCObject(self, statement, index, object, dbtype)
self.DataHandler = OracleReturningDataHandler
def initialize(self, connection):
super(OracleDialect_zxjdbc, self).initialize(connection)
self.implicit_returning = connection.connection.driverversion >= '10.2'
def _create_jdbc_url(self, url):
return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database)
def _get_server_version_info(self, connection):
version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
return tuple(int(x) for x in version.split('.'))
dialect = OracleDialect_zxjdbc

View File

@ -0,0 +1,10 @@
# backwards compat with the old name
from sqlalchemy.util import warn_deprecated
warn_deprecated(
"The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. "
"The new URL format is postgresql[+driver]://<user>:<pass>@<host>/<dbname>"
)
from sqlalchemy.dialects.postgresql import *
from sqlalchemy.dialects.postgresql import base

View File

@ -0,0 +1,14 @@
from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, pypostgresql, zxjdbc
base.dialect = psycopg2.dialect
from sqlalchemy.dialects.postgresql.base import \
INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \
CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME,\
DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect
__all__ = (
'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET',
'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME',
'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect'
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,105 @@
"""Support for the PostgreSQL database via the pg8000 driver.
Connecting
----------
URLs are of the form
`postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.
Unicode
-------
pg8000 requires that the postgresql client encoding be configured in the postgresql.conf file
in order to use encodings other than ascii. Set this value to the same value as
the "encoding" parameter on create_engine(), usually "utf-8".
Interval
--------
Passing data from/to the Interval type is not supported as of yet.
"""
import decimal
from sqlalchemy.engine import default
from sqlalchemy import util, exc
from sqlalchemy import processors
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, \
PGCompiler, PGIdentifierPreparer, PGExecutionContext
class _PGNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in (700, 701):
return processors.to_decimal_processor_factory(decimal.Decimal)
elif coltype == 1700:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
else:
if coltype in (700, 701):
# pg8000 returns float natively for 701
return None
elif coltype == 1700:
return processors.to_float
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
class PGExecutionContext_pg8000(PGExecutionContext):
pass
class PGCompiler_pg8000(PGCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)
def post_process_text(self, text):
if '%%' in text:
util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() "
"expressions to '%%'.")
return text.replace('%', '%%')
class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%')
class PGDialect_pg8000(PGDialect):
driver = 'pg8000'
supports_unicode_statements = True
supports_unicode_binds = True
default_paramstyle = 'format'
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_pg8000
statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric : _PGNumeric,
}
)
@classmethod
def dbapi(cls):
return __import__('pg8000').dbapi
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if 'port' in opts:
opts['port'] = int(opts['port'])
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e):
return "connection is closed" in str(e)
dialect = PGDialect_pg8000

View File

@ -0,0 +1,239 @@
"""Support for the PostgreSQL database via the psycopg2 driver.
Driver
------
The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
The dialect has several behaviors which are specifically tailored towards compatibility
with this module.
Note that psycopg1 is **not** supported.
Connecting
----------
URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`.
psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
this feature. What this essentially means from a psycopg2 point of view is that the cursor is
created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
are not immediately pre-fetched and buffered after statement execution, but are instead left
on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
at a time are fetched over the wire to reduce conversational overhead.
* *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection. True
by default.
* *isolation_level* - Sets the transaction isolation level for each transaction
within the engine. Valid isolation levels are `READ_COMMITTED`,
`READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`.
Transactions
------------
The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
NOTICE logging
---------------
The psycopg2 dialect will log Postgresql NOTICE messages via the
``sqlalchemy.dialects.postgresql`` logger::
import logging
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
Per-Statement Execution Options
-------------------------------
The following per-statement execution options are respected:
* *stream_results* - Enable or disable usage of server side cursors for the SELECT-statement.
If *None* or not set, the *server_side_cursors* option of the connection is used. If
auto-commit is enabled, the option is ignored.
"""
import random
import re
import decimal
import logging
from sqlalchemy import util
from sqlalchemy import processors
from sqlalchemy.engine import base, default
from sqlalchemy.sql import expression
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \
PGIdentifierPreparer, PGExecutionContext, \
ENUM, ARRAY
logger = logging.getLogger('sqlalchemy.dialects.postgresql')
class _PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
return None
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in (700, 701):
return processors.to_decimal_processor_factory(decimal.Decimal)
elif coltype == 1700:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
else:
if coltype in (700, 701):
# pg8000 returns float natively for 701
return None
elif coltype == 1700:
return processors.to_float
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
class _PGEnum(ENUM):
def __init__(self, *arg, **kw):
super(_PGEnum, self).__init__(*arg, **kw)
if self.convert_unicode:
self.convert_unicode = "force"
class _PGArray(ARRAY):
def __init__(self, *arg, **kw):
super(_PGArray, self).__init__(*arg, **kw)
# FIXME: this check won't work for setups that
# have convert_unicode only on their create_engine().
if isinstance(self.item_type, sqltypes.String) and \
self.item_type.convert_unicode:
self.item_type.convert_unicode = "force"
# When we're handed literal SQL, ensure it's a SELECT-query. Since
# 8.3, combining cursors and "FOR UPDATE" has been fine.
SERVER_SIDE_CURSOR_RE = re.compile(
r'\s*SELECT',
re.I | re.UNICODE)
class PGExecutionContext_psycopg2(PGExecutionContext):
def create_cursor(self):
# TODO: coverage for server side cursors + select.for_update()
if self.dialect.server_side_cursors:
is_server_side = \
self.execution_options.get('stream_results', True) and (
(self.compiled and isinstance(self.compiled.statement, expression.Selectable) \
or \
(
(not self.compiled or
isinstance(self.compiled.statement, expression._TextClause))
and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
)
)
else:
is_server_side = self.execution_options.get('stream_results', False)
self.__is_server_side = is_server_side
if is_server_side:
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
return self._connection.connection.cursor(ident)
else:
return self._connection.connection.cursor()
def get_result_proxy(self):
if logger.isEnabledFor(logging.INFO):
self._log_notices(self.cursor)
if self.__is_server_side:
return base.BufferedRowResultProxy(self)
else:
return base.ResultProxy(self)
def _log_notices(self, cursor):
for notice in cursor.connection.notices:
# NOTICE messages have a
# newline character at the end
logger.info(notice.rstrip())
cursor.connection.notices[:] = []
class PGCompiler_psycopg2(PGCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)
def post_process_text(self, text):
return text.replace('%', '%%')
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%')
class PGDialect_psycopg2(PGDialect):
driver = 'psycopg2'
supports_unicode_statements = False
default_paramstyle = 'pyformat'
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_psycopg2
statement_compiler = PGCompiler_psycopg2
preparer = PGIdentifierPreparer_psycopg2
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric : _PGNumeric,
ENUM : _PGEnum, # needs force_unicode
sqltypes.Enum : _PGEnum, # needs force_unicode
ARRAY : _PGArray, # needs force_unicode
}
)
def __init__(self, server_side_cursors=False, use_native_unicode=True, **kwargs):
PGDialect.__init__(self, **kwargs)
self.server_side_cursors = server_side_cursors
self.use_native_unicode = use_native_unicode
self.supports_unicode_binds = use_native_unicode
@classmethod
def dbapi(cls):
psycopg = __import__('psycopg2')
return psycopg
def on_connect(self):
base_on_connect = super(PGDialect_psycopg2, self).on_connect()
if self.dbapi and self.use_native_unicode:
extensions = __import__('psycopg2.extensions').extensions
def connect(conn):
extensions.register_type(extensions.UNICODE, conn)
if base_on_connect:
base_on_connect(conn)
return connect
else:
return base_on_connect
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if 'port' in opts:
opts['port'] = int(opts['port'])
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e):
if isinstance(e, self.dbapi.OperationalError):
return 'closed the connection' in str(e) or 'connection not open' in str(e)
elif isinstance(e, self.dbapi.InterfaceError):
return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
elif isinstance(e, self.dbapi.ProgrammingError):
# yes, it really says "losed", not "closed"
return "losed the connection unexpectedly" in str(e)
else:
return False
dialect = PGDialect_psycopg2

View File

@ -0,0 +1,69 @@
"""Support for the PostgreSQL database via py-postgresql.
Connecting
----------
URLs are of the form `postgresql+pypostgresql://user@password@host:port/dbname[?key=value&key=value...]`.
"""
from sqlalchemy.engine import default
import decimal
from sqlalchemy import util
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
from sqlalchemy import processors
class PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
return processors.to_str
def result_processor(self, dialect, coltype):
if self.asdecimal:
return None
else:
return processors.to_float
class PGExecutionContext_pypostgresql(PGExecutionContext):
pass
class PGDialect_pypostgresql(PGDialect):
driver = 'pypostgresql'
supports_unicode_statements = True
supports_unicode_binds = True
description_encoding = None
default_paramstyle = 'pyformat'
# requires trunk version to support sane rowcounts
# TODO: use dbapi version information to set this flag appropariately
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_pypostgresql
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric : PGNumeric,
sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used
}
)
@classmethod
def dbapi(cls):
from postgresql.driver import dbapi20
return dbapi20
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if 'port' in opts:
opts['port'] = int(opts['port'])
else:
opts['port'] = 5432
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e):
return "connection is closed" in str(e)
dialect = PGDialect_pypostgresql

View File

@ -0,0 +1,19 @@
"""Support for the PostgreSQL database via the zxjdbc JDBC connector.
JDBC Driver
-----------
The official Postgresql JDBC driver is at http://jdbc.postgresql.org/.
"""
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.postgresql.base import PGDialect
class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect):
jdbc_db_name = 'postgresql'
jdbc_driver_name = 'org.postgresql.Driver'
def _get_server_version_info(self, connection):
return tuple(int(x) for x in connection.connection.dbversion.split('.'))
dialect = PGDialect_zxjdbc

View File

@ -0,0 +1,14 @@
from sqlalchemy.dialects.sqlite import base, pysqlite
# default dialect
base.dialect = pysqlite.dialect
from sqlalchemy.dialects.sqlite.base import \
BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER,\
NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect
__all__ = (
'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'FLOAT', 'INTEGER',
'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', 'TIMESTAMP', 'VARCHAR', 'dialect'
)

View File

@ -0,0 +1,596 @@
# sqlite.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for the SQLite database.
For information on connecting using a specific driver, see the documentation
section regarding that driver.
Date and Time Types
-------------------
SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide
out of the box functionality for translating values between Python `datetime` objects
and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
and related types provide date formatting and parsing functionality when SQlite is used.
The implementation classes are :class:`DATETIME`, :class:`DATE` and :class:`TIME`.
These types represent dates and times as ISO formatted strings, which also nicely
support ordering. There's no reliance on typical "libc" internals for these functions
so historical dates are fully supported.
Auto Incrementing Behavior
--------------------------
Background on SQLite's autoincrement is at: http://sqlite.org/autoinc.html
Two things to note:
* The AUTOINCREMENT keyword is **not** required for SQLite tables to
generate primary key values automatically. AUTOINCREMENT only means that
the algorithm used to generate ROWID values should be slightly different.
* SQLite does **not** generate primary key (i.e. ROWID) values, even for
one column, if the table has a composite (i.e. multi-column) primary key.
This is regardless of the AUTOINCREMENT keyword being present or not.
To specifically render the AUTOINCREMENT keyword on the primary key
column when rendering DDL, add the flag ``sqlite_autoincrement=True``
to the Table construct::
Table('sometable', metadata,
Column('id', Integer, primary_key=True),
sqlite_autoincrement=True)
"""
import datetime, re, time
from sqlalchemy import schema as sa_schema
from sqlalchemy import sql, exc, pool, DefaultClause
from sqlalchemy.engine import default
from sqlalchemy.engine import reflection
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.sql import compiler, functions as sql_functions
from sqlalchemy.util import NoneType
from sqlalchemy import processors
from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\
FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\
TIMESTAMP, VARCHAR
class _DateTimeMixin(object):
_reg = None
_storage_format = None
def __init__(self, storage_format=None, regexp=None, **kwargs):
if regexp is not None:
self._reg = re.compile(regexp)
if storage_format is not None:
self._storage_format = storage_format
class DATETIME(_DateTimeMixin, sqltypes.DateTime):
_storage_format = "%04d-%02d-%02d %02d:%02d:%02d.%06d"
def bind_processor(self, dialect):
datetime_datetime = datetime.datetime
datetime_date = datetime.date
format = self._storage_format
def process(value):
if value is None:
return None
elif isinstance(value, datetime_datetime):
return format % (value.year, value.month, value.day,
value.hour, value.minute, value.second,
value.microsecond)
elif isinstance(value, datetime_date):
return format % (value.year, value.month, value.day,
0, 0, 0, 0)
else:
raise TypeError("SQLite DateTime type only accepts Python "
"datetime and date objects as input.")
return process
def result_processor(self, dialect, coltype):
if self._reg:
return processors.str_to_datetime_processor_factory(
self._reg, datetime.datetime)
else:
return processors.str_to_datetime
class DATE(_DateTimeMixin, sqltypes.Date):
_storage_format = "%04d-%02d-%02d"
def bind_processor(self, dialect):
datetime_date = datetime.date
format = self._storage_format
def process(value):
if value is None:
return None
elif isinstance(value, datetime_date):
return format % (value.year, value.month, value.day)
else:
raise TypeError("SQLite Date type only accepts Python "
"date objects as input.")
return process
def result_processor(self, dialect, coltype):
if self._reg:
return processors.str_to_datetime_processor_factory(
self._reg, datetime.date)
else:
return processors.str_to_date
class TIME(_DateTimeMixin, sqltypes.Time):
_storage_format = "%02d:%02d:%02d.%06d"
def bind_processor(self, dialect):
datetime_time = datetime.time
format = self._storage_format
def process(value):
if value is None:
return None
elif isinstance(value, datetime_time):
return format % (value.hour, value.minute, value.second,
value.microsecond)
else:
raise TypeError("SQLite Time type only accepts Python "
"time objects as input.")
return process
def result_processor(self, dialect, coltype):
if self._reg:
return processors.str_to_datetime_processor_factory(
self._reg, datetime.time)
else:
return processors.str_to_time
colspecs = {
sqltypes.Date: DATE,
sqltypes.DateTime: DATETIME,
sqltypes.Time: TIME,
}
ischema_names = {
'BLOB': sqltypes.BLOB,
'BOOL': sqltypes.BOOLEAN,
'BOOLEAN': sqltypes.BOOLEAN,
'CHAR': sqltypes.CHAR,
'DATE': sqltypes.DATE,
'DATETIME': sqltypes.DATETIME,
'DECIMAL': sqltypes.DECIMAL,
'FLOAT': sqltypes.FLOAT,
'INT': sqltypes.INTEGER,
'INTEGER': sqltypes.INTEGER,
'NUMERIC': sqltypes.NUMERIC,
'REAL': sqltypes.Numeric,
'SMALLINT': sqltypes.SMALLINT,
'TEXT': sqltypes.TEXT,
'TIME': sqltypes.TIME,
'TIMESTAMP': sqltypes.TIMESTAMP,
'VARCHAR': sqltypes.VARCHAR,
}
class SQLiteCompiler(compiler.SQLCompiler):
extract_map = util.update_copy(
compiler.SQLCompiler.extract_map,
{
'month': '%m',
'day': '%d',
'year': '%Y',
'second': '%S',
'hour': '%H',
'doy': '%j',
'minute': '%M',
'epoch': '%s',
'dow': '%w',
'week': '%W'
})
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
def visit_char_length_func(self, fn, **kw):
return "length%s" % self.function_argspec(fn)
def visit_cast(self, cast, **kwargs):
if self.dialect.supports_cast:
return super(SQLiteCompiler, self).visit_cast(cast)
else:
return self.process(cast.clause)
def visit_extract(self, extract, **kw):
try:
return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
self.extract_map[extract.field], self.process(extract.expr, **kw))
except KeyError:
raise exc.ArgumentError(
"%s is not a valid extract argument." % extract.field)
def limit_clause(self, select):
text = ""
if select._limit is not None:
text += " \n LIMIT " + str(select._limit)
if select._offset is not None:
if select._limit is None:
text += " \n LIMIT -1"
text += " OFFSET " + str(select._offset)
else:
text += " OFFSET 0"
return text
def for_update_clause(self, select):
# sqlite has no "FOR UPDATE" AFAICT
return ''
class SQLiteDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
if not column.nullable:
colspec += " NOT NULL"
if column.primary_key and \
column.table.kwargs.get('sqlite_autoincrement', False) and \
len(column.table.primary_key.columns) == 1 and \
isinstance(column.type, sqltypes.Integer) and \
not column.foreign_keys:
colspec += " PRIMARY KEY AUTOINCREMENT"
return colspec
def visit_primary_key_constraint(self, constraint):
# for columns with sqlite_autoincrement=True,
# the PRIMARY KEY constraint can only be inline
# with the column itself.
if len(constraint.columns) == 1:
c = list(constraint)[0]
if c.primary_key and \
c.table.kwargs.get('sqlite_autoincrement', False) and \
isinstance(c.type, sqltypes.Integer) and \
not c.foreign_keys:
return ''
return super(SQLiteDDLCompiler, self).\
visit_primary_key_constraint(constraint)
def visit_create_index(self, create):
index = create.element
preparer = self.preparer
text = "CREATE "
if index.unique:
text += "UNIQUE "
text += "INDEX %s ON %s (%s)" \
% (preparer.format_index(index,
name=self._validate_identifier(index.name, True)),
preparer.format_table(index.table, use_schema=False),
', '.join(preparer.quote(c.name, c.quote)
for c in index.columns))
return text
class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
def visit_large_binary(self, type_):
return self.visit_BLOB(type_)
class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = set([
'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
'conflict', 'constraint', 'create', 'cross', 'current_date',
'current_time', 'current_timestamp', 'database', 'default',
'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
'plan', 'pragma', 'primary', 'query', 'raise', 'references',
'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
'vacuum', 'values', 'view', 'virtual', 'when', 'where',
])
def format_index(self, index, use_schema=True, name=None):
"""Prepare a quoted index and schema name."""
if name is None:
name = index.name
result = self.quote(name, index.quote)
if not self.omit_schema and use_schema and getattr(index.table, "schema", None):
result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result
return result
class SQLiteDialect(default.DefaultDialect):
name = 'sqlite'
supports_alter = False
supports_unicode_statements = True
supports_unicode_binds = True
supports_default_values = True
supports_empty_insert = False
supports_cast = True
default_paramstyle = 'qmark'
statement_compiler = SQLiteCompiler
ddl_compiler = SQLiteDDLCompiler
type_compiler = SQLiteTypeCompiler
preparer = SQLiteIdentifierPreparer
ischema_names = ischema_names
colspecs = colspecs
isolation_level = None
supports_cast = True
supports_default_values = True
def __init__(self, isolation_level=None, native_datetime=False, **kwargs):
default.DefaultDialect.__init__(self, **kwargs)
if isolation_level and isolation_level not in ('SERIALIZABLE',
'READ UNCOMMITTED'):
raise exc.ArgumentError("Invalid value for isolation_level. "
"Valid isolation levels for sqlite are 'SERIALIZABLE' and "
"'READ UNCOMMITTED'.")
self.isolation_level = isolation_level
# this flag used by pysqlite dialect, and perhaps others in the
# future, to indicate the driver is handling date/timestamp
# conversions (and perhaps datetime/time as well on some
# hypothetical driver ?)
self.native_datetime = native_datetime
if self.dbapi is not None:
self.supports_default_values = \
self.dbapi.sqlite_version_info >= (3, 3, 8)
self.supports_cast = \
self.dbapi.sqlite_version_info >= (3, 2, 3)
def on_connect(self):
if self.isolation_level is not None:
if self.isolation_level == 'READ UNCOMMITTED':
isolation_level = 1
else:
isolation_level = 0
def connect(conn):
cursor = conn.cursor()
cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
cursor.close()
return connect
else:
return None
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
master = '%s.sqlite_master' % qschema
s = ("SELECT name FROM %s "
"WHERE type='table' ORDER BY name") % (master,)
rs = connection.execute(s)
else:
try:
s = ("SELECT name FROM "
" (SELECT * FROM sqlite_master UNION ALL "
" SELECT * FROM sqlite_temp_master) "
"WHERE type='table' ORDER BY name")
rs = connection.execute(s)
except exc.DBAPIError:
raise
s = ("SELECT name FROM sqlite_master "
"WHERE type='table' ORDER BY name")
rs = connection.execute(s)
return [row[0] for row in rs]
def has_table(self, connection, table_name, schema=None):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
pragma = "PRAGMA %s." % quote(schema)
else:
pragma = "PRAGMA "
qtable = quote(table_name)
cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
row = cursor.fetchone()
# consume remaining rows, to work around
# http://www.sqlite.org/cvstrac/tktview?tn=1884
while cursor.fetchone() is not None:
pass
return (row is not None)
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
master = '%s.sqlite_master' % qschema
s = ("SELECT name FROM %s "
"WHERE type='view' ORDER BY name") % (master,)
rs = connection.execute(s)
else:
try:
s = ("SELECT name FROM "
" (SELECT * FROM sqlite_master UNION ALL "
" SELECT * FROM sqlite_temp_master) "
"WHERE type='view' ORDER BY name")
rs = connection.execute(s)
except exc.DBAPIError:
raise
s = ("SELECT name FROM sqlite_master "
"WHERE type='view' ORDER BY name")
rs = connection.execute(s)
return [row[0] for row in rs]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
master = '%s.sqlite_master' % qschema
s = ("SELECT sql FROM %s WHERE name = '%s'"
"AND type='view'") % (master, view_name)
rs = connection.execute(s)
else:
try:
s = ("SELECT sql FROM "
" (SELECT * FROM sqlite_master UNION ALL "
" SELECT * FROM sqlite_temp_master) "
"WHERE name = '%s' "
"AND type='view'") % view_name
rs = connection.execute(s)
except exc.DBAPIError:
raise
s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
"AND type='view'") % view_name
rs = connection.execute(s)
result = rs.fetchall()
if result:
return result[0].sql
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
pragma = "PRAGMA %s." % quote(schema)
else:
pragma = "PRAGMA "
qtable = quote(table_name)
c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
found_table = False
columns = []
while True:
row = c.fetchone()
if row is None:
break
(name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
name = re.sub(r'^\"|\"$', '', name)
if default:
default = re.sub(r"^\'|\'$", '', default)
match = re.match(r'(\w+)(\(.*?\))?', type_)
if match:
coltype = match.group(1)
args = match.group(2)
else:
coltype = "VARCHAR"
args = ''
try:
coltype = self.ischema_names[coltype]
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" %
(coltype, name))
coltype = sqltypes.NullType
if args is not None:
args = re.findall(r'(\d+)', args)
coltype = coltype(*[int(a) for a in args])
columns.append({
'name' : name,
'type' : coltype,
'nullable' : nullable,
'default' : default,
'primary_key': primary_key
})
return columns
@reflection.cache
def get_primary_keys(self, connection, table_name, schema=None, **kw):
cols = self.get_columns(connection, table_name, schema, **kw)
pkeys = []
for col in cols:
if col['primary_key']:
pkeys.append(col['name'])
return pkeys
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
pragma = "PRAGMA %s." % quote(schema)
else:
pragma = "PRAGMA "
qtable = quote(table_name)
c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)))
fkeys = []
fks = {}
while True:
row = c.fetchone()
if row is None:
break
(constraint_name, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4])
rtbl = re.sub(r'^\"|\"$', '', rtbl)
lcol = re.sub(r'^\"|\"$', '', lcol)
rcol = re.sub(r'^\"|\"$', '', rcol)
try:
fk = fks[constraint_name]
except KeyError:
fk = {
'name' : constraint_name,
'constrained_columns' : [],
'referred_schema' : None,
'referred_table' : rtbl,
'referred_columns' : []
}
fkeys.append(fk)
fks[constraint_name] = fk
# look up the table based on the given table's engine, not 'self',
# since it could be a ProxyEngine
if lcol not in fk['constrained_columns']:
fk['constrained_columns'].append(lcol)
if rcol not in fk['referred_columns']:
fk['referred_columns'].append(rcol)
return fkeys
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
pragma = "PRAGMA %s." % quote(schema)
else:
pragma = "PRAGMA "
include_auto_indexes = kw.pop('include_auto_indexes', False)
qtable = quote(table_name)
c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable)))
indexes = []
while True:
row = c.fetchone()
if row is None:
break
# ignore implicit primary key index.
# http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html
elif not include_auto_indexes and row[1].startswith('sqlite_autoindex'):
continue
indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
# loop thru unique indexes to get the column names.
for idx in indexes:
c = connection.execute("%sindex_info(%s)" % (pragma, quote(idx['name'])))
cols = idx['column_names']
while True:
row = c.fetchone()
if row is None:
break
cols.append(row[2])
return indexes
def _pragma_cursor(cursor):
"""work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows."""
if cursor.closed:
cursor._fetchone_impl = lambda: None
return cursor

View File

@ -0,0 +1,236 @@
"""Support for the SQLite database via pysqlite.
Note that pysqlite is the same driver as the ``sqlite3``
module included with the Python distribution.
Driver
------
When using Python 2.5 and above, the built in ``sqlite3`` driver is
already installed and no additional installation is needed. Otherwise,
the ``pysqlite2`` driver needs to be present. This is the same driver as
``sqlite3``, just with a different name.
The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
is loaded. This allows an explicitly installed pysqlite driver to take
precedence over the built in one. As with all dialects, a specific
DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control
this explicitly::
from sqlite3 import dbapi2 as sqlite
e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)
Full documentation on pysqlite is available at:
`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
Connect Strings
---------------
The file specification for the SQLite database is taken as the "database" portion of
the URL. Note that the format of a url is::
driver://user:pass@host/database
This means that the actual filename to be used starts with the characters to the
**right** of the third slash. So connecting to a relative filepath looks like::
# relative path
e = create_engine('sqlite:///path/to/database.db')
An absolute path, which is denoted by starting with a slash, means you need **four**
slashes::
# absolute path
e = create_engine('sqlite:////path/to/database.db')
To use a Windows path, regular drive specifications and backslashes can be used.
Double backslashes are probably needed::
# absolute path on Windows
e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify
``sqlite://`` and nothing else::
# in-memory database
e = create_engine('sqlite://')
Compatibility with sqlite3 "native" date and datetime types
-----------------------------------------------------------
The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
sqlite3.PARSE_COLNAMES options, which have the effect of any column
or expression explicitly cast as "date" or "timestamp" will be converted
to a Python date or datetime object. The date and datetime types provided
with the pysqlite dialect are not currently compatible with these options,
since they render the ISO date/datetime including microseconds, which
pysqlite's driver does not. Additionally, SQLAlchemy does not at
this time automatically render the "cast" syntax required for the
freestanding functions "current_timestamp" and "current_date" to return
datetime/date types natively. Unfortunately, pysqlite
does not provide the standard DBAPI types in `cursor.description`,
leaving SQLAlchemy with no way to detect these types on the fly
without expensive per-row type checks.
Usage of PARSE_DECLTYPES can be forced if one configures
"native_datetime=True" on create_engine()::
engine = create_engine('sqlite://',
connect_args={'detect_types': sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
native_datetime=True
)
With this flag enabled, the DATE and TIMESTAMP types (but note - not the DATETIME
or TIME types...confused yet ?) will not perform any bind parameter or result
processing. Execution of "func.current_date()" will return a string.
"func.current_timestamp()" is registered as returning a DATETIME type in
SQLAlchemy, so this function still receives SQLAlchemy-level result processing.
Threading Behavior
------------------
Pysqlite connections do not support being moved between threads, unless
the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition,
when using an in-memory SQLite database, the full database exists only within
the scope of a single connection. It is reported that an in-memory
database does not support being shared between threads regardless of the
``check_same_thread`` flag - which means that a multithreaded
application **cannot** share data from a ``:memory:`` database across threads
unless access to the connection is limited to a single worker thread which communicates
through a queueing mechanism to concurrent threads.
To provide a default which accomodates SQLite's default threading capabilities
somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
be used by default. This pool maintains a single SQLite connection per thread
that is held open up to a count of five concurrent threads. When more than five threads
are used, a cleanup mechanism will dispose of excess unused connections.
Two optional pool implementations that may be appropriate for particular SQLite usage scenarios:
* the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
application using an in-memory database, assuming the threading issues inherent in
pysqlite are somehow accomodated for. This pool holds persistently onto a single connection
which is never closed, and is returned for all requests.
* the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
makes use of a file-based sqlite database. This pool disables any actual "pooling"
behavior, and simply opens and closes real connections corresonding to the :func:`connect()`
and :func:`close()` methods. SQLite can "connect" to a particular file with very high
efficiency, so this option may actually perform better without the extra overhead
of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection
useless since the database would be lost as soon as the connection is "returned" to the pool.
Unicode
-------
In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's
default behavior regarding Unicode is that all strings are returned as Python unicode objects
in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is
*not* used, you will still always receive unicode data back from a result set. It is
**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
to represent strings, since it will raise a warning if a non-unicode Python string is
passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can
quickly create confusion, particularly when using the ORM as internal data is not
always represented by an actual database result string.
"""
from sqlalchemy.dialects.sqlite.base import SQLiteDialect, DATETIME, DATE
from sqlalchemy import schema, exc, pool
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes
from sqlalchemy import util
class _SQLite_pysqliteTimeStamp(DATETIME):
def bind_processor(self, dialect):
if dialect.native_datetime:
return None
else:
return DATETIME.bind_processor(self, dialect)
def result_processor(self, dialect, coltype):
if dialect.native_datetime:
return None
else:
return DATETIME.result_processor(self, dialect, coltype)
class _SQLite_pysqliteDate(DATE):
def bind_processor(self, dialect):
if dialect.native_datetime:
return None
else:
return DATE.bind_processor(self, dialect)
def result_processor(self, dialect, coltype):
if dialect.native_datetime:
return None
else:
return DATE.result_processor(self, dialect, coltype)
class SQLiteDialect_pysqlite(SQLiteDialect):
default_paramstyle = 'qmark'
poolclass = pool.SingletonThreadPool
colspecs = util.update_copy(
SQLiteDialect.colspecs,
{
sqltypes.Date:_SQLite_pysqliteDate,
sqltypes.TIMESTAMP:_SQLite_pysqliteTimeStamp,
}
)
# Py3K
#description_encoding = None
driver = 'pysqlite'
def __init__(self, **kwargs):
SQLiteDialect.__init__(self, **kwargs)
if self.dbapi is not None:
sqlite_ver = self.dbapi.version_info
if sqlite_ver < (2, 1, 3):
util.warn(
("The installed version of pysqlite2 (%s) is out-dated "
"and will cause errors in some cases. Version 2.1.3 "
"or greater is recommended.") %
'.'.join([str(subver) for subver in sqlite_ver]))
@classmethod
def dbapi(cls):
try:
from pysqlite2 import dbapi2 as sqlite
except ImportError, e:
try:
from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
except ImportError:
raise e
return sqlite
def _get_server_version_info(self, connection):
return self.dbapi.sqlite_version_info
def create_connect_args(self, url):
if url.username or url.password or url.host or url.port:
raise exc.ArgumentError(
"Invalid SQLite URL: %s\n"
"Valid SQLite URL forms are:\n"
" sqlite:///:memory: (or, sqlite://)\n"
" sqlite:///relative/path/to/file.db\n"
" sqlite:////absolute/path/to/file.db" % (url,))
filename = url.database or ':memory:'
opts = url.query.copy()
util.coerce_kw_type(opts, 'timeout', float)
util.coerce_kw_type(opts, 'isolation_level', str)
util.coerce_kw_type(opts, 'detect_types', int)
util.coerce_kw_type(opts, 'check_same_thread', bool)
util.coerce_kw_type(opts, 'cached_statements', int)
return ([filename], opts)
def is_disconnect(self, e):
return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
dialect = SQLiteDialect_pysqlite

View File

@ -0,0 +1,20 @@
from sqlalchemy.dialects.sybase import base, pysybase, pyodbc
from base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
BIGINT,INT, INTEGER, SMALLINT, BINARY,\
VARBINARY,UNITEXT,UNICHAR,UNIVARCHAR,\
IMAGE,BIT,MONEY,SMALLMONEY,TINYINT
# default dialect
base.dialect = pyodbc.dialect
__all__ = (
'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR',
'TEXT','DATE','DATETIME', 'FLOAT', 'NUMERIC',
'BIGINT','INT', 'INTEGER', 'SMALLINT', 'BINARY',
'VARBINARY','UNITEXT','UNICHAR','UNIVARCHAR',
'IMAGE','BIT','MONEY','SMALLMONEY','TINYINT',
'dialect'
)

View File

@ -0,0 +1,420 @@
# sybase/base.py
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com
# get_select_precolumns(), limit_clause() implementation
# copyright (C) 2007 Fisch Asset Management
# AG http://www.fam.ch, with coding by Alexander Houben
# alexander.houben@thor-solutions.ch
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for Sybase Adaptive Server Enterprise (ASE).
Note that this dialect is no longer specific to Sybase iAnywhere.
ASE is the primary support platform.
"""
import operator
from sqlalchemy.sql import compiler, expression, text, bindparam
from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import schema as sa_schema
from sqlalchemy import util, sql, exc
from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
BIGINT,INT, INTEGER, SMALLINT, BINARY,\
VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
UnicodeText
RESERVED_WORDS = set([
"add", "all", "alter", "and",
"any", "as", "asc", "backup",
"begin", "between", "bigint", "binary",
"bit", "bottom", "break", "by",
"call", "capability", "cascade", "case",
"cast", "char", "char_convert", "character",
"check", "checkpoint", "close", "comment",
"commit", "connect", "constraint", "contains",
"continue", "convert", "create", "cross",
"cube", "current", "current_timestamp", "current_user",
"cursor", "date", "dbspace", "deallocate",
"dec", "decimal", "declare", "default",
"delete", "deleting", "desc", "distinct",
"do", "double", "drop", "dynamic",
"else", "elseif", "encrypted", "end",
"endif", "escape", "except", "exception",
"exec", "execute", "existing", "exists",
"externlogin", "fetch", "first", "float",
"for", "force", "foreign", "forward",
"from", "full", "goto", "grant",
"group", "having", "holdlock", "identified",
"if", "in", "index", "index_lparen",
"inner", "inout", "insensitive", "insert",
"inserting", "install", "instead", "int",
"integer", "integrated", "intersect", "into",
"iq", "is", "isolation", "join",
"key", "lateral", "left", "like",
"lock", "login", "long", "match",
"membership", "message", "mode", "modify",
"natural", "new", "no", "noholdlock",
"not", "notify", "null", "numeric",
"of", "off", "on", "open",
"option", "options", "or", "order",
"others", "out", "outer", "over",
"passthrough", "precision", "prepare", "primary",
"print", "privileges", "proc", "procedure",
"publication", "raiserror", "readtext", "real",
"reference", "references", "release", "remote",
"remove", "rename", "reorganize", "resource",
"restore", "restrict", "return", "revoke",
"right", "rollback", "rollup", "save",
"savepoint", "scroll", "select", "sensitive",
"session", "set", "setuser", "share",
"smallint", "some", "sqlcode", "sqlstate",
"start", "stop", "subtrans", "subtransaction",
"synchronize", "syntax_error", "table", "temporary",
"then", "time", "timestamp", "tinyint",
"to", "top", "tran", "trigger",
"truncate", "tsequal", "unbounded", "union",
"unique", "unknown", "unsigned", "update",
"updating", "user", "using", "validate",
"values", "varbinary", "varchar", "variable",
"varying", "view", "wait", "waitfor",
"when", "where", "while", "window",
"with", "with_cube", "with_lparen", "with_rollup",
"within", "work", "writetext",
])
class _SybaseUnitypeMixin(object):
"""these types appear to return a buffer object."""
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
return str(value) #.decode("ucs-2")
else:
return None
return process
class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
__visit_name__ = 'UNICHAR'
class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
__visit_name__ = 'UNIVARCHAR'
class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText):
__visit_name__ = 'UNITEXT'
class TINYINT(sqltypes.Integer):
__visit_name__ = 'TINYINT'
class BIT(sqltypes.TypeEngine):
__visit_name__ = 'BIT'
class MONEY(sqltypes.TypeEngine):
__visit_name__ = "MONEY"
class SMALLMONEY(sqltypes.TypeEngine):
__visit_name__ = "SMALLMONEY"
class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
__visit_name__ = "UNIQUEIDENTIFIER"
class IMAGE(sqltypes.LargeBinary):
__visit_name__ = 'IMAGE'
class SybaseTypeCompiler(compiler.GenericTypeCompiler):
def visit_large_binary(self, type_):
return self.visit_IMAGE(type_)
def visit_boolean(self, type_):
return self.visit_BIT(type_)
def visit_unicode(self, type_):
return self.visit_NVARCHAR(type_)
def visit_UNICHAR(self, type_):
return "UNICHAR(%d)" % type_.length
def visit_UNIVARCHAR(self, type_):
return "UNIVARCHAR(%d)" % type_.length
def visit_UNITEXT(self, type_):
return "UNITEXT"
def visit_TINYINT(self, type_):
return "TINYINT"
def visit_IMAGE(self, type_):
return "IMAGE"
def visit_BIT(self, type_):
return "BIT"
def visit_MONEY(self, type_):
return "MONEY"
def visit_SMALLMONEY(self, type_):
return "SMALLMONEY"
def visit_UNIQUEIDENTIFIER(self, type_):
return "UNIQUEIDENTIFIER"
ischema_names = {
'integer' : INTEGER,
'unsigned int' : INTEGER, # TODO: unsigned flags
'unsigned smallint' : SMALLINT, # TODO: unsigned flags
'unsigned bigint' : BIGINT, # TODO: unsigned flags
'bigint': BIGINT,
'smallint' : SMALLINT,
'tinyint' : TINYINT,
'varchar' : VARCHAR,
'long varchar' : TEXT, # TODO
'char' : CHAR,
'decimal' : DECIMAL,
'numeric' : NUMERIC,
'float' : FLOAT,
'double' : NUMERIC, # TODO
'binary' : BINARY,
'varbinary' : VARBINARY,
'bit': BIT,
'image' : IMAGE,
'timestamp': TIMESTAMP,
'money': MONEY,
'smallmoney': MONEY,
'uniqueidentifier': UNIQUEIDENTIFIER,
}
class SybaseExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
def set_ddl_autocommit(self, connection, value):
"""Must be implemented by subclasses to accommodate DDL executions.
"connection" is the raw unwrapped DBAPI connection. "value"
is True or False. when True, the connection should be configured
such that a DDL can take place subsequently. when False,
a DDL has taken place and the connection should be resumed
into non-autocommit mode.
"""
raise NotImplementedError()
def pre_exec(self):
if self.isinsert:
tbl = self.compiled.statement.table
seq_column = tbl._autoincrement_column
insert_has_sequence = seq_column is not None
if insert_has_sequence:
self._enable_identity_insert = seq_column.key in self.compiled_parameters[0]
else:
self._enable_identity_insert = False
if self._enable_identity_insert:
self.cursor.execute("SET IDENTITY_INSERT %s ON" %
self.dialect.identifier_preparer.format_table(tbl))
if self.isddl:
# TODO: to enhance this, we can detect "ddl in tran" on the
# database settings. this error message should be improved to
# include a note about that.
if not self.should_autocommit:
raise exc.InvalidRequestError("The Sybase dialect only supports "
"DDL in 'autocommit' mode at this time.")
self.root_connection.engine.logger.info("AUTOCOMMIT (Assuming no Sybase 'ddl in tran')")
self.set_ddl_autocommit(self.root_connection.connection.connection, True)
def post_exec(self):
if self.isddl:
self.set_ddl_autocommit(self.root_connection, False)
if self._enable_identity_insert:
self.cursor.execute(
"SET IDENTITY_INSERT %s OFF" %
self.dialect.identifier_preparer.
format_table(self.compiled.statement.table)
)
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT @@identity AS lastrowid")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class SybaseSQLCompiler(compiler.SQLCompiler):
ansi_bind_rules = True
extract_map = util.update_copy(
compiler.SQLCompiler.extract_map,
{
'doy': 'dayofyear',
'dow': 'weekday',
'milliseconds': 'millisecond'
})
def get_select_precolumns(self, select):
s = select._distinct and "DISTINCT " or ""
if select._limit:
#if select._limit == 1:
#s += "FIRST "
#else:
#s += "TOP %s " % (select._limit,)
s += "TOP %s " % (select._limit,)
if select._offset:
if not select._limit:
# FIXME: sybase doesn't allow an offset without a limit
# so use a huge value for TOP here
s += "TOP 1000000 "
s += "START AT %s " % (select._offset+1,)
return s
def get_from_hint_text(self, table, text):
return text
def limit_clause(self, select):
# Limit in sybase is after the select keyword
return ""
def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
return ''
def order_by_clause(self, select, **kw):
kw['literal_binds'] = True
order_by = self.process(select._order_by_clause, **kw)
# SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
if order_by and (not self.is_subquery() or select._limit):
return " ORDER BY " + order_by
else:
return ""
class SybaseDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + \
self.dialect.type_compiler.process(column.type)
if column.table is None:
raise exc.InvalidRequestError("The Sybase dialect requires Table-bound "\
"columns in order to generate DDL")
seq_col = column.table._autoincrement_column
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if seq_col is column:
sequence = isinstance(column.default, sa_schema.Sequence) and column.default
if sequence:
start, increment = sequence.start or 1, sequence.increment or 1
else:
start, increment = 1, 1
if (start, increment) == (1, 1):
colspec += " IDENTITY"
else:
# TODO: need correct syntax for this
colspec += " IDENTITY(%s,%s)" % (start, increment)
else:
if column.nullable is not None:
if not column.nullable or column.primary_key:
colspec += " NOT NULL"
else:
colspec += " NULL"
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
return colspec
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX %s.%s" % (
self.preparer.quote_identifier(index.table.name),
self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
)
class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
class SybaseDialect(default.DefaultDialect):
name = 'sybase'
supports_unicode_statements = False
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_native_boolean = False
supports_unicode_binds = False
postfetch_lastrowid = True
colspecs = {}
ischema_names = ischema_names
type_compiler = SybaseTypeCompiler
statement_compiler = SybaseSQLCompiler
ddl_compiler = SybaseDDLCompiler
preparer = SybaseIdentifierPreparer
def _get_default_schema_name(self, connection):
return connection.scalar(
text("SELECT user_name() as user_name", typemap={'user_name':Unicode})
)
def initialize(self, connection):
super(SybaseDialect, self).initialize(connection)
if self.server_version_info is not None and\
self.server_version_info < (15, ):
self.max_identifier_length = 30
else:
self.max_identifier_length = 255
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self.default_schema_name
result = connection.execute(
text("select sysobjects.name from sysobjects, sysusers "
"where sysobjects.uid=sysusers.uid and "
"sysusers.name=:schemaname and "
"sysobjects.type='U'",
bindparams=[
bindparam('schemaname', schema)
])
)
return [r[0] for r in result]
def has_table(self, connection, tablename, schema=None):
if schema is None:
schema = self.default_schema_name
result = connection.execute(
text("select sysobjects.name from sysobjects, sysusers "
"where sysobjects.uid=sysusers.uid and "
"sysobjects.name=:tablename and "
"sysusers.name=:schemaname and "
"sysobjects.type='U'",
bindparams=[
bindparam('tablename', tablename),
bindparam('schemaname', schema)
])
)
return result.scalar() is not None
def reflecttable(self, connection, table, include_columns):
raise NotImplementedError()

View File

@ -0,0 +1,17 @@
"""
Support for Sybase via mxodbc.
This dialect is a stub only and is likely non functional at this time.
"""
from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
from sqlalchemy.connectors.mxodbc import MxODBCConnector
class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
pass
class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_mxodbc
dialect = SybaseDialect_mxodbc

View File

@ -0,0 +1,75 @@
"""
Support for Sybase via pyodbc.
http://pypi.python.org/pypi/pyodbc/
Connect strings are of the form::
sybase+pyodbc://<username>:<password>@<dsn>/
sybase+pyodbc://<username>:<password>@<host>/<database>
Unicode Support
---------------
The pyodbc driver currently supports usage of these Sybase types with
Unicode or multibyte strings::
CHAR
NCHAR
NVARCHAR
TEXT
VARCHAR
Currently *not* supported are::
UNICHAR
UNITEXT
UNIVARCHAR
"""
from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector
import decimal
from sqlalchemy import types as sqltypes, util, processors
class _SybNumeric_pyodbc(sqltypes.Numeric):
"""Turns Decimals with adjusted() < -6 into floats.
It's not yet known how to get decimals with many
significant digits or very large adjusted() into Sybase
via pyodbc.
"""
def bind_processor(self, dialect):
super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)
def process(value):
if self.asdecimal and \
isinstance(value, decimal.Decimal):
if value.adjusted() < -6:
return processors.to_float(value)
if super_process:
return super_process(value)
else:
return value
return process
class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
def set_ddl_autocommit(self, connection, value):
if value:
connection.autocommit = True
else:
connection.autocommit = False
class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_pyodbc
colspecs = {
sqltypes.Numeric:_SybNumeric_pyodbc,
}
dialect = SybaseDialect_pyodbc

View File

@ -0,0 +1,98 @@
# pysybase.py
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Support for Sybase via the python-sybase driver.
http://python-sybase.sourceforge.net/
Connect strings are of the form::
sybase+pysybase://<username>:<password>@<dsn>/[database name]
Unicode Support
---------------
The python-sybase driver does not appear to support non-ASCII strings of any
kind at this time.
"""
from sqlalchemy import types as sqltypes, processors
from sqlalchemy.dialects.sybase.base import SybaseDialect, \
SybaseExecutionContext, SybaseSQLCompiler
class _SybNumeric(sqltypes.Numeric):
def result_processor(self, dialect, type_):
if not self.asdecimal:
return processors.to_float
else:
return sqltypes.Numeric.result_processor(self, dialect, type_)
class SybaseExecutionContext_pysybase(SybaseExecutionContext):
def set_ddl_autocommit(self, dbapi_connection, value):
if value:
# call commit() on the Sybase connection directly,
# to avoid any side effects of calling a Connection
# transactional method inside of pre_exec()
dbapi_connection.commit()
def pre_exec(self):
SybaseExecutionContext.pre_exec(self)
for param in self.parameters:
for key in list(param):
param["@" + key] = param[key]
del param[key]
class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
def bindparam_string(self, name):
return "@" + name
class SybaseDialect_pysybase(SybaseDialect):
driver = 'pysybase'
execution_ctx_cls = SybaseExecutionContext_pysybase
statement_compiler = SybaseSQLCompiler_pysybase
colspecs={
sqltypes.Numeric:_SybNumeric,
sqltypes.Float:sqltypes.Float
}
@classmethod
def dbapi(cls):
import Sybase
return Sybase
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user', password='passwd')
return ([opts.pop('host')], opts)
def do_executemany(self, cursor, statement, parameters, context=None):
# calling python-sybase executemany yields:
# TypeError: string too long for buffer
for param in parameters:
cursor.execute(statement, param)
def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version_number")
# i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0), (12, 5, 0, 0)
return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10)
def is_disconnect(self, e):
if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)):
msg = str(e)
return ('Unable to complete network request to host' in msg or
'Invalid connection state' in msg or
'Invalid cursor state' in msg)
else:
return False
dialect = SybaseDialect_pysybase

View File

@ -0,0 +1,145 @@
Rules for Migrating TypeEngine classes to 0.6
---------------------------------------------
1. the TypeEngine classes are used for:
a. Specifying behavior which needs to occur for bind parameters
or result row columns.
b. Specifying types that are entirely specific to the database
in use and have no analogue in the sqlalchemy.types package.
c. Specifying types where there is an analogue in sqlalchemy.types,
but the database in use takes vendor-specific flags for those
types.
d. If a TypeEngine class doesn't provide any of this, it should be
*removed* from the dialect.
2. the TypeEngine classes are *no longer* used for generating DDL. Dialects
now have a TypeCompiler subclass which uses the same visit_XXX model as
other compilers.
3. the "ischema_names" and "colspecs" dictionaries are now required members on
the Dialect class.
4. The names of types within dialects are now important. If a dialect-specific type
is a subclass of an existing generic type and is only provided for bind/result behavior,
the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case,
end users would never need to use _PGNumeric directly. However, if a dialect-specific
type is specifying a type *or* arguments that are not present generically, it should
match the real name of the type on that backend, in uppercase. E.g. postgresql.INET,
mysql.ENUM, postgresql.ARRAY.
Or follow this handy flowchart:
is the type meant to provide bind/result is the type the same name as an
behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ?
type in types.py ? | |
| no yes
yes | |
| | does your type need special
| +<--- yes --- behavior or arguments ?
| | |
| | no
name the type using | |
_MixedCase, i.e. v V
_OracleBoolean. it name the type don't make a
stays private to the dialect identically as that type, make sure the dialect's
and is invoked *only* via within the DB, base.py imports the types.py
the colspecs dict. using UPPERCASE UPPERCASE name into its namespace
| (i.e. BIT, NCHAR, INTERVAL).
| Users can import it.
| |
v v
subclass the closest is the name of this type
MixedCase type types.py, identical to an UPPERCASE
i.e. <--- no ------- name in types.py ?
class _DateTime(types.DateTime),
class DATETIME2(types.DateTime), |
class BIT(types.TypeEngine). yes
|
v
the type should
subclass the
UPPERCASE
type in types.py
(i.e. class BLOB(types.BLOB))
Example 1. pysqlite needs bind/result processing for the DateTime type in types.py,
which applies to all DateTimes and subclasses. It's named _SLDateTime and
subclasses types.DateTime.
Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument
that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py,
and subclasses types.TIME. Users can then say mssql.TIME(precision=10).
Example 3. MS-SQL dialects also need special bind/result processing for date
But its DATE type doesn't render DDL differently than that of a plain
DATE, i.e. it takes no special arguments. Therefore we are just adding behavior
to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses
types.Date.
Example 4. MySQL has a SET type, there's no analogue for this in types.py. So
MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
it ultimately deals with strings.
Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly,
and no special arguments are used in PG's DDL beyond what types.py provides.
Postgresql dialect therefore imports types.DATETIME into its base.py.
Ideally one should be able to specify a schema using names imported completely from a
dialect, all matching the real name on that backend:
from sqlalchemy.dialects.postgresql import base as pg
t = Table('mytable', metadata,
Column('id', pg.INTEGER, primary_key=True),
Column('name', pg.VARCHAR(300)),
Column('inetaddr', pg.INET)
)
where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types,
but the PG dialect makes them available in its own namespace.
5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types
linked to types specified in the dialect. Again, if a type in the dialect does not
specify any special behavior for bind_processor() or result_processor() and does not
indicate a special type only available in this database, it must be *removed* from the
module and from this dictionary.
6. "ischema_names" indicates string descriptions of types as returned from the database
linked to TypeEngine classes.
a. The string name should be matched to the most specific type possible within
sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
case it points to a dialect type. *It doesn't matter* if the dialect has it's
own subclass of that type with special bind/result behavior - reflect to the types.py
UPPERCASE type as much as possible. With very few exceptions, all types
should reflect to an UPPERCASE type.
b. If the dialect contains a matching dialect-specific type that takes extra arguments
which the generic one does not, then point to the dialect-specific type. E.g.
mssql.VARCHAR takes a "collation" parameter which should be preserved.
5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
a subclass of compiler.GenericTypeCompiler.
a. your TypeCompiler class will receive generic and uppercase types from
sqlalchemy.types. Do not assume the presence of dialect-specific attributes on
these types.
b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
methods that produce a different DDL name. Uppercase types don't do any kind of
"guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
all cases, regardless of whether or not that type is legal on the backend database.
c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
arguments and flags to those types.
d. the visit_lowercase methods are overridden to provide an interpretation of a generic
type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)".
e. visit_lowercase methods should *never* render strings directly - it should always
be via calling a visit_UPPERCASE() method.

View File

@ -0,0 +1,274 @@
# engine/__init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQL connections, SQL execution and high-level DB-API interface.
The engine package defines the basic components used to interface
DB-API modules with higher-level statement construction,
connection-management, execution and result contexts. The primary
"entry point" class into this package is the Engine and it's public
constructor ``create_engine()``.
This package includes:
base.py
Defines interface classes and some implementation classes which
comprise the basic components used to interface between a DB-API,
constructed and plain-text statements, connections, transactions,
and results.
default.py
Contains default implementations of some of the components defined
in base.py. All current database dialects use the classes in
default.py as base classes for their own database-specific
implementations.
strategies.py
The mechanics of constructing ``Engine`` objects are represented
here. Defines the ``EngineStrategy`` class which represents how
to go from arguments specified to the ``create_engine()``
function, to a fully constructed ``Engine``, including
initialization of connection pooling, dialects, and specific
subclasses of ``Engine``.
threadlocal.py
The ``TLEngine`` class is defined here, which is a subclass of
the generic ``Engine`` and tracks ``Connection`` and
``Transaction`` objects against the identity of the current
thread. This allows certain programming patterns based around
the concept of a "thread-local connection" to be possible.
The ``TLEngine`` is created by using the "threadlocal" engine
strategy in conjunction with the ``create_engine()`` function.
url.py
Defines the ``URL`` class which represents the individual
components of a string URL passed to ``create_engine()``. Also
defines a basic module-loading strategy for the dialect specifier
within a URL.
"""
# not sure what this was used for
#import sqlalchemy.databases
from sqlalchemy.engine.base import (
BufferedColumnResultProxy,
BufferedColumnRow,
BufferedRowResultProxy,
Compiled,
Connectable,
Connection,
Dialect,
Engine,
ExecutionContext,
NestedTransaction,
ResultProxy,
RootTransaction,
RowProxy,
Transaction,
TwoPhaseTransaction,
TypeCompiler
)
from sqlalchemy.engine import strategies
from sqlalchemy import util
__all__ = (
'BufferedColumnResultProxy',
'BufferedColumnRow',
'BufferedRowResultProxy',
'Compiled',
'Connectable',
'Connection',
'Dialect',
'Engine',
'ExecutionContext',
'NestedTransaction',
'ResultProxy',
'RootTransaction',
'RowProxy',
'Transaction',
'TwoPhaseTransaction',
'TypeCompiler',
'create_engine',
'engine_from_config',
)
default_strategy = 'plain'
def create_engine(*args, **kwargs):
"""Create a new Engine instance.
The standard method of specifying the engine is via URL as the
first positional argument, to indicate the appropriate database
dialect and connection arguments, with additional keyword
arguments sent as options to the dialect and resulting Engine.
The URL is a string in the form
``dialect+driver://user:password@host/dbname[?key=value..]``, where
``dialect`` is a database name such as ``mysql``, ``oracle``,
``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
`**kwargs` takes a wide variety of options which are routed
towards their appropriate components. Arguments may be
specific to the Engine, the underlying Dialect, as well as the
Pool. Specific dialects also accept keyword arguments that
are unique to that dialect. Here, we describe the parameters
that are common to most ``create_engine()`` usage.
:param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode
object is passed when SQLAlchemy would coerce into an encoding
(note: but **not** when the DBAPI handles unicode objects natively).
To suppress or raise this warning to an
error, use the Python warnings filter documented at:
http://docs.python.org/library/warnings.html
:param connect_args: a dictionary of options which will be
passed directly to the DBAPI's ``connect()`` method as
additional keyword arguments.
:param convert_unicode=False: if set to True, all
String/character based types will convert Unicode values to raw
byte values going into the database, and all raw byte values to
Python Unicode coming out in result sets. This is an
engine-wide method to provide unicode conversion across the
board. For unicode conversion on a column-by-column level, use
the ``Unicode`` column type instead, described in `types`.
:param creator: a callable which returns a DBAPI connection.
This creation function will be passed to the underlying
connection pool and will be used to create all new database
connections. Usage of this function causes connection
parameters specified in the URL argument to be bypassed.
:param echo=False: if True, the Engine will log all statements
as well as a repr() of their parameter lists to the engines
logger, which defaults to sys.stdout. The ``echo`` attribute of
``Engine`` can be modified at any time to turn logging on and
off. If set to the string ``"debug"``, result rows will be
printed to the standard output as well. This flag ultimately
controls a Python logger; see :ref:`dbengine_logging` for
information on how to configure logging directly.
:param echo_pool=False: if True, the connection pool will log
all checkouts/checkins to the logging stream, which defaults to
sys.stdout. This flag ultimately controls a Python logger; see
:ref:`dbengine_logging` for information on how to configure logging
directly.
:param encoding='utf-8': the encoding to use for all Unicode
translations, both by engine-wide unicode conversion as well as
the ``Unicode`` type object.
:param label_length=None: optional integer value which limits
the size of dynamically generated column labels to that many
characters. If less than 6, labels are generated as
"_(counter)". If ``None``, the value of
``dialect.max_identifier_length`` is used instead.
:param listeners: A list of one or more
:class:`~sqlalchemy.interfaces.PoolListener` objects which will
receive connection pool events.
:param logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.engine" logger. Defaults to a hexstring of the
object's id.
:param max_overflow=10: the number of connections to allow in
connection pool "overflow", that is connections that can be
opened above and beyond the pool_size setting, which defaults
to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`.
:param module=None: used by database implementations which
support multiple DBAPI modules, this is a reference to a DBAPI2
module to be used instead of the engine's default module. For
PostgreSQL, the default is psycopg2. For Oracle, it's cx_Oracle.
:param pool=None: an already-constructed instance of
:class:`~sqlalchemy.pool.Pool`, such as a
:class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this
pool will be used directly as the underlying connection pool
for the engine, bypassing whatever connection parameters are
present in the URL argument. For information on constructing
connection pools manually, see `pooling`.
:param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
subclass, which will be used to create a connection pool
instance using the connection parameters given in the URL. Note
this differs from ``pool`` in that you don't actually
instantiate the pool in this case, you just indicate what type
of pool to be used.
:param pool_logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.pool" logger. Defaults to a hexstring of the object's
id.
:param pool_size=5: the number of connections to keep open
inside the connection pool. This used with :class:`~sqlalchemy.pool.QueuePool` as
well as :class:`~sqlalchemy.pool.SingletonThreadPool`.
:param pool_recycle=-1: this setting causes the pool to recycle
connections after the given number of seconds has passed. It
defaults to -1, or no timeout. For example, setting to 3600
means connections will be recycled after one hour. Note that
MySQL in particular will ``disconnect automatically`` if no
activity is detected on a connection for eight hours (although
this is configurable with the MySQLDB connection itself and the
server configuration as well).
:param pool_timeout=30: number of seconds to wait before giving
up on getting a connection from the pool. This is only used
with :class:`~sqlalchemy.pool.QueuePool`.
:param strategy='plain': used to invoke alternate :class:`~sqlalchemy.engine.base.Engine.`
implementations. Currently available is the ``threadlocal``
strategy, which is described in :ref:`threadlocal_strategy`.
"""
strategy = kwargs.pop('strategy', default_strategy)
strategy = strategies.strategies[strategy]
return strategy.create(*args, **kwargs)
def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
"""Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file where keys
are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The
'prefix' argument indicates the prefix to be searched for.
A select set of keyword arguments will be "coerced" to their
expected type based on string values. In a future release, this
functionality will be expanded and include dialect-specific
arguments.
"""
opts = _coerce_config(configuration, prefix)
opts.update(kwargs)
url = opts.pop('url')
return create_engine(url, **opts)
def _coerce_config(configuration, prefix):
"""Convert configuration values to expected types."""
options = dict((key[len(prefix):], configuration[key])
for key in configuration
if key.startswith(prefix))
for option, type_ in (
('convert_unicode', bool),
('pool_timeout', int),
('echo', bool),
('echo_pool', bool),
('pool_recycle', int),
('pool_size', int),
('max_overflow', int),
('pool_threadlocal', bool),
):
util.coerce_kw_type(options, option, type_)
return options

2422
sqlalchemy/engine/base.py Normal file

File diff suppressed because it is too large Load Diff

128
sqlalchemy/engine/ddl.py Normal file
View File

@ -0,0 +1,128 @@
# engine/ddl.py
# Copyright (C) 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Routines to handle CREATE/DROP workflow."""
from sqlalchemy import engine, schema
from sqlalchemy.sql import util as sql_util
class DDLBase(schema.SchemaVisitor):
def __init__(self, connection):
self.connection = connection
class SchemaGenerator(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
super(SchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables and set(tables) or None
self.preparer = dialect.identifier_preparer
self.dialect = dialect
def _can_create(self, table):
self.dialect.validate_identifier(table.name)
if table.schema:
self.dialect.validate_identifier(table.schema)
return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
def visit_metadata(self, metadata):
if self.tables:
tables = self.tables
else:
tables = metadata.tables.values()
collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
for listener in metadata.ddl_listeners['before-create']:
listener('before-create', metadata, self.connection, tables=collection)
for table in collection:
self.traverse_single(table)
for listener in metadata.ddl_listeners['after-create']:
listener('after-create', metadata, self.connection, tables=collection)
def visit_table(self, table):
for listener in table.ddl_listeners['before-create']:
listener('before-create', table, self.connection)
for column in table.columns:
if column.default is not None:
self.traverse_single(column.default)
self.connection.execute(schema.CreateTable(table))
if hasattr(table, 'indexes'):
for index in table.indexes:
self.traverse_single(index)
for listener in table.ddl_listeners['after-create']:
listener('after-create', table, self.connection)
def visit_sequence(self, sequence):
if self.dialect.supports_sequences:
if ((not self.dialect.sequences_optional or
not sequence.optional) and
(not self.checkfirst or
not self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
self.connection.execute(schema.CreateSequence(sequence))
def visit_index(self, index):
self.connection.execute(schema.CreateIndex(index))
class SchemaDropper(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
super(SchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
self.preparer = dialect.identifier_preparer
self.dialect = dialect
def visit_metadata(self, metadata):
if self.tables:
tables = self.tables
else:
tables = metadata.tables.values()
collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
for listener in metadata.ddl_listeners['before-drop']:
listener('before-drop', metadata, self.connection, tables=collection)
for table in collection:
self.traverse_single(table)
for listener in metadata.ddl_listeners['after-drop']:
listener('after-drop', metadata, self.connection, tables=collection)
def _can_drop(self, table):
self.dialect.validate_identifier(table.name)
if table.schema:
self.dialect.validate_identifier(table.schema)
return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
def visit_index(self, index):
self.connection.execute(schema.DropIndex(index))
def visit_table(self, table):
for listener in table.ddl_listeners['before-drop']:
listener('before-drop', table, self.connection)
for column in table.columns:
if column.default is not None:
self.traverse_single(column.default)
self.connection.execute(schema.DropTable(table))
for listener in table.ddl_listeners['after-drop']:
listener('after-drop', table, self.connection)
def visit_sequence(self, sequence):
if self.dialect.supports_sequences:
if ((not self.dialect.sequences_optional or
not sequence.optional) and
(not self.checkfirst or
self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
self.connection.execute(schema.DropSequence(sequence))

View File

@ -0,0 +1,700 @@
# engine/default.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Default implementations of per-dialect sqlalchemy.engine classes.
These are semi-private implementation classes which are only of importance
to database dialect authors; dialects will usually use the classes here
as the base class for their own corresponding classes.
"""
import re, random
from sqlalchemy.engine import base, reflection
from sqlalchemy.sql import compiler, expression
from sqlalchemy import exc, types as sqltypes, util
AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
re.I | re.UNICODE)
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
statement_compiler = compiler.SQLCompiler
ddl_compiler = compiler.DDLCompiler
type_compiler = compiler.GenericTypeCompiler
preparer = compiler.IdentifierPreparer
supports_alter = True
# most DBAPIs happy with this for execute().
# not cx_oracle.
execute_sequence_format = tuple
supports_sequences = False
sequences_optional = False
preexecute_autoincrement_sequences = False
postfetch_lastrowid = True
implicit_returning = False
supports_native_enum = False
supports_native_boolean = False
# if the NUMERIC type
# returns decimal.Decimal.
# *not* the FLOAT type however.
supports_native_decimal = False
# Py3K
#supports_unicode_statements = True
#supports_unicode_binds = True
# Py2K
supports_unicode_statements = False
supports_unicode_binds = False
returns_unicode_strings = False
# end Py2K
name = 'default'
max_identifier_length = 9999
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
dbapi_type_map = {}
colspecs = {}
default_paramstyle = 'named'
supports_default_values = False
supports_empty_insert = True
server_version_info = None
# indicates symbol names are
# UPPERCASEd if they are case insensitive
# within the database.
# if this is True, the methods normalize_name()
# and denormalize_name() must be provided.
requires_name_normalize = False
reflection_options = ()
def __init__(self, convert_unicode=False, assert_unicode=False,
encoding='utf-8', paramstyle=None, dbapi=None,
implicit_returning=None,
label_length=None, **kwargs):
if not getattr(self, 'ported_sqla_06', True):
util.warn(
"The %s dialect is not yet ported to SQLAlchemy 0.6" % self.name)
self.convert_unicode = convert_unicode
if assert_unicode:
util.warn_deprecated("assert_unicode is deprecated. "
"SQLAlchemy emits a warning in all cases where it "
"would otherwise like to encode a Python unicode object "
"into a specific encoding but a plain bytestring is received. "
"This does *not* apply to DBAPIs that coerce Unicode natively."
)
self.encoding = encoding
self.positional = False
self._ischema = None
self.dbapi = dbapi
if paramstyle is not None:
self.paramstyle = paramstyle
elif self.dbapi is not None:
self.paramstyle = self.dbapi.paramstyle
else:
self.paramstyle = self.default_paramstyle
if implicit_returning is not None:
self.implicit_returning = implicit_returning
self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
self.identifier_preparer = self.preparer(self)
self.type_compiler = self.type_compiler(self)
if label_length and label_length > self.max_identifier_length:
raise exc.ArgumentError("Label length of %d is greater than this dialect's"
" maximum identifier length of %d" %
(label_length, self.max_identifier_length))
self.label_length = label_length
if not hasattr(self, 'description_encoding'):
self.description_encoding = getattr(self, 'description_encoding', encoding)
@property
def dialect_description(self):
return self.name + "+" + self.driver
def initialize(self, connection):
try:
self.server_version_info = self._get_server_version_info(connection)
except NotImplementedError:
self.server_version_info = None
try:
self.default_schema_name = self._get_default_schema_name(connection)
except NotImplementedError:
self.default_schema_name = None
self.returns_unicode_strings = self._check_unicode_returns(connection)
self.do_rollback(connection.connection)
def on_connect(self):
"""return a callable which sets up a newly created DBAPI connection.
This is used to set dialect-wide per-connection options such as isolation
modes, unicode modes, etc.
If a callable is returned, it will be assembled into a pool listener
that receives the direct DBAPI connection, with all wrappers removed.
If None is returned, no listener will be generated.
"""
return None
def _check_unicode_returns(self, connection):
# Py2K
if self.supports_unicode_statements:
cast_to = unicode
else:
cast_to = str
# end Py2K
# Py3K
#cast_to = str
def check_unicode(type_):
cursor = connection.connection.cursor()
try:
cursor.execute(
cast_to(
expression.select(
[expression.cast(
expression.literal_column("'test unicode returns'"), type_)
]).compile(dialect=self)
)
)
row = cursor.fetchone()
return isinstance(row[0], unicode)
finally:
cursor.close()
# detect plain VARCHAR
unicode_for_varchar = check_unicode(sqltypes.VARCHAR(60))
# detect if there's an NVARCHAR type with different behavior available
unicode_for_unicode = check_unicode(sqltypes.Unicode(60))
if unicode_for_unicode and not unicode_for_varchar:
return "conditional"
else:
return unicode_for_varchar
def type_descriptor(self, typeobj):
"""Provide a database-specific ``TypeEngine`` object, given
the generic object which comes from the types module.
This method looks for a dictionary called
``colspecs`` as a class or instance-level variable,
and passes on to ``types.adapt_type()``.
"""
return sqltypes.adapt_type(typeobj, self.colspecs)
def reflecttable(self, connection, table, include_columns):
insp = reflection.Inspector.from_engine(connection)
return insp.reflecttable(table, include_columns)
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
"Identifier '%s' exceeds maximum length of %d characters" %
(ident, self.max_identifier_length)
)
def connect(self, *cargs, **cparams):
return self.dbapi.connect(*cargs, **cparams)
def create_connect_args(self, url):
opts = url.translate_connect_args()
opts.update(url.query)
return [[], opts]
def do_begin(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
"""
pass
def do_rollback(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
"""
connection.rollback()
def do_commit(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
"""
connection.commit()
def create_xid(self):
"""Create a random two-phase transaction ID.
This id will be passed to do_begin_twophase(), do_rollback_twophase(),
do_commit_twophase(). Its format is unspecified.
"""
return "_sa_%032x" % random.randint(0, 2 ** 128)
def do_savepoint(self, connection, name):
connection.execute(expression.SavepointClause(name))
def do_rollback_to_savepoint(self, connection, name):
connection.execute(expression.RollbackToSavepointClause(name))
def do_release_savepoint(self, connection, name):
connection.execute(expression.ReleaseSavepointClause(name))
def do_executemany(self, cursor, statement, parameters, context=None):
cursor.executemany(statement, parameters)
def do_execute(self, cursor, statement, parameters, context=None):
cursor.execute(statement, parameters)
def is_disconnect(self, e):
return False
class DefaultExecutionContext(base.ExecutionContext):
execution_options = util.frozendict()
isinsert = False
isupdate = False
isdelete = False
isddl = False
executemany = False
result_map = None
compiled = None
statement = None
def __init__(self,
dialect,
connection,
compiled_sql=None,
compiled_ddl=None,
statement=None,
parameters=None):
self.dialect = dialect
self._connection = self.root_connection = connection
self.engine = connection.engine
if compiled_ddl is not None:
self.compiled = compiled = compiled_ddl
self.isddl = True
if compiled.statement._execution_options:
self.execution_options = compiled.statement._execution_options
if connection._execution_options:
self.execution_options = self.execution_options.union(
connection._execution_options
)
if not dialect.supports_unicode_statements:
self.unicode_statement = unicode(compiled)
self.statement = self.unicode_statement.encode(self.dialect.encoding)
else:
self.statement = self.unicode_statement = unicode(compiled)
self.cursor = self.create_cursor()
self.compiled_parameters = []
self.parameters = [self._default_params]
elif compiled_sql is not None:
self.compiled = compiled = compiled_sql
if not compiled.can_execute:
raise exc.ArgumentError("Not an executable clause: %s" % compiled)
if compiled.statement._execution_options:
self.execution_options = compiled.statement._execution_options
if connection._execution_options:
self.execution_options = self.execution_options.union(
connection._execution_options
)
# compiled clauseelement. process bind params, process table defaults,
# track collections used by ResultProxy to target and process results
self.processors = dict(
(key, value) for key, value in
( (compiled.bind_names[bindparam],
bindparam.bind_processor(self.dialect))
for bindparam in compiled.bind_names )
if value is not None)
self.result_map = compiled.result_map
if not dialect.supports_unicode_statements:
self.unicode_statement = unicode(compiled)
self.statement = self.unicode_statement.encode(self.dialect.encoding)
else:
self.statement = self.unicode_statement = unicode(compiled)
self.isinsert = compiled.isinsert
self.isupdate = compiled.isupdate
self.isdelete = compiled.isdelete
if not parameters:
self.compiled_parameters = [compiled.construct_params()]
else:
self.compiled_parameters = [compiled.construct_params(m, _group_number=grp) for
grp,m in enumerate(parameters)]
self.executemany = len(parameters) > 1
self.cursor = self.create_cursor()
if self.isinsert or self.isupdate:
self.__process_defaults()
self.parameters = self.__convert_compiled_params(self.compiled_parameters)
elif statement is not None:
# plain text statement
if connection._execution_options:
self.execution_options = self.execution_options.union(connection._execution_options)
self.parameters = self.__encode_param_keys(parameters)
self.executemany = len(parameters) > 1
if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
self.unicode_statement = statement
self.statement = statement.encode(self.dialect.encoding)
else:
self.statement = self.unicode_statement = statement
self.cursor = self.create_cursor()
else:
# no statement. used for standalone ColumnDefault execution.
if connection._execution_options:
self.execution_options = self.execution_options.union(connection._execution_options)
self.cursor = self.create_cursor()
@util.memoized_property
def is_crud(self):
return self.isinsert or self.isupdate or self.isdelete
@util.memoized_property
def should_autocommit(self):
autocommit = self.execution_options.get('autocommit',
not self.compiled and
self.statement and
expression.PARSE_AUTOCOMMIT
or False)
if autocommit is expression.PARSE_AUTOCOMMIT:
return self.should_autocommit_text(self.unicode_statement)
else:
return autocommit
@util.memoized_property
def _is_explicit_returning(self):
return self.compiled and \
getattr(self.compiled.statement, '_returning', False)
@util.memoized_property
def _is_implicit_returning(self):
return self.compiled and \
bool(self.compiled.returning) and \
not self.compiled.statement._returning
@util.memoized_property
def _default_params(self):
if self.dialect.positional:
return self.dialect.execute_sequence_format()
else:
return {}
def _execute_scalar(self, stmt):
"""Execute a string statement on the current cursor, returning a scalar result.
Used to fire off sequences, default phrases, and "select lastrowid"
types of statements individually
or in the context of a parent INSERT or UPDATE statement.
"""
conn = self._connection
if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
stmt = stmt.encode(self.dialect.encoding)
conn._cursor_execute(self.cursor, stmt, self._default_params)
return self.cursor.fetchone()[0]
@property
def connection(self):
return self._connection._branch()
def __encode_param_keys(self, params):
"""Apply string encoding to the keys of dictionary-based bind parameters.
This is only used executing textual, non-compiled SQL expressions.
"""
if not params:
return [self._default_params]
elif isinstance(params[0], self.dialect.execute_sequence_format):
return params
elif isinstance(params[0], dict):
if self.dialect.supports_unicode_statements:
return params
else:
def proc(d):
return dict((k.encode(self.dialect.encoding), d[k]) for k in d)
return [proc(d) for d in params] or [{}]
else:
return [self.dialect.execute_sequence_format(p) for p in params]
def __convert_compiled_params(self, compiled_parameters):
"""Convert the dictionary of bind parameter values into a dict or list
to be sent to the DBAPI's execute() or executemany() method.
"""
processors = self.processors
parameters = []
if self.dialect.positional:
for compiled_params in compiled_parameters:
param = []
for key in self.compiled.positiontup:
if key in processors:
param.append(processors[key](compiled_params[key]))
else:
param.append(compiled_params[key])
parameters.append(self.dialect.execute_sequence_format(param))
else:
encode = not self.dialect.supports_unicode_statements
for compiled_params in compiled_parameters:
param = {}
if encode:
encoding = self.dialect.encoding
for key in compiled_params:
if key in processors:
param[key.encode(encoding)] = processors[key](compiled_params[key])
else:
param[key.encode(encoding)] = compiled_params[key]
else:
for key in compiled_params:
if key in processors:
param[key] = processors[key](compiled_params[key])
else:
param[key] = compiled_params[key]
parameters.append(param)
return self.dialect.execute_sequence_format(parameters)
def should_autocommit_text(self, statement):
return AUTOCOMMIT_REGEXP.match(statement)
def create_cursor(self):
return self._connection.connection.cursor()
def pre_exec(self):
pass
def post_exec(self):
pass
def get_lastrowid(self):
"""return self.cursor.lastrowid, or equivalent, after an INSERT.
This may involve calling special cursor functions,
issuing a new SELECT on the cursor (or a new one),
or returning a stored value that was
calculated within post_exec().
This function will only be called for dialects
which support "implicit" primary key generation,
keep preexecute_autoincrement_sequences set to False,
and when no explicit id value was bound to the
statement.
The function is called once, directly after
post_exec() and before the transaction is committed
or ResultProxy is generated. If the post_exec()
method assigns a value to `self._lastrowid`, the
value is used in place of calling get_lastrowid().
Note that this method is *not* equivalent to the
``lastrowid`` method on ``ResultProxy``, which is a
direct proxy to the DBAPI ``lastrowid`` accessor
in all cases.
"""
return self.cursor.lastrowid
def handle_dbapi_exception(self, e):
pass
def get_result_proxy(self):
return base.ResultProxy(self)
@property
def rowcount(self):
return self.cursor.rowcount
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount
def supports_sane_multi_rowcount(self):
return self.dialect.supports_sane_multi_rowcount
def post_insert(self):
if self.dialect.postfetch_lastrowid and \
(not len(self._inserted_primary_key) or \
None in self._inserted_primary_key):
table = self.compiled.statement.table
lastrowid = self.get_lastrowid()
self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
for c, v in zip(table.primary_key, self._inserted_primary_key)
]
def _fetch_implicit_returning(self, resultproxy):
table = self.compiled.statement.table
row = resultproxy.fetchone()
self._inserted_primary_key = [v is not None and v or row[c]
for c, v in zip(table.primary_key, self._inserted_primary_key)
]
def last_inserted_params(self):
return self._last_inserted_params
def last_updated_params(self):
return self._last_updated_params
def lastrow_has_defaults(self):
return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
def set_input_sizes(self, translate=None, exclude_types=None):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
"""
if not hasattr(self.compiled, 'bind_names'):
return
types = dict(
(self.compiled.bind_names[bindparam], bindparam.type)
for bindparam in self.compiled.bind_names)
if self.dialect.positional:
inputsizes = []
for key in self.compiled.positiontup:
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*inputsizes)
except Exception, e:
self._connection._handle_dbapi_exception(e, None, None, None, self)
raise
else:
inputsizes = {}
for key in self.compiled.bind_names.values():
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
if translate:
key = translate.get(key, key)
inputsizes[key.encode(self.dialect.encoding)] = dbtype
try:
self.cursor.setinputsizes(**inputsizes)
except Exception, e:
self._connection._handle_dbapi_exception(e, None, None, None, self)
raise
def _exec_default(self, default):
if default.is_sequence:
return self.fire_sequence(default)
elif default.is_callable:
return default.arg(self)
elif default.is_clause_element:
# TODO: expensive branching here should be
# pulled into _exec_scalar()
conn = self.connection
c = expression.select([default.arg]).compile(bind=conn)
return conn._execute_compiled(c, (), {}).scalar()
else:
return default.arg
def get_insert_default(self, column):
if column.default is None:
return None
else:
return self._exec_default(column.default)
def get_update_default(self, column):
if column.onupdate is None:
return None
else:
return self._exec_default(column.onupdate)
def __process_defaults(self):
"""Generate default values for compiled insert/update statements,
and generate inserted_primary_key collection.
"""
if self.executemany:
if len(self.compiled.prefetch):
scalar_defaults = {}
# pre-determine scalar Python-side defaults
# to avoid many calls of get_insert_default()/get_update_default()
for c in self.compiled.prefetch:
if self.isinsert and c.default and c.default.is_scalar:
scalar_defaults[c] = c.default.arg
elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
scalar_defaults[c] = c.onupdate.arg
for param in self.compiled_parameters:
self.current_parameters = param
for c in self.compiled.prefetch:
if c in scalar_defaults:
val = scalar_defaults[c]
elif self.isinsert:
val = self.get_insert_default(c)
else:
val = self.get_update_default(c)
if val is not None:
param[c.key] = val
del self.current_parameters
else:
self.current_parameters = compiled_parameters = self.compiled_parameters[0]
for c in self.compiled.prefetch:
if self.isinsert:
val = self.get_insert_default(c)
else:
val = self.get_update_default(c)
if val is not None:
compiled_parameters[c.key] = val
del self.current_parameters
if self.isinsert:
self._inserted_primary_key = [compiled_parameters.get(c.key, None)
for c in self.compiled.statement.table.primary_key]
self._last_inserted_params = compiled_parameters
else:
self._last_updated_params = compiled_parameters
self.postfetch_cols = self.compiled.postfetch
self.prefetch_cols = self.compiled.prefetch
DefaultDialect.execution_ctx_cls = DefaultExecutionContext

View File

@ -0,0 +1,370 @@
"""Provides an abstraction for obtaining database schema information.
Usage Notes:
Here are some general conventions when accessing the low level inspector
methods such as get_table_names, get_columns, etc.
1. Inspector methods return lists of dicts in most cases for the following
reasons:
* They're both standard types that can be serialized.
* Using a dict instead of a tuple allows easy expansion of attributes.
* Using a list for the outer structure maintains order and is easy to work
with (e.g. list comprehension [d['name'] for d in cols]).
2. Records that contain a name, such as the column name in a column record
use the key 'name'. So for most return values, each record will have a
'name' attribute..
"""
import sqlalchemy
from sqlalchemy import exc, sql
from sqlalchemy import util
from sqlalchemy.types import TypeEngine
from sqlalchemy import schema as sa_schema
@util.decorator
def cache(fn, self, con, *args, **kw):
info_cache = kw.get('info_cache', None)
if info_cache is None:
return fn(self, con, *args, **kw)
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, basestring)),
tuple((k, v) for k, v in kw.iteritems() if isinstance(v, (basestring, int, float)))
)
ret = info_cache.get(key)
if ret is None:
ret = fn(self, con, *args, **kw)
info_cache[key] = ret
return ret
class Inspector(object):
"""Performs database schema inspection.
The Inspector acts as a proxy to the dialects' reflection methods and
provides higher level functions for accessing database schema information.
"""
def __init__(self, conn):
"""Initialize the instance.
:param conn: a :class:`~sqlalchemy.engine.base.Connectable`
"""
self.conn = conn
# set the engine
if hasattr(conn, 'engine'):
self.engine = conn.engine
else:
self.engine = conn
self.dialect = self.engine.dialect
self.info_cache = {}
@classmethod
def from_engine(cls, engine):
if hasattr(engine.dialect, 'inspector'):
return engine.dialect.inspector(engine)
return Inspector(engine)
@property
def default_schema_name(self):
return self.dialect.default_schema_name
def get_schema_names(self):
"""Return all schema names.
"""
if hasattr(self.dialect, 'get_schema_names'):
return self.dialect.get_schema_names(self.conn,
info_cache=self.info_cache)
return []
def get_table_names(self, schema=None, order_by=None):
"""Return all table names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
:param order_by: Optional, may be the string "foreign_key" to sort
the result on foreign key dependencies.
This should probably not return view names or maybe it should return
them with an indicator t or v.
"""
if hasattr(self.dialect, 'get_table_names'):
tnames = self.dialect.get_table_names(self.conn,
schema,
info_cache=self.info_cache)
else:
tnames = self.engine.table_names(schema)
if order_by == 'foreign_key':
ordered_tnames = tnames[:]
# Order based on foreign key dependencies.
for tname in tnames:
table_pos = tnames.index(tname)
fkeys = self.get_foreign_keys(tname, schema)
for fkey in fkeys:
rtable = fkey['referred_table']
if rtable in ordered_tnames:
ref_pos = ordered_tnames.index(rtable)
# Make sure it's lower in the list than anything it
# references.
if table_pos > ref_pos:
ordered_tnames.pop(table_pos) # rtable moves up 1
# insert just below rtable
ordered_tnames.index(ref_pos, tname)
tnames = ordered_tnames
return tnames
def get_table_options(self, table_name, schema=None, **kw):
if hasattr(self.dialect, 'get_table_options'):
return self.dialect.get_table_options(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return {}
def get_view_names(self, schema=None):
"""Return all view names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
"""
return self.dialect.get_view_names(self.conn, schema,
info_cache=self.info_cache)
def get_view_definition(self, view_name, schema=None):
"""Return definition for `view_name`.
:param schema: Optional, retrieve names from a non-default schema.
"""
return self.dialect.get_view_definition(
self.conn, view_name, schema, info_cache=self.info_cache)
def get_columns(self, table_name, schema=None, **kw):
"""Return information about columns in `table_name`.
Given a string `table_name` and an optional string `schema`, return
column information as a list of dicts with these keys:
name
the column's name
type
:class:`~sqlalchemy.types.TypeEngine`
nullable
boolean
default
the column's default value
attrs
dict containing optional column attributes
"""
col_defs = self.dialect.get_columns(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
for col_def in col_defs:
# make this easy and only return instances for coltype
coltype = col_def['type']
if not isinstance(coltype, TypeEngine):
col_def['type'] = coltype()
return col_defs
def get_primary_keys(self, table_name, schema=None, **kw):
"""Return information about primary keys in `table_name`.
Given a string `table_name`, and an optional string `schema`, return
primary key information as a list of column names.
"""
pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return pkeys
def get_foreign_keys(self, table_name, schema=None, **kw):
"""Return information about foreign_keys in `table_name`.
Given a string `table_name`, and an optional string `schema`, return
foreign key information as a list of dicts with these keys:
constrained_columns
a list of column names that make up the foreign key
referred_schema
the name of the referred schema
referred_table
the name of the referred table
referred_columns
a list of column names in the referred table that correspond to
constrained_columns
\**kw
other options passed to the dialect's get_foreign_keys() method.
"""
fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return fk_defs
def get_indexes(self, table_name, schema=None, **kw):
"""Return information about indexes in `table_name`.
Given a string `table_name` and an optional string `schema`, return
index information as a list of dicts with these keys:
name
the index's name
column_names
list of column names in order
unique
boolean
\**kw
other options passed to the dialect's get_indexes() method.
"""
indexes = self.dialect.get_indexes(self.conn, table_name,
schema,
info_cache=self.info_cache, **kw)
return indexes
def reflecttable(self, table, include_columns):
dialect = self.conn.dialect
# MySQL dialect does this. Applicable with other dialects?
if hasattr(dialect, '_connection_charset') \
and hasattr(dialect, '_adjust_casing'):
charset = dialect._connection_charset
dialect._adjust_casing(table)
# table attributes we might need.
reflection_options = dict(
(k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
schema = table.schema
table_name = table.name
# apply table options
tbl_opts = self.get_table_options(table_name, schema, **table.kwargs)
if tbl_opts:
table.kwargs.update(tbl_opts)
# table.kwargs will need to be passed to each reflection method. Make
# sure keywords are strings.
tblkw = table.kwargs.copy()
for (k, v) in tblkw.items():
del tblkw[k]
tblkw[str(k)] = v
# Py2K
if isinstance(schema, str):
schema = schema.decode(dialect.encoding)
if isinstance(table_name, str):
table_name = table_name.decode(dialect.encoding)
# end Py2K
# columns
found_table = False
for col_d in self.get_columns(table_name, schema, **tblkw):
found_table = True
name = col_d['name']
if include_columns and name not in include_columns:
continue
coltype = col_d['type']
col_kw = {
'nullable':col_d['nullable'],
}
if 'autoincrement' in col_d:
col_kw['autoincrement'] = col_d['autoincrement']
if 'quote' in col_d:
col_kw['quote'] = col_d['quote']
colargs = []
if col_d.get('default') is not None:
# the "default" value is assumed to be a literal SQL expression,
# so is wrapped in text() so that no quoting occurs on re-issuance.
colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
if 'sequence' in col_d:
# TODO: mssql, maxdb and sybase are using this.
seq = col_d['sequence']
sequence = sa_schema.Sequence(seq['name'], 1, 1)
if 'start' in seq:
sequence.start = seq['start']
if 'increment' in seq:
sequence.increment = seq['increment']
colargs.append(sequence)
col = sa_schema.Column(name, coltype, *colargs, **col_kw)
table.append_column(col)
if not found_table:
raise exc.NoSuchTableError(table.name)
# Primary keys
primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
if pk in table.c
])
table.append_constraint(primary_key_constraint)
# Foreign keys
fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
for fkey_d in fkeys:
conname = fkey_d['name']
constrained_columns = fkey_d['constrained_columns']
referred_schema = fkey_d['referred_schema']
referred_table = fkey_d['referred_table']
referred_columns = fkey_d['referred_columns']
refspec = []
if referred_schema is not None:
sa_schema.Table(referred_table, table.metadata,
autoload=True, schema=referred_schema,
autoload_with=self.conn,
**reflection_options
)
for column in referred_columns:
refspec.append(".".join(
[referred_schema, referred_table, column]))
else:
sa_schema.Table(referred_table, table.metadata, autoload=True,
autoload_with=self.conn,
**reflection_options
)
for column in referred_columns:
refspec.append(".".join([referred_table, column]))
table.append_constraint(
sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
conname, link_to_name=True))
# Indexes
indexes = self.get_indexes(table_name, schema)
for index_d in indexes:
name = index_d['name']
columns = index_d['column_names']
unique = index_d['unique']
flavor = index_d.get('type', 'unknown type')
if include_columns and \
not set(columns).issubset(include_columns):
util.warn(
"Omitting %s KEY for (%s), key covers omitted columns." %
(flavor, ', '.join(columns)))
continue
sa_schema.Index(name, *[table.columns[c] for c in columns],
**dict(unique=unique))

View File

@ -0,0 +1,227 @@
"""Strategies for creating new instances of Engine types.
These are semi-private implementation classes which provide the
underlying behavior for the "strategy" keyword argument available on
:func:`~sqlalchemy.engine.create_engine`. Current available options are
``plain``, ``threadlocal``, and ``mock``.
New strategies can be added via new ``EngineStrategy`` classes.
"""
from operator import attrgetter
from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exc
from sqlalchemy import pool as poollib
strategies = {}
class EngineStrategy(object):
"""An adaptor that processes input arguements and produces an Engine.
Provides a ``create`` method that receives input arguments and
produces an instance of base.Engine or a subclass.
"""
def __init__(self):
strategies[self.name] = self
def create(self, *args, **kwargs):
"""Given arguments, returns a new Engine instance."""
raise NotImplementedError()
class DefaultEngineStrategy(EngineStrategy):
"""Base class for built-in stratgies."""
pool_threadlocal = False
def create(self, name_or_url, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
dialect_cls = u.get_dialect()
dialect_args = {}
# consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls):
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
dbapi = kwargs.pop('module', None)
if dbapi is None:
dbapi_args = {}
for k in util.get_func_kwargs(dialect_cls.dbapi):
if k in kwargs:
dbapi_args[k] = kwargs.pop(k)
dbapi = dialect_cls.dbapi(**dbapi_args)
dialect_args['dbapi'] = dbapi
# create dialect
dialect = dialect_cls(**dialect_args)
# assemble connection arguments
(cargs, cparams) = dialect.create_connect_args(u)
cparams.update(kwargs.pop('connect_args', {}))
# look for existing pool or create
pool = kwargs.pop('pool', None)
if pool is None:
def connect():
try:
return dialect.connect(*cargs, **cparams)
except Exception, e:
# Py3K
#raise exc.DBAPIError.instance(None, None, e) from e
# Py2K
import sys
raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2]
# end Py2K
creator = kwargs.pop('creator', connect)
poolclass = (kwargs.pop('poolclass', None) or
getattr(dialect_cls, 'poolclass', poollib.QueuePool))
pool_args = {}
# consume pool arguments from kwargs, translating a few of
# the arguments
translate = {'logging_name': 'pool_logging_name',
'echo': 'echo_pool',
'timeout': 'pool_timeout',
'recycle': 'pool_recycle',
'use_threadlocal':'pool_threadlocal'}
for k in util.get_cls_kwargs(poolclass):
tk = translate.get(k, k)
if tk in kwargs:
pool_args[k] = kwargs.pop(tk)
pool_args.setdefault('use_threadlocal', self.pool_threadlocal)
pool = poolclass(creator, **pool_args)
else:
if isinstance(pool, poollib._DBProxy):
pool = pool.get_pool(*cargs, **cparams)
else:
pool = pool
# create engine.
engineclass = self.engine_cls
engine_args = {}
for k in util.get_cls_kwargs(engineclass):
if k in kwargs:
engine_args[k] = kwargs.pop(k)
_initialize = kwargs.pop('_initialize', True)
# all kwargs should be consumed
if kwargs:
raise TypeError(
"Invalid argument(s) %s sent to create_engine(), "
"using configuration %s/%s/%s. Please check that the "
"keyword arguments are appropriate for this combination "
"of components." % (','.join("'%s'" % k for k in kwargs),
dialect.__class__.__name__,
pool.__class__.__name__,
engineclass.__name__))
engine = engineclass(pool, dialect, u, **engine_args)
if _initialize:
do_on_connect = dialect.on_connect()
if do_on_connect:
def on_connect(conn, rec):
conn = getattr(conn, '_sqla_unwrap', conn)
if conn is None:
return
do_on_connect(conn)
pool.add_listener({'first_connect': on_connect, 'connect':on_connect})
def first_connect(conn, rec):
c = base.Connection(engine, connection=conn)
dialect.initialize(c)
pool.add_listener({'first_connect':first_connect})
return engine
class PlainEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring a regular Engine."""
name = 'plain'
engine_cls = base.Engine
PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring an Engine with thredlocal behavior."""
name = 'threadlocal'
pool_threadlocal = True
engine_cls = threadlocal.TLEngine
ThreadLocalEngineStrategy()
class MockEngineStrategy(EngineStrategy):
"""Strategy for configuring an Engine-like object with mocked execution.
Produces a single mock Connectable object which dispatches
statement execution to a passed-in function.
"""
name = 'mock'
def create(self, name_or_url, executor, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
dialect_cls = u.get_dialect()
dialect_args = {}
# consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls):
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
# create dialect
dialect = dialect_cls(**dialect_args)
return MockEngineStrategy.MockConnection(dialect, executor)
class MockConnection(base.Connectable):
def __init__(self, dialect, execute):
self._dialect = dialect
self.execute = execute
engine = property(lambda s: s)
dialect = property(attrgetter('_dialect'))
name = property(lambda s: s._dialect.name)
def contextual_connect(self, **kwargs):
return self
def compiler(self, statement, parameters, **kwargs):
return self._dialect.compiler(
statement, parameters, engine=self, **kwargs)
def create(self, entity, **kwargs):
kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl
ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)
def drop(self, entity, **kwargs):
kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl
ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
MockEngineStrategy()

View File

@ -0,0 +1,103 @@
"""Provides a thread-local transactional wrapper around the root Engine class.
The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag
with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is
invoked automatically when the threadlocal engine strategy is used.
"""
from sqlalchemy import util
from sqlalchemy.engine import base
import weakref
class TLConnection(base.Connection):
def __init__(self, *arg, **kw):
super(TLConnection, self).__init__(*arg, **kw)
self.__opencount = 0
def _increment_connect(self):
self.__opencount += 1
return self
def close(self):
if self.__opencount == 1:
base.Connection.close(self)
self.__opencount -= 1
def _force_close(self):
self.__opencount = 0
base.Connection.close(self)
class TLEngine(base.Engine):
"""An Engine that includes support for thread-local managed transactions."""
def __init__(self, *args, **kwargs):
super(TLEngine, self).__init__(*args, **kwargs)
self._connections = util.threading.local()
proxy = kwargs.get('proxy')
if proxy:
self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
else:
self.TLConnection = TLConnection
def contextual_connect(self, **kw):
if not hasattr(self._connections, 'conn'):
connection = None
else:
connection = self._connections.conn()
if connection is None or connection.closed:
# guards against pool-level reapers, if desired.
# or not connection.connection.is_valid:
connection = self.TLConnection(self, self.pool.connect(), **kw)
self._connections.conn = conn = weakref.ref(connection)
return connection._increment_connect()
def begin_twophase(self, xid=None):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid))
def begin_nested(self):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_nested())
def begin(self):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin())
def prepare(self):
self._connections.trans[-1].prepare()
def commit(self):
trans = self._connections.trans.pop(-1)
trans.commit()
def rollback(self):
trans = self._connections.trans.pop(-1)
trans.rollback()
def dispose(self):
self._connections = util.threading.local()
super(TLEngine, self).dispose()
@property
def closed(self):
return not hasattr(self._connections, 'conn') or \
self._connections.conn() is None or \
self._connections.conn().closed
def close(self):
if not self.closed:
self.contextual_connect().close()
connection = self._connections.conn()
connection._force_close()
del self._connections.conn
self._connections.trans = []
def __repr__(self):
return 'TLEngine(%s)' % str(self.url)

214
sqlalchemy/engine/url.py Normal file
View File

@ -0,0 +1,214 @@
"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
information about a database connection specification.
The URL object is created automatically when :func:`~sqlalchemy.engine.create_engine` is called
with a string argument; alternatively, the URL is a public-facing construct which can
be used directly and is also accepted directly by ``create_engine()``.
"""
import re, cgi, sys, urllib
from sqlalchemy import exc
class URL(object):
"""
Represent the components of a URL used to connect to a database.
This object is suitable to be passed directly to a
``create_engine()`` call. The fields of the URL are parsed from a
string by the ``module-level make_url()`` function. the string
format of the URL is an RFC-1738-style string.
All initialization parameters are available as public attributes.
:param drivername: the name of the database backend.
This name will correspond to a module in sqlalchemy/databases
or a third party plug-in.
:param username: The user name.
:param password: database password.
:param host: The name of the host.
:param port: The port number.
:param database: The database name.
:param query: A dictionary of options to be passed to the
dialect and/or the DBAPI upon connect.
"""
def __init__(self, drivername, username=None, password=None,
host=None, port=None, database=None, query=None):
self.drivername = drivername
self.username = username
self.password = password
self.host = host
if port is not None:
self.port = int(port)
else:
self.port = None
self.database = database
self.query = query or {}
def __str__(self):
s = self.drivername + "://"
if self.username is not None:
s += self.username
if self.password is not None:
s += ':' + urllib.quote_plus(self.password)
s += "@"
if self.host is not None:
s += self.host
if self.port is not None:
s += ':' + str(self.port)
if self.database is not None:
s += '/' + self.database
if self.query:
keys = self.query.keys()
keys.sort()
s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
return s
def __hash__(self):
return hash(str(self))
def __eq__(self, other):
return \
isinstance(other, URL) and \
self.drivername == other.drivername and \
self.username == other.username and \
self.password == other.password and \
self.host == other.host and \
self.database == other.database and \
self.query == other.query
def get_dialect(self):
"""Return the SQLAlchemy database dialect class corresponding
to this URL's driver name.
"""
try:
if '+' in self.drivername:
dialect, driver = self.drivername.split('+')
else:
dialect, driver = self.drivername, 'base'
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
module = getattr(module, dialect)
module = getattr(module, driver)
return module.dialect
except ImportError:
module = self._load_entry_point()
if module is not None:
return module
else:
raise
def _load_entry_point(self):
"""attempt to load this url's dialect from entry points, or return None
if pkg_resources is not installed or there is no matching entry point.
Raise ImportError if the actual load fails.
"""
try:
import pkg_resources
except ImportError:
return None
for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'):
if res.name == self.drivername:
return res.load()
else:
return None
def translate_connect_args(self, names=[], **kw):
"""Translate url attributes into a dictionary of connection arguments.
Returns attributes of this url (`host`, `database`, `username`,
`password`, `port`) as a plain dictionary. The attribute names are
used as the keys by default. Unset or false attributes are omitted
from the final dictionary.
:param \**kw: Optional, alternate key names for url attributes.
:param names: Deprecated. Same purpose as the keyword-based alternate names,
but correlates the name to the original positionally.
"""
translated = {}
attribute_names = ['host', 'database', 'username', 'password', 'port']
for sname in attribute_names:
if names:
name = names.pop(0)
elif sname in kw:
name = kw[sname]
else:
name = sname
if name is not None and getattr(self, sname, False):
translated[name] = getattr(self, sname)
return translated
def make_url(name_or_url):
"""Given a string or unicode instance, produce a new URL instance.
The given string is parsed according to the RFC 1738 spec. If an
existing URL object is passed, just returns the object.
"""
if isinstance(name_or_url, basestring):
return _parse_rfc1738_args(name_or_url)
else:
return name_or_url
def _parse_rfc1738_args(name):
pattern = re.compile(r'''
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>[^/]*))?
@)?
(?:
(?P<host>[^/:]*)
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
'''
, re.X)
m = pattern.match(name)
if m is not None:
components = m.groupdict()
if components['database'] is not None:
tokens = components['database'].split('?', 2)
components['database'] = tokens[0]
query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None
# Py2K
if query is not None:
query = dict((k.encode('ascii'), query[k]) for k in query)
# end Py2K
else:
query = None
components['query'] = query
if components['password'] is not None:
components['password'] = urllib.unquote_plus(components['password'])
name = components.pop('name')
return URL(name, **components)
else:
raise exc.ArgumentError(
"Could not parse rfc1738 URL from string '%s'" % name)
def _parse_keyvalue_args(name):
m = re.match( r'(\w+)://(.*)', name)
if m is not None:
(name, args) = m.group(1, 2)
opts = dict( cgi.parse_qsl( args ) )
return URL(name, *opts)
else:
return None

191
sqlalchemy/exc.py Normal file
View File

@ -0,0 +1,191 @@
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Exceptions used with SQLAlchemy.
The base exception class is SQLAlchemyError. Exceptions which are raised as a
result of DBAPI exceptions are all subclasses of
:class:`~sqlalchemy.exc.DBAPIError`.
"""
class SQLAlchemyError(Exception):
"""Generic error class."""
class ArgumentError(SQLAlchemyError):
"""Raised when an invalid or conflicting function argument is supplied.
This error generally corresponds to construction time state errors.
"""
class CircularDependencyError(SQLAlchemyError):
"""Raised by topological sorts when a circular dependency is detected"""
class CompileError(SQLAlchemyError):
"""Raised when an error occurs during SQL compilation"""
class IdentifierError(SQLAlchemyError):
"""Raised when a schema name is beyond the max character limit"""
# Moved to orm.exc; compatability definition installed by orm import until 0.6
ConcurrentModificationError = None
class DisconnectionError(SQLAlchemyError):
"""A disconnect is detected on a raw DB-API connection.
This error is raised and consumed internally by a connection pool. It can
be raised by a ``PoolListener`` so that the host pool forces a disconnect.
"""
# Moved to orm.exc; compatability definition installed by orm import until 0.6
FlushError = None
class TimeoutError(SQLAlchemyError):
"""Raised when a connection pool times out on getting a connection."""
class InvalidRequestError(SQLAlchemyError):
"""SQLAlchemy was asked to do something it can't do.
This error generally corresponds to runtime state errors.
"""
class NoSuchColumnError(KeyError, InvalidRequestError):
"""A nonexistent column is requested from a ``RowProxy``."""
class NoReferenceError(InvalidRequestError):
"""Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
class NoReferencedTableError(NoReferenceError):
"""Raised by ``ForeignKey`` when the referred ``Table`` cannot be located."""
class NoReferencedColumnError(NoReferenceError):
"""Raised by ``ForeignKey`` when the referred ``Column`` cannot be located."""
class NoSuchTableError(InvalidRequestError):
"""Table does not exist or is not visible to a connection."""
class UnboundExecutionError(InvalidRequestError):
"""SQL was attempted without a database connection to execute it on."""
# Moved to orm.exc; compatability definition installed by orm import until 0.6
UnmappedColumnError = None
class DBAPIError(SQLAlchemyError):
"""Raised when the execution of a database operation fails.
``DBAPIError`` wraps exceptions raised by the DB-API underlying the
database operation. Driver-specific implementations of the standard
DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
``DBAPIError`` when possible. DB-API's ``Error`` type maps to
``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note
that there is no guarantee that different DB-API implementations will
raise the same exception type for any given error condition.
If the error-raising operation occured in the execution of a SQL
statement, that statement and its parameters will be available on
the exception object in the ``statement`` and ``params`` attributes.
The wrapped exception object is available in the ``orig`` attribute.
Its type and properties are DB-API implementation specific.
"""
@classmethod
def instance(cls, statement, params, orig, connection_invalidated=False):
# Don't ever wrap these, just return them directly as if
# DBAPIError didn't exist.
if isinstance(orig, (KeyboardInterrupt, SystemExit)):
return orig
if orig is not None:
name, glob = orig.__class__.__name__, globals()
if name in glob and issubclass(glob[name], DBAPIError):
cls = glob[name]
return cls(statement, params, orig, connection_invalidated)
def __init__(self, statement, params, orig, connection_invalidated=False):
try:
text = str(orig)
except (KeyboardInterrupt, SystemExit):
raise
except Exception, e:
text = 'Error in str() of DB-API-generated exception: ' + str(e)
SQLAlchemyError.__init__(
self, '(%s) %s' % (orig.__class__.__name__, text))
self.statement = statement
self.params = params
self.orig = orig
self.connection_invalidated = connection_invalidated
def __str__(self):
if isinstance(self.params, (list, tuple)) and len(self.params) > 10 and isinstance(self.params[0], (list, dict, tuple)):
return ' '.join((SQLAlchemyError.__str__(self),
repr(self.statement),
repr(self.params[:2]),
'... and a total of %i bound parameter sets' % len(self.params)))
return ' '.join((SQLAlchemyError.__str__(self),
repr(self.statement), repr(self.params)))
# As of 0.4, SQLError is now DBAPIError.
# SQLError alias will be removed in 0.6.
SQLError = DBAPIError
class InterfaceError(DBAPIError):
"""Wraps a DB-API InterfaceError."""
class DatabaseError(DBAPIError):
"""Wraps a DB-API DatabaseError."""
class DataError(DatabaseError):
"""Wraps a DB-API DataError."""
class OperationalError(DatabaseError):
"""Wraps a DB-API OperationalError."""
class IntegrityError(DatabaseError):
"""Wraps a DB-API IntegrityError."""
class InternalError(DatabaseError):
"""Wraps a DB-API InternalError."""
class ProgrammingError(DatabaseError):
"""Wraps a DB-API ProgrammingError."""
class NotSupportedError(DatabaseError):
"""Wraps a DB-API NotSupportedError."""
# Warnings
class SADeprecationWarning(DeprecationWarning):
"""Issued once per usage of a deprecated API."""
class SAPendingDeprecationWarning(PendingDeprecationWarning):
"""Issued once per usage of a deprecated API."""
class SAWarning(RuntimeWarning):
"""Issued at runtime."""

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,878 @@
"""Contain the ``AssociationProxy`` class.
The ``AssociationProxy`` is a Python property object which provides
transparent proxied access to the endpoint of an association object.
See the example ``examples/association/proxied_association.py``.
"""
import itertools
import operator
import weakref
from sqlalchemy import exceptions
from sqlalchemy import orm
from sqlalchemy import util
from sqlalchemy.orm import collections
from sqlalchemy.sql import not_
def association_proxy(target_collection, attr, **kw):
"""Return a Python property implementing a view of *attr* over a collection.
Implements a read/write view over an instance's *target_collection*,
extracting *attr* from each member of the collection. The property acts
somewhat like this list comprehension::
[getattr(member, *attr*)
for member in getattr(instance, *target_collection*)]
Unlike the list comprehension, the collection returned by the property is
always in sync with *target_collection*, and mutations made to either
collection will be reflected in both.
Implements a Python property representing a relationship as a collection of
simpler values. The proxied property will mimic the collection type of
the target (list, dict or set), or, in the case of a one to one relationship,
a simple scalar value.
:param target_collection: Name of the relationship attribute we'll proxy to,
usually created with :func:`~sqlalchemy.orm.relationship`.
:param attr: Attribute on the associated instances we'll proxy for.
For example, given a target collection of [obj1, obj2], a list created
by this proxy property would look like [getattr(obj1, *attr*),
getattr(obj2, *attr*)]
If the relationship is one-to-one or otherwise uselist=False, then simply:
getattr(obj, *attr*)
:param creator: optional.
When new items are added to this proxied collection, new instances of
the class collected by the target collection will be created. For list
and set collections, the target class constructor will be called with
the 'value' for the new instance. For dict types, two arguments are
passed: key and value.
If you want to construct instances differently, supply a *creator*
function that takes arguments as above and returns instances.
For scalar relationships, creator() will be called if the target is None.
If the target is present, set operations are proxied to setattr() on the
associated object.
If you have an associated object with multiple attributes, you may set
up multiple association proxies mapping to different attributes. See
the unit tests for examples, and for examples of how creator() functions
can be used to construct the scalar relationship on-demand in this
situation.
:param \*\*kw: Passes along any other keyword arguments to
:class:`AssociationProxy`.
"""
return AssociationProxy(target_collection, attr, **kw)
class AssociationProxy(object):
"""A descriptor that presents a read/write view of an object attribute."""
def __init__(self, target_collection, attr, creator=None,
getset_factory=None, proxy_factory=None, proxy_bulk_set=None):
"""Arguments are:
target_collection
Name of the collection we'll proxy to, usually created with
'relationship()' in a mapper setup.
attr
Attribute on the collected instances we'll proxy for. For example,
given a target collection of [obj1, obj2], a list created by this
proxy property would look like [getattr(obj1, attr), getattr(obj2,
attr)]
creator
Optional. When new items are added to this proxied collection, new
instances of the class collected by the target collection will be
created. For list and set collections, the target class constructor
will be called with the 'value' for the new instance. For dict
types, two arguments are passed: key and value.
If you want to construct instances differently, supply a 'creator'
function that takes arguments as above and returns instances.
getset_factory
Optional. Proxied attribute access is automatically handled by
routines that get and set values based on the `attr` argument for
this proxy.
If you would like to customize this behavior, you may supply a
`getset_factory` callable that produces a tuple of `getter` and
`setter` functions. The factory is called with two arguments, the
abstract type of the underlying collection and this proxy instance.
proxy_factory
Optional. The type of collection to emulate is determined by
sniffing the target collection. If your collection type can't be
determined by duck typing or you'd like to use a different
collection implementation, you may supply a factory function to
produce those collections. Only applicable to non-scalar relationships.
proxy_bulk_set
Optional, use with proxy_factory. See the _set() method for
details.
"""
self.target_collection = target_collection
self.value_attr = attr
self.creator = creator
self.getset_factory = getset_factory
self.proxy_factory = proxy_factory
self.proxy_bulk_set = proxy_bulk_set
self.scalar = None
self.owning_class = None
self.key = '_%s_%s_%s' % (
type(self).__name__, target_collection, id(self))
self.collection_class = None
def _get_property(self):
return (orm.class_mapper(self.owning_class).
get_property(self.target_collection))
@property
def target_class(self):
"""The class the proxy is attached to."""
return self._get_property().mapper.class_
def _target_is_scalar(self):
return not self._get_property().uselist
def __get__(self, obj, class_):
if self.owning_class is None:
self.owning_class = class_ and class_ or type(obj)
if obj is None:
return self
elif self.scalar is None:
self.scalar = self._target_is_scalar()
if self.scalar:
self._initialize_scalar_accessors()
if self.scalar:
return self._scalar_get(getattr(obj, self.target_collection))
else:
try:
# If the owning instance is reborn (orm session resurrect,
# etc.), refresh the proxy cache.
creator_id, proxy = getattr(obj, self.key)
if id(obj) == creator_id:
return proxy
except AttributeError:
pass
proxy = self._new(_lazy_collection(obj, self.target_collection))
setattr(obj, self.key, (id(obj), proxy))
return proxy
def __set__(self, obj, values):
if self.owning_class is None:
self.owning_class = type(obj)
if self.scalar is None:
self.scalar = self._target_is_scalar()
if self.scalar:
self._initialize_scalar_accessors()
if self.scalar:
creator = self.creator and self.creator or self.target_class
target = getattr(obj, self.target_collection)
if target is None:
setattr(obj, self.target_collection, creator(values))
else:
self._scalar_set(target, values)
else:
proxy = self.__get__(obj, None)
if proxy is not values:
proxy.clear()
self._set(proxy, values)
def __delete__(self, obj):
if self.owning_class is None:
self.owning_class = type(obj)
delattr(obj, self.key)
def _initialize_scalar_accessors(self):
if self.getset_factory:
get, set = self.getset_factory(None, self)
else:
get, set = self._default_getset(None)
self._scalar_get, self._scalar_set = get, set
def _default_getset(self, collection_class):
attr = self.value_attr
getter = operator.attrgetter(attr)
if collection_class is dict:
setter = lambda o, k, v: setattr(o, attr, v)
else:
setter = lambda o, v: setattr(o, attr, v)
return getter, setter
def _new(self, lazy_collection):
creator = self.creator and self.creator or self.target_class
self.collection_class = util.duck_type_collection(lazy_collection())
if self.proxy_factory:
return self.proxy_factory(lazy_collection, creator, self.value_attr, self)
if self.getset_factory:
getter, setter = self.getset_factory(self.collection_class, self)
else:
getter, setter = self._default_getset(self.collection_class)
if self.collection_class is list:
return _AssociationList(lazy_collection, creator, getter, setter, self)
elif self.collection_class is dict:
return _AssociationDict(lazy_collection, creator, getter, setter, self)
elif self.collection_class is set:
return _AssociationSet(lazy_collection, creator, getter, setter, self)
else:
raise exceptions.ArgumentError(
'could not guess which interface to use for '
'collection_class "%s" backing "%s"; specify a '
'proxy_factory and proxy_bulk_set manually' %
(self.collection_class.__name__, self.target_collection))
def _inflate(self, proxy):
creator = self.creator and self.creator or self.target_class
if self.getset_factory:
getter, setter = self.getset_factory(self.collection_class, self)
else:
getter, setter = self._default_getset(self.collection_class)
proxy.creator = creator
proxy.getter = getter
proxy.setter = setter
def _set(self, proxy, values):
if self.proxy_bulk_set:
self.proxy_bulk_set(proxy, values)
elif self.collection_class is list:
proxy.extend(values)
elif self.collection_class is dict:
proxy.update(values)
elif self.collection_class is set:
proxy.update(values)
else:
raise exceptions.ArgumentError(
'no proxy_bulk_set supplied for custom '
'collection_class implementation')
@property
def _comparator(self):
return self._get_property().comparator
def any(self, criterion=None, **kwargs):
return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
def has(self, criterion=None, **kwargs):
return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
def contains(self, obj):
return self._comparator.any(**{self.value_attr: obj})
def __eq__(self, obj):
return self._comparator.has(**{self.value_attr: obj})
def __ne__(self, obj):
return not_(self.__eq__(obj))
class _lazy_collection(object):
def __init__(self, obj, target):
self.ref = weakref.ref(obj)
self.target = target
def __call__(self):
obj = self.ref()
if obj is None:
raise exceptions.InvalidRequestError(
"stale association proxy, parent object has gone out of "
"scope")
return getattr(obj, self.target)
def __getstate__(self):
return {'obj':self.ref(), 'target':self.target}
def __setstate__(self, state):
self.ref = weakref.ref(state['obj'])
self.target = state['target']
class _AssociationCollection(object):
def __init__(self, lazy_collection, creator, getter, setter, parent):
"""Constructs an _AssociationCollection.
This will always be a subclass of either _AssociationList,
_AssociationSet, or _AssociationDict.
lazy_collection
A callable returning a list-based collection of entities (usually an
object attribute managed by a SQLAlchemy relationship())
creator
A function that creates new target entities. Given one parameter:
value. This assertion is assumed::
obj = creator(somevalue)
assert getter(obj) == somevalue
getter
A function. Given an associated object, return the 'value'.
setter
A function. Given an associated object and a value, store that
value on the object.
"""
self.lazy_collection = lazy_collection
self.creator = creator
self.getter = getter
self.setter = setter
self.parent = parent
col = property(lambda self: self.lazy_collection())
def __len__(self):
return len(self.col)
def __nonzero__(self):
return bool(self.col)
def __getstate__(self):
return {'parent':self.parent, 'lazy_collection':self.lazy_collection}
def __setstate__(self, state):
self.parent = state['parent']
self.lazy_collection = state['lazy_collection']
self.parent._inflate(self)
class _AssociationList(_AssociationCollection):
"""Generic, converting, list-to-list proxy."""
def _create(self, value):
return self.creator(value)
def _get(self, object):
return self.getter(object)
def _set(self, object, value):
return self.setter(object, value)
def __getitem__(self, index):
return self._get(self.col[index])
def __setitem__(self, index, value):
if not isinstance(index, slice):
self._set(self.col[index], value)
else:
if index.stop is None:
stop = len(self)
elif index.stop < 0:
stop = len(self) + index.stop
else:
stop = index.stop
step = index.step or 1
rng = range(index.start or 0, stop, step)
if step == 1:
for i in rng:
del self[index.start]
i = index.start
for item in value:
self.insert(i, item)
i += 1
else:
if len(value) != len(rng):
raise ValueError(
"attempt to assign sequence of size %s to "
"extended slice of size %s" % (len(value),
len(rng)))
for i, item in zip(rng, value):
self._set(self.col[i], item)
def __delitem__(self, index):
del self.col[index]
def __contains__(self, value):
for member in self.col:
# testlib.pragma exempt:__eq__
if self._get(member) == value:
return True
return False
def __getslice__(self, start, end):
return [self._get(member) for member in self.col[start:end]]
def __setslice__(self, start, end, values):
members = [self._create(v) for v in values]
self.col[start:end] = members
def __delslice__(self, start, end):
del self.col[start:end]
def __iter__(self):
"""Iterate over proxied values.
For the actual domain objects, iterate over .col instead or
just use the underlying collection directly from its property
on the parent.
"""
for member in self.col:
yield self._get(member)
raise StopIteration
def append(self, value):
item = self._create(value)
self.col.append(item)
def count(self, value):
return sum([1 for _ in
itertools.ifilter(lambda v: v == value, iter(self))])
def extend(self, values):
for v in values:
self.append(v)
def insert(self, index, value):
self.col[index:index] = [self._create(value)]
def pop(self, index=-1):
return self.getter(self.col.pop(index))
def remove(self, value):
for i, val in enumerate(self):
if val == value:
del self.col[i]
return
raise ValueError("value not in list")
def reverse(self):
"""Not supported, use reversed(mylist)"""
raise NotImplementedError
def sort(self):
"""Not supported, use sorted(mylist)"""
raise NotImplementedError
def clear(self):
del self.col[0:len(self.col)]
def __eq__(self, other):
return list(self) == other
def __ne__(self, other):
return list(self) != other
def __lt__(self, other):
return list(self) < other
def __le__(self, other):
return list(self) <= other
def __gt__(self, other):
return list(self) > other
def __ge__(self, other):
return list(self) >= other
def __cmp__(self, other):
return cmp(list(self), other)
def __add__(self, iterable):
try:
other = list(iterable)
except TypeError:
return NotImplemented
return list(self) + other
def __radd__(self, iterable):
try:
other = list(iterable)
except TypeError:
return NotImplemented
return other + list(self)
def __mul__(self, n):
if not isinstance(n, int):
return NotImplemented
return list(self) * n
__rmul__ = __mul__
def __iadd__(self, iterable):
self.extend(iterable)
return self
def __imul__(self, n):
# unlike a regular list *=, proxied __imul__ will generate unique
# backing objects for each copy. *= on proxied lists is a bit of
# a stretch anyhow, and this interpretation of the __imul__ contract
# is more plausibly useful than copying the backing objects.
if not isinstance(n, int):
return NotImplemented
if n == 0:
self.clear()
elif n > 1:
self.extend(list(self) * (n - 1))
return self
def copy(self):
return list(self)
def __repr__(self):
return repr(list(self))
def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
_NotProvided = util.symbol('_NotProvided')
class _AssociationDict(_AssociationCollection):
"""Generic, converting, dict-to-dict proxy."""
def _create(self, key, value):
return self.creator(key, value)
def _get(self, object):
return self.getter(object)
def _set(self, object, key, value):
return self.setter(object, key, value)
def __getitem__(self, key):
return self._get(self.col[key])
def __setitem__(self, key, value):
if key in self.col:
self._set(self.col[key], key, value)
else:
self.col[key] = self._create(key, value)
def __delitem__(self, key):
del self.col[key]
def __contains__(self, key):
# testlib.pragma exempt:__hash__
return key in self.col
def has_key(self, key):
# testlib.pragma exempt:__hash__
return key in self.col
def __iter__(self):
return self.col.iterkeys()
def clear(self):
self.col.clear()
def __eq__(self, other):
return dict(self) == other
def __ne__(self, other):
return dict(self) != other
def __lt__(self, other):
return dict(self) < other
def __le__(self, other):
return dict(self) <= other
def __gt__(self, other):
return dict(self) > other
def __ge__(self, other):
return dict(self) >= other
def __cmp__(self, other):
return cmp(dict(self), other)
def __repr__(self):
return repr(dict(self.items()))
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def setdefault(self, key, default=None):
if key not in self.col:
self.col[key] = self._create(key, default)
return default
else:
return self[key]
def keys(self):
return self.col.keys()
def iterkeys(self):
return self.col.iterkeys()
def values(self):
return [ self._get(member) for member in self.col.values() ]
def itervalues(self):
for key in self.col:
yield self._get(self.col[key])
raise StopIteration
def items(self):
return [(k, self._get(self.col[k])) for k in self]
def iteritems(self):
for key in self.col:
yield (key, self._get(self.col[key]))
raise StopIteration
def pop(self, key, default=_NotProvided):
if default is _NotProvided:
member = self.col.pop(key)
else:
member = self.col.pop(key, default)
return self._get(member)
def popitem(self):
item = self.col.popitem()
return (item[0], self._get(item[1]))
def update(self, *a, **kw):
if len(a) > 1:
raise TypeError('update expected at most 1 arguments, got %i' %
len(a))
elif len(a) == 1:
seq_or_map = a[0]
for item in seq_or_map:
if isinstance(item, tuple):
self[item[0]] = item[1]
else:
self[item] = seq_or_map[item]
for key, value in kw:
self[key] = value
def copy(self):
return dict(self.items())
def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(dict, func_name)):
func.__doc__ = getattr(dict, func_name).__doc__
del func_name, func
class _AssociationSet(_AssociationCollection):
"""Generic, converting, set-to-set proxy."""
def _create(self, value):
return self.creator(value)
def _get(self, object):
return self.getter(object)
def _set(self, object, value):
return self.setter(object, value)
def __len__(self):
return len(self.col)
def __nonzero__(self):
if self.col:
return True
else:
return False
def __contains__(self, value):
for member in self.col:
# testlib.pragma exempt:__eq__
if self._get(member) == value:
return True
return False
def __iter__(self):
"""Iterate over proxied values.
For the actual domain objects, iterate over .col instead or just use
the underlying collection directly from its property on the parent.
"""
for member in self.col:
yield self._get(member)
raise StopIteration
def add(self, value):
if value not in self:
self.col.add(self._create(value))
# for discard and remove, choosing a more expensive check strategy rather
# than call self.creator()
def discard(self, value):
for member in self.col:
if self._get(member) == value:
self.col.discard(member)
break
def remove(self, value):
for member in self.col:
if self._get(member) == value:
self.col.discard(member)
return
raise KeyError(value)
def pop(self):
if not self.col:
raise KeyError('pop from an empty set')
member = self.col.pop()
return self._get(member)
def update(self, other):
for value in other:
self.add(value)
def __ior__(self, other):
if not collections._set_binops_check_strict(self, other):
return NotImplemented
for value in other:
self.add(value)
return self
def _set(self):
return set(iter(self))
def union(self, other):
return set(self).union(other)
__or__ = union
def difference(self, other):
return set(self).difference(other)
__sub__ = difference
def difference_update(self, other):
for value in other:
self.discard(value)
def __isub__(self, other):
if not collections._set_binops_check_strict(self, other):
return NotImplemented
for value in other:
self.discard(value)
return self
def intersection(self, other):
return set(self).intersection(other)
__and__ = intersection
def intersection_update(self, other):
want, have = self.intersection(other), set(self)
remove, add = have - want, want - have
for value in remove:
self.remove(value)
for value in add:
self.add(value)
def __iand__(self, other):
if not collections._set_binops_check_strict(self, other):
return NotImplemented
want, have = self.intersection(other), set(self)
remove, add = have - want, want - have
for value in remove:
self.remove(value)
for value in add:
self.add(value)
return self
def symmetric_difference(self, other):
return set(self).symmetric_difference(other)
__xor__ = symmetric_difference
def symmetric_difference_update(self, other):
want, have = self.symmetric_difference(other), set(self)
remove, add = have - want, want - have
for value in remove:
self.remove(value)
for value in add:
self.add(value)
def __ixor__(self, other):
if not collections._set_binops_check_strict(self, other):
return NotImplemented
want, have = self.symmetric_difference(other), set(self)
remove, add = have - want, want - have
for value in remove:
self.remove(value)
for value in add:
self.add(value)
return self
def issubset(self, other):
return set(self).issubset(other)
def issuperset(self, other):
return set(self).issuperset(other)
def clear(self):
self.col.clear()
def copy(self):
return set(self)
def __eq__(self, other):
return set(self) == other
def __ne__(self, other):
return set(self) != other
def __lt__(self, other):
return set(self) < other
def __le__(self, other):
return set(self) <= other
def __gt__(self, other):
return set(self) > other
def __ge__(self, other):
return set(self) >= other
def __repr__(self):
return repr(set(self))
def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(set, func_name)):
func.__doc__ = getattr(set, func_name).__doc__
del func_name, func

194
sqlalchemy/ext/compiler.py Normal file
View File

@ -0,0 +1,194 @@
"""Provides an API for creation of custom ClauseElements and compilers.
Synopsis
========
Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement`
subclasses and one or more callables defining its compilation::
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause
class MyColumn(ColumnClause):
pass
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
the base expression element for named column objects. The ``compiles``
decorator registers itself with the ``MyColumn`` class so that it is invoked
when the object is compiled to a string::
from sqlalchemy import select
s = select([MyColumn('x'), MyColumn('y')])
print str(s)
Produces::
SELECT [x], [y]
Dialect-specific compilation rules
==================================
Compilers can also be made dialect-specific. The appropriate compiler will be
invoked for the dialect in use::
from sqlalchemy.schema import DDLElement
class AlterColumn(DDLElement):
def __init__(self, column, cmd):
self.column = column
self.cmd = cmd
@compiles(AlterColumn)
def visit_alter_column(element, compiler, **kw):
return "ALTER COLUMN %s ..." % element.column.name
@compiles(AlterColumn, 'postgresql')
def visit_alter_column(element, compiler, **kw):
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name)
The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used.
Compiling sub-elements of a custom expression construct
=======================================================
The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled`
object in use. This object can be inspected for any information about the
in-progress compilation, including ``compiler.dialect``,
``compiler.statement`` etc. The :class:`~sqlalchemy.sql.compiler.SQLCompiler`
and :class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
method which can be used for compilation of embedded attributes::
from sqlalchemy.sql.expression import Executable, ClauseElement
class InsertFromSelect(Executable, ClauseElement):
def __init__(self, table, select):
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True),
compiler.process(element.select)
)
insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5))
print insert
Produces::
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)"
Cross Compiling between SQL and DDL compilers
---------------------------------------------
SQL and DDL constructs are each compiled using different base compilers - ``SQLCompiler``
and ``DDLCompiler``. A common need is to access the compilation rules of SQL expressions
from within a DDL expression. The ``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as below where we generate a CHECK
constraint that embeds a SQL expression::
@compiles(MyConstraint)
def compile_my_constraint(constraint, ddlcompiler, **kw):
return "CONSTRAINT %s CHECK (%s)" % (
constraint.name,
ddlcompiler.sql_compiler.process(constraint.expression)
)
Changing the default compilation of existing constructs
=======================================================
The compiler extension applies just as well to the existing constructs. When overriding
the compilation of a built in SQL construct, the @compiles decorator is invoked upon
the appropriate class (be sure to use the class, i.e. ``Insert`` or ``Select``, instead of the creation function such as ``insert()`` or ``select()``).
Within the new compilation function, to get at the "original" compilation routine,
use the appropriate visit_XXX method - this because compiler.process() will call upon the
overriding routine and cause an endless loop. Such as, to add "prefix" to all insert statements::
from sqlalchemy.sql.expression import Insert
@compiles(Insert)
def prefix_inserts(insert, compiler, **kw):
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
The above compiler will prefix all INSERT statements with "some prefix" when compiled.
Subclassing Guidelines
======================
A big part of using the compiler extension is subclassing SQLAlchemy expression constructs. To make this easier, the expression and schema packages feature a set of "bases" intended for common tasks. A synopsis is as follows:
* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
expression class. Any SQL expression can be derived from this base, and is
probably the best choice for longer constructs such as specialized INSERT
statements.
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
"column-like" elements. Anything that you'd place in the "columns" clause of
a SELECT statement (as well as order by and group by) can derive from this -
the object will automatically have Python "comparison" behavior.
:class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
``type`` member which is expression's return type. This can be established
at the instance level in the constructor, or at the class level if its
generally constant::
class timestamp(ColumnElement):
type = TIMESTAMP()
* :class:`~sqlalchemy.sql.expression.FunctionElement` - This is a hybrid of a
``ColumnElement`` and a "from clause" like object, and represents a SQL
function or stored procedure type of call. Since most databases support
statements along the line of "SELECT FROM <some function>"
``FunctionElement`` adds in the ability to be used in the FROM clause of a
``select()`` construct.
* :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions,
like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement``
subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``.
``DDLElement`` also features ``Table`` and ``MetaData`` event hooks via the
``execute_at()`` method, allowing the construct to be invoked during CREATE
TABLE and DROP TABLE sequences.
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which should be
used with any expression class that represents a "standalone" SQL statement that
can be passed directly to an ``execute()`` method. It is already implicit
within ``DDLElement`` and ``FunctionElement``.
"""
def compiles(class_, *specs):
def decorate(fn):
existing = getattr(class_, '_compiler_dispatcher', None)
if not existing:
existing = _dispatcher()
# TODO: why is the lambda needed ?
setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw))
setattr(class_, '_compiler_dispatcher', existing)
if specs:
for s in specs:
existing.specs[s] = fn
else:
existing.specs['default'] = fn
return fn
return decorate
class _dispatcher(object):
def __init__(self):
self.specs = {}
def __call__(self, element, compiler, **kw):
# TODO: yes, this could also switch off of DBAPI in use.
fn = self.specs.get(compiler.dialect.name, None)
if not fn:
fn = self.specs['default']
return fn(element, compiler, **kw)

View File

@ -0,0 +1,940 @@
"""
Synopsis
========
SQLAlchemy object-relational configuration involves the use of
:class:`~sqlalchemy.schema.Table`, :func:`~sqlalchemy.orm.mapper`, and
class objects to define the three areas of configuration.
:mod:`~sqlalchemy.ext.declarative` allows all three types of
configuration to be expressed declaratively on an individual
mapped class. Regular SQLAlchemy schema elements and ORM constructs
are used in most cases.
As a simple example::
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class SomeClass(Base):
__tablename__ = 'some_table'
id = Column(Integer, primary_key=True)
name = Column(String(50))
Above, the :func:`declarative_base` callable returns a new base class from which
all mapped classes should inherit. When the class definition is completed, a
new :class:`~sqlalchemy.schema.Table` and
:class:`~sqlalchemy.orm.mapper` will have been generated, accessible
via the ``__table__`` and ``__mapper__`` attributes on the ``SomeClass`` class.
Defining Attributes
===================
In the above example, the :class:`~sqlalchemy.schema.Column` objects are
automatically named with the name of the attribute to which they are
assigned.
They can also be explicitly named, and that name does not have to be
the same as name assigned on the class.
The column will be assigned to the :class:`~sqlalchemy.schema.Table` using the
given name, and mapped to the class using the attribute name::
class SomeClass(Base):
__tablename__ = 'some_table'
id = Column("some_table_id", Integer, primary_key=True)
name = Column("name", String(50))
Attributes may be added to the class after its construction, and they will be
added to the underlying :class:`~sqlalchemy.schema.Table` and
:func:`~sqlalchemy.orm.mapper()` definitions as appropriate::
SomeClass.data = Column('data', Unicode)
SomeClass.related = relationship(RelatedInfo)
Classes which are mapped explicitly using
:func:`~sqlalchemy.orm.mapper()` can interact freely with declarative
classes.
It is recommended, though not required, that all tables
share the same underlying :class:`~sqlalchemy.schema.MetaData` object,
so that string-configured :class:`~sqlalchemy.schema.ForeignKey`
references can be resolved without issue.
Association of Metadata and Engine
==================================
The :func:`declarative_base` base class contains a
:class:`~sqlalchemy.schema.MetaData` object where newly
defined :class:`~sqlalchemy.schema.Table` objects are collected. This
is accessed via the :class:`~sqlalchemy.schema.MetaData` class level
accessor, so to create tables we can say::
engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
The :class:`~sqlalchemy.engine.base.Engine` created above may also be
directly associated with the declarative base class using the ``bind``
keyword argument, where it will be associated with the underlying
:class:`~sqlalchemy.schema.MetaData` object and allow SQL operations
involving that metadata and its tables to make use of that engine
automatically::
Base = declarative_base(bind=create_engine('sqlite://'))
Alternatively, by way of the normal
:class:`~sqlalchemy.schema.MetaData` behaviour, the ``bind`` attribute
of the class level accessor can be assigned at any time as follows::
Base.metadata.bind = create_engine('sqlite://')
The :func:`declarative_base` can also receive a pre-created
:class:`~sqlalchemy.schema.MetaData` object, which allows a
declarative setup to be associated with an already
existing traditional collection of :class:`~sqlalchemy.schema.Table`
objects::
mymetadata = MetaData()
Base = declarative_base(metadata=mymetadata)
Configuring Relationships
=========================
Relationships to other classes are done in the usual way, with the added
feature that the class specified to :func:`~sqlalchemy.orm.relationship`
may be a string name (note that :func:`~sqlalchemy.orm.relationship` is
only available as of SQLAlchemy 0.6beta2, and in all prior versions is known
as :func:`~sqlalchemy.orm.relation`,
including 0.5 and 0.4). The "class registry" associated with ``Base``
is used at mapper compilation time to resolve the name into the actual
class object, which is expected to have been defined once the mapper
configuration is used::
class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
name = Column(String(50))
addresses = relationship("Address", backref="user")
class Address(Base):
__tablename__ = 'addresses'
id = Column(Integer, primary_key=True)
email = Column(String(50))
user_id = Column(Integer, ForeignKey('users.id'))
Column constructs, since they are just that, are immediately usable,
as below where we define a primary join condition on the ``Address``
class using them::
class Address(Base):
__tablename__ = 'addresses'
id = Column(Integer, primary_key=True)
email = Column(String(50))
user_id = Column(Integer, ForeignKey('users.id'))
user = relationship(User, primaryjoin=user_id == User.id)
In addition to the main argument for :func:`~sqlalchemy.orm.relationship`,
other arguments which depend upon the columns present on an as-yet
undefined class may also be specified as strings. These strings are
evaluated as Python expressions. The full namespace available within
this evaluation includes all classes mapped for this declarative base,
as well as the contents of the ``sqlalchemy`` package, including
expression functions like :func:`~sqlalchemy.sql.expression.desc` and
:attr:`~sqlalchemy.sql.expression.func`::
class User(Base):
# ....
addresses = relationship("Address",
order_by="desc(Address.email)",
primaryjoin="Address.user_id==User.id")
As an alternative to string-based attributes, attributes may also be
defined after all classes have been created. Just add them to the target
class after the fact::
User.addresses = relationship(Address,
primaryjoin=Address.user_id==User.id)
Configuring Many-to-Many Relationships
======================================
There's nothing special about many-to-many with declarative. The
``secondary`` argument to :func:`~sqlalchemy.orm.relationship` still
requires a :class:`~sqlalchemy.schema.Table` object, not a declarative
class. The :class:`~sqlalchemy.schema.Table` should share the same
:class:`~sqlalchemy.schema.MetaData` object used by the declarative
base::
keywords = Table(
'keywords', Base.metadata,
Column('author_id', Integer, ForeignKey('authors.id')),
Column('keyword_id', Integer, ForeignKey('keywords.id'))
)
class Author(Base):
__tablename__ = 'authors'
id = Column(Integer, primary_key=True)
keywords = relationship("Keyword", secondary=keywords)
You should generally **not** map a class and also specify its table in
a many-to-many relationship, since the ORM may issue duplicate INSERT and
DELETE statements.
Defining Synonyms
=================
Synonyms are introduced in :ref:`synonyms`. To define a getter/setter
which proxies to an underlying attribute, use
:func:`~sqlalchemy.orm.synonym` with the ``descriptor`` argument::
class MyClass(Base):
__tablename__ = 'sometable'
_attr = Column('attr', String)
def _get_attr(self):
return self._some_attr
def _set_attr(self, attr):
self._some_attr = attr
attr = synonym('_attr', descriptor=property(_get_attr, _set_attr))
The above synonym is then usable as an instance attribute as well as a
class-level expression construct::
x = MyClass()
x.attr = "some value"
session.query(MyClass).filter(MyClass.attr == 'some other value').all()
For simple getters, the :func:`synonym_for` decorator can be used in
conjunction with ``@property``::
class MyClass(Base):
__tablename__ = 'sometable'
_attr = Column('attr', String)
@synonym_for('_attr')
@property
def attr(self):
return self._some_attr
Similarly, :func:`comparable_using` is a front end for the
:func:`~sqlalchemy.orm.comparable_property` ORM function::
class MyClass(Base):
__tablename__ = 'sometable'
name = Column('name', String)
@comparable_using(MyUpperCaseComparator)
@property
def uc_name(self):
return self.name.upper()
Table Configuration
===================
Table arguments other than the name, metadata, and mapped Column
arguments are specified using the ``__table_args__`` class attribute.
This attribute accommodates both positional as well as keyword
arguments that are normally sent to the
:class:`~sqlalchemy.schema.Table` constructor.
The attribute can be specified in one of two forms. One is as a
dictionary::
class MyClass(Base):
__tablename__ = 'sometable'
__table_args__ = {'mysql_engine':'InnoDB'}
The other, a tuple of the form
``(arg1, arg2, ..., {kwarg1:value, ...})``, which allows positional
arguments to be specified as well (usually constraints)::
class MyClass(Base):
__tablename__ = 'sometable'
__table_args__ = (
ForeignKeyConstraint(['id'], ['remote_table.id']),
UniqueConstraint('foo'),
{'autoload':True}
)
Note that the keyword parameters dictionary is required in the tuple
form even if empty.
As an alternative to ``__tablename__``, a direct
:class:`~sqlalchemy.schema.Table` construct may be used. The
:class:`~sqlalchemy.schema.Column` objects, which in this case require
their names, will be added to the mapping just like a regular mapping
to a table::
class MyClass(Base):
__table__ = Table('my_table', Base.metadata,
Column('id', Integer, primary_key=True),
Column('name', String(50))
)
Mapper Configuration
====================
Configuration of mappers is done with the
:func:`~sqlalchemy.orm.mapper` function and all the possible mapper
configuration parameters can be found in the documentation for that
function.
:func:`~sqlalchemy.orm.mapper` is still used by declaratively mapped
classes and keyword parameters to the function can be passed by
placing them in the ``__mapper_args__`` class variable::
class Widget(Base):
__tablename__ = 'widgets'
id = Column(Integer, primary_key=True)
__mapper_args__ = {'extension': MyWidgetExtension()}
Inheritance Configuration
=========================
Declarative supports all three forms of inheritance as intuitively
as possible. The ``inherits`` mapper keyword argument is not needed
as declarative will determine this from the class itself. The various
"polymorphic" keyword arguments are specified using ``__mapper_args__``.
Joined Table Inheritance
~~~~~~~~~~~~~~~~~~~~~~~~
Joined table inheritance is defined as a subclass that defines its own
table::
class Person(Base):
__tablename__ = 'people'
id = Column(Integer, primary_key=True)
discriminator = Column('type', String(50))
__mapper_args__ = {'polymorphic_on': discriminator}
class Engineer(Person):
__tablename__ = 'engineers'
__mapper_args__ = {'polymorphic_identity': 'engineer'}
id = Column(Integer, ForeignKey('people.id'), primary_key=True)
primary_language = Column(String(50))
Note that above, the ``Engineer.id`` attribute, since it shares the
same attribute name as the ``Person.id`` attribute, will in fact
represent the ``people.id`` and ``engineers.id`` columns together, and
will render inside a query as ``"people.id"``.
To provide the ``Engineer`` class with an attribute that represents
only the ``engineers.id`` column, give it a different attribute name::
class Engineer(Person):
__tablename__ = 'engineers'
__mapper_args__ = {'polymorphic_identity': 'engineer'}
engineer_id = Column('id', Integer, ForeignKey('people.id'), primary_key=True)
primary_language = Column(String(50))
Single Table Inheritance
~~~~~~~~~~~~~~~~~~~~~~~~
Single table inheritance is defined as a subclass that does not have
its own table; you just leave out the ``__table__`` and ``__tablename__``
attributes::
class Person(Base):
__tablename__ = 'people'
id = Column(Integer, primary_key=True)
discriminator = Column('type', String(50))
__mapper_args__ = {'polymorphic_on': discriminator}
class Engineer(Person):
__mapper_args__ = {'polymorphic_identity': 'engineer'}
primary_language = Column(String(50))
When the above mappers are configured, the ``Person`` class is mapped
to the ``people`` table *before* the ``primary_language`` column is
defined, and this column will not be included in its own mapping.
When ``Engineer`` then defines the ``primary_language`` column, the
column is added to the ``people`` table so that it is included in the
mapping for ``Engineer`` and is also part of the table's full set of
columns. Columns which are not mapped to ``Person`` are also excluded
from any other single or joined inheriting classes using the
``exclude_properties`` mapper argument. Below, ``Manager`` will have
all the attributes of ``Person`` and ``Manager`` but *not* the
``primary_language`` attribute of ``Engineer``::
class Manager(Person):
__mapper_args__ = {'polymorphic_identity': 'manager'}
golf_swing = Column(String(50))
The attribute exclusion logic is provided by the
``exclude_properties`` mapper argument, and declarative's default
behavior can be disabled by passing an explicit ``exclude_properties``
collection (empty or otherwise) to the ``__mapper_args__``.
Concrete Table Inheritance
~~~~~~~~~~~~~~~~~~~~~~~~~~
Concrete is defined as a subclass which has its own table and sets the
``concrete`` keyword argument to ``True``::
class Person(Base):
__tablename__ = 'people'
id = Column(Integer, primary_key=True)
name = Column(String(50))
class Engineer(Person):
__tablename__ = 'engineers'
__mapper_args__ = {'concrete':True}
id = Column(Integer, primary_key=True)
primary_language = Column(String(50))
name = Column(String(50))
Usage of an abstract base class is a little less straightforward as it
requires usage of :func:`~sqlalchemy.orm.util.polymorphic_union`::
engineers = Table('engineers', Base.metadata,
Column('id', Integer, primary_key=True),
Column('name', String(50)),
Column('primary_language', String(50))
)
managers = Table('managers', Base.metadata,
Column('id', Integer, primary_key=True),
Column('name', String(50)),
Column('golf_swing', String(50))
)
punion = polymorphic_union({
'engineer':engineers,
'manager':managers
}, 'type', 'punion')
class Person(Base):
__table__ = punion
__mapper_args__ = {'polymorphic_on':punion.c.type}
class Engineer(Person):
__table__ = engineers
__mapper_args__ = {'polymorphic_identity':'engineer', 'concrete':True}
class Manager(Person):
__table__ = managers
__mapper_args__ = {'polymorphic_identity':'manager', 'concrete':True}
Mix-in Classes
==============
A common need when using :mod:`~sqlalchemy.ext.declarative` is to
share some functionality, often a set of columns, across many
classes. The normal python idiom would be to put this common code into
a base class and have all the other classes subclass this class.
When using :mod:`~sqlalchemy.ext.declarative`, this need is met by
using a "mix-in class". A mix-in class is one that isn't mapped to a
table and doesn't subclass the declarative :class:`Base`. For example::
class MyMixin(object):
__table_args__ = {'mysql_engine':'InnoDB'}
__mapper_args__=dict(always_refresh=True)
id = Column(Integer, primary_key=True)
def foo(self):
return 'bar'+str(self.id)
class MyModel(Base,MyMixin):
__tablename__='test'
name = Column(String(1000), nullable=False, index=True)
As the above example shows, ``__table_args__`` and ``__mapper_args__``
can both be abstracted out into a mix-in if you use common values for
these across many classes.
However, particularly in the case of ``__table_args__``, you may want
to combine some parameters from several mix-ins with those you wish to
define on the class iteself. To help with this, a
:func:`~sqlalchemy.util.classproperty` decorator is provided that lets
you implement a class property with a function. For example::
from sqlalchemy.util import classproperty
class MySQLSettings:
__table_args__ = {'mysql_engine':'InnoDB'}
class MyOtherMixin:
__table_args__ = {'info':'foo'}
class MyModel(Base,MySQLSettings,MyOtherMixin):
__tablename__='my_model'
@classproperty
def __table_args__(self):
args = dict()
args.update(MySQLSettings.__table_args__)
args.update(MyOtherMixin.__table_args__)
return args
id = Column(Integer, primary_key=True)
Class Constructor
=================
As a convenience feature, the :func:`declarative_base` sets a default
constructor on classes which takes keyword arguments, and assigns them
to the named attributes::
e = Engineer(primary_language='python')
Sessions
========
Note that ``declarative`` does nothing special with sessions, and is
only intended as an easier way to configure mappers and
:class:`~sqlalchemy.schema.Table` objects. A typical application
setup using :func:`~sqlalchemy.orm.scoped_session` might look like::
engine = create_engine('postgresql://scott:tiger@localhost/test')
Session = scoped_session(sessionmaker(autocommit=False,
autoflush=False,
bind=engine))
Base = declarative_base()
Mapped instances then make usage of
:class:`~sqlalchemy.orm.session.Session` in the usual way.
"""
from sqlalchemy.schema import Table, Column, MetaData
from sqlalchemy.orm import synonym as _orm_synonym, mapper, comparable_property, class_mapper
from sqlalchemy.orm.interfaces import MapperProperty
from sqlalchemy.orm.properties import RelationshipProperty, ColumnProperty
from sqlalchemy.orm.util import _is_mapped_class
from sqlalchemy import util, exceptions
from sqlalchemy.sql import util as sql_util
__all__ = 'declarative_base', 'synonym_for', 'comparable_using', 'instrument_declarative'
def instrument_declarative(cls, registry, metadata):
"""Given a class, configure the class declaratively,
using the given registry, which can be any dictionary, and
MetaData object.
"""
if '_decl_class_registry' in cls.__dict__:
raise exceptions.InvalidRequestError(
"Class %r already has been "
"instrumented declaratively" % cls)
cls._decl_class_registry = registry
cls.metadata = metadata
_as_declarative(cls, cls.__name__, cls.__dict__)
def _as_declarative(cls, classname, dict_):
# dict_ will be a dictproxy, which we can't write to, and we need to!
dict_ = dict(dict_)
column_copies = dict()
unmapped_mixins = False
for base in cls.__bases__:
names = dir(base)
if not _is_mapped_class(base):
unmapped_mixins = True
for name in names:
obj = getattr(base,name, None)
if isinstance(obj, Column):
if obj.foreign_keys:
raise exceptions.InvalidRequestError(
"Columns with foreign keys to other columns "
"are not allowed on declarative mixins at this time."
)
dict_[name]=column_copies[obj]=obj.copy()
elif isinstance(obj, RelationshipProperty):
raise exceptions.InvalidRequestError(
"relationships are not allowed on "
"declarative mixins at this time.")
# doing it this way enables these attributes to be descriptors
get_mapper_args = '__mapper_args__' in dict_
get_table_args = '__table_args__' in dict_
if unmapped_mixins:
get_mapper_args = get_mapper_args or getattr(cls,'__mapper_args__',None)
get_table_args = get_table_args or getattr(cls,'__table_args__',None)
tablename = getattr(cls,'__tablename__',None)
if tablename:
# subtle: if tablename is a descriptor here, we actually
# put the wrong value in, but it serves as a marker to get
# the right value value...
dict_['__tablename__']=tablename
# now that we know whether or not to get these, get them from the class
# if we should, enabling them to be decorators
mapper_args = get_mapper_args and cls.__mapper_args__ or {}
table_args = get_table_args and cls.__table_args__ or None
# make sure that column copies are used rather than the original columns
# from any mixins
for k, v in mapper_args.iteritems():
mapper_args[k] = column_copies.get(v,v)
cls._decl_class_registry[classname] = cls
our_stuff = util.OrderedDict()
for k in dict_:
value = dict_[k]
if (isinstance(value, tuple) and len(value) == 1 and
isinstance(value[0], (Column, MapperProperty))):
util.warn("Ignoring declarative-like tuple value of attribute "
"%s: possibly a copy-and-paste error with a comma "
"left at the end of the line?" % k)
continue
if not isinstance(value, (Column, MapperProperty)):
continue
prop = _deferred_relationship(cls, value)
our_stuff[k] = prop
# set up attributes in the order they were created
our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)
# extract columns from the class dict
cols = []
for key, c in our_stuff.iteritems():
if isinstance(c, ColumnProperty):
for col in c.columns:
if isinstance(col, Column) and col.table is None:
_undefer_column_name(key, col)
cols.append(col)
elif isinstance(c, Column):
_undefer_column_name(key, c)
cols.append(c)
# if the column is the same name as the key,
# remove it from the explicit properties dict.
# the normal rules for assigning column-based properties
# will take over, including precedence of columns
# in multi-column ColumnProperties.
if key == c.key:
del our_stuff[key]
table = None
if '__table__' not in dict_:
if '__tablename__' in dict_:
# see above: if __tablename__ is a descriptor, this
# means we get the right value used!
tablename = cls.__tablename__
if isinstance(table_args, dict):
args, table_kw = (), table_args
elif isinstance(table_args, tuple):
args = table_args[0:-1]
table_kw = table_args[-1]
if len(table_args) < 2 or not isinstance(table_kw, dict):
raise exceptions.ArgumentError(
"Tuple form of __table_args__ is "
"(arg1, arg2, arg3, ..., {'kw1':val1, 'kw2':val2, ...})"
)
else:
args, table_kw = (), {}
autoload = dict_.get('__autoload__')
if autoload:
table_kw['autoload'] = True
cls.__table__ = table = Table(tablename, cls.metadata,
*(tuple(cols) + tuple(args)), **table_kw)
else:
table = cls.__table__
if cols:
for c in cols:
if not table.c.contains_column(c):
raise exceptions.ArgumentError(
"Can't add additional column %r when specifying __table__" % key
)
if 'inherits' not in mapper_args:
for c in cls.__bases__:
if _is_mapped_class(c):
mapper_args['inherits'] = cls._decl_class_registry.get(c.__name__, None)
break
if hasattr(cls, '__mapper_cls__'):
mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
else:
mapper_cls = mapper
if table is None and 'inherits' not in mapper_args:
raise exceptions.InvalidRequestError(
"Class %r does not have a __table__ or __tablename__ "
"specified and does not inherit from an existing table-mapped class." % cls
)
elif 'inherits' in mapper_args and not mapper_args.get('concrete', False):
inherited_mapper = class_mapper(mapper_args['inherits'], compile=False)
inherited_table = inherited_mapper.local_table
if 'inherit_condition' not in mapper_args and table is not None:
# figure out the inherit condition with relaxed rules
# about nonexistent tables, to allow for ForeignKeys to
# not-yet-defined tables (since we know for sure that our
# parent table is defined within the same MetaData)
mapper_args['inherit_condition'] = sql_util.join_condition(
mapper_args['inherits'].__table__, table,
ignore_nonexistent_tables=True)
if table is None:
# single table inheritance.
# ensure no table args
if table_args is not None:
raise exceptions.ArgumentError(
"Can't place __table_args__ on an inherited class with no table."
)
# add any columns declared here to the inherited table.
for c in cols:
if c.primary_key:
raise exceptions.ArgumentError(
"Can't place primary key columns on an inherited class with no table."
)
if c.name in inherited_table.c:
raise exceptions.ArgumentError(
"Column '%s' on class %s conflicts with existing column '%s'" %
(c, cls, inherited_table.c[c.name])
)
inherited_table.append_column(c)
# single or joined inheritance
# exclude any cols on the inherited table which are not mapped on the
# parent class, to avoid
# mapping columns specific to sibling/nephew classes
inherited_mapper = class_mapper(mapper_args['inherits'], compile=False)
inherited_table = inherited_mapper.local_table
if 'exclude_properties' not in mapper_args:
mapper_args['exclude_properties'] = exclude_properties = \
set([c.key for c in inherited_table.c
if c not in inherited_mapper._columntoproperty])
exclude_properties.difference_update([c.key for c in cols])
cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
class DeclarativeMeta(type):
def __init__(cls, classname, bases, dict_):
if '_decl_class_registry' in cls.__dict__:
return type.__init__(cls, classname, bases, dict_)
_as_declarative(cls, classname, cls.__dict__)
return type.__init__(cls, classname, bases, dict_)
def __setattr__(cls, key, value):
if '__mapper__' in cls.__dict__:
if isinstance(value, Column):
_undefer_column_name(key, value)
cls.__table__.append_column(value)
cls.__mapper__.add_property(key, value)
elif isinstance(value, ColumnProperty):
for col in value.columns:
if isinstance(col, Column) and col.table is None:
_undefer_column_name(key, col)
cls.__table__.append_column(col)
cls.__mapper__.add_property(key, value)
elif isinstance(value, MapperProperty):
cls.__mapper__.add_property(key, _deferred_relationship(cls, value))
else:
type.__setattr__(cls, key, value)
else:
type.__setattr__(cls, key, value)
class _GetColumns(object):
def __init__(self, cls):
self.cls = cls
def __getattr__(self, key):
mapper = class_mapper(self.cls, compile=False)
if mapper:
prop = mapper.get_property(key)
if not isinstance(prop, ColumnProperty):
raise exceptions.InvalidRequestError(
"Property %r is not an instance of"
" ColumnProperty (i.e. does not correspond"
" directly to a Column)." % key)
return getattr(self.cls, key)
def _deferred_relationship(cls, prop):
def resolve_arg(arg):
import sqlalchemy
def access_cls(key):
if key in cls._decl_class_registry:
return _GetColumns(cls._decl_class_registry[key])
elif key in cls.metadata.tables:
return cls.metadata.tables[key]
else:
return sqlalchemy.__dict__[key]
d = util.PopulateDict(access_cls)
def return_cls():
try:
x = eval(arg, globals(), d)
if isinstance(x, _GetColumns):
return x.cls
else:
return x
except NameError, n:
raise exceptions.InvalidRequestError(
"When compiling mapper %s, expression %r failed to locate a name (%r). "
"If this is a class name, consider adding this relationship() to the %r "
"class after both dependent classes have been defined." % (
prop.parent, arg, n.args[0], cls))
return return_cls
if isinstance(prop, RelationshipProperty):
for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
'secondary', '_foreign_keys', 'remote_side'):
v = getattr(prop, attr)
if isinstance(v, basestring):
setattr(prop, attr, resolve_arg(v))
if prop.backref and isinstance(prop.backref, tuple):
key, kwargs = prop.backref
for attr in ('primaryjoin', 'secondaryjoin', 'secondary',
'foreign_keys', 'remote_side', 'order_by'):
if attr in kwargs and isinstance(kwargs[attr], basestring):
kwargs[attr] = resolve_arg(kwargs[attr])
return prop
def synonym_for(name, map_column=False):
"""Decorator, make a Python @property a query synonym for a column.
A decorator version of :func:`~sqlalchemy.orm.synonym`. The function being
decorated is the 'descriptor', otherwise passes its arguments through
to synonym()::
@synonym_for('col')
@property
def prop(self):
return 'special sauce'
The regular ``synonym()`` is also usable directly in a declarative setting
and may be convenient for read/write properties::
prop = synonym('col', descriptor=property(_read_prop, _write_prop))
"""
def decorate(fn):
return _orm_synonym(name, map_column=map_column, descriptor=fn)
return decorate
def comparable_using(comparator_factory):
"""Decorator, allow a Python @property to be used in query criteria.
This is a decorator front end to
:func:`~sqlalchemy.orm.comparable_property` that passes
through the comparator_factory and the function being decorated::
@comparable_using(MyComparatorType)
@property
def prop(self):
return 'special sauce'
The regular ``comparable_property()`` is also usable directly in a
declarative setting and may be convenient for read/write properties::
prop = comparable_property(MyComparatorType)
"""
def decorate(fn):
return comparable_property(comparator_factory, fn)
return decorate
def _declarative_constructor(self, **kwargs):
"""A simple constructor that allows initialization from kwargs.
Sets attributes on the constructed instance using the names and
values in ``kwargs``.
Only keys that are present as
attributes of the instance's class are allowed. These could be,
for example, any mapped columns or relationships.
"""
for k in kwargs:
if not hasattr(type(self), k):
raise TypeError(
"%r is an invalid keyword argument for %s" %
(k, type(self).__name__))
setattr(self, k, kwargs[k])
_declarative_constructor.__name__ = '__init__'
def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
name='Base', constructor=_declarative_constructor,
metaclass=DeclarativeMeta):
"""Construct a base class for declarative class definitions.
The new base class will be given a metaclass that produces
appropriate :class:`~sqlalchemy.schema.Table` objects and makes
the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the
information provided declaratively in the class and any subclasses
of the class.
:param bind: An optional
:class:`~sqlalchemy.engine.base.Connectable`, will be assigned
the ``bind`` attribute on the :class:`~sqlalchemy.MetaData`
instance.
:param metadata:
An optional :class:`~sqlalchemy.MetaData` instance. All
:class:`~sqlalchemy.schema.Table` objects implicitly declared by
subclasses of the base will share this MetaData. A MetaData instance
will be created if none is provided. The
:class:`~sqlalchemy.MetaData` instance will be available via the
`metadata` attribute of the generated declarative base class.
:param mapper:
An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will be
used to map subclasses to their Tables.
:param cls:
Defaults to :class:`object`. A type to use as the base for the generated
declarative base class. May be a class or tuple of classes.
:param name:
Defaults to ``Base``. The display name for the generated
class. Customizing this is not required, but can improve clarity in
tracebacks and debugging.
:param constructor:
Defaults to
:func:`~sqlalchemy.ext.declarative._declarative_constructor`, an
__init__ implementation that assigns \**kwargs for declared
fields and relationships to an instance. If ``None`` is supplied,
no __init__ will be provided and construction will fall back to
cls.__init__ by way of the normal Python semantics.
:param metaclass:
Defaults to :class:`DeclarativeMeta`. A metaclass or __metaclass__
compatible callable to use as the meta type of the generated
declarative base class.
"""
lcl_metadata = metadata or MetaData()
if bind:
lcl_metadata.bind = bind
bases = not isinstance(cls, tuple) and (cls,) or cls
class_dict = dict(_decl_class_registry=dict(),
metadata=lcl_metadata)
if constructor:
class_dict['__init__'] = constructor
if mapper:
class_dict['__mapper_cls__'] = mapper
return metaclass(name, bases, class_dict)
def _undefer_column_name(key, column):
if column.key is None:
column.key = key
if column.name is None:
column.name = key

View File

@ -0,0 +1,125 @@
# horizontal_shard.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Horizontal sharding support.
Defines a rudimental 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.
For a usage example, see the :ref:`examples_sharding` example included in
the source distrbution.
"""
import sqlalchemy.exceptions as sa_exc
from sqlalchemy import util
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.query import Query
__all__ = ['ShardedSession', 'ShardedQuery']
class ShardedSession(Session):
def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a
SQL clause, returns a shard ID. This id may be based off of the
attributes present within the object, or on some round-robin
scheme. If the scheme is based on a selection, it should set
whatever state on the instance to mark it in the future as
participating in that shard.
:param id_chooser: A callable, passed a query and a tuple of identity values, which
should return a list of shard ids where the ID might reside. The
databases will be queried in the order of this listing.
:param query_chooser: For a given Query, returns the list of shard_ids where the query
should be issued. Results from all shards returned will be combined
together into a single listing.
:param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.base.Engine`
objects.
"""
super(ShardedSession, self).__init__(**kwargs)
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
self._mapper_flush_opts = {'connection_callable':self.connection}
self._query_cls = ShardedQuery
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance)
if self.transaction is not None:
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(mapper,
shard_id=shard_id,
instance=instance).contextual_connect(**kwargs)
def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance, clause=clause)
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind
class ShardedQuery(Query):
def __init__(self, *args, **kwargs):
super(ShardedQuery, self).__init__(*args, **kwargs)
self.id_chooser = self.session.id_chooser
self.query_chooser = self.session.query_chooser
self._shard_id = None
def set_shard(self, shard_id):
"""return a new query, limited to a single shard ID.
all subsequent operations with the returned query will
be against the single shard regardless of other state.
"""
q = self._clone()
q._shard_id = shard_id
return q
def _execute_and_instances(self, context):
if self._shard_id is not None:
result = self.session.connection(
mapper=self._mapper_zero(),
shard_id=self._shard_id).execute(context.statement, self._params)
return self.instances(result, context)
else:
partial = []
for shard_id in self.query_chooser(self):
result = self.session.connection(
mapper=self._mapper_zero(),
shard_id=shard_id).execute(context.statement, self._params)
partial = partial + list(self.instances(result, context))
# if some kind of in memory 'sorting'
# were done, this is where it would happen
return iter(partial)
def get(self, ident, **kwargs):
if self._shard_id is not None:
return super(ShardedQuery, self).get(ident)
else:
ident = util.to_list(ident)
for shard_id in self.id_chooser(self, ident):
o = self.set_shard(shard_id).get(ident, **kwargs)
if o is not None:
return o
else:
return None

View File

@ -0,0 +1,315 @@
"""A custom list that manages index/position information for its children.
:author: Jason Kirtland
``orderinglist`` is a helper for mutable ordered relationships. It will intercept
list operations performed on a relationship collection and automatically
synchronize changes in list position with an attribute on the related objects.
(See :ref:`advdatamapping_entitycollections` for more information on the general pattern.)
Example: Two tables that store slides in a presentation. Each slide
has a number of bullet points, displayed in order by the 'position'
column on the bullets table. These bullets can be inserted and re-ordered
by your end users, and you need to update the 'position' column of all
affected rows when changes are made.
.. sourcecode:: python+sql
slides_table = Table('Slides', metadata,
Column('id', Integer, primary_key=True),
Column('name', String))
bullets_table = Table('Bullets', metadata,
Column('id', Integer, primary_key=True),
Column('slide_id', Integer, ForeignKey('Slides.id')),
Column('position', Integer),
Column('text', String))
class Slide(object):
pass
class Bullet(object):
pass
mapper(Slide, slides_table, properties={
'bullets': relationship(Bullet, order_by=[bullets_table.c.position])
})
mapper(Bullet, bullets_table)
The standard relationship mapping will produce a list-like attribute on each Slide
containing all related Bullets, but coping with changes in ordering is totally
your responsibility. If you insert a Bullet into that list, there is no
magic- it won't have a position attribute unless you assign it it one, and
you'll need to manually renumber all the subsequent Bullets in the list to
accommodate the insert.
An ``orderinglist`` can automate this and manage the 'position' attribute on all
related bullets for you.
.. sourcecode:: python+sql
mapper(Slide, slides_table, properties={
'bullets': relationship(Bullet,
collection_class=ordering_list('position'),
order_by=[bullets_table.c.position])
})
mapper(Bullet, bullets_table)
s = Slide()
s.bullets.append(Bullet())
s.bullets.append(Bullet())
s.bullets[1].position
>>> 1
s.bullets.insert(1, Bullet())
s.bullets[2].position
>>> 2
Use the ``ordering_list`` function to set up the ``collection_class`` on relationships
(as in the mapper example above). This implementation depends on the list
starting in the proper order, so be SURE to put an order_by on your relationship.
.. warning:: ``ordering_list`` only provides limited functionality when a primary
key column or unique column is the target of the sort. Since changing the order of
entries often means that two rows must trade values, this is not possible when
the value is constrained by a primary key or unique constraint, since one of the rows
would temporarily have to point to a third available value so that the other row
could take its old value. ``ordering_list`` doesn't do any of this for you,
nor does SQLAlchemy itself.
``ordering_list`` takes the name of the related object's ordering attribute as
an argument. By default, the zero-based integer index of the object's
position in the ``ordering_list`` is synchronized with the ordering attribute:
index 0 will get position 0, index 1 position 1, etc. To start numbering at 1
or some other integer, provide ``count_from=1``.
Ordering values are not limited to incrementing integers. Almost any scheme
can implemented by supplying a custom ``ordering_func`` that maps a Python list
index to any value you require.
"""
from sqlalchemy.orm.collections import collection
from sqlalchemy import util
__all__ = [ 'ordering_list' ]
def ordering_list(attr, count_from=None, **kw):
"""Prepares an OrderingList factory for use in mapper definitions.
Returns an object suitable for use as an argument to a Mapper relationship's
``collection_class`` option. Arguments are:
attr
Name of the mapped attribute to use for storage and retrieval of
ordering information
count_from (optional)
Set up an integer-based ordering, starting at ``count_from``. For
example, ``ordering_list('pos', count_from=1)`` would create a 1-based
list in SQL, storing the value in the 'pos' column. Ignored if
``ordering_func`` is supplied.
Passes along any keyword arguments to ``OrderingList`` constructor.
"""
kw = _unsugar_count_from(count_from=count_from, **kw)
return lambda: OrderingList(attr, **kw)
# Ordering utility functions
def count_from_0(index, collection):
"""Numbering function: consecutive integers starting at 0."""
return index
def count_from_1(index, collection):
"""Numbering function: consecutive integers starting at 1."""
return index + 1
def count_from_n_factory(start):
"""Numbering function: consecutive integers starting at arbitrary start."""
def f(index, collection):
return index + start
try:
f.__name__ = 'count_from_%i' % start
except TypeError:
pass
return f
def _unsugar_count_from(**kw):
"""Builds counting functions from keywrod arguments.
Keyword argument filter, prepares a simple ``ordering_func`` from a
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
"""
count_from = kw.pop('count_from', None)
if kw.get('ordering_func', None) is None and count_from is not None:
if count_from == 0:
kw['ordering_func'] = count_from_0
elif count_from == 1:
kw['ordering_func'] = count_from_1
else:
kw['ordering_func'] = count_from_n_factory(count_from)
return kw
class OrderingList(list):
"""A custom list that manages position information for its children.
See the module and __init__ documentation for more details. The
``ordering_list`` factory function is used to configure ``OrderingList``
collections in ``mapper`` relationship definitions.
"""
def __init__(self, ordering_attr=None, ordering_func=None,
reorder_on_append=False):
"""A custom list that manages position information for its children.
``OrderingList`` is a ``collection_class`` list implementation that
syncs position in a Python list with a position attribute on the
mapped objects.
This implementation relies on the list starting in the proper order,
so be **sure** to put an ``order_by`` on your relationship.
ordering_attr
Name of the attribute that stores the object's order in the
relationship.
ordering_func
Optional. A function that maps the position in the Python list to a
value to store in the ``ordering_attr``. Values returned are
usually (but need not be!) integers.
An ``ordering_func`` is called with two positional parameters: the
index of the element in the list, and the list itself.
If omitted, Python list indexes are used for the attribute values.
Two basic pre-built numbering functions are provided in this module:
``count_from_0`` and ``count_from_1``. For more exotic examples
like stepped numbering, alphabetical and Fibonacci numbering, see
the unit tests.
reorder_on_append
Default False. When appending an object with an existing (non-None)
ordering value, that value will be left untouched unless
``reorder_on_append`` is true. This is an optimization to avoid a
variety of dangerous unexpected database writes.
SQLAlchemy will add instances to the list via append() when your
object loads. If for some reason the result set from the database
skips a step in the ordering (say, row '1' is missing but you get
'2', '3', and '4'), reorder_on_append=True would immediately
renumber the items to '1', '2', '3'. If you have multiple sessions
making changes, any of whom happen to load this collection even in
passing, all of the sessions would try to "clean up" the numbering
in their commits, possibly causing all but one to fail with a
concurrent modification error. Spooky action at a distance.
Recommend leaving this with the default of False, and just call
``reorder()`` if you're doing ``append()`` operations with
previously ordered instances or when doing some housekeeping after
manual sql operations.
"""
self.ordering_attr = ordering_attr
if ordering_func is None:
ordering_func = count_from_0
self.ordering_func = ordering_func
self.reorder_on_append = reorder_on_append
# More complex serialization schemes (multi column, e.g.) are possible by
# subclassing and reimplementing these two methods.
def _get_order_value(self, entity):
return getattr(entity, self.ordering_attr)
def _set_order_value(self, entity, value):
setattr(entity, self.ordering_attr, value)
def reorder(self):
"""Synchronize ordering for the entire collection.
Sweeps through the list and ensures that each object has accurate
ordering information set.
"""
for index, entity in enumerate(self):
self._order_entity(index, entity, True)
# As of 0.5, _reorder is no longer semi-private
_reorder = reorder
def _order_entity(self, index, entity, reorder=True):
have = self._get_order_value(entity)
# Don't disturb existing ordering if reorder is False
if have is not None and not reorder:
return
should_be = self.ordering_func(index, self)
if have != should_be:
self._set_order_value(entity, should_be)
def append(self, entity):
super(OrderingList, self).append(entity)
self._order_entity(len(self) - 1, entity, self.reorder_on_append)
def _raw_append(self, entity):
"""Append without any ordering behavior."""
super(OrderingList, self).append(entity)
_raw_append = collection.adds(1)(_raw_append)
def insert(self, index, entity):
super(OrderingList, self).insert(index, entity)
self._reorder()
def remove(self, entity):
super(OrderingList, self).remove(entity)
self._reorder()
def pop(self, index=-1):
entity = super(OrderingList, self).pop(index)
self._reorder()
return entity
def __setitem__(self, index, entity):
if isinstance(index, slice):
step = index.step or 1
start = index.start or 0
if start < 0:
start += len(self)
stop = index.stop or len(self)
if stop < 0:
stop += len(self)
for i in xrange(start, stop, step):
self.__setitem__(i, entity[i])
else:
self._order_entity(index, entity, True)
super(OrderingList, self).__setitem__(index, entity)
def __delitem__(self, index):
super(OrderingList, self).__delitem__(index)
self._reorder()
# Py2K
def __setslice__(self, start, end, values):
super(OrderingList, self).__setslice__(start, end, values)
self._reorder()
def __delslice__(self, start, end):
super(OrderingList, self).__delslice__(start, end)
self._reorder()
# end Py2K
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func

View File

@ -0,0 +1,155 @@
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
allowing "contextual" deserialization.
Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session
etc. which are referenced by the structure are not persisted in serialized
form, but are instead re-associated with the query structure
when it is deserialized.
Usage is nearly the same as that of the standard Python pickle module::
from sqlalchemy.ext.serializer import loads, dumps
metadata = MetaData(bind=some_engine)
Session = scoped_session(sessionmaker())
# ... define mappers
query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
# pickle the query
serialized = dumps(query)
# unpickle. Pass in metadata + scoped_session
query2 = loads(serialized, metadata, Session)
print query2.all()
Similar restrictions as when using raw pickle apply; mapped classes must be
themselves be pickleable, meaning they are importable from a module-level
namespace.
The serializer module is only appropriate for query structures. It is not
needed for:
* instances of user-defined classes. These contain no references to engines,
sessions or expression constructs in the typical case and can be serialized directly.
* Table metadata that is to be loaded entirely from the serialized structure (i.e. is
not already declared in the application). Regular pickle.loads()/dumps() can
be used to fully dump any ``MetaData`` object, typically one which was reflected
from an existing database at some previous point in time. The serializer module
is specifically for the opposite case, where the Table metadata is already present
in memory.
"""
from sqlalchemy.orm import class_mapper, Query
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.attributes import QueryableAttribute
from sqlalchemy import Table, Column
from sqlalchemy.engine import Engine
from sqlalchemy.util import pickle
import re
import base64
# Py3K
#from io import BytesIO as byte_buffer
# Py2K
from cStringIO import StringIO as byte_buffer
# end Py2K
# Py3K
#def b64encode(x):
# return base64.b64encode(x).decode('ascii')
#def b64decode(x):
# return base64.b64decode(x.encode('ascii'))
# Py2K
b64encode = base64.b64encode
b64decode = base64.b64decode
# end Py2K
__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
def Serializer(*args, **kw):
pickler = pickle.Pickler(*args, **kw)
def persistent_id(obj):
#print "serializing:", repr(obj)
if isinstance(obj, QueryableAttribute):
cls = obj.impl.class_
key = obj.impl.key
id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls))
elif isinstance(obj, Mapper) and not obj.non_primary:
id = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, Table):
id = "table:" + str(obj)
elif isinstance(obj, Column) and isinstance(obj.table, Table):
id = "column:" + str(obj.table) + ":" + obj.key
elif isinstance(obj, Session):
id = "session:"
elif isinstance(obj, Engine):
id = "engine:"
else:
return None
return id
pickler.persistent_id = persistent_id
return pickler
our_ids = re.compile(r'(mapper|table|column|session|attribute|engine):(.*)')
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
unpickler = pickle.Unpickler(file)
def get_engine():
if engine:
return engine
elif scoped_session and scoped_session().bind:
return scoped_session().bind
elif metadata and metadata.bind:
return metadata.bind
else:
return None
def persistent_load(id):
m = our_ids.match(id)
if not m:
return None
else:
type_, args = m.group(1, 2)
if type_ == 'attribute':
key, clsarg = args.split(":")
cls = pickle.loads(b64decode(clsarg))
return getattr(cls, key)
elif type_ == "mapper":
cls = pickle.loads(b64decode(args))
return class_mapper(cls)
elif type_ == "table":
return metadata.tables[args]
elif type_ == "column":
table, colname = args.split(':')
return metadata.tables[table].c[colname]
elif type_ == "session":
return scoped_session()
elif type_ == "engine":
return get_engine()
else:
raise Exception("Unknown token: %s" % type_)
unpickler.persistent_load = persistent_load
return unpickler
def dumps(obj, protocol=0):
buf = byte_buffer()
pickler = Serializer(buf, protocol)
pickler.dump(obj)
return buf.getvalue()
def loads(data, metadata=None, scoped_session=None, engine=None):
buf = byte_buffer(data)
unpickler = Deserializer(buf, metadata, scoped_session, engine)
return unpickler.load()

551
sqlalchemy/ext/sqlsoup.py Normal file
View File

@ -0,0 +1,551 @@
"""
Introduction
============
SqlSoup provides a convenient way to access existing database tables without
having to declare table or mapper classes ahead of time. It is built on top of the SQLAlchemy ORM and provides a super-minimalistic interface to an existing database.
Suppose we have a database with users, books, and loans tables
(corresponding to the PyWebOff dataset, if you're curious).
Creating a SqlSoup gateway is just like creating an SQLAlchemy
engine::
>>> from sqlalchemy.ext.sqlsoup import SqlSoup
>>> db = SqlSoup('sqlite:///:memory:')
or, you can re-use an existing engine::
>>> db = SqlSoup(engine)
You can optionally specify a schema within the database for your
SqlSoup::
>>> db.schema = myschemaname
Loading objects
===============
Loading objects is as easy as this::
>>> users = db.users.all()
>>> users.sort()
>>> users
[MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0), MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)]
Of course, letting the database do the sort is better::
>>> db.users.order_by(db.users.name).all()
[MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1), MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0)]
Field access is intuitive::
>>> users[0].email
u'student@example.edu'
Of course, you don't want to load all users very often. Let's add a
WHERE clause. Let's also switch the order_by to DESC while we're at
it::
>>> from sqlalchemy import or_, and_, desc
>>> where = or_(db.users.name=='Bhargan Basepair', db.users.email=='student@example.edu')
>>> db.users.filter(where).order_by(desc(db.users.name)).all()
[MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0), MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)]
You can also use .first() (to retrieve only the first object from a query) or
.one() (like .first when you expect exactly one user -- it will raise an
exception if more were returned)::
>>> db.users.filter(db.users.name=='Bhargan Basepair').one()
MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)
Since name is the primary key, this is equivalent to
>>> db.users.get('Bhargan Basepair')
MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)
This is also equivalent to
>>> db.users.filter_by(name='Bhargan Basepair').one()
MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)
filter_by is like filter, but takes kwargs instead of full clause expressions.
This makes it more concise for simple queries like this, but you can't do
complex queries like the or\_ above or non-equality based comparisons this way.
Full query documentation
------------------------
Get, filter, filter_by, order_by, limit, and the rest of the
query methods are explained in detail in :ref:`ormtutorial_querying`.
Modifying objects
=================
Modifying objects is intuitive::
>>> user = _
>>> user.email = 'basepair+nospam@example.edu'
>>> db.commit()
(SqlSoup leverages the sophisticated SQLAlchemy unit-of-work code, so
multiple updates to a single object will be turned into a single
``UPDATE`` statement when you commit.)
To finish covering the basics, let's insert a new loan, then delete
it::
>>> book_id = db.books.filter_by(title='Regional Variation in Moss').first().id
>>> db.loans.insert(book_id=book_id, user_name=user.name)
MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None)
>>> loan = db.loans.filter_by(book_id=2, user_name='Bhargan Basepair').one()
>>> db.delete(loan)
>>> db.commit()
You can also delete rows that have not been loaded as objects. Let's
do our insert/delete cycle once more, this time using the loans
table's delete method. (For SQLAlchemy experts: note that no flush()
call is required since this delete acts at the SQL level, not at the
Mapper level.) The same where-clause construction rules apply here as
to the select methods.
::
>>> db.loans.insert(book_id=book_id, user_name=user.name)
MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None)
>>> db.loans.delete(db.loans.book_id==2)
You can similarly update multiple rows at once. This will change the
book_id to 1 in all loans whose book_id is 2::
>>> db.loans.update(db.loans.book_id==2, book_id=1)
>>> db.loans.filter_by(book_id=1).all()
[MappedLoans(book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))]
Joins
=====
Occasionally, you will want to pull out a lot of data from related
tables all at once. In this situation, it is far more efficient to
have the database perform the necessary join. (Here we do not have *a
lot of data* but hopefully the concept is still clear.) SQLAlchemy is
smart enough to recognize that loans has a foreign key to users, and
uses that as the join condition automatically.
::
>>> join1 = db.join(db.users, db.loans, isouter=True)
>>> join1.filter_by(name='Joe Student').all()
[MappedJoin(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0,book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))]
If you're unfortunate enough to be using MySQL with the default MyISAM
storage engine, you'll have to specify the join condition manually,
since MyISAM does not store foreign keys. Here's the same join again,
with the join condition explicitly specified::
>>> db.join(db.users, db.loans, db.users.name==db.loans.user_name, isouter=True)
<class 'sqlalchemy.ext.sqlsoup.MappedJoin'>
You can compose arbitrarily complex joins by combining Join objects
with tables or other joins. Here we combine our first join with the
books table::
>>> join2 = db.join(join1, db.books)
>>> join2.all()
[MappedJoin(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0,book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0),id=1,title=u'Mustards I Have Known',published_year=u'1989',authors=u'Jones')]
If you join tables that have an identical column name, wrap your join
with `with_labels`, to disambiguate columns with their table name
(.c is short for .columns)::
>>> db.with_labels(join1).c.keys()
[u'users_name', u'users_email', u'users_password', u'users_classname', u'users_admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']
You can also join directly to a labeled object::
>>> labeled_loans = db.with_labels(db.loans)
>>> db.join(db.users, labeled_loans, isouter=True).c.keys()
[u'name', u'email', u'password', u'classname', u'admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']
Relationships
=============
You can define relationships on SqlSoup classes:
>>> db.users.relate('loans', db.loans)
These can then be used like a normal SA property:
>>> db.users.get('Joe Student').loans
[MappedLoans(book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))]
>>> db.users.filter(~db.users.loans.any()).all()
[MappedUsers(name=u'Bhargan Basepair',email='basepair+nospam@example.edu',password=u'basepair',classname=None,admin=1)]
relate can take any options that the relationship function accepts in normal mapper definition:
>>> del db._cache['users']
>>> db.users.relate('loans', db.loans, order_by=db.loans.loan_date, cascade='all, delete-orphan')
Advanced Use
============
Sessions, Transations and Application Integration
-------------------------------------------------
**Note:** please read and understand this section thoroughly before using SqlSoup in any web application.
SqlSoup uses a ScopedSession to provide thread-local sessions. You
can get a reference to the current one like this::
>>> session = db.session
The default session is available at the module level in SQLSoup, via::
>>> from sqlalchemy.ext.sqlsoup import Session
The configuration of this session is ``autoflush=True``, ``autocommit=False``.
This means when you work with the SqlSoup object, you need to call ``db.commit()``
in order to have changes persisted. You may also call ``db.rollback()`` to
roll things back.
Since the SqlSoup object's Session automatically enters into a transaction as soon
as it's used, it is *essential* that you call ``commit()`` or ``rollback()``
on it when the work within a thread completes. This means all the guidelines
for web application integration at :ref:`session_lifespan` must be followed.
The SqlSoup object can have any session or scoped session configured onto it.
This is of key importance when integrating with existing code or frameworks
such as Pylons. If your application already has a ``Session`` configured,
pass it to your SqlSoup object::
>>> from myapplication import Session
>>> db = SqlSoup(session=Session)
If the ``Session`` is configured with ``autocommit=True``, use ``flush()``
instead of ``commit()`` to persist changes - in this case, the ``Session``
closes out its transaction immediately and no external management is needed. ``rollback()`` is also not available. Configuring a new SQLSoup object in "autocommit" mode looks like::
>>> from sqlalchemy.orm import scoped_session, sessionmaker
>>> db = SqlSoup('sqlite://', session=scoped_session(sessionmaker(autoflush=False, expire_on_commit=False, autocommit=True)))
Mapping arbitrary Selectables
-----------------------------
SqlSoup can map any SQLAlchemy ``Selectable`` with the map
method. Let's map a ``Select`` object that uses an aggregate function;
we'll use the SQLAlchemy ``Table`` that SqlSoup introspected as the
basis. (Since we're not mapping to a simple table or join, we need to
tell SQLAlchemy how to find the *primary key* which just needs to be
unique within the select, and not necessarily correspond to a *real*
PK in the database.)
::
>>> from sqlalchemy import select, func
>>> b = db.books._table
>>> s = select([b.c.published_year, func.count('*').label('n')], from_obj=[b], group_by=[b.c.published_year])
>>> s = s.alias('years_with_count')
>>> years_with_count = db.map(s, primary_key=[s.c.published_year])
>>> years_with_count.filter_by(published_year='1989').all()
[MappedBooks(published_year=u'1989',n=1)]
Obviously if we just wanted to get a list of counts associated with
book years once, raw SQL is going to be less work. The advantage of
mapping a Select is reusability, both standalone and in Joins. (And if
you go to full SQLAlchemy, you can perform mappings like this directly
to your object models.)
An easy way to save mapped selectables like this is to just hang them on
your db object::
>>> db.years_with_count = years_with_count
Python is flexible like that!
Raw SQL
-------
SqlSoup works fine with SQLAlchemy's text construct, described in :ref:`sqlexpression_text`.
You can also execute textual SQL directly using the `execute()` method,
which corresponds to the `execute()` method on the underlying `Session`.
Expressions here are expressed like ``text()`` constructs, using named parameters
with colons::
>>> rp = db.execute('select name, email from users where name like :name order by name', name='%Bhargan%')
>>> for name, email in rp.fetchall(): print name, email
Bhargan Basepair basepair+nospam@example.edu
Or you can get at the current transaction's connection using `connection()`. This is the
raw connection object which can accept any sort of SQL expression or raw SQL string passed to the database::
>>> conn = db.connection()
>>> conn.execute("'select name, email from users where name like ? order by name'", '%Bhargan%')
Dynamic table names
-------------------
You can load a table whose name is specified at runtime with the entity() method:
>>> tablename = 'loans'
>>> db.entity(tablename) == db.loans
True
entity() also takes an optional schema argument. If none is specified, the
default schema is used.
"""
from sqlalchemy import Table, MetaData, join
from sqlalchemy import schema, sql
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \
class_mapper, relationship, session,\
object_session
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
from sqlalchemy.exceptions import SQLAlchemyError, InvalidRequestError, ArgumentError
from sqlalchemy.sql import expression
__all__ = ['PKNotFoundError', 'SqlSoup']
Session = scoped_session(sessionmaker(autoflush=True, autocommit=False))
class AutoAdd(MapperExtension):
def __init__(self, scoped_session):
self.scoped_session = scoped_session
def instrument_class(self, mapper, class_):
class_.__init__ = self._default__init__(mapper)
def _default__init__(ext, mapper):
def __init__(self, **kwargs):
for key, value in kwargs.iteritems():
setattr(self, key, value)
return __init__
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
session = self.scoped_session()
session._save_without_cascade(instance)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
sess = object_session(instance)
if sess:
sess.expunge(instance)
return EXT_CONTINUE
class PKNotFoundError(SQLAlchemyError):
pass
def _ddl_error(cls):
msg = 'SQLSoup can only modify mapped Tables (found: %s)' \
% cls._table.__class__.__name__
raise InvalidRequestError(msg)
# metaclass is necessary to expose class methods with getattr, e.g.
# we want to pass db.users.select through to users._mapper.select
class SelectableClassType(type):
def insert(cls, **kwargs):
_ddl_error(cls)
def __clause_element__(cls):
return cls._table
def __getattr__(cls, attr):
if attr == '_query':
# called during mapper init
raise AttributeError()
return getattr(cls._query, attr)
class TableClassType(SelectableClassType):
def insert(cls, **kwargs):
o = cls()
o.__dict__.update(kwargs)
return o
def relate(cls, propname, *args, **kwargs):
class_mapper(cls)._configure_property(propname, relationship(*args, **kwargs))
def _is_outer_join(selectable):
if not isinstance(selectable, sql.Join):
return False
if selectable.isouter:
return True
return _is_outer_join(selectable.left) or _is_outer_join(selectable.right)
def _selectable_name(selectable):
if isinstance(selectable, sql.Alias):
return _selectable_name(selectable.element)
elif isinstance(selectable, sql.Select):
return ''.join(_selectable_name(s) for s in selectable.froms)
elif isinstance(selectable, schema.Table):
return selectable.name.capitalize()
else:
x = selectable.__class__.__name__
if x[0] == '_':
x = x[1:]
return x
def _class_for_table(session, engine, selectable, **mapper_kwargs):
selectable = expression._clause_element_as_expr(selectable)
mapname = 'Mapped' + _selectable_name(selectable)
# Py2K
if isinstance(mapname, unicode):
engine_encoding = engine.dialect.encoding
mapname = mapname.encode(engine_encoding)
# end Py2K
if isinstance(selectable, Table):
klass = TableClassType(mapname, (object,), {})
else:
klass = SelectableClassType(mapname, (object,), {})
def _compare(self, o):
L = list(self.__class__.c.keys())
L.sort()
t1 = [getattr(self, k) for k in L]
try:
t2 = [getattr(o, k) for k in L]
except AttributeError:
raise TypeError('unable to compare with %s' % o.__class__)
return t1, t2
# python2/python3 compatible system of
# __cmp__ - __lt__ + __eq__
def __lt__(self, o):
t1, t2 = _compare(self, o)
return t1 < t2
def __eq__(self, o):
t1, t2 = _compare(self, o)
return t1 == t2
def __repr__(self):
L = ["%s=%r" % (key, getattr(self, key, ''))
for key in self.__class__.c.keys()]
return '%s(%s)' % (self.__class__.__name__, ','.join(L))
for m in ['__eq__', '__repr__', '__lt__']:
setattr(klass, m, eval(m))
klass._table = selectable
klass.c = expression.ColumnCollection()
mappr = mapper(klass,
selectable,
extension=AutoAdd(session),
**mapper_kwargs)
for k in mappr.iterate_properties:
klass.c[k.key] = k.columns[0]
klass._query = session.query_property()
return klass
class SqlSoup(object):
def __init__(self, engine_or_metadata, **kw):
"""Initialize a new ``SqlSoup``.
`args` may either be an ``SQLEngine`` or a set of arguments
suitable for passing to ``create_engine``.
"""
self.session = kw.pop('session', Session)
if isinstance(engine_or_metadata, MetaData):
self._metadata = engine_or_metadata
elif isinstance(engine_or_metadata, (basestring, Engine)):
self._metadata = MetaData(engine_or_metadata)
else:
raise ArgumentError("invalid engine or metadata argument %r" % engine_or_metadata)
self._cache = {}
self.schema = None
@property
def engine(self):
return self._metadata.bind
bind = engine
def delete(self, *args, **kwargs):
self.session.delete(*args, **kwargs)
def execute(self, stmt, **params):
return self.session.execute(sql.text(stmt, bind=self.bind), **params)
@property
def _underlying_session(self):
if isinstance(self.session, session.Session):
return self.session
else:
return self.session()
def connection(self):
return self._underlying_session._connection_for_bind(self.bind)
def flush(self):
self.session.flush()
def rollback(self):
self.session.rollback()
def commit(self):
self.session.commit()
def clear(self):
self.session.expunge_all()
def expunge(self, *args, **kw):
self.session.expunge(*args, **kw)
def expunge_all(self):
self.session.expunge_all()
def map(self, selectable, **kwargs):
try:
t = self._cache[selectable]
except KeyError:
t = _class_for_table(self.session, self.engine, selectable, **kwargs)
self._cache[selectable] = t
return t
def with_labels(self, item):
# TODO give meaningful aliases
return self.map(
expression._clause_element_as_expr(item).
select(use_labels=True).
alias('foo'))
def join(self, *args, **kwargs):
j = join(*args, **kwargs)
return self.map(j)
def entity(self, attr, schema=None):
try:
t = self._cache[attr]
except KeyError, ke:
table = Table(attr, self._metadata, autoload=True, autoload_with=self.bind, schema=schema or self.schema)
if not table.primary_key.columns:
raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys())))
if table.columns:
t = _class_for_table(self.session, self.engine, table)
else:
t = None
self._cache[attr] = t
return t
def __getattr__(self, attr):
return self.entity(attr)
def __repr__(self):
return 'SqlSoup(%r)' % self._metadata

205
sqlalchemy/interfaces.py Normal file
View File

@ -0,0 +1,205 @@
# interfaces.py
# Copyright (C) 2007 Jason Kirtland jek@discorporate.us
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Interfaces and abstract types."""
class PoolListener(object):
"""Hooks into the lifecycle of connections in a ``Pool``.
Usage::
class MyListener(PoolListener):
def connect(self, dbapi_con, con_record):
'''perform connect operations'''
# etc.
# create a new pool with a listener
p = QueuePool(..., listeners=[MyListener()])
# add a listener after the fact
p.add_listener(MyListener())
# usage with create_engine()
e = create_engine("url://", listeners=[MyListener()])
All of the standard connection :class:`~sqlalchemy.pool.Pool` types can
accept event listeners for key connection lifecycle events:
creation, pool check-out and check-in. There are no events fired
when a connection closes.
For any given DB-API connection, there will be one ``connect``
event, `n` number of ``checkout`` events, and either `n` or `n - 1`
``checkin`` events. (If a ``Connection`` is detached from its
pool via the ``detach()`` method, it won't be checked back in.)
These are low-level events for low-level objects: raw Python
DB-API connections, without the conveniences of the SQLAlchemy
``Connection`` wrapper, ``Dialect`` services or ``ClauseElement``
execution. If you execute SQL through the connection, explicitly
closing all cursors and other resources is recommended.
Events also receive a ``_ConnectionRecord``, a long-lived internal
``Pool`` object that basically represents a "slot" in the
connection pool. ``_ConnectionRecord`` objects have one public
attribute of note: ``info``, a dictionary whose contents are
scoped to the lifetime of the DB-API connection managed by the
record. You can use this shared storage area however you like.
There is no need to subclass ``PoolListener`` to handle events.
Any class that implements one or more of these methods can be used
as a pool listener. The ``Pool`` will inspect the methods
provided by a listener object and add the listener to one or more
internal event queues based on its capabilities. In terms of
efficiency and function call overhead, you're much better off only
providing implementations for the hooks you'll be using.
"""
def connect(self, dbapi_con, con_record):
"""Called once for each new DB-API connection or Pool's ``creator()``.
dbapi_con
A newly connected raw DB-API connection (not a SQLAlchemy
``Connection`` wrapper).
con_record
The ``_ConnectionRecord`` that persistently manages the connection
"""
def first_connect(self, dbapi_con, con_record):
"""Called exactly once for the first DB-API connection.
dbapi_con
A newly connected raw DB-API connection (not a SQLAlchemy
``Connection`` wrapper).
con_record
The ``_ConnectionRecord`` that persistently manages the connection
"""
def checkout(self, dbapi_con, con_record, con_proxy):
"""Called when a connection is retrieved from the Pool.
dbapi_con
A raw DB-API connection
con_record
The ``_ConnectionRecord`` that persistently manages the connection
con_proxy
The ``_ConnectionFairy`` which manages the connection for the span of
the current checkout.
If you raise an ``exc.DisconnectionError``, the current
connection will be disposed and a fresh connection retrieved.
Processing of all checkout listeners will abort and restart
using the new connection.
"""
def checkin(self, dbapi_con, con_record):
"""Called when a connection returns to the pool.
Note that the connection may be closed, and may be None if the
connection has been invalidated. ``checkin`` will not be called
for detached connections. (They do not return to the pool.)
dbapi_con
A raw DB-API connection
con_record
The ``_ConnectionRecord`` that persistently manages the connection
"""
class ConnectionProxy(object):
"""Allows interception of statement execution by Connections.
Either or both of the ``execute()`` and ``cursor_execute()``
may be implemented to intercept compiled statement and
cursor level executions, e.g.::
class MyProxy(ConnectionProxy):
def execute(self, conn, execute, clauseelement, *multiparams, **params):
print "compiled statement:", clauseelement
return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
print "raw statement:", statement
return execute(cursor, statement, parameters, context)
The ``execute`` argument is a function that will fulfill the default
execution behavior for the operation. The signature illustrated
in the example should be used.
The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via
the ``proxy`` argument::
e = create_engine('someurl://', proxy=MyProxy())
"""
def execute(self, conn, execute, clauseelement, *multiparams, **params):
"""Intercept high level execute() events."""
return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
"""Intercept low-level cursor execute() events."""
return execute(cursor, statement, parameters, context)
def begin(self, conn, begin):
"""Intercept begin() events."""
return begin()
def rollback(self, conn, rollback):
"""Intercept rollback() events."""
return rollback()
def commit(self, conn, commit):
"""Intercept commit() events."""
return commit()
def savepoint(self, conn, savepoint, name=None):
"""Intercept savepoint() events."""
return savepoint(name=name)
def rollback_savepoint(self, conn, rollback_savepoint, name, context):
"""Intercept rollback_savepoint() events."""
return rollback_savepoint(name, context)
def release_savepoint(self, conn, release_savepoint, name, context):
"""Intercept release_savepoint() events."""
return release_savepoint(name, context)
def begin_twophase(self, conn, begin_twophase, xid):
"""Intercept begin_twophase() events."""
return begin_twophase(xid)
def prepare_twophase(self, conn, prepare_twophase, xid):
"""Intercept prepare_twophase() events."""
return prepare_twophase(xid)
def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared):
"""Intercept rollback_twophase() events."""
return rollback_twophase(xid, is_prepared)
def commit_twophase(self, conn, commit_twophase, xid, is_prepared):
"""Intercept commit_twophase() events."""
return commit_twophase(xid, is_prepared)

119
sqlalchemy/log.py Normal file
View File

@ -0,0 +1,119 @@
# log.py - adapt python logging module to SQLAlchemy
# Copyright (C) 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Logging control and utilities.
Control of logging for SA can be performed from the regular python logging
module. The regular dotted module namespace is used, starting at
'sqlalchemy'. For class-level logging, the class name is appended.
The "echo" keyword parameter which is available on SQLA ``Engine``
and ``Pool`` objects corresponds to a logger specific to that
instance only.
E.g.::
engine.echo = True
is equivalent to::
import logging
logger = logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine)))
logger.setLevel(logging.DEBUG)
"""
import logging
import sys
from sqlalchemy import util
rootlogger = logging.getLogger('sqlalchemy')
if rootlogger.level == logging.NOTSET:
rootlogger.setLevel(logging.WARN)
default_enabled = False
def default_logging(name):
global default_enabled
if logging.getLogger(name).getEffectiveLevel() < logging.WARN:
default_enabled = True
if not default_enabled:
default_enabled = True
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s %(name)s %(message)s'))
rootlogger.addHandler(handler)
_logged_classes = set()
def class_logger(cls, enable=False):
logger = logging.getLogger(cls.__module__ + "." + cls.__name__)
if enable == 'debug':
logger.setLevel(logging.DEBUG)
elif enable == 'info':
logger.setLevel(logging.INFO)
cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
cls.logger = logger
_logged_classes.add(cls)
class Identified(object):
@util.memoized_property
def logging_name(self):
# limit the number of loggers by chopping off the hex(id).
# some novice users unfortunately create an unlimited number
# of Engines in their applications which would otherwise
# cause the app to run out of memory.
return "0x...%s" % hex(id(self))[-4:]
def instance_logger(instance, echoflag=None):
"""create a logger for an instance that implements :class:`Identified`.
Warning: this is an expensive call which also results in a permanent
increase in memory overhead for each call. Use only for
low-volume, long-time-spanning objects.
"""
name = "%s.%s.%s" % (instance.__class__.__module__,
instance.__class__.__name__, instance.logging_name)
if echoflag is not None:
l = logging.getLogger(name)
if echoflag == 'debug':
default_logging(name)
l.setLevel(logging.DEBUG)
elif echoflag is True:
default_logging(name)
l.setLevel(logging.INFO)
elif echoflag is False:
l.setLevel(logging.WARN)
else:
l = logging.getLogger(name)
instance._should_log_debug = lambda: l.isEnabledFor(logging.DEBUG)
instance._should_log_info = lambda: l.isEnabledFor(logging.INFO)
return l
class echo_property(object):
__doc__ = """\
When ``True``, enable log output for this element.
This has the effect of setting the Python logging level for the namespace
of this element's class and object reference. A value of boolean ``True``
indicates that the loglevel ``logging.INFO`` will be set for the logger,
whereas the string value ``debug`` will set the loglevel to
``logging.DEBUG``.
"""
def __get__(self, instance, owner):
if instance is None:
return self
else:
return instance._should_log_debug() and 'debug' or \
(instance._should_log_info() and True or False)
def __set__(self, instance, value):
instance_logger(instance, echoflag=value)

1176
sqlalchemy/orm/__init__.py Normal file

File diff suppressed because it is too large Load Diff

1708
sqlalchemy/orm/attributes.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,575 @@
# orm/dependency.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Relationship dependencies.
Bridges the ``PropertyLoader`` (i.e. a ``relationship()``) and the
``UOWTransaction`` together to allow processing of relationship()-based
dependencies at flush time.
"""
from sqlalchemy import sql, util
import sqlalchemy.exceptions as sa_exc
from sqlalchemy.orm import attributes, exc, sync
from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
def create_dependency_processor(prop):
types = {
ONETOMANY : OneToManyDP,
MANYTOONE: ManyToOneDP,
MANYTOMANY : ManyToManyDP,
}
return types[prop.direction](prop)
class DependencyProcessor(object):
has_dependencies = True
def __init__(self, prop):
self.prop = prop
self.cascade = prop.cascade
self.mapper = prop.mapper
self.parent = prop.parent
self.secondary = prop.secondary
self.direction = prop.direction
self.post_update = prop.post_update
self.passive_deletes = prop.passive_deletes
self.passive_updates = prop.passive_updates
self.enable_typechecks = prop.enable_typechecks
self.key = prop.key
self.dependency_marker = MapperStub(self.parent, self.mapper, self.key)
if not self.prop.synchronize_pairs:
raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relationship %s. "
"No target attributes to populate between parent and child are present" % self.prop)
def _get_instrumented_attribute(self):
"""Return the ``InstrumentedAttribute`` handled by this
``DependencyProecssor``.
"""
return self.parent.class_manager.get_impl(self.key)
def hasparent(self, state):
"""return True if the given object instance has a parent,
according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.
"""
# TODO: use correct API for this
return self._get_instrumented_attribute().hasparent(state)
def register_dependencies(self, uowcommit):
"""Tell a ``UOWTransaction`` what mappers are dependent on
which, with regards to the two or three mappers handled by
this ``DependencyProcessor``.
"""
raise NotImplementedError()
def register_processors(self, uowcommit):
"""Tell a ``UOWTransaction`` about this object as a processor,
which will be executed after that mapper's objects have been
saved or before they've been deleted. The process operation
manages attributes and dependent operations between two mappers.
"""
raise NotImplementedError()
def whose_dependent_on_who(self, state1, state2):
"""Given an object pair assuming `obj2` is a child of `obj1`,
return a tuple with the dependent object second, or None if
there is no dependency.
"""
if state1 is state2:
return None
elif self.direction == ONETOMANY:
return (state1, state2)
else:
return (state2, state1)
def process_dependencies(self, task, deplist, uowcommit, delete = False):
"""This method is called during a flush operation to
synchronize data between a parent and child object.
It is called within the context of the various mappers and
sometimes individual objects sorted according to their
insert/update/delete order (topological sort).
"""
raise NotImplementedError()
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
"""Used before the flushes' topological sort to traverse
through related objects and ensure every instance which will
require save/update/delete is properly added to the
UOWTransaction.
"""
raise NotImplementedError()
def _verify_canload(self, state):
if state is not None and \
not self.mapper._canload(state, allow_subtypes=not self.enable_typechecks):
if self.mapper._canload(state, allow_subtypes=True):
raise exc.FlushError(
"Attempting to flush an item of type %s on collection '%s', "
"which is not the expected type %s. Configure mapper '%s' to "
"load this subtype polymorphically, or set "
"enable_typechecks=False to allow subtypes. "
"Mismatched typeloading may cause bi-directional relationships "
"(backrefs) to not function properly." %
(state.class_, self.prop, self.mapper.class_, self.mapper))
else:
raise exc.FlushError(
"Attempting to flush an item of type %s on collection '%s', "
"whose mapper does not inherit from that of %s." %
(state.class_, self.prop, self.mapper.class_))
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
"""Called during a flush to synchronize primary key identifier
values between a parent/child object, as well as to an
associationrow in the case of many-to-many.
"""
raise NotImplementedError()
def _check_reverse_action(self, uowcommit, parent, child, action):
"""Determine if an action has been performed by the 'reverse' property of this property.
this is used to ensure that only one side of a bidirectional relationship
issues a certain operation for a parent/child pair.
"""
for r in self.prop._reverse_property:
if not r.viewonly and (r._dependency_processor, action, parent, child) in uowcommit.attributes:
return True
return False
def _performed_action(self, uowcommit, parent, child, action):
"""Establish that an action has been performed for a certain parent/child pair.
Used only for actions that are sensitive to bidirectional double-action,
i.e. manytomany, post_update.
"""
uowcommit.attributes[(self, action, parent, child)] = True
def _conditional_post_update(self, state, uowcommit, related):
"""Execute a post_update call.
For relationships that contain the post_update flag, an additional
``UPDATE`` statement may be associated after an ``INSERT`` or
before a ``DELETE`` in order to resolve circular row
dependencies.
This method will check for the post_update flag being set on a
particular relationship, and given a target object and list of
one or more related objects, and execute the ``UPDATE`` if the
given related object list contains ``INSERT``s or ``DELETE``s.
"""
if state is not None and self.post_update:
for x in related:
if x is not None and not self._check_reverse_action(uowcommit, x, state, "postupdate"):
uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs])
self._performed_action(uowcommit, x, state, "postupdate")
break
def _pks_changed(self, uowcommit, state):
raise NotImplementedError()
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, self.prop)
class OneToManyDP(DependencyProcessor):
def register_dependencies(self, uowcommit):
if self.post_update:
uowcommit.register_dependency(self.mapper, self.dependency_marker)
uowcommit.register_dependency(self.parent, self.dependency_marker)
else:
uowcommit.register_dependency(self.parent, self.mapper)
def register_processors(self, uowcommit):
if self.post_update:
uowcommit.register_processor(self.dependency_marker, self, self.parent)
else:
uowcommit.register_processor(self.parent, self, self.parent)
def process_dependencies(self, task, deplist, uowcommit, delete = False):
if delete:
# head object is being deleted, and we manage its list of child objects
# the child objects have to have their foreign key to the parent set to NULL
# this phase can be called safely for any cascade but is unnecessary if delete cascade
# is on.
if self.post_update or not self.passive_deletes == 'all':
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
for child in history.deleted:
if child is not None and self.hasparent(child) is False:
self._synchronize(state, child, None, True, uowcommit)
self._conditional_post_update(child, uowcommit, [state])
if self.post_update or not self.cascade.delete:
for child in history.unchanged:
if child is not None:
self._synchronize(state, child, None, True, uowcommit)
self._conditional_post_update(child, uowcommit, [state])
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.added:
self._synchronize(state, child, None, False, uowcommit)
if child is not None:
self._conditional_post_update(child, uowcommit, [state])
for child in history.deleted:
if not self.cascade.delete_orphan and not self.hasparent(child):
self._synchronize(state, child, None, True, uowcommit)
if self._pks_changed(uowcommit, state):
for child in history.unchanged:
self._synchronize(state, child, None, False, uowcommit)
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
if delete:
# head object is being deleted, and we manage its list of child objects
# the child objects have to have their foreign key to the parent set to NULL
if not self.post_update:
should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all'
for state in deplist:
history = uowcommit.get_attribute_history(
state, self.key, passive=self.passive_deletes)
if history:
for child in history.deleted:
if child is not None and self.hasparent(child) is False:
if self.cascade.delete_orphan:
uowcommit.register_object(child, isdelete=True)
else:
uowcommit.register_object(child)
if should_null_fks:
for child in history.unchanged:
if child is not None:
uowcommit.register_object(child)
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.added:
if child is not None:
uowcommit.register_object(child)
for child in history.deleted:
if not self.cascade.delete_orphan:
uowcommit.register_object(child, isdelete=False)
elif self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c),
isdelete=True)
if self._pks_changed(uowcommit, state):
if not history:
history = uowcommit.get_attribute_history(
state, self.key, passive=self.passive_updates)
if history:
for child in history.unchanged:
if child is not None:
uowcommit.register_object(child)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
source = state
dest = child
if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
return
self._verify_canload(child)
if clearkeys:
sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
else:
sync.populate(source, self.parent, dest, self.mapper,
self.prop.synchronize_pairs, uowcommit,
self.passive_updates)
def _pks_changed(self, uowcommit, state):
return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
class DetectKeySwitch(DependencyProcessor):
"""a special DP that works for many-to-one relationships, fires off for
child items who have changed their referenced key."""
has_dependencies = False
def register_dependencies(self, uowcommit):
pass
def register_processors(self, uowcommit):
uowcommit.register_processor(self.parent, self, self.mapper)
def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
# for non-passive updates, register in the preprocess stage
# so that mapper save_obj() gets a hold of changes
if not delete and not self.passive_updates:
self._process_key_switches(deplist, uowcommit)
def process_dependencies(self, task, deplist, uowcommit, delete=False):
# for passive updates, register objects in the process stage
# so that we avoid ManyToOneDP's registering the object without
# the listonly flag in its own preprocess stage (results in UPDATE)
# statements being emitted
if not delete and self.passive_updates:
self._process_key_switches(deplist, uowcommit)
def _process_key_switches(self, deplist, uowcommit):
switchers = set(s for s in deplist if self._pks_changed(uowcommit, s))
if switchers:
# yes, we're doing a linear search right now through the UOW. only
# takes effect when primary key values have actually changed.
# a possible optimization might be to enhance the "hasparents" capability of
# attributes to actually store all parent references, but this introduces
# more complicated attribute accounting.
for s in [elem for elem in uowcommit.session.identity_map.all_states()
if issubclass(elem.class_, self.parent.class_) and
self.key in elem.dict and
elem.dict[self.key] is not None and
attributes.instance_state(elem.dict[self.key]) in switchers
]:
uowcommit.register_object(s)
sync.populate(
attributes.instance_state(s.dict[self.key]),
self.mapper, s, self.parent, self.prop.synchronize_pairs,
uowcommit, self.passive_updates)
def _pks_changed(self, uowcommit, state):
return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
class ManyToOneDP(DependencyProcessor):
def __init__(self, prop):
DependencyProcessor.__init__(self, prop)
self.mapper._dependency_processors.append(DetectKeySwitch(prop))
def register_dependencies(self, uowcommit):
if self.post_update:
uowcommit.register_dependency(self.mapper, self.dependency_marker)
uowcommit.register_dependency(self.parent, self.dependency_marker)
else:
uowcommit.register_dependency(self.mapper, self.parent)
def register_processors(self, uowcommit):
if self.post_update:
uowcommit.register_processor(self.dependency_marker, self, self.parent)
else:
uowcommit.register_processor(self.mapper, self, self.parent)
def process_dependencies(self, task, deplist, uowcommit, delete=False):
if delete:
if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all':
# post_update means we have to update our row to not reference the child object
# before we can DELETE the row
for state in deplist:
self._synchronize(state, None, None, True, uowcommit)
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
self._conditional_post_update(state, uowcommit, history.sum())
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.added:
self._synchronize(state, child, None, False, uowcommit)
self._conditional_post_update(state, uowcommit, history.sum())
def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
if self.post_update:
return
if delete:
if self.cascade.delete or self.cascade.delete_orphan:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
if self.cascade.delete_orphan:
todelete = history.sum()
else:
todelete = history.non_deleted()
for child in todelete:
if child is None:
continue
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c), isdelete=True)
else:
for state in deplist:
uowcommit.register_object(state)
if self.cascade.delete_orphan:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
for child in history.deleted:
if self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c),
isdelete=True)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
if state is None or (not self.post_update and uowcommit.is_deleted(state)):
return
if clearkeys or child is None:
sync.clear(state, self.parent, self.prop.synchronize_pairs)
else:
self._verify_canload(child)
sync.populate(child, self.mapper, state,
self.parent, self.prop.synchronize_pairs, uowcommit,
self.passive_updates
)
class ManyToManyDP(DependencyProcessor):
def register_dependencies(self, uowcommit):
# many-to-many. create a "Stub" mapper to represent the
# "middle table" in the relationship. This stub mapper doesnt save
# or delete any objects, but just marks a dependency on the two
# related mappers. its dependency processor then populates the
# association table.
uowcommit.register_dependency(self.parent, self.dependency_marker)
uowcommit.register_dependency(self.mapper, self.dependency_marker)
def register_processors(self, uowcommit):
uowcommit.register_processor(self.dependency_marker, self, self.parent)
def process_dependencies(self, task, deplist, uowcommit, delete = False):
connection = uowcommit.transaction.connection(self.mapper)
secondary_delete = []
secondary_insert = []
secondary_update = []
if delete:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
for child in history.non_added():
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
continue
associationrow = {}
self._synchronize(state, child, associationrow, False, uowcommit)
secondary_delete.append(associationrow)
self._performed_action(uowcommit, state, child, "manytomany")
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key)
if history:
for child in history.added:
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
continue
associationrow = {}
self._synchronize(state, child, associationrow, False, uowcommit)
self._performed_action(uowcommit, state, child, "manytomany")
secondary_insert.append(associationrow)
for child in history.deleted:
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
continue
associationrow = {}
self._synchronize(state, child, associationrow, False, uowcommit)
self._performed_action(uowcommit, state, child, "manytomany")
secondary_delete.append(associationrow)
if not self.passive_updates and self._pks_changed(uowcommit, state):
if not history:
history = uowcommit.get_attribute_history(state, self.key, passive=False)
for child in history.unchanged:
associationrow = {}
sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs)
sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs)
#self.syncrules.update(associationrow, state, child, "old_")
secondary_update.append(associationrow)
if secondary_delete:
statement = self.secondary.delete(sql.and_(*[
c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow
]))
result = connection.execute(statement, secondary_delete)
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete):
raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of "
"secondary table rows deleted from table '%s': %d" %
(result.rowcount, self.secondary.description, len(secondary_delete)))
if secondary_update:
statement = self.secondary.update(sql.and_(*[
c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow
]))
result = connection.execute(statement, secondary_update)
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update):
raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of "
"secondary table rows updated from table '%s': %d" %
(result.rowcount, self.secondary.description, len(secondary_update)))
if secondary_insert:
statement = self.secondary.insert()
connection.execute(statement, secondary_insert)
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
if not delete:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.deleted:
if self.cascade.delete_orphan and self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c), isdelete=True)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
if associationrow is None:
return
self._verify_canload(child)
sync.populate_dict(state, self.parent, associationrow,
self.prop.synchronize_pairs)
sync.populate_dict(child, self.mapper, associationrow,
self.prop.secondary_synchronize_pairs)
def _pks_changed(self, uowcommit, state):
return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
class MapperStub(object):
"""Represent a many-to-many dependency within a flush
context.
The UOWTransaction corresponds dependencies to mappers.
MapperStub takes the place of the "association table"
so that a depedendency can be corresponded to it.
"""
def __init__(self, parent, mapper, key):
self.mapper = mapper
self.base_mapper = self
self.class_ = mapper.class_
self._inheriting_mappers = []
def polymorphic_iterator(self):
return iter((self,))
def _register_dependencies(self, uowcommit):
pass
def _register_procesors(self, uowcommit):
pass
def _save_obj(self, *args, **kwargs):
pass
def _delete_obj(self, *args, **kwargs):
pass
def primary_mapper(self):
return self

293
sqlalchemy/orm/dynamic.py Normal file
View File

@ -0,0 +1,293 @@
# dynamic.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Dynamic collection API.
Dynamic collections act like Query() objects for read operations and support
basic add/delete mutation.
"""
from sqlalchemy import log, util
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as sa_exc
from sqlalchemy.sql import operators
from sqlalchemy.orm import (
attributes, object_session, util as mapperutil, strategies, object_mapper
)
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.util import _state_has_identity, has_identity
from sqlalchemy.orm import attributes, collections
class DynaLoader(strategies.AbstractRelationshipLoader):
def init_class_attribute(self, mapper):
self.is_class_level = True
strategies._register_attribute(self,
mapper,
useobject=True,
impl_class=DynamicAttributeImpl,
target_mapper=self.parent_property.mapper,
order_by=self.parent_property.order_by,
query_class=self.parent_property.query_class
)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
return (None, None)
log.class_logger(DynaLoader)
class DynamicAttributeImpl(attributes.AttributeImpl):
uses_objects = True
accepts_scalar_loader = False
def __init__(self, class_, key, typecallable,
target_mapper, order_by, query_class=None, **kwargs):
super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs)
self.target_mapper = target_mapper
self.order_by = order_by
if not query_class:
self.query_class = AppenderQuery
elif AppenderMixin in query_class.mro():
self.query_class = query_class
else:
self.query_class = mixin_user_query(query_class)
def get(self, state, dict_, passive=False):
if passive:
return self._get_collection_history(state, passive=True).added_items
else:
return self.query_class(self, state)
def get_collection(self, state, dict_, user_data=None, passive=True):
if passive:
return self._get_collection_history(state, passive=passive).added_items
else:
history = self._get_collection_history(state, passive=passive)
return history.added_items + history.unchanged_items
def fire_append_event(self, state, dict_, value, initiator):
collection_history = self._modified_event(state, dict_)
collection_history.added_items.append(value)
for ext in self.extensions:
ext.append(state, value, initiator or self)
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), True)
def fire_remove_event(self, state, dict_, value, initiator):
collection_history = self._modified_event(state, dict_)
collection_history.deleted_items.append(value)
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), False)
for ext in self.extensions:
ext.remove(state, value, initiator or self)
def _modified_event(self, state, dict_):
if self.key not in state.committed_state:
state.committed_state[self.key] = CollectionHistory(self, state)
state.modified_event(dict_,
self,
False,
attributes.NEVER_SET,
passive=attributes.PASSIVE_NO_INITIALIZE)
# this is a hack to allow the _base.ComparableEntity fixture
# to work
dict_[self.key] = True
return state.committed_state[self.key]
def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF):
if initiator is self:
return
self._set_iterable(state, dict_, value)
def _set_iterable(self, state, dict_, iterable, adapter=None):
collection_history = self._modified_event(state, dict_)
new_values = list(iterable)
if _state_has_identity(state):
old_collection = list(self.get(state, dict_))
else:
old_collection = []
collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values))
def delete(self, *args, **kwargs):
raise NotImplementedError()
def get_history(self, state, dict_, passive=False):
c = self._get_collection_history(state, passive)
return attributes.History(c.added_items, c.unchanged_items, c.deleted_items)
def _get_collection_history(self, state, passive=False):
if self.key in state.committed_state:
c = state.committed_state[self.key]
else:
c = CollectionHistory(self, state)
if not passive:
return CollectionHistory(self, state, apply_to=c)
else:
return c
def append(self, state, dict_, value, initiator, passive=False):
if initiator is not self:
self.fire_append_event(state, dict_, value, initiator)
def remove(self, state, dict_, value, initiator, passive=False):
if initiator is not self:
self.fire_remove_event(state, dict_, value, initiator)
class DynCollectionAdapter(object):
"""the dynamic analogue to orm.collections.CollectionAdapter"""
def __init__(self, attr, owner_state, data):
self.attr = attr
self.state = owner_state
self.data = data
def __iter__(self):
return iter(self.data)
def append_with_event(self, item, initiator=None):
self.attr.append(self.state, self.state.dict, item, initiator)
def remove_with_event(self, item, initiator=None):
self.attr.remove(self.state, self.state.dict, item, initiator)
def append_without_event(self, item):
pass
def remove_without_event(self, item):
pass
class AppenderMixin(object):
query_class = None
def __init__(self, attr, state):
Query.__init__(self, attr.target_mapper, None)
self.instance = instance = state.obj()
self.attr = attr
mapper = object_mapper(instance)
prop = mapper.get_property(self.attr.key, resolve_synonyms=True)
self._criterion = prop.compare(
operators.eq,
instance,
value_is_parent=True,
alias_secondary=False)
if self.attr.order_by:
self._order_by = self.attr.order_by
def __session(self):
sess = object_session(self.instance)
if sess is not None and self.autoflush and sess.autoflush and self.instance in sess:
sess.flush()
if not has_identity(self.instance):
return None
else:
return sess
def session(self):
return self.__session()
session = property(session, lambda s, x:None)
def __iter__(self):
sess = self.__session()
if sess is None:
return iter(self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items)
else:
return iter(self._clone(sess))
def __getitem__(self, index):
sess = self.__session()
if sess is None:
return self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items.__getitem__(index)
else:
return self._clone(sess).__getitem__(index)
def count(self):
sess = self.__session()
if sess is None:
return len(self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items)
else:
return self._clone(sess).count()
def _clone(self, sess=None):
# note we're returning an entirely new Query class instance
# here without any assignment capabilities; the class of this
# query is determined by the session.
instance = self.instance
if sess is None:
sess = object_session(instance)
if sess is None:
raise orm_exc.DetachedInstanceError(
"Parent instance %s is not bound to a Session, and no "
"contextual session is established; lazy load operation "
"of attribute '%s' cannot proceed" % (
mapperutil.instance_str(instance), self.attr.key))
if self.query_class:
query = self.query_class(self.attr.target_mapper, session=sess)
else:
query = sess.query(self.attr.target_mapper)
query._criterion = self._criterion
query._order_by = self._order_by
return query
def append(self, item):
self.attr.append(
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
def remove(self, item):
self.attr.remove(
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
class AppenderQuery(AppenderMixin, Query):
"""A dynamic query that supports basic collection storage operations."""
def mixin_user_query(cls):
"""Return a new class with AppenderQuery functionality layered over."""
name = 'Appender' + cls.__name__
return type(name, (AppenderMixin, cls), {'query_class': cls})
class CollectionHistory(object):
"""Overrides AttributeHistory to receive append/remove events directly."""
def __init__(self, attr, state, apply_to=None):
if apply_to:
deleted = util.IdentitySet(apply_to.deleted_items)
added = apply_to.added_items
coll = AppenderQuery(attr, state).autoflush(False)
self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted]
self.added_items = apply_to.added_items
self.deleted_items = apply_to.deleted_items
else:
self.deleted_items = []
self.added_items = []
self.unchanged_items = []

104
sqlalchemy/orm/evaluator.py Normal file
View File

@ -0,0 +1,104 @@
import operator
from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql
class UnevaluatableError(Exception):
pass
_straight_ops = set(getattr(operators, op)
for op in ('add', 'mul', 'sub',
# Py2K
'div',
# end Py2K
'mod', 'truediv',
'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
_notimplemented_ops = set(getattr(operators, op)
for op in ('like_op', 'notlike_op', 'ilike_op',
'notilike_op', 'between_op', 'in_op',
'notin_op', 'endswith_op', 'concat_op'))
class EvaluatorCompiler(object):
def process(self, clause):
meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
if not meth:
raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__)
return meth(clause)
def visit_grouping(self, clause):
return self.process(clause.element)
def visit_null(self, clause):
return lambda obj: None
def visit_column(self, clause):
if 'parentmapper' in clause._annotations:
key = clause._annotations['parentmapper']._get_col_to_prop(clause).key
else:
key = clause.key
get_corresponding_attr = operator.attrgetter(key)
return lambda obj: get_corresponding_attr(obj)
def visit_clauselist(self, clause):
evaluators = map(self.process, clause.clauses)
if clause.operator is operators.or_:
def evaluate(obj):
has_null = False
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if value:
return True
has_null = has_null or value is None
if has_null:
return None
return False
elif clause.operator is operators.and_:
def evaluate(obj):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if not value:
if value is None:
return None
return False
return True
else:
raise UnevaluatableError("Cannot evaluate clauselist with operator %s" % clause.operator)
return evaluate
def visit_binary(self, clause):
eval_left,eval_right = map(self.process, [clause.left, clause.right])
operator = clause.operator
if operator is operators.is_:
def evaluate(obj):
return eval_left(obj) == eval_right(obj)
elif operator is operators.isnot:
def evaluate(obj):
return eval_left(obj) != eval_right(obj)
elif operator in _straight_ops:
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
if left_val is None or right_val is None:
return None
return operator(eval_left(obj), eval_right(obj))
else:
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
return evaluate
def visit_unary(self, clause):
eval_inner = self.process(clause.element)
if clause.operator is operators.inv:
def evaluate(obj):
value = eval_inner(obj)
if value is None:
return None
return not value
return evaluate
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
def visit_bindparam(self, clause):
val = clause.value
return lambda obj: val

98
sqlalchemy/orm/exc.py Normal file
View File

@ -0,0 +1,98 @@
# exc.py - ORM exceptions
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQLAlchemy ORM exceptions."""
import sqlalchemy as sa
NO_STATE = (AttributeError, KeyError)
"""Exception types that may be raised by instrumentation implementations."""
class ConcurrentModificationError(sa.exc.SQLAlchemyError):
"""Rows have been modified outside of the unit of work."""
class FlushError(sa.exc.SQLAlchemyError):
"""A invalid condition was detected during flush()."""
class UnmappedError(sa.exc.InvalidRequestError):
"""TODO"""
class DetachedInstanceError(sa.exc.SQLAlchemyError):
"""An attempt to access unloaded attributes on a mapped instance that is detached."""
class UnmappedInstanceError(UnmappedError):
"""An mapping operation was requested for an unknown instance."""
def __init__(self, obj, msg=None):
if not msg:
try:
mapper = sa.orm.class_mapper(type(obj))
name = _safe_cls_name(type(obj))
msg = ("Class %r is mapped, but this instance lacks "
"instrumentation. This occurs when the instance is created "
"before sqlalchemy.orm.mapper(%s) was called." % (name, name))
except UnmappedClassError:
msg = _default_unmapped(type(obj))
if isinstance(obj, type):
msg += (
'; was a class (%s) supplied where an instance was '
'required?' % _safe_cls_name(obj))
UnmappedError.__init__(self, msg)
class UnmappedClassError(UnmappedError):
"""An mapping operation was requested for an unknown class."""
def __init__(self, cls, msg=None):
if not msg:
msg = _default_unmapped(cls)
UnmappedError.__init__(self, msg)
class ObjectDeletedError(sa.exc.InvalidRequestError):
"""An refresh() operation failed to re-retrieve an object's row."""
class UnmappedColumnError(sa.exc.InvalidRequestError):
"""Mapping operation was requested on an unknown column."""
class NoResultFound(sa.exc.InvalidRequestError):
"""A database result was required but none was found."""
class MultipleResultsFound(sa.exc.InvalidRequestError):
"""A single database result was required but more than one were found."""
# Legacy compat until 0.6.
sa.exc.ConcurrentModificationError = ConcurrentModificationError
sa.exc.FlushError = FlushError
sa.exc.UnmappedColumnError
def _safe_cls_name(cls):
try:
cls_name = '.'.join((cls.__module__, cls.__name__))
except AttributeError:
cls_name = getattr(cls, '__name__', None)
if cls_name is None:
cls_name = repr(cls)
return cls_name
def _default_unmapped(cls):
try:
mappers = sa.orm.attributes.manager_of_class(cls).mappers
except NO_STATE:
mappers = {}
except TypeError:
mappers = {}
name = _safe_cls_name(cls)
if not mappers:
return "Class '%s' is not mapped" % name

251
sqlalchemy/orm/identity.py Normal file
View File

@ -0,0 +1,251 @@
# identity.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import weakref
from sqlalchemy import util as base_util
from sqlalchemy.orm import attributes
class IdentityMap(dict):
def __init__(self):
self._mutable_attrs = set()
self._modified = set()
self._wr = weakref.ref(self)
def replace(self, state):
raise NotImplementedError()
def add(self, state):
raise NotImplementedError()
def remove(self, state):
raise NotImplementedError()
def update(self, dict):
raise NotImplementedError("IdentityMap uses add() to insert data")
def clear(self):
raise NotImplementedError("IdentityMap uses remove() to remove data")
def _manage_incoming_state(self, state):
state._instance_dict = self._wr
if state.modified:
self._modified.add(state)
if state.manager.mutable_attributes:
self._mutable_attrs.add(state)
def _manage_removed_state(self, state):
del state._instance_dict
self._mutable_attrs.discard(state)
self._modified.discard(state)
def _dirty_states(self):
return self._modified.union(s for s in self._mutable_attrs.copy()
if s.modified)
def check_modified(self):
"""return True if any InstanceStates present have been marked as 'modified'."""
if self._modified:
return True
else:
for state in self._mutable_attrs.copy():
if state.modified:
return True
return False
def has_key(self, key):
return key in self
def popitem(self):
raise NotImplementedError("IdentityMap uses remove() to remove data")
def pop(self, key, *args):
raise NotImplementedError("IdentityMap uses remove() to remove data")
def setdefault(self, key, default=None):
raise NotImplementedError("IdentityMap uses add() to insert data")
def copy(self):
raise NotImplementedError()
def __setitem__(self, key, value):
raise NotImplementedError("IdentityMap uses add() to insert data")
def __delitem__(self, key):
raise NotImplementedError("IdentityMap uses remove() to remove data")
class WeakInstanceDict(IdentityMap):
def __getitem__(self, key):
state = dict.__getitem__(self, key)
o = state.obj()
if o is None:
o = state._is_really_none()
if o is None:
raise KeyError, key
return o
def __contains__(self, key):
try:
if dict.__contains__(self, key):
state = dict.__getitem__(self, key)
o = state.obj()
if o is None:
o = state._is_really_none()
else:
return False
except KeyError:
return False
else:
return o is not None
def contains_state(self, state):
return dict.get(self, state.key) is state
def replace(self, state):
if dict.__contains__(self, state.key):
existing = dict.__getitem__(self, state.key)
if existing is not state:
self._manage_removed_state(existing)
else:
return
dict.__setitem__(self, state.key, state)
self._manage_incoming_state(state)
def add(self, state):
if state.key in self:
if dict.__getitem__(self, state.key) is not state:
raise AssertionError("A conflicting state is already "
"present in the identity map for key %r"
% (state.key, ))
else:
dict.__setitem__(self, state.key, state)
self._manage_incoming_state(state)
def remove_key(self, key):
state = dict.__getitem__(self, key)
self.remove(state)
def remove(self, state):
if dict.pop(self, state.key) is not state:
raise AssertionError("State %s is not present in this identity map" % state)
self._manage_removed_state(state)
def discard(self, state):
if self.contains_state(state):
dict.__delitem__(self, state.key)
self._manage_removed_state(state)
def get(self, key, default=None):
state = dict.get(self, key, default)
if state is default:
return default
o = state.obj()
if o is None:
o = state._is_really_none()
if o is None:
return default
return o
# Py2K
def items(self):
return list(self.iteritems())
def iteritems(self):
for state in dict.itervalues(self):
# end Py2K
# Py3K
#def items(self):
# for state in dict.values(self):
value = state.obj()
if value is not None:
yield state.key, value
# Py2K
def values(self):
return list(self.itervalues())
def itervalues(self):
for state in dict.itervalues(self):
# end Py2K
# Py3K
#def values(self):
# for state in dict.values(self):
instance = state.obj()
if instance is not None:
yield instance
def all_states(self):
# Py3K
# return list(dict.values(self))
# Py2K
return dict.values(self)
# end Py2K
def prune(self):
return 0
class StrongInstanceDict(IdentityMap):
def all_states(self):
return [attributes.instance_state(o) for o in self.itervalues()]
def contains_state(self, state):
return state.key in self and attributes.instance_state(self[state.key]) is state
def replace(self, state):
if dict.__contains__(self, state.key):
existing = dict.__getitem__(self, state.key)
existing = attributes.instance_state(existing)
if existing is not state:
self._manage_removed_state(existing)
else:
return
dict.__setitem__(self, state.key, state.obj())
self._manage_incoming_state(state)
def add(self, state):
if state.key in self:
if attributes.instance_state(dict.__getitem__(self, state.key)) is not state:
raise AssertionError("A conflicting state is already present in the identity map for key %r" % (state.key, ))
else:
dict.__setitem__(self, state.key, state.obj())
self._manage_incoming_state(state)
def remove(self, state):
if attributes.instance_state(dict.pop(self, state.key)) is not state:
raise AssertionError("State %s is not present in this identity map" % state)
self._manage_removed_state(state)
def discard(self, state):
if self.contains_state(state):
dict.__delitem__(self, state.key)
self._manage_removed_state(state)
def remove_key(self, key):
state = attributes.instance_state(dict.__getitem__(self, key))
self.remove(state)
def prune(self):
"""prune unreferenced, non-dirty states."""
ref_count = len(self)
dirty = [s.obj() for s in self.all_states() if s.modified]
# work around http://bugs.python.org/issue6149
keepers = weakref.WeakValueDictionary()
keepers.update(self)
dict.clear(self)
dict.update(self, keepers)
self.modified = bool(dirty)
return ref_count - len(self)

1098
sqlalchemy/orm/interfaces.py Normal file

File diff suppressed because it is too large Load Diff

1958
sqlalchemy/orm/mapper.py Normal file

File diff suppressed because it is too large Load Diff

1205
sqlalchemy/orm/properties.py Normal file

File diff suppressed because it is too large Load Diff

2469
sqlalchemy/orm/query.py Normal file

File diff suppressed because it is too large Load Diff

205
sqlalchemy/orm/scoping.py Normal file
View File

@ -0,0 +1,205 @@
# scoping.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc
from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \
to_list, get_cls_kwargs, deprecated
from sqlalchemy.orm import (
EXT_CONTINUE, MapperExtension, class_mapper, object_session
)
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm.session import Session
__all__ = ['ScopedSession']
class ScopedSession(object):
"""Provides thread-local management of Sessions.
Usage::
Session = scoped_session(sessionmaker(autoflush=True))
... use session normally.
"""
def __init__(self, session_factory, scopefunc=None):
self.session_factory = session_factory
if scopefunc:
self.registry = ScopedRegistry(session_factory, scopefunc)
else:
self.registry = ThreadLocalRegistry(session_factory)
self.extension = _ScopedExt(self)
def __call__(self, **kwargs):
if kwargs:
scope = kwargs.pop('scope', False)
if scope is not None:
if self.registry.has():
raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
else:
sess = self.session_factory(**kwargs)
self.registry.set(sess)
return sess
else:
return self.session_factory(**kwargs)
else:
return self.registry()
def remove(self):
"""Dispose of the current contextual session."""
if self.registry.has():
self.registry().close()
self.registry.clear()
@deprecated("Session.mapper is deprecated. "
"Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper "
"for information on how to replicate its behavior.")
def mapper(self, *args, **kwargs):
"""return a mapper() function which associates this ScopedSession with the Mapper.
DEPRECATED.
"""
from sqlalchemy.orm import mapper
extension_args = dict((arg, kwargs.pop(arg))
for arg in get_cls_kwargs(_ScopedExt)
if arg in kwargs)
kwargs['extension'] = extension = to_list(kwargs.get('extension', []))
if extension_args:
extension.append(self.extension.configure(**extension_args))
else:
extension.append(self.extension)
return mapper(*args, **kwargs)
def configure(self, **kwargs):
"""reconfigure the sessionmaker used by this ScopedSession."""
self.session_factory.configure(**kwargs)
def query_property(self, query_cls=None):
"""return a class property which produces a `Query` object against the
class when called.
e.g.::
Session = scoped_session(sessionmaker())
class MyClass(object):
query = Session.query_property()
# after mappers are defined
result = MyClass.query.filter(MyClass.name=='foo').all()
Produces instances of the session's configured query class by
default. To override and use a custom implementation, provide
a ``query_cls`` callable. The callable will be invoked with
the class's mapper as a positional argument and a session
keyword argument.
There is no limit to the number of query properties placed on
a class.
"""
class query(object):
def __get__(s, instance, owner):
try:
mapper = class_mapper(owner)
if mapper:
if query_cls:
# custom query class
return query_cls(mapper, session=self.registry())
else:
# session's configured query class
return self.registry().query(mapper)
except orm_exc.UnmappedClassError:
return None
return query()
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
return do
for meth in Session.public_methods:
setattr(ScopedSession, meth, instrument(meth))
def makeprop(name):
def set(self, attr):
setattr(self.registry(), name, attr)
def get(self):
return getattr(self.registry(), name)
return property(get, set)
for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'):
setattr(ScopedSession, prop, makeprop(prop))
def clslevel(name):
def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs)
return classmethod(do)
for prop in ('close_all', 'object_session', 'identity_key'):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
def __init__(self, context, validate=False, save_on_init=True):
self.context = context
self.validate = validate
self.save_on_init = save_on_init
self.set_kwargs_on_init = True
def validating(self):
return _ScopedExt(self.context, validate=True)
def configure(self, **kwargs):
return _ScopedExt(self.context, **kwargs)
def instrument_class(self, mapper, class_):
class query(object):
def __getattr__(s, key):
return getattr(self.context.registry().query(class_), key)
def __call__(s):
return self.context.registry().query(class_)
def __get__(self, instance, cls):
return self
if not 'query' in class_.__dict__:
class_.query = query()
if self.set_kwargs_on_init and class_.__init__ is object.__init__:
class_.__init__ = self._default__init__(mapper)
def _default__init__(ext, mapper):
def __init__(self, **kwargs):
for key, value in kwargs.iteritems():
if ext.validate:
if not mapper.get_property(key, resolve_synonyms=False,
raiseerr=False):
raise sa_exc.ArgumentError(
"Invalid __init__ argument: '%s'" % key)
setattr(self, key, value)
return __init__
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
if self.save_on_init:
session = kwargs.pop('_sa_session', None)
if session is None:
session = self.context.registry()
session._save_without_cascade(instance)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
sess = object_session(instance)
if sess:
sess.expunge(instance)
return EXT_CONTINUE
def dispose_class(self, mapper, class_):
if hasattr(class_, 'query'):
delattr(class_, 'query')

1604
sqlalchemy/orm/session.py Normal file

File diff suppressed because it is too large Load Diff

15
sqlalchemy/orm/shard.py Normal file
View File

@ -0,0 +1,15 @@
# shard.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy import util
util.warn_deprecated(
"Horizontal sharding is now importable via "
"'import sqlalchemy.ext.horizontal_shard"
)
from sqlalchemy.ext.horizontal_shard import *

527
sqlalchemy/orm/state.py Normal file
View File

@ -0,0 +1,527 @@
from sqlalchemy.util import EMPTY_SET
import weakref
from sqlalchemy import util
from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, PASSIVE_OFF, \
NEVER_SET, NO_VALUE, manager_of_class, \
ATTR_WAS_SET
from sqlalchemy.orm import attributes, exc as orm_exc, interfaces
import sys
attributes.state = sys.modules['sqlalchemy.orm.state']
class InstanceState(object):
"""tracks state information at the instance level."""
session_id = None
key = None
runid = None
load_options = EMPTY_SET
load_path = ()
insert_order = None
mutable_dict = None
_strong_obj = None
modified = False
expired = False
def __init__(self, obj, manager):
self.class_ = obj.__class__
self.manager = manager
self.obj = weakref.ref(obj, self._cleanup)
@util.memoized_property
def committed_state(self):
return {}
@util.memoized_property
def parents(self):
return {}
@util.memoized_property
def pending(self):
return {}
@util.memoized_property
def callables(self):
return {}
def detach(self):
if self.session_id:
try:
del self.session_id
except AttributeError:
pass
def dispose(self):
self.detach()
del self.obj
def _cleanup(self, ref):
instance_dict = self._instance_dict()
if instance_dict:
try:
instance_dict.remove(self)
except AssertionError:
pass
# remove possible cycles
self.__dict__.pop('callables', None)
self.dispose()
def obj(self):
return None
@property
def dict(self):
o = self.obj()
if o is not None:
return attributes.instance_dict(o)
else:
return {}
@property
def sort_key(self):
return self.key and self.key[1] or (self.insert_order, )
def initialize_instance(*mixed, **kwargs):
self, instance, args = mixed[0], mixed[1], mixed[2:]
manager = self.manager
for fn in manager.events.on_init:
fn(self, instance, args, kwargs)
# LESSTHANIDEAL:
# adjust for the case where the InstanceState was created before
# mapper compilation, and this actually needs to be a MutableAttrInstanceState
if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState:
self.__class__ = MutableAttrInstanceState
self.obj = weakref.ref(self.obj(), self._cleanup)
self.mutable_dict = {}
try:
return manager.events.original_init(*mixed[1:], **kwargs)
except:
for fn in manager.events.on_init_failure:
fn(self, instance, args, kwargs)
raise
def get_history(self, key, **kwargs):
return self.manager.get_impl(key).get_history(self, self.dict, **kwargs)
def get_impl(self, key):
return self.manager.get_impl(key)
def get_pending(self, key):
if key not in self.pending:
self.pending[key] = PendingCollection()
return self.pending[key]
def value_as_iterable(self, key, passive=PASSIVE_OFF):
"""return an InstanceState attribute as a list,
regardless of it being a scalar or collection-based
attribute.
returns None if passive is not PASSIVE_OFF and the getter returns
PASSIVE_NO_RESULT.
"""
impl = self.get_impl(key)
dict_ = self.dict
x = impl.get(self, dict_, passive=passive)
if x is PASSIVE_NO_RESULT:
return None
elif hasattr(impl, 'get_collection'):
return impl.get_collection(self, dict_, x, passive=passive)
else:
return [x]
def _run_on_load(self, instance):
self.manager.events.run('on_load', instance)
def __getstate__(self):
d = {'instance':self.obj()}
d.update(
(k, self.__dict__[k]) for k in (
'committed_state', 'pending', 'parents', 'modified', 'expired',
'callables', 'key', 'load_options', 'mutable_dict'
) if k in self.__dict__
)
if self.load_path:
d['load_path'] = interfaces.serialize_path(self.load_path)
return d
def __setstate__(self, state):
self.obj = weakref.ref(state['instance'], self._cleanup)
self.class_ = state['instance'].__class__
self.manager = manager = manager_of_class(self.class_)
if manager is None:
raise orm_exc.UnmappedInstanceError(
state['instance'],
"Cannot deserialize object of type %r - no mapper() has"
" been configured for this class within the current Python process!" %
self.class_)
elif manager.mapper and not manager.mapper.compiled:
manager.mapper.compile()
self.committed_state = state.get('committed_state', {})
self.pending = state.get('pending', {})
self.parents = state.get('parents', {})
self.modified = state.get('modified', False)
self.expired = state.get('expired', False)
self.callables = state.get('callables', {})
if self.modified:
self._strong_obj = state['instance']
self.__dict__.update([
(k, state[k]) for k in (
'key', 'load_options', 'mutable_dict'
) if k in state
])
if 'load_path' in state:
self.load_path = interfaces.deserialize_path(state['load_path'])
def initialize(self, key):
"""Set this attribute to an empty value or collection,
based on the AttributeImpl in use."""
self.manager.get_impl(key).initialize(self, self.dict)
def reset(self, dict_, key):
"""Remove the given attribute and any
callables associated with it."""
dict_.pop(key, None)
self.callables.pop(key, None)
def expire_attribute_pre_commit(self, dict_, key):
"""a fast expire that can be called by column loaders during a load.
The additional bookkeeping is finished up in commit_all().
This method is actually called a lot with joined-table
loading, when the second table isn't present in the result.
"""
dict_.pop(key, None)
self.callables[key] = self
def set_callable(self, dict_, key, callable_):
"""Remove the given attribute and set the given callable
as a loader."""
dict_.pop(key, None)
self.callables[key] = callable_
def expire_attributes(self, dict_, attribute_names, instance_dict=None):
"""Expire all or a group of attributes.
If all attributes are expired, the "expired" flag is set to True.
"""
if attribute_names is None:
attribute_names = self.manager.keys()
self.expired = True
if self.modified:
if not instance_dict:
instance_dict = self._instance_dict()
if instance_dict:
instance_dict._modified.discard(self)
else:
instance_dict._modified.discard(self)
self.modified = False
filter_deferred = True
else:
filter_deferred = False
to_clear = (
self.__dict__.get('pending', None),
self.__dict__.get('committed_state', None),
self.mutable_dict
)
for key in attribute_names:
impl = self.manager[key].impl
if impl.accepts_scalar_loader and \
(not filter_deferred or impl.expire_missing or key in dict_):
self.callables[key] = self
dict_.pop(key, None)
for d in to_clear:
if d is not None:
d.pop(key, None)
def __call__(self, **kw):
"""__call__ allows the InstanceState to act as a deferred
callable for loading expired attributes, which is also
serializable (picklable).
"""
if kw.get('passive') is attributes.PASSIVE_NO_FETCH:
return attributes.PASSIVE_NO_RESULT
toload = self.expired_attributes.\
intersection(self.unmodified)
self.manager.deferred_scalar_loader(self, toload)
# if the loader failed, or this
# instance state didn't have an identity,
# the attributes still might be in the callables
# dict. ensure they are removed.
for k in toload.intersection(self.callables):
del self.callables[k]
return ATTR_WAS_SET
@property
def unmodified(self):
"""Return the set of keys which have no uncommitted changes"""
return set(self.manager).difference(self.committed_state)
@property
def unloaded(self):
"""Return the set of keys which do not have a loaded value.
This includes expired attributes and any other attribute that
was never populated or modified.
"""
return set(self.manager).\
difference(self.committed_state).\
difference(self.dict)
@property
def expired_attributes(self):
"""Return the set of keys which are 'expired' to be loaded by
the manager's deferred scalar loader, assuming no pending
changes.
see also the ``unmodified`` collection which is intersected
against this set when a refresh operation occurs.
"""
return set([k for k, v in self.callables.items() if v is self])
def _instance_dict(self):
return None
def _is_really_none(self):
return self.obj()
def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF):
needs_committed = attr.key not in self.committed_state
if needs_committed:
if previous is NEVER_SET:
if passive:
if attr.key in dict_:
previous = dict_[attr.key]
else:
previous = attr.get(self, dict_)
if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
previous = attr.copy(previous)
if needs_committed:
self.committed_state[attr.key] = previous
if not self.modified:
instance_dict = self._instance_dict()
if instance_dict:
instance_dict._modified.add(self)
self.modified = True
if self._strong_obj is None:
self._strong_obj = self.obj()
def commit(self, dict_, keys):
"""Commit attributes.
This is used by a partial-attribute load operation to mark committed
those attributes which were refreshed from the database.
Attributes marked as "expired" can potentially remain "expired" after
this step if a value was not populated in state.dict.
"""
class_manager = self.manager
for key in keys:
if key in dict_ and key in class_manager.mutable_attributes:
self.committed_state[key] = self.manager[key].impl.copy(dict_[key])
else:
self.committed_state.pop(key, None)
self.expired = False
for key in set(self.callables).\
intersection(keys).\
intersection(dict_):
del self.callables[key]
def commit_all(self, dict_, instance_dict=None):
"""commit all attributes unconditionally.
This is used after a flush() or a full load/refresh
to remove all pending state from the instance.
- all attributes are marked as "committed"
- the "strong dirty reference" is removed
- the "modified" flag is set to False
- any "expired" markers/callables for attributes loaded are removed.
Attributes marked as "expired" can potentially remain "expired" after this step
if a value was not populated in state.dict.
"""
self.__dict__.pop('committed_state', None)
self.__dict__.pop('pending', None)
if 'callables' in self.__dict__:
callables = self.callables
for key in list(callables):
if key in dict_ and callables[key] is self:
del callables[key]
for key in self.manager.mutable_attributes:
if key in dict_:
self.committed_state[key] = self.manager[key].impl.copy(dict_[key])
if instance_dict and self.modified:
instance_dict._modified.discard(self)
self.modified = self.expired = False
self._strong_obj = None
class MutableAttrInstanceState(InstanceState):
"""InstanceState implementation for objects that reference 'mutable'
attributes.
Has a more involved "cleanup" handler that checks mutable attributes
for changes upon dereference, resurrecting if needed.
"""
@util.memoized_property
def mutable_dict(self):
return {}
def _get_modified(self, dict_=None):
if self.__dict__.get('modified', False):
return True
else:
if dict_ is None:
dict_ = self.dict
for key in self.manager.mutable_attributes:
if self.manager[key].impl.check_mutable_modified(self, dict_):
return True
else:
return False
def _set_modified(self, value):
self.__dict__['modified'] = value
modified = property(_get_modified, _set_modified)
@property
def unmodified(self):
"""a set of keys which have no uncommitted changes"""
dict_ = self.dict
return set([
key for key in self.manager
if (key not in self.committed_state or
(key in self.manager.mutable_attributes and
not self.manager[key].impl.check_mutable_modified(self, dict_)))])
def _is_really_none(self):
"""do a check modified/resurrect.
This would be called in the extremely rare
race condition that the weakref returned None but
the cleanup handler had not yet established the
__resurrect callable as its replacement.
"""
if self.modified:
self.obj = self.__resurrect
return self.obj()
else:
return None
def reset(self, dict_, key):
self.mutable_dict.pop(key, None)
InstanceState.reset(self, dict_, key)
def _cleanup(self, ref):
"""weakref callback.
This method may be called by an asynchronous
gc.
If the state shows pending changes, the weakref
is replaced by the __resurrect callable which will
re-establish an object reference on next access,
else removes this InstanceState from the owning
identity map, if any.
"""
if self._get_modified(self.mutable_dict):
self.obj = self.__resurrect
else:
instance_dict = self._instance_dict()
if instance_dict:
try:
instance_dict.remove(self)
except AssertionError:
pass
self.dispose()
def __resurrect(self):
"""A substitute for the obj() weakref function which resurrects."""
# store strong ref'ed version of the object; will revert
# to weakref when changes are persisted
obj = self.manager.new_instance(state=self)
self.obj = weakref.ref(obj, self._cleanup)
self._strong_obj = obj
obj.__dict__.update(self.mutable_dict)
# re-establishes identity attributes from the key
self.manager.events.run('on_resurrect', self, obj)
# TODO: don't really think we should run this here.
# resurrect is only meant to preserve the minimal state needed to
# do an UPDATE, not to produce a fully usable object
self._run_on_load(obj)
return obj
class PendingCollection(object):
"""A writable placeholder for an unloaded collection.
Stores items appended to and removed from a collection that has not yet
been loaded. When the collection is loaded, the changes stored in
PendingCollection are applied to it to produce the final result.
"""
def __init__(self):
self.deleted_items = util.IdentitySet()
self.added_items = util.OrderedIdentitySet()
def append(self, value):
if value in self.deleted_items:
self.deleted_items.remove(value)
self.added_items.add(value)
def remove(self, value):
if value in self.added_items:
self.added_items.remove(value)
self.deleted_items.add(value)

1229
sqlalchemy/orm/strategies.py Normal file

File diff suppressed because it is too large Load Diff

98
sqlalchemy/orm/sync.py Normal file
View File

@ -0,0 +1,98 @@
# mapper/sync.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""private module containing functions used for copying data
between instances based on join conditions.
"""
from sqlalchemy.orm import exc, util as mapperutil
def populate(source, source_mapper, dest, dest_mapper,
synchronize_pairs, uowcommit, passive_updates):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
try:
dest_mapper._set_state_attr_by_column(dest, r, value)
except exc.UnmappedColumnError:
_raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
# techically the "r.primary_key" check isn't
# needed here, but we check for this condition to limit
# how often this logic is invoked for memory/performance
# reasons, since we only need this info for a primary key
# destination.
if l.primary_key and r.primary_key and \
r.references(l) and passive_updates:
uowcommit.attributes[("pk_cascaded", dest, r)] = True
def clear(dest, dest_mapper, synchronize_pairs):
for l, r in synchronize_pairs:
if r.primary_key:
raise AssertionError(
"Dependency rule tried to blank-out primary key "
"column '%s' on instance '%s'" %
(r, mapperutil.state_str(dest))
)
try:
dest_mapper._set_state_attr_by_column(dest, r, None)
except exc.UnmappedColumnError:
_raise_col_to_prop(True, None, l, dest_mapper, r)
def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
for l, r in synchronize_pairs:
try:
oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
value = source_mapper._get_state_attr_by_column(source, l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
dest[r.key] = value
dest[old_prefix + r.key] = oldvalue
def populate_dict(source, source_mapper, dict_, synchronize_pairs):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
dict_[r.key] = value
def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
"""return true if the source object has changes from an old to a
new value on the given synchronize pairs
"""
for l, r in synchronize_pairs:
try:
prop = source_mapper._get_col_to_prop(l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
history = uowcommit.get_attribute_history(source, prop.key, passive=True)
if len(history.deleted):
return True
else:
return False
def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
if isdest:
raise exc.UnmappedColumnError(
"Can't execute sync rule for destination column '%s'; "
"mapper '%s' does not map this column. Try using an explicit"
" `foreign_keys` collection which does not include this column "
"(or use a viewonly=True relation)." % (dest_column, source_mapper)
)
else:
raise exc.UnmappedColumnError(
"Can't execute sync rule for source column '%s'; mapper '%s' "
"does not map this column. Try using an explicit `foreign_keys`"
" collection which does not include destination column '%s' (or "
"use a viewonly=True relation)." %
(source_column, source_mapper, dest_column)
)

View File

@ -0,0 +1,781 @@
# orm/unitofwork.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""The internals for the Unit Of Work system.
Includes hooks into the attributes package enabling the routing of
change events to Unit Of Work objects, as well as the flush()
mechanism which creates a dependency structure that executes change
operations.
A Unit of Work is essentially a system of maintaining a graph of
in-memory objects and their modified state. Objects are maintained as
unique against their primary key identity using an *identity map*
pattern. The Unit of Work then maintains lists of objects that are
new, dirty, or deleted and provides the capability to flush all those
changes at once.
"""
from sqlalchemy import util, log, topological
from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.mapper import _state_mapper
# Load lazily
object_session = None
_state_session = None
class UOWEventHandler(interfaces.AttributeExtension):
"""An event handler added to all relationship attributes which handles
session cascade operations.
"""
active_history = False
def __init__(self, key):
self.key = key
def append(self, state, item, initiator):
# process "save_update" cascade rules for when an instance is appended to the list of another instance
sess = _state_session(state)
if sess:
prop = _state_mapper(state).get_property(self.key)
if prop.cascade.save_update and item not in sess:
sess.add(item)
return item
def remove(self, state, item, initiator):
sess = _state_session(state)
if sess:
prop = _state_mapper(state).get_property(self.key)
# expunge pending orphans
if prop.cascade.delete_orphan and \
item in sess.new and \
prop.mapper._is_orphan(attributes.instance_state(item)):
sess.expunge(item)
def set(self, state, newvalue, oldvalue, initiator):
# process "save_update" cascade rules for when an instance is attached to another instance
if oldvalue is newvalue:
return newvalue
sess = _state_session(state)
if sess:
prop = _state_mapper(state).get_property(self.key)
if newvalue is not None and prop.cascade.save_update and newvalue not in sess:
sess.add(newvalue)
if prop.cascade.delete_orphan and oldvalue in sess.new and \
prop.mapper._is_orphan(attributes.instance_state(oldvalue)):
sess.expunge(oldvalue)
return newvalue
class UOWTransaction(object):
"""Handles the details of organizing and executing transaction
tasks during a UnitOfWork object's flush() operation.
The central operation is to form a graph of nodes represented by the
``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object
that issues SQL and instance-synchronizing operations via the related
packages.
"""
def __init__(self, session):
self.session = session
self.mapper_flush_opts = session._mapper_flush_opts
# stores tuples of mapper/dependent mapper pairs,
# representing a partial ordering fed into topological sort
self.dependencies = set()
# dictionary of mappers to UOWTasks
self.tasks = {}
# dictionary used by external actors to store arbitrary state
# information.
self.attributes = {}
self.processors = set()
def get_attribute_history(self, state, key, passive=True):
hashkey = ("history", state, key)
# cache the objects, not the states; the strong reference here
# prevents newly loaded objects from being dereferenced during the
# flush process
if hashkey in self.attributes:
(history, cached_passive) = self.attributes[hashkey]
# if the cached lookup was "passive" and now we want non-passive, do a non-passive
# lookup and re-cache
if cached_passive and not passive:
history = attributes.get_state_history(state, key, passive=False)
self.attributes[hashkey] = (history, passive)
else:
history = attributes.get_state_history(state, key, passive=passive)
self.attributes[hashkey] = (history, passive)
if not history or not state.get_impl(key).uses_objects:
return history
else:
return history.as_state()
def register_object(self, state, isdelete=False,
listonly=False, postupdate=False, post_update_cols=None):
# if object is not in the overall session, do nothing
if not self.session._contains_state(state):
return
mapper = _state_mapper(state)
task = self.get_task_by_mapper(mapper)
if postupdate:
task.append_postupdate(state, post_update_cols)
else:
task.append(state, listonly=listonly, isdelete=isdelete)
# ensure the mapper for this object has had its
# DependencyProcessors added.
if mapper not in self.processors:
mapper._register_processors(self)
self.processors.add(mapper)
if mapper.base_mapper not in self.processors:
mapper.base_mapper._register_processors(self)
self.processors.add(mapper.base_mapper)
def set_row_switch(self, state):
"""mark a deleted object as a 'row switch'.
this indicates that an INSERT statement elsewhere corresponds to this DELETE;
the INSERT is converted to an UPDATE and the DELETE does not occur.
"""
mapper = _state_mapper(state)
task = self.get_task_by_mapper(mapper)
taskelement = task._objects[state]
taskelement.isdelete = "rowswitch"
def is_deleted(self, state):
"""return true if the given state is marked as deleted within this UOWTransaction."""
mapper = _state_mapper(state)
task = self.get_task_by_mapper(mapper)
return task.is_deleted(state)
def get_task_by_mapper(self, mapper, dontcreate=False):
"""return UOWTask element corresponding to the given mapper.
Will create a new UOWTask, including a UOWTask corresponding to the
"base" inherited mapper, if needed, unless the dontcreate flag is True.
"""
try:
return self.tasks[mapper]
except KeyError:
if dontcreate:
return None
base_mapper = mapper.base_mapper
if base_mapper in self.tasks:
base_task = self.tasks[base_mapper]
else:
self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper)
base_mapper._register_dependencies(self)
if mapper not in self.tasks:
self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task)
mapper._register_dependencies(self)
else:
task = self.tasks[mapper]
return task
def register_dependency(self, mapper, dependency):
"""register a dependency between two mappers.
Called by ``mapper.PropertyLoader`` to register the objects
handled by one mapper being dependent on the objects handled
by another.
"""
# correct for primary mapper
# also convert to the "base mapper", the parentmost task at the top of an inheritance chain
# dependency sorting is done via non-inheriting mappers only, dependencies between mappers
# in the same inheritance chain is done at the per-object level
mapper = mapper.primary_mapper().base_mapper
dependency = dependency.primary_mapper().base_mapper
self.dependencies.add((mapper, dependency))
def register_processor(self, mapper, processor, mapperfrom):
"""register a dependency processor, corresponding to
operations which occur between two mappers.
"""
# correct for primary mapper
mapper = mapper.primary_mapper()
mapperfrom = mapperfrom.primary_mapper()
task = self.get_task_by_mapper(mapper)
targettask = self.get_task_by_mapper(mapperfrom)
up = UOWDependencyProcessor(processor, targettask)
task.dependencies.add(up)
def execute(self):
"""Execute this UOWTransaction.
This will organize all collected UOWTasks into a dependency-sorted
list which is then traversed using the traversal scheme
encoded in the UOWExecutor class. Operations to mappers and dependency
processors are fired off in order to issue SQL to the database and
synchronize instance attributes with database values and related
foreign key values."""
# pre-execute dependency processors. this process may
# result in new tasks, objects and/or dependency processors being added,
# particularly with 'delete-orphan' cascade rules.
# keep running through the full list of tasks until all
# objects have been processed.
while True:
ret = False
for task in self.tasks.values():
for up in list(task.dependencies):
if up.preexecute(self):
ret = True
if not ret:
break
tasks = self._sort_dependencies()
if self._should_log_info():
self.logger.info("Task dump:\n%s", self._dump(tasks))
UOWExecutor().execute(self, tasks)
self.logger.info("Execute Complete")
def _dump(self, tasks):
from uowdumper import UOWDumper
return UOWDumper.dump(tasks)
@property
def elements(self):
"""Iterate UOWTaskElements."""
for task in self.tasks.itervalues():
for elem in task.elements:
yield elem
def finalize_flush_changes(self):
"""mark processed objects as clean / deleted after a successful flush().
this method is called within the flush() method after the
execute() method has succeeded and the transaction has been committed.
"""
for elem in self.elements:
if elem.isdelete:
self.session._remove_newly_deleted(elem.state)
elif not elem.listonly:
self.session._register_newly_persistent(elem.state)
def _sort_dependencies(self):
nodes = topological.sort_with_cycles(self.dependencies,
[t.mapper for t in self.tasks.itervalues() if t.base_task is t]
)
ret = []
for item, cycles in nodes:
task = self.get_task_by_mapper(item)
if cycles:
for t in task._sort_circular_dependencies(
self,
[self.get_task_by_mapper(i) for i in cycles]
):
ret.append(t)
else:
ret.append(task)
return ret
log.class_logger(UOWTransaction)
class UOWTask(object):
"""A collection of mapped states corresponding to a particular mapper."""
def __init__(self, uowtransaction, mapper, base_task=None):
self.uowtransaction = uowtransaction
# base_task is the UOWTask which represents the "base mapper"
# in our mapper's inheritance chain. if the mapper does not
# inherit from any other mapper, the base_task is self.
# the _inheriting_tasks dictionary is a dictionary present only
# on the "base_task"-holding UOWTask, which maps all mappers within
# an inheritance hierarchy to their corresponding UOWTask instances.
if base_task is None:
self.base_task = self
self._inheriting_tasks = {mapper:self}
else:
self.base_task = base_task
base_task._inheriting_tasks[mapper] = self
# the Mapper which this UOWTask corresponds to
self.mapper = mapper
# mapping of InstanceState -> UOWTaskElement
self._objects = {}
self.dependent_tasks = []
self.dependencies = set()
self.cyclical_dependencies = set()
@util.memoized_property
def inheriting_mappers(self):
return list(self.mapper.polymorphic_iterator())
@property
def polymorphic_tasks(self):
"""Return an iterator of UOWTask objects corresponding to the
inheritance sequence of this UOWTask's mapper.
e.g. if mapper B and mapper C inherit from mapper A, and
mapper D inherits from B:
mapperA -> mapperB -> mapperD
-> mapperC
the inheritance sequence starting at mapper A is a depth-first
traversal:
[mapperA, mapperB, mapperD, mapperC]
this method will therefore return
[UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD),
UOWTask(mapperC)]
The concept of "polymporphic iteration" is adapted into
several property-based iterators which return object
instances, UOWTaskElements and UOWDependencyProcessors in an
order corresponding to this sequence of parent UOWTasks. This
is used to issue operations related to inheritance-chains of
mappers in the proper order based on dependencies between
those mappers.
"""
for mapper in self.inheriting_mappers:
t = self.base_task._inheriting_tasks.get(mapper, None)
if t is not None:
yield t
def is_empty(self):
"""return True if this UOWTask is 'empty', meaning it has no child items.
used only for debugging output.
"""
return not self._objects and not self.dependencies
def append(self, state, listonly=False, isdelete=False):
if state not in self._objects:
self._objects[state] = rec = UOWTaskElement(state)
else:
rec = self._objects[state]
rec.update(listonly, isdelete)
def append_postupdate(self, state, post_update_cols):
"""issue a 'post update' UPDATE statement via this object's mapper immediately.
this operation is used only with relationships that specify the `post_update=True`
flag.
"""
# postupdates are UPDATED immeditely (for now)
# convert post_update_cols list to a Set so that __hash__() is used to compare columns
# instead of __eq__()
self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=set(post_update_cols))
def __contains__(self, state):
"""return True if the given object is contained within this UOWTask or inheriting tasks."""
for task in self.polymorphic_tasks:
if state in task._objects:
return True
else:
return False
def is_deleted(self, state):
"""return True if the given object is marked as to be deleted within this UOWTask."""
try:
return self._objects[state].isdelete
except KeyError:
return False
def _polymorphic_collection(fn):
"""return a property that will adapt the collection returned by the
given callable into a polymorphic traversal."""
@property
def collection(self):
for task in self.polymorphic_tasks:
for rec in fn(task):
yield rec
return collection
def _polymorphic_collection_filtered(fn):
def collection(self, mappers):
for task in self.polymorphic_tasks:
if task.mapper in mappers:
for rec in fn(task):
yield rec
return collection
@property
def elements(self):
return self._objects.values()
@_polymorphic_collection
def polymorphic_elements(self):
return self.elements
@_polymorphic_collection_filtered
def filter_polymorphic_elements(self):
return self.elements
@property
def polymorphic_tosave_elements(self):
return [rec for rec in self.polymorphic_elements if not rec.isdelete]
@property
def polymorphic_todelete_elements(self):
return [rec for rec in self.polymorphic_elements if rec.isdelete]
@property
def polymorphic_tosave_objects(self):
return [
rec.state for rec in self.polymorphic_elements
if rec.state is not None and not rec.listonly and rec.isdelete is False
]
@property
def polymorphic_todelete_objects(self):
return [
rec.state for rec in self.polymorphic_elements
if rec.state is not None and not rec.listonly and rec.isdelete is True
]
@_polymorphic_collection
def polymorphic_dependencies(self):
return self.dependencies
@_polymorphic_collection
def polymorphic_cyclical_dependencies(self):
return self.cyclical_dependencies
def _sort_circular_dependencies(self, trans, cycles):
"""Topologically sort individual entities with row-level dependencies.
Builds a modified UOWTask structure, and is invoked when the
per-mapper topological structure is found to have cycles.
"""
dependencies = {}
def set_processor_for_state(state, depprocessor, target_state, isdelete):
if state not in dependencies:
dependencies[state] = {}
tasks = dependencies[state]
if depprocessor not in tasks:
tasks[depprocessor] = UOWDependencyProcessor(
depprocessor.processor,
UOWTask(self.uowtransaction, depprocessor.targettask.mapper)
)
tasks[depprocessor].targettask.append(target_state, isdelete=isdelete)
cycles = set(cycles)
def dependency_in_cycles(dep):
proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True)
targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True)
return targettask in cycles and (proctask is not None and proctask in cycles)
deps_by_targettask = {}
extradeplist = []
for task in cycles:
for dep in task.polymorphic_dependencies:
if not dependency_in_cycles(dep):
extradeplist.append(dep)
for t in dep.targettask.polymorphic_tasks:
l = deps_by_targettask.setdefault(t, [])
l.append(dep)
object_to_original_task = {}
tuples = []
for task in cycles:
for subtask in task.polymorphic_tasks:
for taskelement in subtask.elements:
state = taskelement.state
object_to_original_task[state] = subtask
if subtask not in deps_by_targettask:
continue
for dep in deps_by_targettask[subtask]:
if not dep.processor.has_dependencies or not dependency_in_cycles(dep):
continue
(processor, targettask) = (dep.processor, dep.targettask)
isdelete = taskelement.isdelete
# list of dependent objects from this object
(added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True)
if not added and not unchanged and not deleted:
continue
# the task corresponding to saving/deleting of those dependent objects
childtask = trans.get_task_by_mapper(processor.mapper)
childlist = added + unchanged + deleted
for o in childlist:
if o is None:
continue
if o not in childtask:
childtask.append(o, listonly=True)
object_to_original_task[o] = childtask
whosdep = dep.whose_dependent_on_who(state, o)
if whosdep is not None:
tuples.append(whosdep)
if whosdep[0] is state:
set_processor_for_state(whosdep[0], dep, whosdep[0], isdelete=isdelete)
else:
set_processor_for_state(whosdep[0], dep, whosdep[1], isdelete=isdelete)
else:
# TODO: no test coverage here
set_processor_for_state(state, dep, state, isdelete=isdelete)
t = UOWTask(self.uowtransaction, self.mapper)
t.dependencies.update(extradeplist)
used_tasks = set()
# rationale for "tree" sort as opposed to a straight
# dependency - keep non-dependent objects
# grouped together, so that insert ordering as determined
# by session.add() is maintained.
# An alternative might be to represent the "insert order"
# as part of the topological sort itself, which would
# eliminate the need for this step (but may make the original
# topological sort more expensive)
head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys())
if head is not None:
original_to_tasks = {}
stack = [(head, t)]
while stack:
((state, cycles, children), parenttask) = stack.pop()
originating_task = object_to_original_task[state]
used_tasks.add(originating_task)
if (parenttask, originating_task) not in original_to_tasks:
task = UOWTask(self.uowtransaction, originating_task.mapper)
original_to_tasks[(parenttask, originating_task)] = task
parenttask.dependent_tasks.append(task)
else:
task = original_to_tasks[(parenttask, originating_task)]
task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete)
if state in dependencies:
task.cyclical_dependencies.update(dependencies[state].itervalues())
stack += [(n, task) for n in children]
ret = [t]
# add tasks that were in the cycle, but didnt get assembled
# into the cyclical tree, to the start of the list
for t2 in cycles:
if t2 not in used_tasks and t2 is not self:
localtask = UOWTask(self.uowtransaction, t2.mapper)
for state in t2.elements:
localtask.append(state, t2.listonly, isdelete=t2._objects[state].isdelete)
for dep in t2.dependencies:
localtask.dependencies.add(dep)
ret.insert(0, localtask)
return ret
def __repr__(self):
return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper))
class UOWTaskElement(object):
"""Corresponds to a single InstanceState to be saved, deleted,
or otherwise marked as having dependencies. A collection of
UOWTaskElements are held by a UOWTask.
"""
def __init__(self, state):
self.state = state
self.listonly = True
self.isdelete = False
self.preprocessed = set()
def update(self, listonly, isdelete):
if not listonly and self.listonly:
self.listonly = False
self.preprocessed.clear()
if isdelete and not self.isdelete:
self.isdelete = True
self.preprocessed.clear()
def __repr__(self):
return "UOWTaskElement/%d: %s/%d %s" % (
id(self),
self.state.class_.__name__,
id(self.state.obj()),
(self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save'))
)
class UOWDependencyProcessor(object):
"""In between the saving and deleting of objects, process
dependent data, such as filling in a foreign key on a child item
from a new primary key, or deleting association rows before a
delete. This object acts as a proxy to a DependencyProcessor.
"""
def __init__(self, processor, targettask):
self.processor = processor
self.targettask = targettask
prop = processor.prop
# define a set of mappers which
# will filter the lists of entities
# this UOWDP processes. this allows
# MapperProperties to be overridden
# at least for concrete mappers.
self._mappers = set([
m
for m in self.processor.parent.polymorphic_iterator()
if m._props[prop.key] is prop
]).union(self.processor.mapper.polymorphic_iterator())
def __repr__(self):
return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask))
def __eq__(self, other):
return other.processor is self.processor and other.targettask is self.targettask
def __hash__(self):
return hash((self.processor, self.targettask))
def preexecute(self, trans):
"""preprocess all objects contained within this ``UOWDependencyProcessor``s target task.
This may locate additional objects which should be part of the
transaction, such as those affected deletes, orphans to be
deleted, etc.
Once an object is preprocessed, its ``UOWTaskElement`` is marked as processed. If subsequent
changes occur to the ``UOWTaskElement``, its processed flag is reset, and will require processing
again.
Return True if any objects were preprocessed, or False if no
objects were preprocessed. If True is returned, the parent ``UOWTransaction`` will
ultimately call ``preexecute()`` again on all processors until no new objects are processed.
"""
def getobj(elem):
elem.preprocessed.add(self)
return elem.state
ret = False
elements = [getobj(elem) for elem in
self.targettask.filter_polymorphic_elements(self._mappers)
if self not in elem.preprocessed and not elem.isdelete]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
elements = [getobj(elem) for elem in
self.targettask.filter_polymorphic_elements(self._mappers)
if self not in elem.preprocessed and elem.isdelete]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
return ret
def execute(self, trans, delete):
"""process all objects contained within this ``UOWDependencyProcessor``s target task."""
elements = [e for e in
self.targettask.filter_polymorphic_elements(self._mappers)
if bool(e.isdelete)==delete]
self.processor.process_dependencies(
self.targettask,
[elem.state for elem in elements],
trans,
delete=delete)
def get_object_dependencies(self, state, trans, passive):
return trans.get_attribute_history(state, self.processor.key, passive=passive)
def whose_dependent_on_who(self, state1, state2):
"""establish which object is operationally dependent amongst a parent/child
using the semantics stated by the dependency processor.
This method is used to establish a partial ordering (set of dependency tuples)
when toplogically sorting on a per-instance basis.
"""
return self.processor.whose_dependent_on_who(state1, state2)
class UOWExecutor(object):
"""Encapsulates the execution traversal of a UOWTransaction structure."""
def execute(self, trans, tasks, isdelete=None):
if isdelete is not True:
for task in tasks:
self.execute_save_steps(trans, task)
if isdelete is not False:
for task in reversed(tasks):
self.execute_delete_steps(trans, task)
def save_objects(self, trans, task):
task.mapper._save_obj(task.polymorphic_tosave_objects, trans)
def delete_objects(self, trans, task):
task.mapper._delete_obj(task.polymorphic_todelete_objects, trans)
def execute_dependency(self, trans, dep, isdelete):
dep.execute(trans, isdelete)
def execute_save_steps(self, trans, task):
self.save_objects(trans, task)
for dep in task.polymorphic_cyclical_dependencies:
self.execute_dependency(trans, dep, False)
for dep in task.polymorphic_cyclical_dependencies:
self.execute_dependency(trans, dep, True)
self.execute_cyclical_dependencies(trans, task, False)
self.execute_dependencies(trans, task)
def execute_delete_steps(self, trans, task):
self.execute_cyclical_dependencies(trans, task, True)
self.delete_objects(trans, task)
def execute_dependencies(self, trans, task):
polymorphic_dependencies = list(task.polymorphic_dependencies)
for dep in polymorphic_dependencies:
self.execute_dependency(trans, dep, False)
for dep in reversed(polymorphic_dependencies):
self.execute_dependency(trans, dep, True)
def execute_cyclical_dependencies(self, trans, task, isdelete):
for t in task.dependent_tasks:
self.execute(trans, [t], isdelete)

101
sqlalchemy/orm/uowdumper.py Normal file
View File

@ -0,0 +1,101 @@
# orm/uowdumper.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Dumps out a string representation of a UOWTask structure"""
from sqlalchemy.orm import unitofwork
from sqlalchemy.orm import util as mapperutil
import StringIO
class UOWDumper(unitofwork.UOWExecutor):
def __init__(self, tasks, buf):
self.indent = 0
self.tasks = tasks
self.buf = buf
self.execute(None, tasks)
@classmethod
def dump(cls, tasks):
buf = StringIO.StringIO()
UOWDumper(tasks, buf)
return buf.getvalue()
def execute(self, trans, tasks, isdelete=None):
if isdelete is not True:
for task in tasks:
self._execute(trans, task, False)
if isdelete is not False:
for task in reversed(tasks):
self._execute(trans, task, True)
def _execute(self, trans, task, isdelete):
try:
i = self._indent()
if i:
i = i[:-1] + "+-"
self.buf.write(i + " " + self._repr_task(task))
self.buf.write(" (" + (isdelete and "delete " or "save/update ") + "phase) \n")
self.indent += 1
super(UOWDumper, self).execute(trans, [task], isdelete)
finally:
self.indent -= 1
def save_objects(self, trans, task):
for rec in sorted(task.polymorphic_tosave_elements, key=lambda a: a.state.sort_key):
if rec.listonly:
continue
self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n")
def delete_objects(self, trans, task):
for rec in task.polymorphic_todelete_elements:
if rec.listonly:
continue
self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n")
def execute_dependency(self, transaction, dep, isdelete):
self._dump_processor(dep, isdelete)
def _dump_processor(self, proc, deletes):
if deletes:
val = proc.targettask.polymorphic_todelete_elements
else:
val = proc.targettask.polymorphic_tosave_elements
for v in val:
self.buf.write(self._indent() + " +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n")
def _repr_task_element(self, te, attribute=None, process=False):
if getattr(te, 'state', None) is None:
objid = "(placeholder)"
else:
if attribute is not None:
objid = "%s.%s" % (mapperutil.state_str(te.state), attribute)
else:
objid = mapperutil.state_str(te.state)
if process:
return "Process %s" % (objid)
else:
return "%s %s" % ((te.isdelete and "Delete" or "Save"), objid)
def _repr_task(self, task):
if task.mapper is not None:
if task.mapper.__class__.__name__ == 'Mapper':
name = task.mapper.class_.__name__ + "/" + task.mapper.local_table.description
else:
name = repr(task.mapper)
else:
name = '(none)'
return ("UOWTask(%s, %s)" % (hex(id(task)), name))
def _repr_task_class(self, task):
if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper':
return task.mapper.class_.__name__
else:
return '(none)'
def _indent(self):
return " |" * self.indent

668
sqlalchemy/orm/util.py Normal file
View File

@ -0,0 +1,668 @@
# mapper/util.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc
from sqlalchemy import sql, util
from sqlalchemy.sql import expression, util as sql_util, operators
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, \
MapperProperty, AttributeExtension
from sqlalchemy.orm import attributes, exc
mapperlib = None
all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire",
"none"))
_INSTRUMENTOR = ('mapper', 'instrumentor')
class CascadeOptions(object):
"""Keeps track of the options sent to relationship().cascade"""
def __init__(self, arg=""):
if not arg:
values = set()
else:
values = set(c.strip() for c in arg.split(','))
self.delete_orphan = "delete-orphan" in values
self.delete = "delete" in values or "all" in values
self.save_update = "save-update" in values or "all" in values
self.merge = "merge" in values or "all" in values
self.expunge = "expunge" in values or "all" in values
self.refresh_expire = "refresh-expire" in values or "all" in values
if self.delete_orphan and not self.delete:
util.warn("The 'delete-orphan' cascade option requires "
"'delete'. This will raise an error in 0.6.")
for x in values:
if x not in all_cascades:
raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
def __contains__(self, item):
return getattr(self, item.replace("-", "_"), False)
def __repr__(self):
return "CascadeOptions(%s)" % repr(",".join(
[x for x in ['delete', 'save_update', 'merge', 'expunge',
'delete_orphan', 'refresh-expire']
if getattr(self, x, False) is True]))
class Validator(AttributeExtension):
"""Runs a validation method on an attribute value to be set or appended.
The Validator class is used by the :func:`~sqlalchemy.orm.validates`
decorator, and direct access is usually not needed.
"""
def __init__(self, key, validator):
"""Construct a new Validator.
key - name of the attribute to be validated;
will be passed as the second argument to
the validation method (the first is the object instance itself).
validator - an function or instance method which accepts
three arguments; an instance (usually just 'self' for a method),
the key name of the attribute, and the value. The function should
return the same value given, unless it wishes to modify it.
"""
self.key = key
self.validator = validator
def append(self, state, value, initiator):
return self.validator(state.obj(), self.key, value)
def set(self, state, value, oldvalue, initiator):
return self.validator(state.obj(), self.key, value)
def polymorphic_union(table_map, typecolname, aliasname='p_union'):
"""Create a ``UNION`` statement used by a polymorphic mapper.
See :ref:`concrete_inheritance` for an example of how
this is used.
"""
colnames = set()
colnamemaps = {}
types = {}
for key in table_map.keys():
table = table_map[key]
# mysql doesnt like selecting from a select; make it an alias of the select
if isinstance(table, sql.Select):
table = table.alias()
table_map[key] = table
m = {}
for c in table.c:
colnames.add(c.key)
m[c.key] = c
types[c.key] = c.type
colnamemaps[table] = m
def col(name, table):
try:
return colnamemaps[table][name]
except KeyError:
return sql.cast(sql.null(), types[name]).label(name)
result = []
for type, table in table_map.iteritems():
if typecolname is not None:
result.append(sql.select([col(name, table) for name in colnames] +
[sql.literal_column("'%s'" % type).label(typecolname)],
from_obj=[table]))
else:
result.append(sql.select([col(name, table) for name in colnames],
from_obj=[table]))
return sql.union_all(*result).alias(aliasname)
def identity_key(*args, **kwargs):
"""Get an identity key.
Valid call signatures:
* ``identity_key(class, ident)``
class
mapped class (must be a positional argument)
ident
primary key, if the key is composite this is a tuple
* ``identity_key(instance=instance)``
instance
object instance (must be given as a keyword arg)
* ``identity_key(class, row=row)``
class
mapped class (must be a positional argument)
row
result proxy row (must be given as a keyword arg)
"""
if args:
if len(args) == 1:
class_ = args[0]
try:
row = kwargs.pop("row")
except KeyError:
ident = kwargs.pop("ident")
elif len(args) == 2:
class_, ident = args
elif len(args) == 3:
class_, ident = args
else:
raise sa_exc.ArgumentError("expected up to three "
"positional arguments, got %s" % len(args))
if kwargs:
raise sa_exc.ArgumentError("unknown keyword arguments: %s"
% ", ".join(kwargs.keys()))
mapper = class_mapper(class_)
if "ident" in locals():
return mapper.identity_key_from_primary_key(ident)
return mapper.identity_key_from_row(row)
instance = kwargs.pop("instance")
if kwargs:
raise sa_exc.ArgumentError("unknown keyword arguments: %s"
% ", ".join(kwargs.keys()))
mapper = object_mapper(instance)
return mapper.identity_key_from_instance(instance)
class ExtensionCarrier(dict):
"""Fronts an ordered collection of MapperExtension objects.
Bundles multiple MapperExtensions into a unified callable unit,
encapsulating ordering, looping and EXT_CONTINUE logic. The
ExtensionCarrier implements the MapperExtension interface, e.g.::
carrier.after_insert(...args...)
The dictionary interface provides containment for implemented
method names mapped to a callable which executes that method
for participating extensions.
"""
interface = set(method for method in dir(MapperExtension)
if not method.startswith('_'))
def __init__(self, extensions=None):
self._extensions = []
for ext in extensions or ():
self.append(ext)
def copy(self):
return ExtensionCarrier(self._extensions)
def push(self, extension):
"""Insert a MapperExtension at the beginning of the collection."""
self._register(extension)
self._extensions.insert(0, extension)
def append(self, extension):
"""Append a MapperExtension at the end of the collection."""
self._register(extension)
self._extensions.append(extension)
def __iter__(self):
"""Iterate over MapperExtensions in the collection."""
return iter(self._extensions)
def _register(self, extension):
"""Register callable fronts for overridden interface methods."""
for method in self.interface.difference(self):
impl = getattr(extension, method, None)
if impl and impl is not getattr(MapperExtension, method):
self[method] = self._create_do(method)
def _create_do(self, method):
"""Return a closure that loops over impls of the named method."""
def _do(*args, **kwargs):
for ext in self._extensions:
ret = getattr(ext, method)(*args, **kwargs)
if ret is not EXT_CONTINUE:
return ret
else:
return EXT_CONTINUE
_do.__name__ = method
return _do
@staticmethod
def _pass(*args, **kwargs):
return EXT_CONTINUE
def __getattr__(self, key):
"""Delegate MapperExtension methods to bundled fronts."""
if key not in self.interface:
raise AttributeError(key)
return self.get(key, self._pass)
class ORMAdapter(sql_util.ColumnAdapter):
"""Extends ColumnAdapter to accept ORM entities.
The selectable is extracted from the given entity,
and the AliasedClass if any is referenced.
"""
def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
self.mapper, selectable, is_aliased_class = _entity_info(entity)
if is_aliased_class:
self.aliased_class = entity
else:
self.aliased_class = None
sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to, adapt_required=adapt_required)
def replace(self, elem):
entity = elem._annotations.get('parentmapper', None)
if not entity or entity.isa(self.mapper):
return sql_util.ColumnAdapter.replace(self, elem)
else:
return None
class AliasedClass(object):
"""Represents an "aliased" form of a mapped class for usage with Query.
The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias`
construct, this object mimics the mapped class using a
__getattr__ scheme and maintains a reference to a
real :class:`~sqlalchemy.sql.expression.Alias` object.
Usage is via the :class:`~sqlalchemy.orm.aliased()` synonym::
# find all pairs of users with the same name
user_alias = aliased(User)
session.query(User, user_alias).\\
join((user_alias, User.id > user_alias.id)).\\
filter(User.name==user_alias.name)
"""
def __init__(self, cls, alias=None, name=None):
self.__mapper = _class_to_mapper(cls)
self.__target = self.__mapper.class_
if alias is None:
alias = self.__mapper._with_polymorphic_selectable.alias()
self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
self.__alias = alias
# used to assign a name to the RowTuple object
# returned by Query.
self._sa_label_name = name
self.__name__ = 'AliasedClass_' + str(self.__target)
def __getstate__(self):
return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}
def __setstate__(self, state):
self.__mapper = state['mapper']
self.__target = self.__mapper.class_
alias = state['alias']
self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
self.__alias = alias
name = state['name']
self._sa_label_name = name
self.__name__ = 'AliasedClass_' + str(self.__target)
def __adapt_element(self, elem):
return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})
def __adapt_prop(self, prop):
existing = getattr(self.__target, prop.key)
comparator = existing.comparator.adapted(self.__adapt_element)
queryattr = attributes.QueryableAttribute(prop.key,
impl=existing.impl, parententity=self, comparator=comparator)
setattr(self, prop.key, queryattr)
return queryattr
def __getattr__(self, key):
prop = self.__mapper._get_property(key, raiseerr=False)
if prop:
return self.__adapt_prop(prop)
for base in self.__target.__mro__:
try:
attr = object.__getattribute__(base, key)
except AttributeError:
continue
else:
break
else:
raise AttributeError(key)
if hasattr(attr, 'func_code'):
is_method = getattr(self.__target, key, None)
if is_method and is_method.im_self is not None:
return util.types.MethodType(attr.im_func, self, self)
else:
return None
elif hasattr(attr, '__get__'):
return attr.__get__(None, self)
else:
return attr
def __repr__(self):
return '<AliasedClass at 0x%x; %s>' % (
id(self), self.__target.__name__)
def _orm_annotate(element, exclude=None):
"""Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag.
Elements within the exclude collection will be cloned but not annotated.
"""
return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude)
_orm_deannotate = sql_util._deep_deannotate
class _ORMJoin(expression.Join):
"""Extend Join to support ORM constructs as input."""
__visit_name__ = expression.Join.__visit_name__
def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
adapt_from = None
if hasattr(left, '_orm_mappers'):
left_mapper = left._orm_mappers[1]
if join_to_left:
adapt_from = left.right
else:
left_mapper, left, left_is_aliased = _entity_info(left)
if join_to_left and (left_is_aliased or not left_mapper):
adapt_from = left
right_mapper, right, right_is_aliased = _entity_info(right)
if right_is_aliased:
adapt_to = right
else:
adapt_to = None
if left_mapper or right_mapper:
self._orm_mappers = (left_mapper, right_mapper)
if isinstance(onclause, basestring):
prop = left_mapper.get_property(onclause)
elif isinstance(onclause, attributes.QueryableAttribute):
if adapt_from is None:
adapt_from = onclause.__clause_element__()
prop = onclause.property
elif isinstance(onclause, MapperProperty):
prop = onclause
else:
prop = None
if prop:
pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
source_selectable=adapt_from,
dest_selectable=adapt_to,
source_polymorphic=True,
dest_polymorphic=True,
of_type=right_mapper)
if sj is not None:
left = sql.join(left, secondary, pj, isouter)
onclause = sj
else:
onclause = pj
self._target_adapter = target_adapter
expression.Join.__init__(self, left, right, onclause, isouter)
def join(self, right, onclause=None, isouter=False, join_to_left=True):
return _ORMJoin(self, right, onclause, isouter, join_to_left)
def outerjoin(self, right, onclause=None, join_to_left=True):
return _ORMJoin(self, right, onclause, True, join_to_left)
def join(left, right, onclause=None, isouter=False, join_to_left=True):
"""Produce an inner join between left and right clauses.
In addition to the interface provided by
:func:`~sqlalchemy.sql.expression.join()`, left and right may be mapped
classes or AliasedClass instances. The onclause may be a
string name of a relationship(), or a class-bound descriptor
representing a relationship.
join_to_left indicates to attempt aliasing the ON clause,
in whatever form it is passed, to the selectable
passed as the left side. If False, the onclause
is used as is.
"""
return _ORMJoin(left, right, onclause, isouter, join_to_left)
def outerjoin(left, right, onclause=None, join_to_left=True):
"""Produce a left outer join between left and right clauses.
In addition to the interface provided by
:func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be mapped
classes or AliasedClass instances. The onclause may be a
string name of a relationship(), or a class-bound descriptor
representing a relationship.
"""
return _ORMJoin(left, right, onclause, True, join_to_left)
def with_parent(instance, prop):
"""Return criterion which selects instances with a given parent.
instance
a parent instance, which should be persistent or detached.
property
a class-attached descriptor, MapperProperty or string property name
attached to the parent instance.
\**kwargs
all extra keyword arguments are propagated to the constructor of
Query.
"""
if isinstance(prop, basestring):
mapper = object_mapper(instance)
prop = mapper.get_property(prop, resolve_synonyms=True)
elif isinstance(prop, attributes.QueryableAttribute):
prop = prop.property
return prop.compare(operators.eq, instance, value_is_parent=True)
def _entity_info(entity, compile=True):
"""Return mapping information given a class, mapper, or AliasedClass.
Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this
is an aliased() construct.
If the given entity is not a mapper, mapped class, or aliased construct,
returns None, the entity, False. This is typically used to allow
unmapped selectables through.
"""
if isinstance(entity, AliasedClass):
return entity._AliasedClass__mapper, entity._AliasedClass__alias, True
global mapperlib
if mapperlib is None:
from sqlalchemy.orm import mapperlib
if isinstance(entity, mapperlib.Mapper):
mapper = entity
elif isinstance(entity, type):
class_manager = attributes.manager_of_class(entity)
if class_manager is None:
return None, entity, False
mapper = class_manager.mapper
else:
return None, entity, False
if compile:
mapper = mapper.compile()
return mapper, mapper._with_polymorphic_selectable, False
def _entity_descriptor(entity, key):
"""Return attribute/property information given an entity and string name.
Returns a 2-tuple representing InstrumentedAttribute/MapperProperty.
"""
if isinstance(entity, AliasedClass):
try:
desc = getattr(entity, key)
return desc, desc.property
except AttributeError:
raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
elif isinstance(entity, type):
try:
desc = attributes.manager_of_class(entity)[key]
return desc, desc.property
except KeyError:
raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
else:
try:
desc = entity.class_manager[key]
return desc, desc.property
except KeyError:
raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
def _orm_columns(entity):
mapper, selectable, is_aliased_class = _entity_info(entity)
if isinstance(selectable, expression.Selectable):
return [c for c in selectable.c]
else:
return [selectable]
def _orm_selectable(entity):
mapper, selectable, is_aliased_class = _entity_info(entity)
return selectable
def _is_aliased_class(entity):
return isinstance(entity, AliasedClass)
def _state_mapper(state):
return state.manager.mapper
def object_mapper(instance):
"""Given an object, return the primary Mapper associated with the object instance.
Raises UnmappedInstanceError if no mapping is configured.
"""
try:
state = attributes.instance_state(instance)
if not state.manager.mapper:
raise exc.UnmappedInstanceError(instance)
return state.manager.mapper
except exc.NO_STATE:
raise exc.UnmappedInstanceError(instance)
def class_mapper(class_, compile=True):
"""Given a class, return the primary Mapper associated with the key.
Raises UnmappedClassError if no mapping is configured.
"""
try:
class_manager = attributes.manager_of_class(class_)
mapper = class_manager.mapper
# HACK until [ticket:1142] is complete
if mapper is None:
raise AttributeError
except exc.NO_STATE:
raise exc.UnmappedClassError(class_)
if compile:
mapper = mapper.compile()
return mapper
def _class_to_mapper(class_or_mapper, compile=True):
if _is_aliased_class(class_or_mapper):
return class_or_mapper._AliasedClass__mapper
elif isinstance(class_or_mapper, type):
return class_mapper(class_or_mapper, compile=compile)
elif hasattr(class_or_mapper, 'compile'):
if compile:
return class_or_mapper.compile()
else:
return class_or_mapper
else:
raise exc.UnmappedClassError(class_or_mapper)
def has_identity(object):
state = attributes.instance_state(object)
return _state_has_identity(state)
def _state_has_identity(state):
return bool(state.key)
def _is_mapped_class(cls):
global mapperlib
if mapperlib is None:
from sqlalchemy.orm import mapperlib
if isinstance(cls, (AliasedClass, mapperlib.Mapper)):
return True
if isinstance(cls, expression.ClauseElement):
return False
if isinstance(cls, type):
manager = attributes.manager_of_class(cls)
return manager and _INSTRUMENTOR in manager.info
return False
def instance_str(instance):
"""Return a string describing an instance."""
return state_str(attributes.instance_state(instance))
def state_str(state):
"""Return a string describing an instance via its InstanceState."""
if state is None:
return "None"
else:
return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
def state_attribute_str(state, attribute):
return state_str(state) + "." + attribute
def identity_equal(a, b):
if a is b:
return True
if a is None or b is None:
return False
try:
state_a = attributes.instance_state(a)
state_b = attributes.instance_state(b)
except exc.NO_STATE:
return False
if state_a.key is None or state_b.key is None:
return False
return state_a.key == state_b.key
# TODO: Avoid circular import.
attributes.identity_equal = identity_equal
attributes._is_aliased_class = _is_aliased_class
attributes._entity_info = _entity_info

Some files were not shown because too many files have changed in this diff Show More