2010-05-07 17:33:49 +00:00
parent c7c7498f19
commit cae7e001e3
127 changed files with 57530 additions and 0 deletions

sqlalchemy/engine/__init__.py

@@ -0,0 +1,274 @@
# engine/__init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQL connections, SQL execution and high-level DB-API interface.
The engine package defines the basic components used to interface
DB-API modules with higher-level statement construction,
connection-management, execution and result contexts. The primary
"entry point" class into this package is the Engine and it's public
constructor ``create_engine()``.
This package includes:
base.py
Defines interface classes and some implementation classes which
comprise the basic components used to interface between a DB-API,
constructed and plain-text statements, connections, transactions,
and results.
default.py
Contains default implementations of some of the components defined
in base.py. All current database dialects use the classes in
default.py as base classes for their own database-specific
implementations.
strategies.py
The mechanics of constructing ``Engine`` objects are represented
here. Defines the ``EngineStrategy`` class which represents how
to go from arguments specified to the ``create_engine()``
function, to a fully constructed ``Engine``, including
initialization of connection pooling, dialects, and specific
subclasses of ``Engine``.
threadlocal.py
The ``TLEngine`` class is defined here, which is a subclass of
the generic ``Engine`` and tracks ``Connection`` and
``Transaction`` objects against the identity of the current
thread. This allows certain programming patterns based around
the concept of a "thread-local connection" to be possible.
The ``TLEngine`` is created by using the "threadlocal" engine
strategy in conjunction with the ``create_engine()`` function.
url.py
Defines the ``URL`` class which represents the individual
components of a string URL passed to ``create_engine()``. Also
defines a basic module-loading strategy for the dialect specifier
within a URL.
"""
# not sure what this was used for
#import sqlalchemy.databases
from sqlalchemy.engine.base import (
BufferedColumnResultProxy,
BufferedColumnRow,
BufferedRowResultProxy,
Compiled,
Connectable,
Connection,
Dialect,
Engine,
ExecutionContext,
NestedTransaction,
ResultProxy,
RootTransaction,
RowProxy,
Transaction,
TwoPhaseTransaction,
TypeCompiler
)
from sqlalchemy.engine import strategies
from sqlalchemy import util
__all__ = (
'BufferedColumnResultProxy',
'BufferedColumnRow',
'BufferedRowResultProxy',
'Compiled',
'Connectable',
'Connection',
'Dialect',
'Engine',
'ExecutionContext',
'NestedTransaction',
'ResultProxy',
'RootTransaction',
'RowProxy',
'Transaction',
'TwoPhaseTransaction',
'TypeCompiler',
'create_engine',
'engine_from_config',
)
default_strategy = 'plain'
def create_engine(*args, **kwargs):
"""Create a new Engine instance.
The standard method of specifying the engine is via URL as the
first positional argument, to indicate the appropriate database
dialect and connection arguments, with additional keyword
arguments sent as options to the dialect and resulting Engine.
The URL is a string in the form
``dialect+driver://user:password@host/dbname[?key=value..]``, where
``dialect`` is a database name such as ``mysql``, ``oracle``,
``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
`**kwargs` takes a wide variety of options which are routed
towards their appropriate components. Arguments may be
specific to the Engine, the underlying Dialect, as well as the
Pool. Specific dialects also accept keyword arguments that
are unique to that dialect. Here, we describe the parameters
that are common to most ``create_engine()`` usage.
:param assert_unicode: Deprecated. A warning is raised in all cases where a
non-Unicode object is passed and SQLAlchemy would coerce it into an encoding
(note: but **not** when the DBAPI handles unicode objects natively).
To suppress or raise this warning to an
error, use the Python warnings filter documented at:
http://docs.python.org/library/warnings.html
:param connect_args: a dictionary of options which will be
passed directly to the DBAPI's ``connect()`` method as
additional keyword arguments.
:param convert_unicode=False: if set to True, all
String/character based types will convert Unicode values to raw
byte values going into the database, and all raw byte values to
Python Unicode coming out in result sets. This is an
engine-wide method to provide unicode conversion across the
board. For unicode conversion on a column-by-column level, use
the ``Unicode`` column type instead, described in `types`.
:param creator: a callable which returns a DBAPI connection.
This creation function will be passed to the underlying
connection pool and will be used to create all new database
connections. Usage of this function causes connection
parameters specified in the URL argument to be bypassed.
:param echo=False: if True, the Engine will log all statements
as well as a repr() of their parameter lists to the engines
logger, which defaults to sys.stdout. The ``echo`` attribute of
``Engine`` can be modified at any time to turn logging on and
off. If set to the string ``"debug"``, result rows will be
printed to the standard output as well. This flag ultimately
controls a Python logger; see :ref:`dbengine_logging` for
information on how to configure logging directly.
:param echo_pool=False: if True, the connection pool will log
all checkouts/checkins to the logging stream, which defaults to
sys.stdout. This flag ultimately controls a Python logger; see
:ref:`dbengine_logging` for information on how to configure logging
directly.
:param encoding='utf-8': the encoding to use for all Unicode
translations, both by engine-wide unicode conversion as well as
the ``Unicode`` type object.
:param label_length=None: optional integer value which limits
the size of dynamically generated column labels to that many
characters. If less than 6, labels are generated as
"_(counter)". If ``None``, the value of
``dialect.max_identifier_length`` is used instead.
:param listeners: A list of one or more
:class:`~sqlalchemy.interfaces.PoolListener` objects which will
receive connection pool events.
:param logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.engine" logger. Defaults to a hexstring of the
object's id.
:param max_overflow=10: the number of connections to allow in
connection pool "overflow", that is connections that can be
opened above and beyond the pool_size setting, which defaults
to five. This is only used with :class:`~sqlalchemy.pool.QueuePool`.
:param module=None: used by database implementations which
support multiple DBAPI modules, this is a reference to a DBAPI2
module to be used instead of the engine's default module. For
PostgreSQL, the default is psycopg2. For Oracle, it's cx_Oracle.
:param pool=None: an already-constructed instance of
:class:`~sqlalchemy.pool.Pool`, such as a
:class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this
pool will be used directly as the underlying connection pool
for the engine, bypassing whatever connection parameters are
present in the URL argument. For information on constructing
connection pools manually, see `pooling`.
:param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
subclass, which will be used to create a connection pool
instance using the connection parameters given in the URL. Note
this differs from ``pool`` in that you don't actually
instantiate the pool in this case; you just indicate what type
of pool will be used.
:param pool_logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.pool" logger. Defaults to a hexstring of the object's
id.
:param pool_size=5: the number of connections to keep open
inside the connection pool. This is used with :class:`~sqlalchemy.pool.QueuePool` as
well as :class:`~sqlalchemy.pool.SingletonThreadPool`.
:param pool_recycle=-1: this setting causes the pool to recycle
connections after the given number of seconds has passed. It
defaults to -1, or no timeout. For example, setting to 3600
means connections will be recycled after one hour. Note that
MySQL in particular will ``disconnect automatically`` if no
activity is detected on a connection for eight hours (although
this is configurable with the MySQLdb connection itself and the
server configuration as well).
:param pool_timeout=30: number of seconds to wait before giving
up on getting a connection from the pool. This is only used
with :class:`~sqlalchemy.pool.QueuePool`.
:param strategy='plain': used to invoke alternate :class:`~sqlalchemy.engine.base.Engine`
implementations. Currently available is the ``threadlocal``
strategy, which is described in :ref:`threadlocal_strategy`.
"""
strategy = kwargs.pop('strategy', default_strategy)
strategy = strategies.strategies[strategy]
return strategy.create(*args, **kwargs)
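# A minimal usage sketch (the URL, credentials and pool settings below are
# illustrative, not defaults):
#
#     from sqlalchemy import create_engine
#     engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test',
#                            echo=True, pool_size=10, max_overflow=20)
#     conn = engine.connect()
#     print conn.execute("select 1").scalar()
#     conn.close()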
def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
"""Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file where keys
are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The
'prefix' argument indicates the prefix to be searched for.
A select set of keyword arguments will be "coerced" to their
expected type based on string values. In a future release, this
functionality will be expanded and include dialect-specific
arguments.
"""
opts = _coerce_config(configuration, prefix)
opts.update(kwargs)
url = opts.pop('url')
return create_engine(url, **opts)
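# A sketch of the expected dictionary shape (keys and values are
# illustrative); string values such as 'true' and '10' are coerced to their
# expected types by _coerce_config() below:
#
#     config = {
#         'sqlalchemy.url': 'postgresql://scott:tiger@localhost/test',
#         'sqlalchemy.echo': 'true',
#         'sqlalchemy.pool_size': '10',
#     }
#     engine = engine_from_config(config, prefix='sqlalchemy.')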
def _coerce_config(configuration, prefix):
"""Convert configuration values to expected types."""
options = dict((key[len(prefix):], configuration[key])
for key in configuration
if key.startswith(prefix))
for option, type_ in (
('convert_unicode', bool),
('pool_timeout', int),
('echo', bool),
('echo_pool', bool),
('pool_recycle', int),
('pool_size', int),
('max_overflow', int),
('pool_threadlocal', bool),
):
util.coerce_kw_type(options, option, type_)
return options

sqlalchemy/engine/base.py
File diff suppressed because it is too large

sqlalchemy/engine/ddl.py

@@ -0,0 +1,128 @@
# engine/ddl.py
# Copyright (C) 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Routines to handle CREATE/DROP workflow."""
from sqlalchemy import engine, schema
from sqlalchemy.sql import util as sql_util
class DDLBase(schema.SchemaVisitor):
def __init__(self, connection):
self.connection = connection
class SchemaGenerator(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
super(SchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables and set(tables) or None
self.preparer = dialect.identifier_preparer
self.dialect = dialect
def _can_create(self, table):
self.dialect.validate_identifier(table.name)
if table.schema:
self.dialect.validate_identifier(table.schema)
return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
def visit_metadata(self, metadata):
if self.tables:
tables = self.tables
else:
tables = metadata.tables.values()
collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
for listener in metadata.ddl_listeners['before-create']:
listener('before-create', metadata, self.connection, tables=collection)
for table in collection:
self.traverse_single(table)
for listener in metadata.ddl_listeners['after-create']:
listener('after-create', metadata, self.connection, tables=collection)
def visit_table(self, table):
for listener in table.ddl_listeners['before-create']:
listener('before-create', table, self.connection)
for column in table.columns:
if column.default is not None:
self.traverse_single(column.default)
self.connection.execute(schema.CreateTable(table))
if hasattr(table, 'indexes'):
for index in table.indexes:
self.traverse_single(index)
for listener in table.ddl_listeners['after-create']:
listener('after-create', table, self.connection)
def visit_sequence(self, sequence):
if self.dialect.supports_sequences:
if ((not self.dialect.sequences_optional or
not sequence.optional) and
(not self.checkfirst or
not self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
self.connection.execute(schema.CreateSequence(sequence))
def visit_index(self, index):
self.connection.execute(schema.CreateIndex(index))
class SchemaDropper(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
super(SchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
self.preparer = dialect.identifier_preparer
self.dialect = dialect
def visit_metadata(self, metadata):
if self.tables:
tables = self.tables
else:
tables = metadata.tables.values()
collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
for listener in metadata.ddl_listeners['before-drop']:
listener('before-drop', metadata, self.connection, tables=collection)
for table in collection:
self.traverse_single(table)
for listener in metadata.ddl_listeners['after-drop']:
listener('after-drop', metadata, self.connection, tables=collection)
def _can_drop(self, table):
self.dialect.validate_identifier(table.name)
if table.schema:
self.dialect.validate_identifier(table.schema)
return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
def visit_index(self, index):
self.connection.execute(schema.DropIndex(index))
def visit_table(self, table):
for listener in table.ddl_listeners['before-drop']:
listener('before-drop', table, self.connection)
for column in table.columns:
if column.default is not None:
self.traverse_single(column.default)
self.connection.execute(schema.DropTable(table))
for listener in table.ddl_listeners['after-drop']:
listener('after-drop', table, self.connection)
def visit_sequence(self, sequence):
if self.dialect.supports_sequences:
if ((not self.dialect.sequences_optional or
not sequence.optional) and
(not self.checkfirst or
self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
self.connection.execute(schema.DropSequence(sequence))
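# A sketch of how these visitors are typically driven, assuming an existing
# Engine ``engine`` and MetaData ``metadata`` (normally
# ``metadata.create_all(engine)`` arranges this for you):
#
#     conn = engine.contextual_connect()
#     try:
#         SchemaGenerator(engine.dialect, conn, checkfirst=True).traverse(metadata)
#     finally:
#         conn.close()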

sqlalchemy/engine/default.py

@@ -0,0 +1,700 @@
# engine/default.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Default implementations of per-dialect sqlalchemy.engine classes.
These are semi-private implementation classes which are only of importance
to database dialect authors; dialects will usually use the classes here
as the base class for their own corresponding classes.
"""
import re, random
from sqlalchemy.engine import base, reflection
from sqlalchemy.sql import compiler, expression
from sqlalchemy import exc, types as sqltypes, util
AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
re.I | re.UNICODE)
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
statement_compiler = compiler.SQLCompiler
ddl_compiler = compiler.DDLCompiler
type_compiler = compiler.GenericTypeCompiler
preparer = compiler.IdentifierPreparer
supports_alter = True
# most DBAPIs happy with this for execute().
# not cx_oracle.
execute_sequence_format = tuple
supports_sequences = False
sequences_optional = False
preexecute_autoincrement_sequences = False
postfetch_lastrowid = True
implicit_returning = False
supports_native_enum = False
supports_native_boolean = False
# if the NUMERIC type
# returns decimal.Decimal.
# *not* the FLOAT type however.
supports_native_decimal = False
# Py3K
#supports_unicode_statements = True
#supports_unicode_binds = True
# Py2K
supports_unicode_statements = False
supports_unicode_binds = False
returns_unicode_strings = False
# end Py2K
name = 'default'
max_identifier_length = 9999
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
dbapi_type_map = {}
colspecs = {}
default_paramstyle = 'named'
supports_default_values = False
supports_empty_insert = True
server_version_info = None
# indicates symbol names are
# UPPERCASEd if they are case insensitive
# within the database.
# if this is True, the methods normalize_name()
# and denormalize_name() must be provided.
requires_name_normalize = False
reflection_options = ()
def __init__(self, convert_unicode=False, assert_unicode=False,
encoding='utf-8', paramstyle=None, dbapi=None,
implicit_returning=None,
label_length=None, **kwargs):
if not getattr(self, 'ported_sqla_06', True):
util.warn(
"The %s dialect is not yet ported to SQLAlchemy 0.6" % self.name)
self.convert_unicode = convert_unicode
if assert_unicode:
util.warn_deprecated("assert_unicode is deprecated. "
"SQLAlchemy emits a warning in all cases where it "
"would otherwise like to encode a Python unicode object "
"into a specific encoding but a plain bytestring is received. "
"This does *not* apply to DBAPIs that coerce Unicode natively."
)
self.encoding = encoding
self.positional = False
self._ischema = None
self.dbapi = dbapi
if paramstyle is not None:
self.paramstyle = paramstyle
elif self.dbapi is not None:
self.paramstyle = self.dbapi.paramstyle
else:
self.paramstyle = self.default_paramstyle
if implicit_returning is not None:
self.implicit_returning = implicit_returning
self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
self.identifier_preparer = self.preparer(self)
self.type_compiler = self.type_compiler(self)
if label_length and label_length > self.max_identifier_length:
raise exc.ArgumentError("Label length of %d is greater than this dialect's"
" maximum identifier length of %d" %
(label_length, self.max_identifier_length))
self.label_length = label_length
if not hasattr(self, 'description_encoding'):
self.description_encoding = getattr(self, 'description_encoding', encoding)
@property
def dialect_description(self):
return self.name + "+" + self.driver
def initialize(self, connection):
try:
self.server_version_info = self._get_server_version_info(connection)
except NotImplementedError:
self.server_version_info = None
try:
self.default_schema_name = self._get_default_schema_name(connection)
except NotImplementedError:
self.default_schema_name = None
self.returns_unicode_strings = self._check_unicode_returns(connection)
self.do_rollback(connection.connection)
def on_connect(self):
"""return a callable which sets up a newly created DBAPI connection.
This is used to set dialect-wide per-connection options such as isolation
modes, unicode modes, etc.
If a callable is returned, it will be assembled into a pool listener
that receives the direct DBAPI connection, with all wrappers removed.
If None is returned, no listener will be generated.
"""
return None
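# For illustration, a dialect might override on_connect() roughly as below;
# the connection-level statement shown is hypothetical:
#
#     def on_connect(self):
#         def connect(dbapi_conn):
#             cursor = dbapi_conn.cursor()
#             cursor.execute("SET SESSION some_option = 'value'")
#             cursor.close()
#         return connect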
def _check_unicode_returns(self, connection):
# Py2K
if self.supports_unicode_statements:
cast_to = unicode
else:
cast_to = str
# end Py2K
# Py3K
#cast_to = str
def check_unicode(type_):
cursor = connection.connection.cursor()
try:
cursor.execute(
cast_to(
expression.select(
[expression.cast(
expression.literal_column("'test unicode returns'"), type_)
]).compile(dialect=self)
)
)
row = cursor.fetchone()
return isinstance(row[0], unicode)
finally:
cursor.close()
# detect plain VARCHAR
unicode_for_varchar = check_unicode(sqltypes.VARCHAR(60))
# detect if there's an NVARCHAR type with different behavior available
unicode_for_unicode = check_unicode(sqltypes.Unicode(60))
if unicode_for_unicode and not unicode_for_varchar:
return "conditional"
else:
return unicode_for_varchar
def type_descriptor(self, typeobj):
"""Provide a database-specific ``TypeEngine`` object, given
the generic object which comes from the types module.
This method looks for a dictionary called
``colspecs`` as a class or instance-level variable,
and passes on to ``types.adapt_type()``.
"""
return sqltypes.adapt_type(typeobj, self.colspecs)
def reflecttable(self, connection, table, include_columns):
insp = reflection.Inspector.from_engine(connection)
return insp.reflecttable(table, include_columns)
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
"Identifier '%s' exceeds maximum length of %d characters" %
(ident, self.max_identifier_length)
)
def connect(self, *cargs, **cparams):
return self.dbapi.connect(*cargs, **cparams)
def create_connect_args(self, url):
opts = url.translate_connect_args()
opts.update(url.query)
return [[], opts]
def do_begin(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
"""
pass
def do_rollback(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
"""
connection.rollback()
def do_commit(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
"""
connection.commit()
def create_xid(self):
"""Create a random two-phase transaction ID.
This id will be passed to do_begin_twophase(), do_rollback_twophase(),
do_commit_twophase(). Its format is unspecified.
"""
return "_sa_%032x" % random.randint(0, 2 ** 128)
def do_savepoint(self, connection, name):
connection.execute(expression.SavepointClause(name))
def do_rollback_to_savepoint(self, connection, name):
connection.execute(expression.RollbackToSavepointClause(name))
def do_release_savepoint(self, connection, name):
connection.execute(expression.ReleaseSavepointClause(name))
def do_executemany(self, cursor, statement, parameters, context=None):
cursor.executemany(statement, parameters)
def do_execute(self, cursor, statement, parameters, context=None):
cursor.execute(statement, parameters)
def is_disconnect(self, e):
return False
class DefaultExecutionContext(base.ExecutionContext):
execution_options = util.frozendict()
isinsert = False
isupdate = False
isdelete = False
isddl = False
executemany = False
result_map = None
compiled = None
statement = None
def __init__(self,
dialect,
connection,
compiled_sql=None,
compiled_ddl=None,
statement=None,
parameters=None):
self.dialect = dialect
self._connection = self.root_connection = connection
self.engine = connection.engine
if compiled_ddl is not None:
self.compiled = compiled = compiled_ddl
self.isddl = True
if compiled.statement._execution_options:
self.execution_options = compiled.statement._execution_options
if connection._execution_options:
self.execution_options = self.execution_options.union(
connection._execution_options
)
if not dialect.supports_unicode_statements:
self.unicode_statement = unicode(compiled)
self.statement = self.unicode_statement.encode(self.dialect.encoding)
else:
self.statement = self.unicode_statement = unicode(compiled)
self.cursor = self.create_cursor()
self.compiled_parameters = []
self.parameters = [self._default_params]
elif compiled_sql is not None:
self.compiled = compiled = compiled_sql
if not compiled.can_execute:
raise exc.ArgumentError("Not an executable clause: %s" % compiled)
if compiled.statement._execution_options:
self.execution_options = compiled.statement._execution_options
if connection._execution_options:
self.execution_options = self.execution_options.union(
connection._execution_options
)
# compiled clauseelement. process bind params, process table defaults,
# track collections used by ResultProxy to target and process results
self.processors = dict(
(key, value) for key, value in
( (compiled.bind_names[bindparam],
bindparam.bind_processor(self.dialect))
for bindparam in compiled.bind_names )
if value is not None)
self.result_map = compiled.result_map
if not dialect.supports_unicode_statements:
self.unicode_statement = unicode(compiled)
self.statement = self.unicode_statement.encode(self.dialect.encoding)
else:
self.statement = self.unicode_statement = unicode(compiled)
self.isinsert = compiled.isinsert
self.isupdate = compiled.isupdate
self.isdelete = compiled.isdelete
if not parameters:
self.compiled_parameters = [compiled.construct_params()]
else:
self.compiled_parameters = [compiled.construct_params(m, _group_number=grp) for
grp,m in enumerate(parameters)]
self.executemany = len(parameters) > 1
self.cursor = self.create_cursor()
if self.isinsert or self.isupdate:
self.__process_defaults()
self.parameters = self.__convert_compiled_params(self.compiled_parameters)
elif statement is not None:
# plain text statement
if connection._execution_options:
self.execution_options = self.execution_options.union(connection._execution_options)
self.parameters = self.__encode_param_keys(parameters)
self.executemany = len(parameters) > 1
if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
self.unicode_statement = statement
self.statement = statement.encode(self.dialect.encoding)
else:
self.statement = self.unicode_statement = statement
self.cursor = self.create_cursor()
else:
# no statement. used for standalone ColumnDefault execution.
if connection._execution_options:
self.execution_options = self.execution_options.union(connection._execution_options)
self.cursor = self.create_cursor()
@util.memoized_property
def is_crud(self):
return self.isinsert or self.isupdate or self.isdelete
@util.memoized_property
def should_autocommit(self):
autocommit = self.execution_options.get('autocommit',
not self.compiled and
self.statement and
expression.PARSE_AUTOCOMMIT
or False)
if autocommit is expression.PARSE_AUTOCOMMIT:
return self.should_autocommit_text(self.unicode_statement)
else:
return autocommit
@util.memoized_property
def _is_explicit_returning(self):
return self.compiled and \
getattr(self.compiled.statement, '_returning', False)
@util.memoized_property
def _is_implicit_returning(self):
return self.compiled and \
bool(self.compiled.returning) and \
not self.compiled.statement._returning
@util.memoized_property
def _default_params(self):
if self.dialect.positional:
return self.dialect.execute_sequence_format()
else:
return {}
def _execute_scalar(self, stmt):
"""Execute a string statement on the current cursor, returning a scalar result.
Used to fire off sequences, default phrases, and "select lastrowid"
types of statements individually
or in the context of a parent INSERT or UPDATE statement.
"""
conn = self._connection
if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
stmt = stmt.encode(self.dialect.encoding)
conn._cursor_execute(self.cursor, stmt, self._default_params)
return self.cursor.fetchone()[0]
@property
def connection(self):
return self._connection._branch()
def __encode_param_keys(self, params):
"""Apply string encoding to the keys of dictionary-based bind parameters.
This is only used executing textual, non-compiled SQL expressions.
"""
if not params:
return [self._default_params]
elif isinstance(params[0], self.dialect.execute_sequence_format):
return params
elif isinstance(params[0], dict):
if self.dialect.supports_unicode_statements:
return params
else:
def proc(d):
return dict((k.encode(self.dialect.encoding), d[k]) for k in d)
return [proc(d) for d in params] or [{}]
else:
return [self.dialect.execute_sequence_format(p) for p in params]
def __convert_compiled_params(self, compiled_parameters):
"""Convert the dictionary of bind parameter values into a dict or list
to be sent to the DBAPI's execute() or executemany() method.
"""
processors = self.processors
parameters = []
if self.dialect.positional:
for compiled_params in compiled_parameters:
param = []
for key in self.compiled.positiontup:
if key in processors:
param.append(processors[key](compiled_params[key]))
else:
param.append(compiled_params[key])
parameters.append(self.dialect.execute_sequence_format(param))
else:
encode = not self.dialect.supports_unicode_statements
for compiled_params in compiled_parameters:
param = {}
if encode:
encoding = self.dialect.encoding
for key in compiled_params:
if key in processors:
param[key.encode(encoding)] = processors[key](compiled_params[key])
else:
param[key.encode(encoding)] = compiled_params[key]
else:
for key in compiled_params:
if key in processors:
param[key] = processors[key](compiled_params[key])
else:
param[key] = compiled_params[key]
parameters.append(param)
return self.dialect.execute_sequence_format(parameters)
def should_autocommit_text(self, statement):
return AUTOCOMMIT_REGEXP.match(statement)
def create_cursor(self):
return self._connection.connection.cursor()
def pre_exec(self):
pass
def post_exec(self):
pass
def get_lastrowid(self):
"""return self.cursor.lastrowid, or equivalent, after an INSERT.
This may involve calling special cursor functions,
issuing a new SELECT on the cursor (or a new one),
or returning a stored value that was
calculated within post_exec().
This function will only be called for dialects
which support "implicit" primary key generation,
keep preexecute_autoincrement_sequences set to False,
and when no explicit id value was bound to the
statement.
The function is called once, directly after
post_exec() and before the transaction is committed
or ResultProxy is generated. If the post_exec()
method assigns a value to `self._lastrowid`, the
value is used in place of calling get_lastrowid().
Note that this method is *not* equivalent to the
``lastrowid`` method on ``ResultProxy``, which is a
direct proxy to the DBAPI ``lastrowid`` accessor
in all cases.
"""
return self.cursor.lastrowid
def handle_dbapi_exception(self, e):
pass
def get_result_proxy(self):
return base.ResultProxy(self)
@property
def rowcount(self):
return self.cursor.rowcount
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount
def supports_sane_multi_rowcount(self):
return self.dialect.supports_sane_multi_rowcount
def post_insert(self):
if self.dialect.postfetch_lastrowid and \
(not len(self._inserted_primary_key) or \
None in self._inserted_primary_key):
table = self.compiled.statement.table
lastrowid = self.get_lastrowid()
self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
for c, v in zip(table.primary_key, self._inserted_primary_key)
]
def _fetch_implicit_returning(self, resultproxy):
table = self.compiled.statement.table
row = resultproxy.fetchone()
self._inserted_primary_key = [v is not None and v or row[c]
for c, v in zip(table.primary_key, self._inserted_primary_key)
]
def last_inserted_params(self):
return self._last_inserted_params
def last_updated_params(self):
return self._last_updated_params
def lastrow_has_defaults(self):
return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
def set_input_sizes(self, translate=None, exclude_types=None):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
"""
if not hasattr(self.compiled, 'bind_names'):
return
types = dict(
(self.compiled.bind_names[bindparam], bindparam.type)
for bindparam in self.compiled.bind_names)
if self.dialect.positional:
inputsizes = []
for key in self.compiled.positiontup:
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*inputsizes)
except Exception, e:
self._connection._handle_dbapi_exception(e, None, None, None, self)
raise
else:
inputsizes = {}
for key in self.compiled.bind_names.values():
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
if translate:
key = translate.get(key, key)
inputsizes[key.encode(self.dialect.encoding)] = dbtype
try:
self.cursor.setinputsizes(**inputsizes)
except Exception, e:
self._connection._handle_dbapi_exception(e, None, None, None, self)
raise
def _exec_default(self, default):
if default.is_sequence:
return self.fire_sequence(default)
elif default.is_callable:
return default.arg(self)
elif default.is_clause_element:
# TODO: expensive branching here should be
# pulled into _exec_scalar()
conn = self.connection
c = expression.select([default.arg]).compile(bind=conn)
return conn._execute_compiled(c, (), {}).scalar()
else:
return default.arg
def get_insert_default(self, column):
if column.default is None:
return None
else:
return self._exec_default(column.default)
def get_update_default(self, column):
if column.onupdate is None:
return None
else:
return self._exec_default(column.onupdate)
def __process_defaults(self):
"""Generate default values for compiled insert/update statements,
and generate inserted_primary_key collection.
"""
if self.executemany:
if len(self.compiled.prefetch):
scalar_defaults = {}
# pre-determine scalar Python-side defaults
# to avoid many calls of get_insert_default()/get_update_default()
for c in self.compiled.prefetch:
if self.isinsert and c.default and c.default.is_scalar:
scalar_defaults[c] = c.default.arg
elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
scalar_defaults[c] = c.onupdate.arg
for param in self.compiled_parameters:
self.current_parameters = param
for c in self.compiled.prefetch:
if c in scalar_defaults:
val = scalar_defaults[c]
elif self.isinsert:
val = self.get_insert_default(c)
else:
val = self.get_update_default(c)
if val is not None:
param[c.key] = val
del self.current_parameters
else:
self.current_parameters = compiled_parameters = self.compiled_parameters[0]
for c in self.compiled.prefetch:
if self.isinsert:
val = self.get_insert_default(c)
else:
val = self.get_update_default(c)
if val is not None:
compiled_parameters[c.key] = val
del self.current_parameters
if self.isinsert:
self._inserted_primary_key = [compiled_parameters.get(c.key, None)
for c in self.compiled.statement.table.primary_key]
self._last_inserted_params = compiled_parameters
else:
self._last_updated_params = compiled_parameters
self.postfetch_cols = self.compiled.postfetch
self.prefetch_cols = self.compiled.prefetch
DefaultDialect.execution_ctx_cls = DefaultExecutionContext
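# A minimal sketch of how a dialect builds on these defaults; every name below
# is hypothetical (real dialects live under sqlalchemy/dialects/):
#
#     class MyDBDialect(DefaultDialect):
#         name = 'mydb'
#         supports_sequences = True
#         default_paramstyle = 'qmark'
#
#         @classmethod
#         def dbapi(cls):
#             import mydb_dbapi
#             return mydb_dbapi
#
#         def is_disconnect(self, e):
#             return 'connection closed' in str(e)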

sqlalchemy/engine/reflection.py

@@ -0,0 +1,370 @@
"""Provides an abstraction for obtaining database schema information.
Usage Notes:
Here are some general conventions when accessing the low level inspector
methods such as get_table_names, get_columns, etc.
1. Inspector methods return lists of dicts in most cases for the following
reasons:
* They're both standard types that can be serialized.
* Using a dict instead of a tuple allows easy expansion of attributes.
* Using a list for the outer structure maintains order and is easy to work
with (e.g. list comprehension [d['name'] for d in cols]).
2. Records that contain a name, such as the column name in a column record,
use the key 'name'. So for most return values, each record will have a
'name' attribute.
"""
import sqlalchemy
from sqlalchemy import exc, sql
from sqlalchemy import util
from sqlalchemy.types import TypeEngine
from sqlalchemy import schema as sa_schema
@util.decorator
def cache(fn, self, con, *args, **kw):
info_cache = kw.get('info_cache', None)
if info_cache is None:
return fn(self, con, *args, **kw)
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, basestring)),
tuple((k, v) for k, v in kw.iteritems() if isinstance(v, (basestring, int, float)))
)
ret = info_cache.get(key)
if ret is None:
ret = fn(self, con, *args, **kw)
info_cache[key] = ret
return ret
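# Dialects typically decorate their reflection methods with @cache so that
# repeated lookups against one Inspector reuse the info_cache dictionary; a
# sketch (the query is illustrative):
#
#     @cache
#     def get_table_names(self, connection, schema=None, **kw):
#         return [row[0] for row in connection.execute("SELECT ...")]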
class Inspector(object):
"""Performs database schema inspection.
The Inspector acts as a proxy to the dialects' reflection methods and
provides higher level functions for accessing database schema information.
"""
def __init__(self, conn):
"""Initialize the instance.
:param conn: a :class:`~sqlalchemy.engine.base.Connectable`
"""
self.conn = conn
# set the engine
if hasattr(conn, 'engine'):
self.engine = conn.engine
else:
self.engine = conn
self.dialect = self.engine.dialect
self.info_cache = {}
@classmethod
def from_engine(cls, engine):
if hasattr(engine.dialect, 'inspector'):
return engine.dialect.inspector(engine)
return Inspector(engine)
@property
def default_schema_name(self):
return self.dialect.default_schema_name
def get_schema_names(self):
"""Return all schema names.
"""
if hasattr(self.dialect, 'get_schema_names'):
return self.dialect.get_schema_names(self.conn,
info_cache=self.info_cache)
return []
def get_table_names(self, schema=None, order_by=None):
"""Return all table names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
:param order_by: Optional, may be the string "foreign_key" to sort
the result on foreign key dependencies.
This should probably not return view names or maybe it should return
them with an indicator t or v.
"""
if hasattr(self.dialect, 'get_table_names'):
tnames = self.dialect.get_table_names(self.conn,
schema,
info_cache=self.info_cache)
else:
tnames = self.engine.table_names(schema)
if order_by == 'foreign_key':
ordered_tnames = tnames[:]
# Order based on foreign key dependencies.
for tname in tnames:
table_pos = tnames.index(tname)
fkeys = self.get_foreign_keys(tname, schema)
for fkey in fkeys:
rtable = fkey['referred_table']
if rtable in ordered_tnames:
ref_pos = ordered_tnames.index(rtable)
# Make sure it's lower in the list than anything it
# references.
if table_pos > ref_pos:
ordered_tnames.pop(table_pos) # rtable moves up 1
# insert just below rtable
ordered_tnames.insert(ref_pos, tname)
tnames = ordered_tnames
return tnames
def get_table_options(self, table_name, schema=None, **kw):
if hasattr(self.dialect, 'get_table_options'):
return self.dialect.get_table_options(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return {}
def get_view_names(self, schema=None):
"""Return all view names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
"""
return self.dialect.get_view_names(self.conn, schema,
info_cache=self.info_cache)
def get_view_definition(self, view_name, schema=None):
"""Return definition for `view_name`.
:param schema: Optional, retrieve names from a non-default schema.
"""
return self.dialect.get_view_definition(
self.conn, view_name, schema, info_cache=self.info_cache)
def get_columns(self, table_name, schema=None, **kw):
"""Return information about columns in `table_name`.
Given a string `table_name` and an optional string `schema`, return
column information as a list of dicts with these keys:
name
the column's name
type
:class:`~sqlalchemy.types.TypeEngine`
nullable
boolean
default
the column's default value
attrs
dict containing optional column attributes
"""
col_defs = self.dialect.get_columns(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
for col_def in col_defs:
# make this easy and only return instances for coltype
coltype = col_def['type']
if not isinstance(coltype, TypeEngine):
col_def['type'] = coltype()
return col_defs
def get_primary_keys(self, table_name, schema=None, **kw):
"""Return information about primary keys in `table_name`.
Given a string `table_name`, and an optional string `schema`, return
primary key information as a list of column names.
"""
pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return pkeys
def get_foreign_keys(self, table_name, schema=None, **kw):
"""Return information about foreign_keys in `table_name`.
Given a string `table_name`, and an optional string `schema`, return
foreign key information as a list of dicts with these keys:
constrained_columns
a list of column names that make up the foreign key
referred_schema
the name of the referred schema
referred_table
the name of the referred table
referred_columns
a list of column names in the referred table that correspond to
constrained_columns
\**kw
other options passed to the dialect's get_foreign_keys() method.
"""
fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return fk_defs
def get_indexes(self, table_name, schema=None, **kw):
"""Return information about indexes in `table_name`.
Given a string `table_name` and an optional string `schema`, return
index information as a list of dicts with these keys:
name
the index's name
column_names
list of column names in order
unique
boolean
\**kw
other options passed to the dialect's get_indexes() method.
"""
indexes = self.dialect.get_indexes(self.conn, table_name,
schema,
info_cache=self.info_cache, **kw)
return indexes
def reflecttable(self, table, include_columns):
dialect = self.conn.dialect
# MySQL dialect does this. Applicable with other dialects?
if hasattr(dialect, '_connection_charset') \
and hasattr(dialect, '_adjust_casing'):
charset = dialect._connection_charset
dialect._adjust_casing(table)
# table attributes we might need.
reflection_options = dict(
(k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
schema = table.schema
table_name = table.name
# apply table options
tbl_opts = self.get_table_options(table_name, schema, **table.kwargs)
if tbl_opts:
table.kwargs.update(tbl_opts)
# table.kwargs will need to be passed to each reflection method. Make
# sure keywords are strings.
tblkw = table.kwargs.copy()
for (k, v) in tblkw.items():
del tblkw[k]
tblkw[str(k)] = v
# Py2K
if isinstance(schema, str):
schema = schema.decode(dialect.encoding)
if isinstance(table_name, str):
table_name = table_name.decode(dialect.encoding)
# end Py2K
# columns
found_table = False
for col_d in self.get_columns(table_name, schema, **tblkw):
found_table = True
name = col_d['name']
if include_columns and name not in include_columns:
continue
coltype = col_d['type']
col_kw = {
'nullable':col_d['nullable'],
}
if 'autoincrement' in col_d:
col_kw['autoincrement'] = col_d['autoincrement']
if 'quote' in col_d:
col_kw['quote'] = col_d['quote']
colargs = []
if col_d.get('default') is not None:
# the "default" value is assumed to be a literal SQL expression,
# so is wrapped in text() so that no quoting occurs on re-issuance.
colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
if 'sequence' in col_d:
# TODO: mssql, maxdb and sybase are using this.
seq = col_d['sequence']
sequence = sa_schema.Sequence(seq['name'], 1, 1)
if 'start' in seq:
sequence.start = seq['start']
if 'increment' in seq:
sequence.increment = seq['increment']
colargs.append(sequence)
col = sa_schema.Column(name, coltype, *colargs, **col_kw)
table.append_column(col)
if not found_table:
raise exc.NoSuchTableError(table.name)
# Primary keys
primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
if pk in table.c
])
table.append_constraint(primary_key_constraint)
# Foreign keys
fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
for fkey_d in fkeys:
conname = fkey_d['name']
constrained_columns = fkey_d['constrained_columns']
referred_schema = fkey_d['referred_schema']
referred_table = fkey_d['referred_table']
referred_columns = fkey_d['referred_columns']
refspec = []
if referred_schema is not None:
sa_schema.Table(referred_table, table.metadata,
autoload=True, schema=referred_schema,
autoload_with=self.conn,
**reflection_options
)
for column in referred_columns:
refspec.append(".".join(
[referred_schema, referred_table, column]))
else:
sa_schema.Table(referred_table, table.metadata, autoload=True,
autoload_with=self.conn,
**reflection_options
)
for column in referred_columns:
refspec.append(".".join([referred_table, column]))
table.append_constraint(
sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
conname, link_to_name=True))
# Indexes
indexes = self.get_indexes(table_name, schema)
for index_d in indexes:
name = index_d['name']
columns = index_d['column_names']
unique = index_d['unique']
flavor = index_d.get('type', 'unknown type')
if include_columns and \
not set(columns).issubset(include_columns):
util.warn(
"Omitting %s KEY for (%s), key covers omitted columns." %
(flavor, ', '.join(columns)))
continue
sa_schema.Index(name, *[table.columns[c] for c in columns],
**dict(unique=unique))
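# Example usage (a sketch; the sqlite URL is illustrative):
#
#     from sqlalchemy import create_engine
#     engine = create_engine('sqlite:///some.db')
#     insp = Inspector.from_engine(engine)
#     for tname in insp.get_table_names():
#         print tname, [col['name'] for col in insp.get_columns(tname)]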

sqlalchemy/engine/strategies.py

@@ -0,0 +1,227 @@
"""Strategies for creating new instances of Engine types.
These are semi-private implementation classes which provide the
underlying behavior for the "strategy" keyword argument available on
:func:`~sqlalchemy.engine.create_engine`. Currently available options are
``plain``, ``threadlocal``, and ``mock``.
New strategies can be added via new ``EngineStrategy`` classes.
"""
from operator import attrgetter
from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exc
from sqlalchemy import pool as poollib
strategies = {}
class EngineStrategy(object):
"""An adaptor that processes input arguements and produces an Engine.
Provides a ``create`` method that receives input arguments and
produces an instance of base.Engine or a subclass.
"""
def __init__(self):
strategies[self.name] = self
def create(self, *args, **kwargs):
"""Given arguments, returns a new Engine instance."""
raise NotImplementedError()
class DefaultEngineStrategy(EngineStrategy):
"""Base class for built-in stratgies."""
pool_threadlocal = False
def create(self, name_or_url, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
dialect_cls = u.get_dialect()
dialect_args = {}
# consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls):
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
dbapi = kwargs.pop('module', None)
if dbapi is None:
dbapi_args = {}
for k in util.get_func_kwargs(dialect_cls.dbapi):
if k in kwargs:
dbapi_args[k] = kwargs.pop(k)
dbapi = dialect_cls.dbapi(**dbapi_args)
dialect_args['dbapi'] = dbapi
# create dialect
dialect = dialect_cls(**dialect_args)
# assemble connection arguments
(cargs, cparams) = dialect.create_connect_args(u)
cparams.update(kwargs.pop('connect_args', {}))
# look for existing pool or create
pool = kwargs.pop('pool', None)
if pool is None:
def connect():
try:
return dialect.connect(*cargs, **cparams)
except Exception, e:
# Py3K
#raise exc.DBAPIError.instance(None, None, e) from e
# Py2K
import sys
raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2]
# end Py2K
creator = kwargs.pop('creator', connect)
poolclass = (kwargs.pop('poolclass', None) or
getattr(dialect_cls, 'poolclass', poollib.QueuePool))
pool_args = {}
# consume pool arguments from kwargs, translating a few of
# the arguments
translate = {'logging_name': 'pool_logging_name',
'echo': 'echo_pool',
'timeout': 'pool_timeout',
'recycle': 'pool_recycle',
'use_threadlocal':'pool_threadlocal'}
for k in util.get_cls_kwargs(poolclass):
tk = translate.get(k, k)
if tk in kwargs:
pool_args[k] = kwargs.pop(tk)
pool_args.setdefault('use_threadlocal', self.pool_threadlocal)
pool = poolclass(creator, **pool_args)
else:
if isinstance(pool, poollib._DBProxy):
pool = pool.get_pool(*cargs, **cparams)
else:
pool = pool
# create engine.
engineclass = self.engine_cls
engine_args = {}
for k in util.get_cls_kwargs(engineclass):
if k in kwargs:
engine_args[k] = kwargs.pop(k)
_initialize = kwargs.pop('_initialize', True)
# all kwargs should be consumed
if kwargs:
raise TypeError(
"Invalid argument(s) %s sent to create_engine(), "
"using configuration %s/%s/%s. Please check that the "
"keyword arguments are appropriate for this combination "
"of components." % (','.join("'%s'" % k for k in kwargs),
dialect.__class__.__name__,
pool.__class__.__name__,
engineclass.__name__))
engine = engineclass(pool, dialect, u, **engine_args)
if _initialize:
do_on_connect = dialect.on_connect()
if do_on_connect:
def on_connect(conn, rec):
conn = getattr(conn, '_sqla_unwrap', conn)
if conn is None:
return
do_on_connect(conn)
pool.add_listener({'first_connect': on_connect, 'connect':on_connect})
def first_connect(conn, rec):
c = base.Connection(engine, connection=conn)
dialect.initialize(c)
pool.add_listener({'first_connect':first_connect})
return engine
class PlainEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring a regular Engine."""
name = 'plain'
engine_cls = base.Engine
PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring an Engine with thredlocal behavior."""
name = 'threadlocal'
pool_threadlocal = True
engine_cls = threadlocal.TLEngine
ThreadLocalEngineStrategy()
class MockEngineStrategy(EngineStrategy):
"""Strategy for configuring an Engine-like object with mocked execution.
Produces a single mock Connectable object which dispatches
statement execution to a passed-in function.
"""
name = 'mock'
def create(self, name_or_url, executor, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
dialect_cls = u.get_dialect()
dialect_args = {}
# consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls):
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
# create dialect
dialect = dialect_cls(**dialect_args)
return MockEngineStrategy.MockConnection(dialect, executor)
class MockConnection(base.Connectable):
def __init__(self, dialect, execute):
self._dialect = dialect
self.execute = execute
engine = property(lambda s: s)
dialect = property(attrgetter('_dialect'))
name = property(lambda s: s._dialect.name)
def contextual_connect(self, **kwargs):
return self
def compiler(self, statement, parameters, **kwargs):
return self._dialect.compiler(
statement, parameters, engine=self, **kwargs)
def create(self, entity, **kwargs):
kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl
ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)
def drop(self, entity, **kwargs):
kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl
ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
MockEngineStrategy()
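# The mock strategy is commonly used to echo DDL instead of executing it; a
# sketch (the executor below simply prints each construct):
#
#     def dump(sql, *multiparams, **params):
#         print unicode(sql.compile(dialect=engine.dialect)).strip()
#     engine = create_engine('postgresql://', strategy='mock', executor=dump)
#     metadata.create_all(engine, checkfirst=False)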

sqlalchemy/engine/threadlocal.py

@@ -0,0 +1,103 @@
"""Provides a thread-local transactional wrapper around the root Engine class.
The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag
with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is
invoked automatically when the threadlocal engine strategy is used.
"""
from sqlalchemy import util
from sqlalchemy.engine import base
import weakref
class TLConnection(base.Connection):
def __init__(self, *arg, **kw):
super(TLConnection, self).__init__(*arg, **kw)
self.__opencount = 0
def _increment_connect(self):
self.__opencount += 1
return self
def close(self):
if self.__opencount == 1:
base.Connection.close(self)
self.__opencount -= 1
def _force_close(self):
self.__opencount = 0
base.Connection.close(self)
class TLEngine(base.Engine):
"""An Engine that includes support for thread-local managed transactions."""
def __init__(self, *args, **kwargs):
super(TLEngine, self).__init__(*args, **kwargs)
self._connections = util.threading.local()
proxy = kwargs.get('proxy')
if proxy:
self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
else:
self.TLConnection = TLConnection
def contextual_connect(self, **kw):
if not hasattr(self._connections, 'conn'):
connection = None
else:
connection = self._connections.conn()
if connection is None or connection.closed:
# guards against pool-level reapers, if desired.
# or not connection.connection.is_valid:
connection = self.TLConnection(self, self.pool.connect(), **kw)
self._connections.conn = conn = weakref.ref(connection)
return connection._increment_connect()
def begin_twophase(self, xid=None):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid))
def begin_nested(self):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_nested())
def begin(self):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin())
def prepare(self):
self._connections.trans[-1].prepare()
def commit(self):
trans = self._connections.trans.pop(-1)
trans.commit()
def rollback(self):
trans = self._connections.trans.pop(-1)
trans.rollback()
def dispose(self):
self._connections = util.threading.local()
super(TLEngine, self).dispose()
@property
def closed(self):
return not hasattr(self._connections, 'conn') or \
self._connections.conn() is None or \
self._connections.conn().closed
def close(self):
if not self.closed:
self.contextual_connect().close()
connection = self._connections.conn()
connection._force_close()
del self._connections.conn
self._connections.trans = []
def __repr__(self):
return 'TLEngine(%s)' % str(self.url)
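# Sketch of thread-local transaction usage (URL and statement are
# illustrative):
#
#     engine = create_engine('mysql://scott:tiger@localhost/test',
#                            strategy='threadlocal')
#     engine.begin()
#     try:
#         engine.execute("INSERT INTO some_table (x) VALUES (1)")
#         engine.commit()
#     except:
#         engine.rollback()
#         raise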

sqlalchemy/engine/url.py

@@ -0,0 +1,214 @@
"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
information about a database connection specification.
The URL object is created automatically when :func:`~sqlalchemy.engine.create_engine` is called
with a string argument; alternatively, the URL is a public-facing construct which can
be used directly and is also accepted directly by ``create_engine()``.
"""
import re, cgi, sys, urllib
from sqlalchemy import exc
class URL(object):
"""
Represent the components of a URL used to connect to a database.
This object is suitable to be passed directly to a
``create_engine()`` call. The fields of the URL are parsed from a
string by the module-level ``make_url()`` function. The string
format of the URL is an RFC-1738-style string.
All initialization parameters are available as public attributes.
:param drivername: the name of the database backend.
This name will correspond to a module in sqlalchemy/databases
or a third party plug-in.
:param username: The user name.
:param password: database password.
:param host: The name of the host.
:param port: The port number.
:param database: The database name.
:param query: A dictionary of options to be passed to the
dialect and/or the DBAPI upon connect.
"""
def __init__(self, drivername, username=None, password=None,
host=None, port=None, database=None, query=None):
self.drivername = drivername
self.username = username
self.password = password
self.host = host
if port is not None:
self.port = int(port)
else:
self.port = None
self.database = database
self.query = query or {}
def __str__(self):
s = self.drivername + "://"
if self.username is not None:
s += self.username
if self.password is not None:
s += ':' + urllib.quote_plus(self.password)
s += "@"
if self.host is not None:
s += self.host
if self.port is not None:
s += ':' + str(self.port)
if self.database is not None:
s += '/' + self.database
if self.query:
keys = self.query.keys()
keys.sort()
s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
return s
def __hash__(self):
return hash(str(self))
def __eq__(self, other):
return \
isinstance(other, URL) and \
self.drivername == other.drivername and \
self.username == other.username and \
self.password == other.password and \
self.host == other.host and \
self.database == other.database and \
self.query == other.query
def get_dialect(self):
"""Return the SQLAlchemy database dialect class corresponding
to this URL's driver name.
"""
try:
if '+' in self.drivername:
dialect, driver = self.drivername.split('+')
else:
dialect, driver = self.drivername, 'base'
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
module = getattr(module, dialect)
module = getattr(module, driver)
return module.dialect
except ImportError:
module = self._load_entry_point()
if module is not None:
return module
else:
raise
def _load_entry_point(self):
"""attempt to load this url's dialect from entry points, or return None
if pkg_resources is not installed or there is no matching entry point.
Raise ImportError if the actual load fails.
"""
try:
import pkg_resources
except ImportError:
return None
for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'):
if res.name == self.drivername:
return res.load()
else:
return None
def translate_connect_args(self, names=[], **kw):
"""Translate url attributes into a dictionary of connection arguments.
Returns attributes of this url (`host`, `database`, `username`,
`password`, `port`) as a plain dictionary. The attribute names are
used as the keys by default. Unset or false attributes are omitted
from the final dictionary.
:param \**kw: Optional, alternate key names for url attributes.
:param names: Deprecated. Same purpose as the keyword-based alternate names,
but correlates the name to the original positionally.
"""
translated = {}
attribute_names = ['host', 'database', 'username', 'password', 'port']
for sname in attribute_names:
if names:
name = names.pop(0)
elif sname in kw:
name = kw[sname]
else:
name = sname
if name is not None and getattr(self, sname, False):
translated[name] = getattr(self, sname)
return translated
def make_url(name_or_url):
"""Given a string or unicode instance, produce a new URL instance.
The given string is parsed according to the RFC 1738 spec. If an
existing URL object is passed, just returns the object.
"""
if isinstance(name_or_url, basestring):
return _parse_rfc1738_args(name_or_url)
else:
return name_or_url
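# Example (values are illustrative):
#
#     u = make_url('postgresql+psycopg2://scott:tiger@localhost:5432/test?sslmode=require')
#     u.drivername, u.host, u.port   # 'postgresql+psycopg2', 'localhost', 5432
#     u.query                        # {'sslmode': 'require'}
#     u.translate_connect_args(username='user')
#     # -> {'user': 'scott', 'password': 'tiger', 'host': 'localhost',
#     #     'port': 5432, 'database': 'test'}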
def _parse_rfc1738_args(name):
pattern = re.compile(r'''
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>[^/]*))?
@)?
(?:
(?P<host>[^/:]*)
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
'''
, re.X)
m = pattern.match(name)
if m is not None:
components = m.groupdict()
if components['database'] is not None:
tokens = components['database'].split('?', 2)
components['database'] = tokens[0]
query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None
# Py2K
if query is not None:
query = dict((k.encode('ascii'), query[k]) for k in query)
# end Py2K
else:
query = None
components['query'] = query
if components['password'] is not None:
components['password'] = urllib.unquote_plus(components['password'])
name = components.pop('name')
return URL(name, **components)
else:
raise exc.ArgumentError(
"Could not parse rfc1738 URL from string '%s'" % name)
def _parse_keyvalue_args(name):
m = re.match( r'(\w+)://(.*)', name)
if m is not None:
(name, args) = m.group(1, 2)
opts = dict( cgi.parse_qsl( args ) )
return URL(name, **opts)
else:
return None