This commit is contained in:
2010-05-07 17:33:49 +00:00
parent c7c7498f19
commit cae7e001e3
127 changed files with 57530 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
"""Testing environment and utilities.
This package contains base classes and routines used by
the unit tests. Tests are based on Nose and bootstrapped
by noseplugin.NoseSQLAlchemy.
"""
from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config
from sqlalchemy.test.schema import Column, Table
from sqlalchemy.test.testing import \
AssertsCompiledSQL, \
AssertsExecutionResults, \
ComparesTables, \
TestBase, \
rowset
# Public surface re-exported for test modules doing
# ``from sqlalchemy.test import *``.
__all__ = ('testing',
           'Column', 'Table',
           'rowset',
           'TestBase', 'AssertsExecutionResults',
           'AssertsCompiledSQL', 'ComparesTables',
           'engines', 'profiling', 'pickleable')

View File

@@ -0,0 +1,285 @@
from sqlalchemy.interfaces import ConnectionProxy
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.base import Connection
from sqlalchemy import util
import re
class AssertRule(object):
    """Base class for a rule applied to a stream of SQL execution events.

    Rules receive engine-level (``process_execute``) and cursor-level
    (``process_cursor_execute``) events and report whether their
    condition has been satisfied ("consumed").
    """

    def process_execute(self, clauseelement, *multiparams, **params):
        # engine-level execute() event; no-op by default
        pass

    def process_cursor_execute(self, statement, parameters, context, executemany):
        # cursor-level execute event; no-op by default
        pass

    def is_consumed(self):
        """Return True if this rule has been consumed, False if not.

        Should raise an AssertionError if this rule's condition has definitely failed.
        """
        raise NotImplementedError()

    def rule_passed(self):
        """Return True if the last test of this rule passed, False if failed, None if no test was applied."""
        raise NotImplementedError()

    def consume_final(self):
        """Return True if this rule has been consumed.

        Should raise an AssertionError if this rule's condition has not been consumed or has failed.
        """
        # NOTE(review): relies on subclasses providing ``self._result``
        # (see SQLMatchRule); subclasses with different bookkeeping
        # (CountStatements, AllOf) override this method entirely.
        if self._result is None:
            assert False, "Rule has not been consumed"
        return self.is_consumed()
class SQLMatchRule(AssertRule):
    """Base for rules that compare a received statement against an
    expectation, tracking a tri-state result: None (not yet tested),
    True (matched) or False (mismatch, with an error message)."""

    def __init__(self):
        self._result = None
        self._errmsg = ""

    def rule_passed(self):
        # None until a statement has been examined
        return self._result

    def is_consumed(self):
        tested = self._result is not None
        if tested:
            # a definite mismatch is a hard failure
            assert self._result, self._errmsg
        return tested
class ExactSQL(SQLMatchRule):
    """Rule matching the cursor-level statement string exactly (after
    paramstyle normalization), and optionally its compiled parameters."""

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.sql = sql
        # params may be a dict, a list of dicts, or a callable taking the
        # execution context and returning either
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        if not context:
            return
        _received_statement = _process_engine_statement(context.unicode_statement, context)
        _received_parameters = context.compiled_parameters
        # TODO: remove this step once all unit tests
        # are migrated, as ExactSQL should really be *exact* SQL
        sql = _process_assertion_statement(self.sql, context)
        equivalent = _received_statement == sql
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            equivalent = equivalent and params == context.compiled_parameters
        else:
            params = {}
        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for exact statement %r exact params %r, " \
                "received %r with params %r" % (sql, params, _received_statement, _received_parameters)
class RegexSQL(SQLMatchRule):
    """Rule matching the received statement against a regular expression,
    with an optional positive-only (subset) parameter comparison."""

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        # keep the source pattern for the failure message
        self.orig_regex = regex
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        if not context:
            return
        _received_statement = _process_engine_statement(context.unicode_statement, context)
        _received_parameters = context.compiled_parameters
        equivalent = bool(self.regex.match(_received_statement))
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            # do a positive compare only: every expected key/value must
            # appear in the received parameters; extra received keys are OK
            for param, received in zip(params, _received_parameters):
                for k, v in param.iteritems():
                    if k not in received or received[k] != v:
                        equivalent = False
                        break
        else:
            params = {}
        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for regex %r partial params %r, "\
                "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
class CompiledSQL(SQLMatchRule):
    """Rule that recompiles the executed statement with the default
    dialect and compares the resulting SQL string, dialect-neutrally,
    with an optional positive-only parameter comparison."""

    def __init__(self, statement, params):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        if not context:
            return
        _received_parameters = context.compiled_parameters
        # recompile from the context, using the default dialect
        compiled = context.compiled.statement.\
            compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
        _received_statement = re.sub(r'\n', '', str(compiled))
        equivalent = self.statement == _received_statement
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            # do a positive compare only: expected keys must match,
            # extra received keys are ignored
            for param, received in zip(params, _received_parameters):
                for k, v in param.iteritems():
                    if k not in received or received[k] != v:
                        equivalent = False
                        break
        else:
            params = {}
        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for compiled statement %r partial params %r, " \
                "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
class CountStatements(AssertRule):
    """Rule asserting that exactly ``count`` engine-level executions
    occur over the monitored span."""

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_execute(self, clauseelement, *multiparams, **params):
        # tally every engine-level execution
        self._statement_count += 1

    def process_cursor_execute(self, statement, parameters, context, executemany):
        # cursor-level events are not counted
        pass

    def is_consumed(self):
        # never consumed mid-stream; the tally is verified at the end
        return False

    def consume_final(self):
        assert self.count == self._statement_count, \
            "desired statement count %d does not match %d" % (self.count, self._statement_count)
        return True
class AllOf(AssertRule):
    """Composite rule satisfied when every child rule has passed, in any
    order; events are broadcast to all remaining children."""

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        for rule in self.rules:
            rule.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context, executemany):
        for rule in self.rules:
            rule.process_cursor_execute(statement, parameters, context, executemany)

    def is_consumed(self):
        if not self.rules:
            return True
        # retire at most one passed rule per statement; if none passed,
        # the statement satisfied no expectation and that is a failure
        for rule in list(self.rules):
            if rule.rule_passed():  # a rule passed, move on
                self.rules.remove(rule)
                return len(self.rules) == 0
        assert False, "No assertion rules were satisfied for statement"

    def consume_final(self):
        return len(self.rules) == 0
def _process_engine_statement(query, context):
    """Normalize a statement received from the engine for comparison:
    coerce to unicode on Jython, strip the MSSQL identity suffix and
    embedded newlines."""
    if util.jython:
        # oracle+zxjdbc passes a PyStatement when returning into
        query = unicode(query)
    # len('; select scope_identity()') == 25
    if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
        query = query[:-25]
    query = re.sub(r'\n', '', query)
    return query
def _process_assertion_statement(query, context):
paramstyle = context.dialect.paramstyle
if paramstyle == 'named':
pass
elif paramstyle =='pyformat':
query = re.sub(r':([\w_]+)', r"%(\1)s", query)
else:
# positional params
repl = None
if paramstyle=='qmark':
repl = "?"
elif paramstyle=='format':
repl = r"%s"
elif paramstyle=='numeric':
repl = None
query = re.sub(r':([\w_]+)', repl, query)
return query
class SQLAssert(ConnectionProxy):
    """ConnectionProxy that feeds engine- and cursor-level execution
    events through an ordered list of AssertRules."""

    # class-level default; an instance attribute shadows it while rules
    # are active (see add_rules/clear_rules)
    rules = None

    def add_rules(self, rules):
        self.rules = list(rules)

    def statement_complete(self):
        # called at the end of the monitored span; every rule must have
        # been fully consumed by then
        for rule in self.rules:
            if not rule.consume_final():
                assert False, "All statements are complete, but pending assertion rules remain"

    def clear_rules(self):
        # removes the instance attribute, re-exposing the class-level
        # ``rules = None`` default
        del self.rules

    def execute(self, conn, execute, clauseelement, *multiparams, **params):
        result = execute(clauseelement, *multiparams, **params)
        if self.rules is not None:
            if not self.rules:
                assert False, "All rules have been exhausted, but further statements remain"
            # rules are consumed strictly in order
            rule = self.rules[0]
            rule.process_execute(clauseelement, *multiparams, **params)
            if rule.is_consumed():
                self.rules.pop(0)
        return result

    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
        result = execute(cursor, statement, parameters, context)
        if self.rules:
            # cursor events only feed the head rule; consumption happens
            # at the engine-level execute() event
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context, executemany)
        return result

# module-level singleton installed as the 'proxy' on testing engines
asserter = SQLAssert()

180
sqlalchemy/test/config.py Normal file
View File

@@ -0,0 +1,180 @@
import optparse, os, sys, re, ConfigParser, time, warnings
# 2to3
import StringIO
# lazily-imported logging module (see _log); None until first use
logging = None

__all__ = 'parser', 'configure', 'options',

# the engine selected via --db/--dburi; populated by _create_testing_engine
db = None
# label (e.g. 'sqlite'), resolved URL, and create_engine() keyword options
db_label, db_url, db_opts = None, None, {}

# optparse values object; assigned by the nose plugin at configure time
options = None

# ConfigParser over base_config + test.cfg / ~/.satest.cfg; assigned by
# the nose plugin
file_config = None

# built-in [db] label -> URL macros; may be overridden by config files
base_config = """
[db]
sqlite=sqlite:///:memory:
sqlite_file=sqlite:///querytest.db
postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
postgres=postgresql://scott:tiger@127.0.0.1:5432/test
pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test
mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
mysql=mysql://scott:tiger@127.0.0.1:3306/test
oracle=oracle://scott:tiger@127.0.0.1:1521
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
maxdb=maxdb://MONA:RED@/maxdb1
"""
def _log(option, opt_str, value, parser):
    """optparse callback for --log-info/--log-debug: set the named
    logger's level, importing and configuring logging on first use."""
    global logging
    if not logging:
        import logging
        logging.basicConfig()
    # the option name itself selects the level
    if opt_str.endswith('-info'):
        logging.getLogger(value).setLevel(logging.INFO)
    elif opt_str.endswith('-debug'):
        logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
    """optparse callback for --dbs: print the configured database
    macros and exit."""
    print "Available --db options (use --dburi to override)"
    for macro in sorted(file_config.options('db')):
        print "%20s\t%s" % (macro, file_config.get('db', macro))
    sys.exit(0)
def _server_side_cursors(options, opt_str, value, parser):
    """optparse callback for --serverside: enable server-side cursors on
    the engine options."""
    db_opts['server_side_cursors'] = True
def _engine_strategy(options, opt_str, value, parser):
    """optparse callback for --enginestrategy: pass the strategy name
    through to create_engine()."""
    if value:
        db_opts['strategy'] = value
class _ordered_map(object):
def __init__(self):
self._keys = list()
self._data = dict()
def __setitem__(self, key, value):
if key not in self._keys:
self._keys.append(key)
self._data[key] = value
def __iter__(self):
for key in self._keys:
yield self._data[key]
# at one point in refactoring, modules were injecting into the config
# process. this could probably just become a list now.
post_configure = _ordered_map()
def _engine_uri(options, file_config):
    """Resolve the target database URL and label from --dburi or --db,
    falling back to the [db] section of the config file."""
    global db_label, db_url
    db_label = 'sqlite'
    if options.dburi:
        db_url = options.dburi
        # the label is the dialect portion of the URI, e.g. 'postgresql'
        db_label = db_url[:db_url.index(':')]
    elif options.db:
        db_label = options.db
        db_url = None
    if db_url is None:
        if db_label not in file_config.options('db'):
            raise RuntimeError(
                "Unknown engine. Specify --dbs for known engines.")
        db_url = file_config.get('db', db_label)
post_configure['engine_uri'] = _engine_uri
def _require(options, file_config):
    """Enforce --require and [require] config-section version
    requirements via setuptools/pkg_resources."""
    if not(options.require or
           (file_config.has_section('require') and
            file_config.items('require'))):
        return
    try:
        import pkg_resources
    except ImportError:
        raise RuntimeError("setuptools is required for version requirements")
    cmdline = []
    for requirement in options.require:
        pkg_resources.require(requirement)
        # remember the bare distribution name so config-file requirements
        # for the same package aren't applied twice.  Split on the first
        # version operator character; the previous pattern '(<!>=)' only
        # matched the literal text "<!>=", so nothing was ever split.
        cmdline.append(re.split(r'\s*([<>!=])', requirement, 1)[0])
    if file_config.has_section('require'):
        for label, requirement in file_config.items('require'):
            # apply only requirements for the active database label,
            # either exact ('mysql') or namespaced ('mysql.xyz').  The
            # previous form ``not label == db_label or label.startswith``
            # negated only the first comparison, skipping namespaced
            # labels entirely.
            if not (label == db_label or label.startswith('%s.' % db_label)):
                continue
            seen = [c for c in cmdline if requirement.startswith(c)]
            if seen:
                continue
            pkg_resources.require(requirement)
post_configure['require'] = _require
def _engine_pool(options, file_config):
    """Install AssertionPool (asserts only one connection is checked out
    at a time) when --mockpool is given."""
    if options.mockpool:
        from sqlalchemy import pool
        db_opts['poolclass'] = pool.AssertionPool
post_configure['engine_pool'] = _engine_pool
def _create_testing_engine(options, file_config):
    """Create the global testing engine from the resolved URL/options
    and publish it as both config.db and testing.db."""
    from sqlalchemy.test import engines, testing
    global db
    db = engines.testing_engine(db_url, db_opts)
    testing.db = db
post_configure['create_engine'] = _create_testing_engine
def _prep_testing_database(options, file_config):
    """When --dropfirst is given, reflect and drop all existing tables
    in the target database, after a 5-second abort window."""
    from sqlalchemy.test import engines
    from sqlalchemy import schema
    # also create alt schemas etc. here?
    if options.dropfirst:
        e = engines.utf8_engine()
        existing = e.table_names()
        if existing:
            print "Dropping existing tables in database: " + db_url
            try:
                print "Tables: %s" % ', '.join(existing)
            except:
                # best-effort: table names may not be printable
                pass
            print "Abort within 5 seconds..."
            time.sleep(5)
            md = schema.MetaData(e, reflect=True)
            md.drop_all()
        e.dispose()
post_configure['prep_db'] = _prep_testing_database
def _set_table_options(options, file_config):
    """Apply --table-option key=value pairs (and --mysql-engine) to the
    shared table_options dict used when creating test tables."""
    from sqlalchemy.test import schema
    table_options = schema.table_options
    for spec in options.tableopts:
        # split on the first '=' only, so option values may themselves
        # contain '=' (a bare split raised ValueError in that case)
        key, value = spec.split('=', 1)
        table_options[key] = value
    if options.mysql_engine:
        table_options['mysql_engine'] = options.mysql_engine
post_configure['table_options'] = _set_table_options
def _reverse_topological(options, file_config):
    """When --reversetop is given, monkeypatch the topological sorter to
    reverse its input ordering, helping to reveal hidden dependency-order
    assumptions in the unit of work."""
    if options.reversetop:
        from sqlalchemy.orm import unitofwork
        from sqlalchemy import topological
        class RevQueueDepSort(topological.QueueDependencySorter):
            def __init__(self, tuples, allitems):
                self.tuples = list(tuples)
                self.allitems = list(allitems)
                self.tuples.reverse()
                self.allitems.reverse()
        topological.QueueDependencySorter = RevQueueDepSort
        unitofwork.DependencySorter = RevQueueDepSort
post_configure['topological'] = _reverse_topological

300
sqlalchemy/test/engines.py Normal file
View File

@@ -0,0 +1,300 @@
import sys, types, weakref
from collections import deque
import config
from sqlalchemy.util import function_named, callable
import re
import warnings
class ConnectionKiller(object):
    """Pool checkout listener that weakly tracks live connection proxies
    so tests can mass-rollback or mass-close them between runs."""

    def __init__(self):
        # weak keys: entries disappear as proxies are garbage collected
        self.proxy_refs = weakref.WeakKeyDictionary()

    def checkout(self, dbapi_con, con_record, con_proxy):
        # pool listener hook; only the proxy is retained (weakly)
        self.proxy_refs[con_proxy] = True

    def _apply_all(self, methods):
        # must copy keys atomically
        for rec in self.proxy_refs.keys():
            if rec is not None and rec.is_valid:
                try:
                    # each entry is either a method name or a callable
                    # applied to the proxy
                    for name in methods:
                        if callable(name):
                            name(rec)
                        else:
                            getattr(rec, name)()
                except (SystemExit, KeyboardInterrupt):
                    raise
                except Exception, e:
                    # best-effort cleanup; report but keep going
                    warnings.warn("testing_reaper couldn't close connection: %s" % e)

    def rollback_all(self):
        self._apply_all(('rollback',))

    def close_all(self):
        self._apply_all(('rollback', 'close'))

    def assert_all_closed(self):
        # fails if any tracked proxy is still usable
        for rec in self.proxy_refs:
            if rec.is_valid:
                assert False
# module-level singleton registered as a pool listener on testing engines
testing_reaper = ConnectionKiller()

def drop_all_tables(metadata):
    """Drop all tables in ``metadata``, closing outstanding connections
    first so no locks are held."""
    testing_reaper.close_all()
    metadata.drop_all()
def assert_conns_closed(fn):
    """Decorator asserting that every tracked connection is closed once
    the decorated test completes."""
    def wrapper(*args, **kw):
        try:
            fn(*args, **kw)
        finally:
            # runs even when the test raises
            testing_reaper.assert_all_closed()
    return function_named(wrapper, fn.__name__)
def rollback_open_connections(fn):
    """Decorator that rolls back all open connections after fn execution."""
    def wrapper(*args, **kw):
        try:
            fn(*args, **kw)
        finally:
            # roll back even when the test raises
            testing_reaper.rollback_all()
    return function_named(wrapper, fn.__name__)
def close_first(fn):
    """Decorator that closes all connections before fn execution."""
    def wrapper(*args, **kw):
        # ensure a clean slate before the test body runs
        testing_reaper.close_all()
        fn(*args, **kw)
    return function_named(wrapper, fn.__name__)
def close_open_connections(fn):
    """Decorator that closes all connections after fn execution."""
    def wrapper(*args, **kw):
        try:
            fn(*args, **kw)
        finally:
            # close even when the test raises
            testing_reaper.close_all()
    return function_named(wrapper, fn.__name__)
def all_dialects(exclude=None):
    """Yield an instance of every available dialect, optionally skipping
    names listed in ``exclude``."""
    import sqlalchemy.databases as d
    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None)
        if not mod:
            # not yet imported; import the submodule explicitly
            mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
        yield mod.dialect()
class ReconnectFixture(object):
    """Wrapper around a DB-API module that records every connection it
    hands out, so they can all be closed at once via shutdown() to
    simulate a database restart."""

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []

    def __getattr__(self, key):
        # delegate everything else to the real DB-API module
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        connection = self.dbapi.connect(*args, **kwargs)
        self.connections.append(connection)
        return connection

    def shutdown(self):
        # close every outstanding connection, then forget them
        for connection in list(self.connections):
            connection.close()
        self.connections = []
def reconnecting_engine(url=None, options=None):
    """Engine whose DB-API connections can be mass-closed via
    ``engine.test_shutdown()``, simulating a dropped database."""
    url = url or config.db_url
    dbapi = config.db.dialect.dbapi
    if not options:
        options = {}
    # substitute the recording wrapper for the real DB-API module
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    engine.test_shutdown = engine.dialect.dbapi.shutdown
    return engine
def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides."""
    from sqlalchemy import create_engine
    from sqlalchemy.test.assertsql import asserter
    url = url or config.db_url
    options = options or config.db_opts
    # NOTE(review): when no options are passed this mutates the shared
    # config.db_opts dict in place (setdefault + listeners.append), so
    # repeated calls keep re-appending testing_reaper -- confirm intended.
    options.setdefault('proxy', asserter)
    listeners = options.setdefault('listeners', [])
    listeners.append(testing_reaper)
    engine = create_engine(url, **options)
    # may want to call this, results
    # in first-connect initializers
    #engine.connect()
    return engine
def utf8_engine(url=None, options=None):
    """Hook for dialects or drivers that don't handle utf8 by default."""
    from sqlalchemy.engine import url as engine_url
    if config.db.driver == 'mysqldb':
        dbapi_ver = config.db.dialect.dbapi.version_info
        # these MySQLdb releases have broken character set support
        if (dbapi_ver < (1, 2, 1) or
            dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
                          (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))):
            raise RuntimeError('Character set support unavailable with this '
                               'driver version: %s' % repr(dbapi_ver))
        else:
            # force utf8 via URL query arguments for capable MySQLdb
            url = url or config.db_url
            url = engine_url.make_url(url)
            url.query['charset'] = 'utf8'
            url.query['use_unicode'] = '0'
            url = str(url)
    return testing_engine(url, options)
def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.
    """
    from sqlalchemy import create_engine
    if not dialect_name:
        dialect_name = config.db.name
    buffer = []
    def executor(sql, *a, **kw):
        # capture each emitted statement instead of executing it
        buffer.append(sql)
    def assert_sql(stmts):
        # compare captured SQL (whitespace-normalized) against expectation
        recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
        assert recv == stmts, recv
    engine = create_engine(dialect_name + '://',
                           strategy='mock', executor=executor)
    # guard against clobbering a real attribute on future releases
    assert not hasattr(engine, 'mock')
    engine.mock = buffer
    engine.assert_sql = assert_sql
    return engine
class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class.  It only records a session for later
    playback and makes no assertions on call consistency whatsoever.  It's
    unlikely to be suitable for anything other than DB-API recording.
    """

    # sentinels written into the buffer for values that can't be stored
    # directly: a wrapped callable result, or a missing attribute
    Callable = object()
    NoAttribute = object()

    # "native" (value-like) types pass through unwrapped; function/method
    # types are excluded so their results can be recorded recursively
    Natives = set([getattr(types, t)
                   for t in dir(types) if not t.startswith('_')]). \
                   difference([getattr(types, t)
                   # Py3K
                   #for t in ('FunctionType', 'BuiltinFunctionType',
                   #          'MethodType', 'BuiltinMethodType',
                   #          'LambdaType', )])
                   # Py2K
                   for t in ('FunctionType', 'BuiltinFunctionType',
                             'MethodType', 'BuiltinMethodType',
                             'LambdaType', 'UnboundMethodType',)])
                   # end Py2K

    def __init__(self):
        # shared FIFO of recorded results, consumed by the Player
        self.buffer = deque()

    def recorder(self, base):
        return self.Recorder(self.buffer, base)

    def player(self):
        return self.Player(self.buffer)

    class Recorder(object):
        """Proxy that delegates to ``subject``, appending every result
        (or a sentinel) to the shared buffer as it goes."""

        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            # object.__getattribute__ avoids recursing into our own
            # __getattribute__ override
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            result = subject(*args, **kw)
            if type(result) not in ReplayableSession.Natives:
                # non-native result: record a marker and wrap it too
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)
            else:
                buffer.append(result)
                return result

        @property
        def _sqla_unwrap(self):
            # hook letting SQLAlchemy reach the real DB-API object
            return self._subject

        def __getattribute__(self, key):
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            try:
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                # record the miss so playback raises identically
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) not in ReplayableSession.Natives:
                    buffer.append(ReplayableSession.Callable)
                    return type(self)(buffer, result)
                else:
                    buffer.append(result)
                    return result

    class Player(object):
        """Proxy that replays buffered results instead of touching a
        live subject."""

        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                # the recorded result was itself wrapped; keep proxying
                return self
            else:
                return result

        @property
        def _sqla_unwrap(self):
            # no live DB-API object exists during playback
            return None

        def __getattribute__(self, key):
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                # reproduce the AttributeError seen during recording
                raise AttributeError(key)
            else:
                return result

View File

@@ -0,0 +1,83 @@
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
# ids of instances currently being repr()'d, to break cycles
_repr_stack = set()

class BasicEntity(object):
    """Simple test entity: keyword arguments become attributes, and repr
    lists public attributes sorted by name, guarding against reference
    cycles between entities."""

    def __init__(self, **kw):
        # .items() rather than the Python-2-only .iteritems(): identical
        # iteration behavior on Python 2 and survives 2to3 / Python 3
        for key, value in kw.items():
            setattr(self, key, value)

    def __repr__(self):
        # break infinite recursion when entities reference each other
        if id(self) in _repr_stack:
            return object.__repr__(self)
        _repr_stack.add(id(self))
        try:
            return "%s(%s)" % (
                (self.__class__.__name__),
                ', '.join(["%s=%r" % (key, getattr(self, key))
                           for key in sorted(self.__dict__.keys())
                           if not key.startswith('_')]))
        finally:
            _repr_stack.remove(id(self))
# ids of instances currently inside __eq__, to break comparison cycles
_recursion_stack = set()

class ComparableEntity(BasicEntity):
    def __hash__(self):
        # hash by class so instances stay hashable while __eq__ is deep
        return hash(self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __eq__(self, other):
        """Deep, sparse compare.

        Deeply compare two entities, following the non-None attributes of the
        non-persisted object, if possible.
        """
        if other is self:
            return True
        elif not self.__class__ == other.__class__:
            return False
        # already comparing this instance further up the stack: treat as
        # equal to avoid infinite recursion through relationships
        if id(self) in _recursion_stack:
            return True
        _recursion_stack.add(id(self))
        try:
            # pick the entity thats not SA persisted as the source
            try:
                self_key = sa.orm.attributes.instance_state(self).key
            except sa.orm.exc.NO_STATE:
                self_key = None
            if other is None:
                a = self
                b = other
            elif self_key is not None:
                a = other
                b = self
            else:
                a = self
                b = other
            # compare only 'a's public attributes; attributes present
            # only on 'b' are ignored (the "sparse" part)
            for attr in a.__dict__.keys():
                if attr.startswith('_'):
                    continue
                value = getattr(a, attr)
                try:
                    # handle lazy loader errors
                    battr = getattr(b, attr)
                except (AttributeError, sa_exc.UnboundExecutionError):
                    return False
                if hasattr(value, '__iter__'):
                    # collections are compared element-wise, in order
                    if list(value) != list(battr):
                        return False
                else:
                    # None attributes on the source are skipped
                    if value is not None and value != battr:
                        return False
            return True
        finally:
            _recursion_stack.remove(id(self))

View File

@@ -0,0 +1,162 @@
import logging
import os
import re
import sys
import time
import warnings
import ConfigParser
import StringIO
import nose.case
from nose.plugins import Plugin
from sqlalchemy import util, log as sqla_log
from sqlalchemy.test import testing, config, requires
from sqlalchemy.test.config import (
_create_testing_engine, _engine_pool, _engine_strategy, _engine_uri, _list_dbs, _log,
_prep_testing_database, _require, _reverse_topological, _server_side_cursors,
_set_table_options, base_config, db, db_label, db_url, file_config, post_configure)
log = logging.getLogger('nose.plugins.sqlalchemy')
class NoseSQLAlchemy(Plugin):
"""
Handles the setup and extra properties required for testing SQLAlchemy
"""
enabled = True
name = 'sqlalchemy'
score = 100
def options(self, parser, env=os.environ):
Plugin.options(self, parser, env)
opt = parser.add_option
opt("--log-info", action="callback", type="string", callback=_log,
help="turn on info logging for <LOG> (multiple OK)")
opt("--log-debug", action="callback", type="string", callback=_log,
help="turn on debug logging for <LOG> (multiple OK)")
opt("--require", action="append", dest="require", default=[],
help="require a particular driver or module version (multiple OK)")
opt("--db", action="store", dest="db", default="sqlite",
help="Use prefab database uri")
opt('--dbs', action='callback', callback=_list_dbs,
help="List available prefab dbs")
opt("--dburi", action="store", dest="dburi",
help="Database uri (overrides --db)")
opt("--dropfirst", action="store_true", dest="dropfirst",
help="Drop all tables in the target database first (use with caution on Oracle, "
"MS-SQL)")
opt("--mockpool", action="store_true", dest="mockpool",
help="Use mock pool (asserts only one connection used)")
opt("--enginestrategy", action="callback", type="string",
callback=_engine_strategy,
help="Engine strategy (plain or threadlocal, defaults to plain)")
opt("--reversetop", action="store_true", dest="reversetop", default=False,
help="Reverse the collection ordering for topological sorts (helps "
"reveal dependency issues)")
opt("--unhashable", action="store_true", dest="unhashable", default=False,
help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
help="Disallow SQLAlchemy from performing == on mapped test objects.")
opt("--truthless", action="store_true", dest="truthless", default=False,
help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
opt("--serverside", action="callback", callback=_server_side_cursors,
help="Turn on server side cursors for PG")
opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
help="Use the specified MySQL storage engine for all tables, default is "
"a db-default/InnoDB combo.")
opt("--table-option", action="append", dest="tableopts", default=[],
help="Add a dialect-specific table option, key=value")
global file_config
file_config = ConfigParser.ConfigParser()
file_config.readfp(StringIO.StringIO(base_config))
file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
config.file_config = file_config
def configure(self, options, conf):
Plugin.configure(self, options, conf)
self.options = options
def begin(self):
testing.db = db
testing.requires = requires
# Lazy setup of other options (post coverage)
for fn in post_configure:
fn(self.options, file_config)
def describeTest(self, test):
return ""
def wantClass(self, cls):
"""Return true if you want the main test selector to collect
tests from this class, false if you don't, and None if you don't
care.
:Parameters:
cls : class
The class being examined by the selector
"""
if not issubclass(cls, testing.TestBase):
return False
else:
if (hasattr(cls, '__whitelist__') and testing.db.name in cls.__whitelist__):
return True
else:
return not self.__should_skip_for(cls)
def __should_skip_for(self, cls):
if hasattr(cls, '__requires__'):
def test_suite(): return 'ok'
test_suite.__name__ = cls.__name__
for requirement in cls.__requires__:
check = getattr(requires, requirement)
if check(test_suite)() != 'ok':
# The requirement will perform messaging.
return True
if cls.__unsupported_on__:
spec = testing.db_spec(*cls.__unsupported_on__)
if spec(testing.db):
print "'%s' unsupported on DB implementation '%s'" % (
cls.__class__.__name__, testing.db.name)
return True
if getattr(cls, '__only_on__', None):
spec = testing.db_spec(*util.to_list(cls.__only_on__))
if not spec(testing.db):
print "'%s' unsupported on DB implementation '%s'" % (
cls.__class__.__name__, testing.db.name)
return True
if getattr(cls, '__skip_if__', False):
for c in getattr(cls, '__skip_if__'):
if c():
print "'%s' skipped by %s" % (
cls.__class__.__name__, c.__name__)
return True
for rule in getattr(cls, '__excluded_on__', ()):
if testing._is_excluded(*rule):
print "'%s' unsupported on DB %s version %s" % (
cls.__class__.__name__, testing.db.name,
_server_version())
return True
return False
def beforeTest(self, test):
testing.resetwarnings()
def afterTest(self, test):
testing.resetwarnings()
def afterContext(self):
testing.global_cleanup_assertions()
#def handleError(self, test, err):
#pass
#def finalize(self, result=None):
#pass

111
sqlalchemy/test/orm.py Normal file
View File

@@ -0,0 +1,111 @@
import inspect, re
import config, testing
from sqlalchemy import orm
__all__ = 'mapper',
_whitespace = re.compile(r'^(\s+)')
def _find_pragma(lines, current):
m = _whitespace.match(lines[current])
basis = m and m.group() or ''
for line in reversed(lines[0:current]):
if 'testlib.pragma' in line:
return line
m = _whitespace.match(line)
indent = m and m.group() or ''
# simplistic detection:
# >> # testlib.pragma foo
# >> center_line()
if indent == basis:
break
# >> # testlib.pragma foo
# >> if fleem:
# >> center_line()
if line.endswith(':'):
break
return None
def _make_blocker(method_name, fallback):
    """Creates tripwired variant of a method, raising when called.

    To exempt an invocation from blockage, there are two options.

    1) add a pragma in a comment::

        # testlib.pragma exempt:methodname
        offending_line()

    2) add a magic cookie to the function's namespace::

        __sa_baremethodname_exempt__ = True
        ...
        offending_line()
        another_offending_lines()

    The second is useful for testing and development.
    """
    # strip dunder underscores so the cookie reads __sa_hash_exempt__,
    # not __sa___hash___exempt__
    if method_name.startswith('__') and method_name.endswith('__'):
        frame_marker = '__sa_%s_exempt__' % method_name[2:-2]
    else:
        frame_marker = '__sa_%s_exempt__' % method_name
    pragma_marker = 'exempt:' + method_name
    def method(self, *args, **kw):
        frame_r = None
        try:
            # inspect the immediate caller to decide whether to block
            frame = inspect.stack()[1][0]
            frame_r = inspect.getframeinfo(frame, 9)
            module = frame.f_globals.get('__name__', '')
            type_ = type(self)
            pragma = _find_pragma(*frame_r[3:5])
            # exempt if: called from outside SQLAlchemy itself, guarded
            # by a pragma comment, or flagged via the magic cookie in
            # the caller's locals or on the caller's 'self'
            exempt = (
                (not module.startswith('sqlalchemy')) or
                (pragma and pragma_marker in pragma) or
                (frame_marker in frame.f_locals) or
                ('self' in frame.f_locals and
                 getattr(frame.f_locals['self'], frame_marker, False)))
            if exempt:
                supermeth = getattr(super(type_, self), method_name, None)
                if (supermeth is None or
                    getattr(supermeth, 'im_func', None) is method):
                    return fallback(self, *args, **kw)
                else:
                    return supermeth(*args, **kw)
            else:
                raise AssertionError(
                    "%s.%s called in %s, line %s in %s" % (
                        type_.__name__, method_name, module, frame_r[1], frame_r[2]))
        finally:
            # break the traceback reference cycle created by holding a frame
            del frame
    method.__name__ = method_name
    return method
def mapper(type_, *args, **kw):
    """Drop-in replacement for orm.mapper() that optionally installs
    blockers for hash/compare/truth protocols on mapped test classes,
    driven by the --unhashable/--noncomparable/--truthless options."""
    # (method name, enabling option, fallback used for exempt callers)
    forbidden = [
        ('__hash__', 'unhashable', lambda s: id(s)),
        ('__eq__', 'noncomparable', lambda s, o: s is o),
        ('__ne__', 'noncomparable', lambda s, o: s is not o),
        ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)),
        ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)),
        ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)),
        ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)),
        ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)),
        ('__nonzero__', 'truthless', lambda s: 1), ]
    # only instrument plain direct subclasses of object, and only
    # methods the class does not define itself
    if isinstance(type_, type) and type_.__bases__ == (object,):
        for method_name, option, fallback in forbidden:
            if (getattr(config.options, option, False) and
                method_name not in type_.__dict__):
                setattr(type_, method_name, _make_blocker(method_name, fallback))
    return orm.mapper(type_, *args, **kw)

View File

@@ -0,0 +1,75 @@
"""
some objects used for pickle tests, declared in their own module so that they
are easily pickleable.
"""
class Foo(object):
    """Pickle-test fixture: fixed 'data'/'stuff' values plus a variable
    'moredata'; equality compares all three attributes."""

    __hash__ = object.__hash__

    def __init__(self, moredata):
        self.data = 'im data'
        self.stuff = 'im stuff'
        self.moredata = moredata

    def __eq__(self, other):
        return (self.data == other.data
                and self.stuff == other.stuff
                and self.moredata == other.moredata)
class Bar(object):
    """Pickle-test pair fixture; equality requires the exact same class
    plus matching x and y."""

    __hash__ = object.__hash__

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if other.__class__ is not self.__class__:
            return False
        return other.x == self.x and other.y == self.y

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)
class OldSchool:
    """Old-style (non-object-derived) class with by-value equality, for
    pickling old-style instances."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return (other.__class__ is self.__class__
                and other.x == self.x
                and other.y == self.y)
class OldSchoolWithoutCompare:
    """Old-style class with default comparison semantics (no __eq__)."""

    def __init__(self, x, y):
        self.x, self.y = x, y
class BarWithoutCompare(object):
    """Pair fixture with a __str__ but default identity equality."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)
class NotComparable(object):
    """Fixture whose rich comparisons decline participation by
    returning NotImplemented; hashes by identity."""

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return NotImplemented

    def __ne__(self, other):
        return NotImplemented
class BrokenComparable(object):
    """Fixture whose comparisons raise outright, for exercising error
    handling around comparison; hashes by identity."""

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        raise NotImplementedError

    def __ne__(self, other):
        raise NotImplementedError

View File

@@ -0,0 +1,222 @@
"""Profiling support for unit and performance tests.
These are special purpose profiling methods which operate
in a more fine-grained way than nose's profiling plugin.
"""
import os, sys
from sqlalchemy.test import config
from sqlalchemy.test.util import function_named, gc_collect
from nose import SkipTest
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
# every label ever passed to @profiled; used to warn on label reuse
all_targets = set()
# global profiling configuration, populated by the test runner's
# command-line options; 'targets' holds the labels selected for profiling
profile_config = { 'targets': set(),
                   'report': True,
                   'sort': ('time', 'calls'),
                   'limit': None }
# name of the profiler backend in use ('cProfile', 'hotshot' or 'skip');
# resolved lazily by _profile()
profiler = None
def profiled(target=None, **target_opts):
    """Optional function profiling.
    @profiled('label')
    or
    @profiled('label', report=True, sort=('calls',), limit=20)
    Enables profiling for a function when 'label' is targetted for
    profiling.  Report options can be supplied, and override the global
    configuration and command-line options.
    """
    # manual or automatic namespacing by module would remove conflict issues
    if target is None:
        target = 'anonymous_target'
    elif target in all_targets:
        print "Warning: redefining profile target '%s'" % target
    all_targets.add(target)
    filename = "%s.prof" % target
    def decorator(fn):
        def profiled(*args, **kw):
            # profile only when this label was selected, unless the
            # decorator was created with always=True
            if (target not in profile_config['targets'] and
                not target_opts.get('always', None)):
                return fn(*args, **kw)
            elapsed, load_stats, result = _profile(
                filename, fn, *args, **kw)
            # per-decorator options override the global profile_config
            report = target_opts.get('report', profile_config['report'])
            if report:
                sort_ = target_opts.get('sort', profile_config['sort'])
                limit = target_opts.get('limit', profile_config['limit'])
                print "Profile report for target '%s' (%s)" % (
                    target, filename)
                stats = load_stats()
                stats.sort_stats(*sort_)
                if limit:
                    stats.print_stats(limit)
                else:
                    stats.print_stats()
                #stats.print_callers()
            # the .prof file is a temporary artifact; remove it
            os.unlink(filename)
            return result
        return function_named(profiled, fn.__name__)
    return decorator
def function_call_count(count=None, versions={}, variance=0.05):
    """Assert a target for a test case's function call count.
    count
      Optional, general target function call count.
    versions
      Optional, a dictionary of Python version strings to counts,
      for example::
        { '2.5.1': 110,
          '2.5': 100,
          '2.4': 150 }
      The best match for the current running python will be used.
      If none match, 'count' will be used as the fallback.
    variance
      An +/- deviation percentage, defaults to 5%.
    """
    # this could easily dump the profile report if --verbose is in effect
    version_info = list(sys.version_info)
    py_version = '.'.join([str(v) for v in sys.version_info])
    # the C extensions change call counts, so version keys may carry a
    # '+cextension' suffix when they are importable
    try:
        from sqlalchemy.cprocessors import to_float
        cextension = True
    except ImportError:
        cextension = False
    # find the longest version prefix ('2.5.1', then '2.5', then '2')
    # that has an entry in 'versions'
    while version_info:
        version = '.'.join([str(v) for v in version_info])
        if cextension:
            version += "+cextension"
        if version in versions:
            count = versions[version]
            break
        version_info.pop()
    if count is None:
        # no target configured for this interpreter; leave fn unwrapped
        return lambda fn: fn
    def decorator(fn):
        def counted(*args, **kw):
            try:
                filename = "%s.prof" % fn.__name__
                elapsed, stat_loader, result = _profile(
                    filename, fn, *args, **kw)
                stats = stat_loader()
                calls = stats.total_calls
                stats.sort_stats('calls', 'cumulative')
                stats.print_stats()
                #stats.print_callers()
                # allow the measured count to drift by +/- variance
                deviance = int(count * variance)
                if (calls < (count - deviance) or
                    calls > (count + deviance)):
                    raise AssertionError(
                        "Function call count %s not within %s%% "
                        "of expected %s. (Python version %s)" % (
                        calls, (variance * 100), count, py_version))
                return result
            finally:
                if os.path.exists(filename):
                    os.unlink(filename)
        return function_named(counted, fn.__name__)
    return decorator
def conditional_call_count(discriminator, categories):
    """Apply a function call count conditionally at runtime.
    Takes two arguments, a callable that returns a key value, and a dict
    mapping key values to a tuple of arguments to function_call_count.
    The callable is not evaluated until the decorated function is actually
    invoked.  If the `discriminator` returns a key not present in the
    `categories` dictionary, no call count assertion is applied.
    Useful for integration tests, where running a named test in isolation may
    have a function count penalty not seen in the full suite, due to lazy
    initialization in the DB-API, SA, etc.
    """
    def decorator(fn):
        def at_runtime(*args, **kw):
            # resolve the category at call time, not decoration time
            key = discriminator()
            criteria = categories.get(key)
            if criteria is None:
                # unknown category: run without a call-count assertion
                return fn(*args, **kw)
            counted = function_call_count(*criteria)(fn)
            return counted(*args, **kw)
        return function_named(at_runtime, fn.__name__)
    return decorator
def _profile(filename, fn, *args, **kw):
    """Dispatch to the best available profiler, resolving it on first use.

    Returns (elapsed seconds, stats-loader callable, fn's return value).
    """
    global profiler
    if not profiler:
        # prefer cProfile (available from Python 2.5); fall back to
        # hotshot; mark 'skip' when neither can be imported
        if sys.version_info > (2, 5):
            try:
                import cProfile
                profiler = 'cProfile'
            except ImportError:
                pass
        if not profiler:
            try:
                import hotshot
                profiler = 'hotshot'
            except ImportError:
                profiler = 'skip'
    if profiler == 'skip':
        raise SkipTest('Profiling not supported on this platform')
    elif profiler == 'cProfile':
        return _profile_cProfile(filename, fn, *args, **kw)
    else:
        return _profile_hotshot(filename, fn, *args, **kw)
def _profile_cProfile(filename, fn, *args, **kw):
    """Profile fn(*args, **kw) with cProfile, writing stats to *filename*."""
    import cProfile, gc, pstats, time
    load_stats = lambda: pstats.Stats(filename)
    # collect garbage up front so allocator churn does not skew the numbers
    gc_collect()
    began = time.time()
    cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
                    filename=filename)
    ended = time.time()
    # 'result' was bound inside the namespace passed to runctx
    return ended - began, load_stats, locals()['result']
def _profile_hotshot(filename, fn, *args, **kw):
    """Profile fn(*args, **kw) with hotshot, writing stats to *filename*.

    Returns (elapsed seconds, stats-loader callable, fn's return value).
    """
    import gc, hotshot, hotshot.stats, time
    load_stats = lambda: hotshot.stats.load(filename)
    # collect garbage up front so allocator churn does not skew the numbers
    gc_collect()
    prof = hotshot.Profile(filename)
    began = time.time()
    prof.start()
    try:
        result = fn(*args, **kw)
    finally:
        # stop, timestamp and close even when fn raises; previously the
        # profile file handle was leaked on the error path
        prof.stop()
        ended = time.time()
        prof.close()
    return ended - began, load_stats, result

259
sqlalchemy/test/requires.py Normal file
View File

@@ -0,0 +1,259 @@
"""Global database feature support policy.
Provides decorators to mark tests requiring specific feature support from the
target database.
"""
from testing import \
_block_unconditionally as no_support, \
_chain_decorators_on, \
exclude, \
emits_warning_on,\
skip_if,\
fails_on
import testing
import sys
# Each decorator below skips or excludes the test when the configured
# backend lacks the named feature.
def deferrable_constraints(fn):
    """Target database must support deferrable constraints."""
    return _chain_decorators_on(
        fn,
        no_support('firebird', 'not supported by database'),
        no_support('mysql', 'not supported by database'),
        no_support('mssql', 'not supported by database'),
        )
def foreign_keys(fn):
    """Target database must support foreign keys."""
    return _chain_decorators_on(
        fn,
        no_support('sqlite', 'not supported by database'),
        )
def unbounded_varchar(fn):
    """Target database must support VARCHAR with no length"""
    return _chain_decorators_on(
        fn,
        no_support('firebird', 'not supported by database'),
        no_support('oracle', 'not supported by database'),
        no_support('mysql', 'not supported by database'),
        )
def boolean_col_expressions(fn):
    """Target database must support boolean expressions as columns"""
    return _chain_decorators_on(
        fn,
        no_support('firebird', 'not supported by database'),
        no_support('oracle', 'not supported by database'),
        no_support('mssql', 'not supported by database'),
        no_support('sybase', 'not supported by database'),
        no_support('maxdb', 'FIXME: verify not supported by database'),
        )
def identity(fn):
    """Target database must support GENERATED AS IDENTITY or a facsimile.
    Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other
    column DDL feature that fills in a DB-generated identifier at INSERT-time
    without requiring pre-execution of a SEQUENCE or other artifact.
    """
    return _chain_decorators_on(
        fn,
        no_support('firebird', 'not supported by database'),
        no_support('oracle', 'not supported by database'),
        no_support('postgresql', 'not supported by database'),
        no_support('sybase', 'not supported by database'),
        )
def independent_cursors(fn):
    """Target must support simultaneous, independent database cursors on a single connection."""
    return _chain_decorators_on(
        fn,
        no_support('mssql+pyodbc', 'no driver support'),
        no_support('mssql+mxodbc', 'no driver support'),
        )
def independent_connections(fn):
    """Target must support simultaneous, independent database connections."""
    # This is also true of some configurations of UnixODBC and probably win32
    # ODBC as well.
    return _chain_decorators_on(
        fn,
        no_support('sqlite', 'no driver support'),
        exclude('mssql', '<', (9, 0, 0),
                'SQL Server 2005+ is required for independent connections'),
        )
def row_triggers(fn):
    """Target must support standard statement-running EACH ROW triggers."""
    return _chain_decorators_on(
        fn,
        # no access to same table
        no_support('mysql', 'requires SUPER priv'),
        exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
        # huh?  TODO: implement triggers for PG tests, remove this
        no_support('postgresql', 'PG triggers need to be implemented for tests'),
        )
def correlated_outer_joins(fn):
    """Target must support an outer join to a subquery which correlates to the parent."""
    return _chain_decorators_on(
        fn,
        no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"')
        )
def savepoints(fn):
    """Target database must support savepoints."""
    return _chain_decorators_on(
        fn,
        emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'),
        no_support('access', 'not supported by database'),
        no_support('sqlite', 'not supported by database'),
        no_support('sybase', 'FIXME: guessing, needs confirmation'),
        exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
        )
def denormalized_names(fn):
    """Target database must have 'denormalized', i.e. UPPERCASE as case insensitive names."""
    return skip_if(
        lambda: not testing.db.dialect.requires_name_normalize,
        # fixed typo in the user-facing skip message ("denomralized")
        "Backend does not require denormalized names."
        )(fn)
def schemas(fn):
    """Target database must support external schemas, and have one named 'test_schema'."""
    return _chain_decorators_on(
        fn,
        no_support('sqlite', 'no schema support'),
        no_support('firebird', 'no schema support')
        )
def sequences(fn):
    """Target database must support SEQUENCEs."""
    return _chain_decorators_on(
        fn,
        no_support('access', 'no SEQUENCE support'),
        no_support('mssql', 'no SEQUENCE support'),
        no_support('mysql', 'no SEQUENCE support'),
        no_support('sqlite', 'no SEQUENCE support'),
        no_support('sybase', 'no SEQUENCE support'),
        )
def subqueries(fn):
    """Target database must support subqueries."""
    return _chain_decorators_on(
        fn,
        exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
        )
def intersect(fn):
    """Target database must support INTERSECT or equivalent."""
    return _chain_decorators_on(
        fn,
        fails_on('firebird', 'no support for INTERSECT'),
        fails_on('mysql', 'no support for INTERSECT'),
        fails_on('sybase', 'no support for INTERSECT'),
        )
def except_(fn):
    """Target database must support EXCEPT or equivalent (i.e. MINUS)."""
    return _chain_decorators_on(
        fn,
        fails_on('firebird', 'no support for EXCEPT'),
        fails_on('mysql', 'no support for EXCEPT'),
        fails_on('sybase', 'no support for EXCEPT'),
        )
def offset(fn):
    """Target database must support some method of adding OFFSET or equivalent to a result set."""
    return _chain_decorators_on(
        fn,
        fails_on('sybase', 'no support for OFFSET or equivalent'),
        )
def returning(fn):
    """Target database must support RETURNING or an equivalent construct."""
    return _chain_decorators_on(
        fn,
        no_support('access', 'not supported by database'),
        no_support('sqlite', 'not supported by database'),
        no_support('mysql', 'not supported by database'),
        no_support('maxdb', 'not supported by database'),
        no_support('sybase', 'not supported by database'),
        no_support('informix', 'not supported by database'),
        )
def two_phase_transactions(fn):
    """Target database must support two-phase transactions."""
    return _chain_decorators_on(
        fn,
        no_support('access', 'not supported by database'),
        no_support('firebird', 'no SA implementation'),
        no_support('maxdb', 'not supported by database'),
        no_support('mssql', 'FIXME: guessing, needs confirmation'),
        no_support('oracle', 'no SA implementation'),
        no_support('sqlite', 'not supported by database'),
        no_support('sybase', 'FIXME: guessing, needs confirmation'),
        no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may '
                   'need separate XA implementation'),
        exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
        )
def unicode_connections(fn):
    """Target driver must support some encoding of Unicode across the wire."""
    # TODO: expand to exclude MySQLdb versions w/ broken unicode
    return _chain_decorators_on(
        fn,
        exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
        )
def unicode_ddl(fn):
    """Target driver must support some encoding of Unicode across the wire."""
    # TODO: expand to exclude MySQLdb versions w/ broken unicode
    return _chain_decorators_on(
        fn,
        no_support('maxdb', 'database support flakey'),
        no_support('oracle', 'FIXME: no support in database?'),
        no_support('sybase', 'FIXME: guessing, needs confirmation'),
        no_support('mssql+pymssql', 'no FreeTDS support'),
        exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
        )
def sane_rowcount(fn):
    """Target dialect must report accurate rowcounts (supports_sane_rowcount)."""
    return _chain_decorators_on(
        fn,
        skip_if(lambda: not testing.db.dialect.supports_sane_rowcount)
        )
def python2(fn):
    """Test must run under a Python 2.x interpreter."""
    return _chain_decorators_on(
        fn,
        skip_if(
            lambda: sys.version_info >= (3,),
            "Python version 2.xx is required."
            )
        )
def _has_sqlite():
    """Return True if a working sqlite DBAPI is importable."""
    from sqlalchemy import create_engine
    try:
        # Engine construction imports the DBAPI module; no connection is
        # made for a lazy sqlite:// URL.  (Removed unused binding of the
        # engine to a local.)
        create_engine('sqlite://')
        return True
    except ImportError:
        return False
def sqlite(fn):
    """Test requires a working sqlite DBAPI to be importable."""
    return _chain_decorators_on(
        fn,
        skip_if(lambda: not _has_sqlite())
        )

79
sqlalchemy/test/schema.py Normal file
View File

@@ -0,0 +1,79 @@
"""Enhanced versions of schema.Table and schema.Column which establish
desired state for different backends.
"""
from sqlalchemy.test import testing
from sqlalchemy import schema
__all__ = 'Table', 'Column',
table_options = {}
def Table(*args, **kw):
    """A schema.Table wrapper/hook for dialect-specific tweaks."""
    # pull out 'test_*' keyword options before schema.Table sees them
    test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
                      if k.startswith('test_')])
    kw.update(table_options)
    if testing.against('mysql'):
        # tests needing FK or transactional behavior require InnoDB
        if 'mysql_engine' not in kw and 'mysql_type' not in kw:
            if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
                kw['mysql_engine'] = 'InnoDB'
    # Apply some default cascading rules for self-referential foreign keys.
    # MySQL InnoDB has some issues around selecting self-refs too.
    if testing.against('firebird'):
        table_name = args[0]
        unpack = (testing.config.db.dialect.
                  identifier_preparer.unformat_identifiers)
        # Only going after ForeignKeys in Columns.  May need to
        # expand to ForeignKeyConstraint too.
        fks = [fk
               for col in args if isinstance(col, schema.Column)
               for fk in col.foreign_keys]
        for fk in fks:
            # root around in raw spec
            ref = fk._colspec
            if isinstance(ref, schema.Column):
                name = ref.table.name
            else:
                # take just the table name: on FB there cannot be
                # a schema, so the first element is always the
                # table name, possibly followed by the field name
                name = unpack(ref)[0]
            if name == table_name:
                if fk.ondelete is None:
                    fk.ondelete = 'CASCADE'
                if fk.onupdate is None:
                    fk.onupdate = 'CASCADE'
    return schema.Table(*args, **kw)
def Column(*args, **kw):
    """A schema.Column wrapper/hook for dialect-specific tweaks."""
    # pull out 'test_*' keyword options before schema.Column sees them
    test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
                      if k.startswith('test_')])
    col = schema.Column(*args, **kw)
    # Firebird/Oracle have no autoincrementing columns; emulate one with
    # an optional Sequence attached when the table is associated
    if 'test_needs_autoincrement' in test_opts and \
        kw.get('primary_key', False) and \
        testing.against('firebird', 'oracle'):
        def add_seq(tbl, c):
            c._init_items(
                schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + c.name + '_seq'), optional=True)
                )
        col._on_table_attach(add_seq)
    return col
def _truncate_name(dialect, name):
    """Shorten *name* to fit within dialect.max_identifier_length.

    Over-long names keep a short prefix and gain a hash-derived suffix so
    that distinct long names remain (probabilistically) distinct.
    """
    limit = dialect.max_identifier_length
    if len(name) <= limit:
        return name
    prefix = name[0:max(limit - 6, 0)]
    suffix = hex(hash(name) % 64)[2:]
    return prefix + "_" + suffix

779
sqlalchemy/test/testing.py Normal file
View File

@@ -0,0 +1,779 @@
"""TestCase and TestSuite artifacts and testing decorators."""
import itertools
import operator
import re
import sys
import types
import warnings
from cStringIO import StringIO
from sqlalchemy.test import config, assertsql, util as testutil
from sqlalchemy.util import function_named, py3k
from engines import drop_all_tables
from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool, orm
from sqlalchemy.engine import default
from nose import SkipTest
# map operator strings used in version specs to comparison callables
_ops = { '<': operator.lt,
         '>': operator.gt,
         '==': operator.eq,
         '!=': operator.ne,
         '<=': operator.le,
         '>=': operator.ge,
         'in': operator.contains,
         'between': lambda val, pair: val >= pair[0] and val <= pair[1],
         }
# sugar ('testing.db'); set here by config() at runtime
db = None
# more sugar, installed by __init__
requires = None
def fails_if(callable_, reason=None):
    """Mark a test as expected to fail if callable_ returns True.
    If the callable returns false, the test is run and reported as normal.
    However if the callable returns true, the test is expected to fail and the
    unit test logic is inverted: if the test fails, a success is reported.  If
    the test succeeds, a failure is reported.
    """
    # the condition's docstring (first line only) is used in report output
    docstring = getattr(callable_, '__doc__', None) or callable_.__name__
    description = docstring.split('\n')[0]
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            if not callable_():
                return fn(*args, **kw)
            else:
                try:
                    fn(*args, **kw)
                except Exception, ex:
                    # expected failure: report success
                    print ("'%s' failed as expected (condition: %s): %s " % (
                        fn_name, description, str(ex)))
                    return True
                else:
                    raise AssertionError(
                        "Unexpected success for '%s' (condition: %s)" %
                        (fn_name, description))
        return function_named(maybe, fn_name)
    return decorate
def future(fn):
    """Mark a test as expected to unconditionally fail.
    Takes no arguments, omit parens when using as a decorator.
    """
    fn_name = fn.__name__
    def decorated(*args, **kw):
        try:
            fn(*args, **kw)
        except Exception, ex:
            # expected failure: report success
            print ("Future test '%s' failed as expected: %s " % (
                fn_name, str(ex)))
            return True
        else:
            raise AssertionError(
                "Unexpected success for future test '%s'" % fn_name)
    return function_named(decorated, fn_name)
def db_spec(*dbs):
    """Return a predicate matching engines against dialect/driver specs.

    Each spec is a dialect name ('mysql'), a '+'-prefixed driver name
    ('+pyodbc'), or a combined 'dialect+driver' pair ('mysql+mysqldb').
    """
    dialects = set(x for x in dbs if '+' not in x)
    drivers = set(x[1:] for x in dbs if x.startswith('+'))
    specs = set(tuple(x.split('+')) for x in dbs
                if '+' in x and x not in drivers)
    def check(engine):
        if engine.name in dialects:
            return True
        if engine.driver in drivers:
            return True
        return (engine.name, engine.driver) in specs
    return check
def fails_on(dbs, reason):
    """Mark a test as expected to fail on the specified database
    implementation.
    Unlike ``crashes``, tests marked as ``fails_on`` will be run
    for the named databases.  The test is expected to fail and the unit test
    logic is inverted: if the test fails, a success is reported.  If the test
    succeeds, a failure is reported.
    """
    spec = db_spec(dbs)
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            if not spec(config.db):
                # not one of the named databases; run normally
                return fn(*args, **kw)
            else:
                try:
                    fn(*args, **kw)
                except Exception, ex:
                    print ("'%s' failed as expected on DB implementation "
                           "'%s+%s': %s" % (
                        fn_name, config.db.name, config.db.driver, reason))
                    return True
                else:
                    raise AssertionError(
                        "Unexpected success for '%s' on DB implementation '%s+%s'" %
                        (fn_name, config.db.name, config.db.driver))
        return function_named(maybe, fn_name)
    return decorate
def fails_on_everything_except(*dbs):
    """Mark a test as expected to fail on most database implementations.
    Like ``fails_on``, except failure is the expected outcome on all
    databases except those listed.
    """
    spec = db_spec(*dbs)
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            if spec(config.db):
                # one of the exempted databases; run normally
                return fn(*args, **kw)
            else:
                try:
                    fn(*args, **kw)
                except Exception, ex:
                    print ("'%s' failed as expected on DB implementation "
                           "'%s+%s': %s" % (
                        fn_name, config.db.name, config.db.driver, str(ex)))
                    return True
                else:
                    raise AssertionError(
                        "Unexpected success for '%s' on DB implementation '%s+%s'" %
                        (fn_name, config.db.name, config.db.driver))
        return function_named(maybe, fn_name)
    return decorate
def crashes(db, reason):
    """Mark a test as unsupported by a database implementation.
    ``crashes`` tests will be skipped unconditionally.  Use for feature tests
    that cause deadlocks or other fatal problems.
    """
    carp = _should_carp_about_exclusion(reason)
    spec = db_spec(db)
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            if spec(config.db):
                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
                    fn_name, config.db.name, config.db.driver, reason)
                print msg
                if carp:
                    # mirror suspicious exclusions to stderr as well
                    print >> sys.stderr, msg
                return True
            else:
                return fn(*args, **kw)
        return function_named(maybe, fn_name)
    return decorate
def _block_unconditionally(db, reason):
    """Mark a test as unsupported by a database implementation.
    Will never run the test against any version of the given database, ever,
    no matter what.  Use when your assumptions are infallible; past, present
    and future.
    (Implementation is identical to ``crashes``; kept separate for intent.)
    """
    carp = _should_carp_about_exclusion(reason)
    spec = db_spec(db)
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            if spec(config.db):
                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
                    fn_name, config.db.name, config.db.driver, reason)
                print msg
                if carp:
                    print >> sys.stderr, msg
                return True
            else:
                return fn(*args, **kw)
        return function_named(maybe, fn_name)
    return decorate
def only_on(db, reason):
    """Skip a test on every database implementation except the given one."""
    carp = _should_carp_about_exclusion(reason)
    spec = db_spec(db)
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            if spec(config.db):
                return fn(*args, **kw)
            else:
                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
                    fn_name, config.db.name, config.db.driver, reason)
                print msg
                if carp:
                    print >> sys.stderr, msg
                return True
        return function_named(maybe, fn_name)
    return decorate
def exclude(db, op, spec, reason):
    """Mark a test as unsupported by specific database server versions.
    Stackable, both with other excludes and other decorators.  Examples::
        # Not supported by mydb versions less than 1, 0
        @exclude('mydb', '<', (1,0))
        # Other operators work too
        @exclude('bigdb', '==', (9,0,9))
        @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
    """
    carp = _should_carp_about_exclusion(reason)
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            # the version check runs at call time, when config.db exists
            if _is_excluded(db, op, spec):
                msg = "'%s' unsupported on DB %s version '%s': %s" % (
                    fn_name, config.db.name, _server_version(), reason)
                print msg
                if carp:
                    print >> sys.stderr, msg
                return True
            else:
                return fn(*args, **kw)
        return function_named(maybe, fn_name)
    return decorate
def _should_carp_about_exclusion(reason):
    """Guard against forgotten exclusions."""
    assert reason
    # Carp (echo loudly) when the reason flags an open question
    # ('todo'/'fixme'/'xxx') or is too terse to be meaningful.
    lowered = reason.lower()
    for marker in ('todo', 'fixme', 'xxx'):
        if marker in lowered:
            return True
    if len(reason) < 4:
        return True
    # falls through with None (falsy) for a substantive reason
def _is_excluded(db, op, spec):
    """Return True if the configured db matches an exclusion specification.
    db:
      A dialect name
    op:
      An operator or stringified operator, such as '=='
    spec:
      A value that will be compared to the dialect's server_version_info
      using the supplied operator.
    Examples::
      # Not supported by mydb versions less than 1, 0
      _is_excluded('mydb', '<', (1,0))
      # Other operators work too
      _is_excluded('bigdb', '==', (9,0,9))
      _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
    """
    vendor_spec = db_spec(db)
    if not vendor_spec(config.db):
        return False
    version = _server_version()
    # op may be a callable or a key into the _ops lookup table
    oper = hasattr(op, '__call__') and op or _ops[op]
    return oper(version, spec)
def _server_version(bind=None):
    """Return a server_version_info tuple."""
    if bind is None:
        bind = config.db
    # force metadata to be retrieved
    conn = bind.connect()
    version = getattr(bind.dialect, 'server_version_info', ())
    conn.close()
    return version
def skip_if(predicate, reason=None):
    """Skip a test if predicate is true."""
    reason = reason or predicate.__name__
    carp = _should_carp_about_exclusion(reason)
    def decorate(fn):
        fn_name = fn.__name__
        def maybe(*args, **kw):
            if predicate():
                msg = "'%s' skipped on DB %s version '%s': %s" % (
                    fn_name, config.db.name, _server_version(), reason)
                print msg
                if carp:
                    # mirror suspicious exclusions to stderr as well
                    print >> sys.stderr, msg
                return True
            else:
                return fn(*args, **kw)
        return function_named(maybe, fn_name)
    return decorate
def emits_warning(*messages):
    """Mark a test as emitting a warning.

    With no arguments, squelches all SAWarning failures.  Or pass one or
    more strings; these will be matched to the root of the warning
    description by warnings.filterwarnings().
    """
    def decorate(fn):
        def safe(*args, **kw):
            # always tolerate pending-deprecation noise during the test
            filters = [dict(action='ignore',
                            category=sa_exc.SAPendingDeprecationWarning)]
            if messages:
                # ignore only the named warnings
                for message in messages:
                    filters.append(dict(action='ignore',
                                        message=message,
                                        category=sa_exc.SAWarning))
            else:
                # no messages given: ignore every SAWarning
                filters.append(dict(action='ignore',
                                    category=sa_exc.SAWarning))
            for f in filters:
                warnings.filterwarnings(**f)
            try:
                return fn(*args, **kw)
            finally:
                # restore the strict warning policy for subsequent tests
                resetwarnings()
        return function_named(safe, fn.__name__)
    return decorate
def emits_warning_on(db, *warnings):
    """Mark a test as emitting a warning on a specific dialect.
    With no arguments, squelches all SAWarning failures.  Or pass one or more
    strings; these will be matched to the root of the warning description by
    warnings.filterwarnings().
    """
    spec = db_spec(db)
    def decorate(fn):
        def maybe(*args, **kw):
            if isinstance(db, basestring):
                # dialect/driver spec: filter only when it matches
                if not spec(config.db):
                    return fn(*args, **kw)
                else:
                    wrapped = emits_warning(*warnings)(fn)
                    return wrapped(*args, **kw)
            else:
                # (db, op, spec) exclusion tuple: filter when excluded
                if not _is_excluded(*db):
                    return fn(*args, **kw)
                else:
                    wrapped = emits_warning(*warnings)(fn)
                    return wrapped(*args, **kw)
        return function_named(maybe, fn.__name__)
    return decorate
def uses_deprecated(*messages):
    """Mark a test as immune from fatal deprecation warnings.
    With no arguments, squelches all SADeprecationWarning failures.
    Or pass one or more strings; these will be matched to the root
    of the warning description by warnings.filterwarnings().
    As a special case, you may pass a function name prefixed with //
    and it will be re-written as needed to match the standard warning
    verbiage emitted by the sqlalchemy.util.deprecated decorator.
    """
    def decorate(fn):
        def safe(*args, **kw):
            # todo: should probably be strict about this, too
            filters = [dict(action='ignore',
                            category=sa_exc.SAPendingDeprecationWarning)]
            if not messages:
                filters.append(dict(action='ignore',
                                    category=sa_exc.SADeprecationWarning))
            else:
                filters.extend(
                    [dict(action='ignore',
                          message=message,
                          category=sa_exc.SADeprecationWarning)
                     for message in
                     # '//name' is shorthand for the decorator's verbiage
                     [ (m.startswith('//') and
                        ('Call to deprecated function ' + m[2:]) or m)
                       for m in messages] ])
            for f in filters:
                warnings.filterwarnings(**f)
            try:
                return fn(*args, **kw)
            finally:
                resetwarnings()
        return function_named(safe, fn.__name__)
    return decorate
def resetwarnings():
    """Reset warning behavior to testing defaults."""
    warnings.filterwarnings('ignore',
                            category=sa_exc.SAPendingDeprecationWarning)
    warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
    warnings.filterwarnings('error', category=sa_exc.SAWarning)
    # warnings.simplefilter('error')
    if sys.version_info < (2, 4):
        warnings.filterwarnings('ignore', category=FutureWarning)
def global_cleanup_assertions():
    """Check things that have to be finalized at the end of a test suite.
    Hardcoded at the moment, a modular system can be built here
    to support things like PG prepared transactions, tables all
    dropped, etc.
    """
    testutil.lazy_gc()
    # every checked-out connection must have been returned to its pool
    assert not pool._refs
def against(*queries):
    """Boolean predicate, compares to testing database configuration.
    Given one or more dialect names, returns True if one is the configured
    database engine.
    Also supports comparison to database version when provided with one or
    more 3-tuples of dialect name, operator, and version specification::
        testing.against('mysql', 'postgresql')
        testing.against(('mysql', '>=', (5, 0, 0)))
    """
    for query in queries:
        if isinstance(query, basestring):
            if db_spec(query)(config.db):
                return True
        else:
            name, op, spec = query
            if not db_spec(name)(config.db):
                continue
            have = _server_version()
            # op may be a callable or a key into the _ops lookup table
            oper = hasattr(op, '__call__') and op or _ops[op]
            if oper(have, spec):
                return True
    return False
def _chain_decorators_on(fn, *decorators):
    """Apply a series of decorators to fn, returning a decorated function.

    Decorators are applied innermost-last, i.e. in the same order as if
    they had been stacked with '@' syntax in the given sequence.
    """
    remaining = list(decorators)
    while remaining:
        fn = remaining.pop()(fn)
    return fn
def rowset(results):
    """Converts the results of sql execution into a plain set of column tuples.

    Useful for asserting the results of an unordered query.
    """
    out = set()
    for row in results:
        out.add(tuple(row))
    return out
def eq_(a, b, msg=None):
    """Assert a == b, with repr messaging on failure."""
    assert a == b, msg or "%r != %r" % (a, b)
def ne_(a, b, msg=None):
    """Assert a != b, with repr messaging on failure."""
    assert a != b, msg or "%r == %r" % (a, b)
def is_(a, b, msg=None):
    """Assert a is b, with repr messaging on failure."""
    assert a is b, msg or "%r is not %r" % (a, b)
def is_not_(a, b, msg=None):
    """Assert a is not b, with repr messaging on failure."""
    assert a is not b, msg or "%r is %r" % (a, b)
def startswith_(a, fragment, msg=None):
    """Assert a.startswith(fragment), with repr messaging on failure."""
    assert a.startswith(fragment), msg or "%r does not start with %r" % (
        a, fragment)
def assert_raises(except_cls, callable_, *args, **kw):
    """Assert that callable_(*args, **kw) raises except_cls."""
    try:
        callable_(*args, **kw)
        success = False
    except except_cls, e:
        success = True
    # assert outside the block so it works for AssertionError too !
    assert success, "Callable did not raise an exception"
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    """Assert an except_cls is raised whose text matches the regex *msg*."""
    try:
        callable_(*args, **kwargs)
        assert False, "Callable did not raise an exception"
    except except_cls, e:
        assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
def fail(msg):
    """Unconditionally fail the current test with *msg*."""
    assert False, msg
def fixture(table, columns, *rows):
    """Insert data into table after creation."""
    def onload(event, schema_item, connection):
        insert = table.insert()
        column_names = [col.key for col in columns]
        # zip each row of values against the column keys
        connection.execute(insert, [dict(zip(column_names, column_values))
                                    for column_values in rows])
    table.append_ddl_listener('after-create', onload)
def resolve_artifact_names(fn):
    """Decorator, augment function globals with tables and classes.
    Swaps out the function's globals at execution time. The 'global' statement
    will not work as expected inside a decorated function.
    """
    # This could be automatically applied to framework and test_ methods in
    # the MappedTest-derived test suites but... *some* explicitness for this
    # magic is probably good.  Especially as 'global' won't work- these
    # rebound functions aren't regular Python..
    #
    # Also: it's lame that CPython accepts a dict-subclass for globals, but
    # only calls dict methods.  That would allow 'global' to pass through to
    # the func_globals.
    def resolved(*args, **kwargs):
        self = args[0]
        context = dict(fn.func_globals)
        for source in self._artifact_registries:
            context.update(getattr(self, source))
        # jython bug #1034
        rebound = types.FunctionType(
            fn.func_code, context, fn.func_name, fn.func_defaults,
            fn.func_closure)
        return rebound(*args, **kwargs)
    return function_named(resolved, fn.func_name)
class adict(dict):
    """Dict whose keys are also reachable as attributes.

    A key shadows any dict attribute of the same name.
    """

    def __getattribute__(self, key):
        # Prefer item lookup; fall back to normal attribute access so
        # that dict methods remain reachable for non-key names.
        try:
            return self[key]
        except KeyError:
            return dict.__getattribute__(self, key)

    def get_all(self, *keys):
        """Return a tuple of values for *keys, in the given order."""
        values = []
        for key in keys:
            values.append(self[key])
        return tuple(values)
class TestBase(object):
    """Base for test classes; carries suite-level inclusion criteria."""
    # A sequence of database names to always run, regardless of the
    # constraints below.
    __whitelist__ = ()
    # A sequence of requirement names matching testing.requires decorators
    __requires__ = ()
    # A sequence of dialect names to exclude from the test class.
    __unsupported_on__ = ()
    # If present, test class is only runnable for the *single* specified
    # dialect.  If you need multiple, use __unsupported_on__ and invert.
    __only_on__ = None
    # A sequence of no-arg callables. If any are True, the entire testcase is
    # skipped.
    __skip_if__ = None
    # names of instance attributes holding artifact dicts, consulted by
    # resolve_artifact_names()
    _artifact_registries = ()
    def assert_(self, val, msg=None):
        """Assert that *val* is truthy, failing with *msg* otherwise."""
        assert val, msg
class AssertsCompiledSQL(object):
    """Mixin providing ``assert_compile`` for checking the SQL string a
    clause construct compiles to under a given dialect."""

    def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False):
        """Compile ``clause`` and assert its SQL equals ``result``.

        :param clause: a ClauseElement, or an orm.Query whose statement
          is extracted and compiled with labels.
        :param result: expected SQL string; compared against the compiled
          output with newlines and tabs stripped.
        :param params: optional bind-parameter dict whose keys are passed
          to the compiler as ``column_keys``.
        :param checkparams: optional dict compared against
          ``construct_params(params)`` of the compiled construct.
        :param dialect: dialect to compile against; defaults to the test
          class's ``__dialect__`` attribute when None.
        :param use_default_dialect: when True, compile with a plain
          DefaultDialect.
        """
        # NOTE(review): use_default_dialect silently overrides an explicit
        # ``dialect`` argument -- confirm that precedence is intended.
        if use_default_dialect:
            dialect = default.DefaultDialect()

        if dialect is None:
            dialect = getattr(self, '__dialect__', None)

        kw = {}
        if params is not None:
            kw['column_keys'] = params.keys()

        # A Query is not directly compilable; compile its underlying
        # statement with labels applied, as the ORM would at execute time.
        if isinstance(clause, orm.Query):
            context = clause._compile_context()
            context.statement.use_labels = True
            clause = context.statement

        c = clause.compile(dialect=dialect, **kw)

        param_str = repr(getattr(c, 'params', {}))
        # Py3K
        #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')

        print "\nSQL String:\n" + str(c) + param_str

        # compare ignoring the whitespace formatting of the generated SQL
        cc = re.sub(r'[\n\t]', '', str(c))

        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))

        if checkparams is not None:
            eq_(c.construct_params(params), checkparams)
class ComparesTables(object):
    """Mixin with helpers for comparing a Table definition against its
    reflected counterpart."""

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        """Assert ``reflected_table`` mirrors ``table``: column names,
        flags, types (exact when ``strict_types``), foreign keys, server
        defaults and primary key membership."""
        assert len(table.c) == len(reflected_table.c)

        for orig_col, refl_col in zip(table.c, reflected_table.c):
            eq_(orig_col.name, refl_col.name)
            assert refl_col is reflected_table.c[orig_col.name]
            eq_(orig_col.primary_key, refl_col.primary_key)
            eq_(orig_col.nullable, refl_col.nullable)

            if strict_types:
                # exact type class match required
                assert type(refl_col.type) is type(orig_col.type), \
                    "Type '%s' doesn't correspond to type '%s'" % (refl_col.type, orig_col.type)
            else:
                # affinity-level comparison only
                self.assert_types_base(refl_col, orig_col)

            if isinstance(orig_col.type, sqltypes.String):
                eq_(orig_col.type.length, refl_col.type.length)

            eq_(set(f.column.name for f in orig_col.foreign_keys),
                set(f.column.name for f in refl_col.foreign_keys))

            if orig_col.server_default:
                assert isinstance(refl_col.server_default,
                                  schema.FetchedValue)

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for pk_col in table.primary_key:
            assert reflected_table.primary_key.columns[pk_col.name] is not None

    def assert_types_base(self, c1, c2):
        """Assert the two columns' types share the same type affinity."""
        assert c1.type._compare_type_affinity(c2.type),\
            "On column %r, type '%s' doesn't correspond to type '%s'" % \
            (c1.name, c1.type, c2.type)
class AssertsExecutionResults(object):
    """Mixin with assertions over rows/objects produced by statement or
    ORM execution, plus hooks for asserting on the SQL actually emitted."""

    def assert_result(self, result, class_, *objects):
        """Assert ``result`` is a sequence of ``class_`` instances whose
        attributes match the ``objects`` dicts, in order."""
        result = list(result)
        print repr(result)
        self.assert_list(result, class_, objects)

    def assert_list(self, result, class_, list):
        # NOTE(review): the ``list`` parameter shadows the builtin of the
        # same name within this method.
        self.assert_(len(result) == len(list),
                     "result list is not the same size as test list, " +
                     "for class " + class_.__name__)
        for i in range(0, len(list)):
            self.assert_row(class_, result[i], list[i])

    def assert_row(self, class_, rowobj, desc):
        """Assert ``rowobj`` is a ``class_`` instance matching the
        attribute dict ``desc``.

        A tuple value ``(cls, list)`` describes a nested collection;
        ``(cls, dict)`` describes a nested single row.
        """
        self.assert_(rowobj.__class__ is class_,
                     "item class is not " + repr(class_))
        for key, value in desc.iteritems():
            if isinstance(value, tuple):
                if isinstance(value[1], list):
                    self.assert_list(getattr(rowobj, key), value[0], value[1])
                else:
                    self.assert_row(value[0], getattr(rowobj, key), value[1])
            else:
                self.assert_(getattr(rowobj, key) == value,
                             "attribute %s value %s does not match %s" % (
                             key, getattr(rowobj, key), value))

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        The algorithm is very expensive but not a big deal for the small
        numbers of rows that the test suite manipulates.
        """
        # Hash by identity so that duplicate expected dicts are counted
        # separately instead of collapsing into one set entry.
        class frozendict(dict):
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        expected = set([frozendict(e) for e in expected])

        # every found object must be exactly of type ``cls``
        for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
            fail('Unexpected type "%s", expected "%s"' % (
                type(wrong).__name__, cls.__name__))

        if len(found) != len(expected):
            fail('Unexpected object count "%s", expected "%s"' % (
                len(found), len(expected)))

        NOVALUE = object()  # sentinel: distinguishes "attribute missing" from None
        def _compare_item(obj, spec):
            # True if ``obj`` satisfies every key/value in ``spec``;
            # tuple values recurse into nested unordered comparison.
            for key, value in spec.iteritems():
                if isinstance(value, tuple):
                    try:
                        self.assert_unordered_result(
                            getattr(obj, key), value[0], *value[1])
                    except AssertionError:
                        return False
                else:
                    if getattr(obj, key, NOVALUE) != value:
                        return False
            return True

        for expected_item in expected:
            for found_item in found:
                if _compare_item(found_item, expected_item):
                    # consume the match so duplicates pair one-to-one
                    found.remove(found_item)
                    break
            else:
                fail(
                    "Expected %s instance with attributes %s not found." % (
                    cls.__name__, repr(expected_item)))
        return True

    def assert_sql_execution(self, db, callable_, *rules):
        """Run ``callable_`` with the given assertsql rules installed,
        clearing the rules afterwards even on failure."""
        assertsql.asserter.add_rules(rules)
        try:
            callable_()
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_sql(self, db, callable_, list_, with_sequences=None):
        """Assert the exact SQL statements emitted by ``callable_``.

        ``with_sequences``, when given, replaces ``list_`` on dialects
        that render sequence firing.  A dict entry means "all of these
        statements, in any order".
        """
        # sequence-capable dialects emit extra statements; use the
        # alternate expectation list for them
        if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
            rules = with_sequences
        else:
            rules = list_

        newrules = []
        for rule in rules:
            if isinstance(rule, dict):
                newrule = assertsql.AllOf(*[
                    assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
                ])
            else:
                newrule = assertsql.ExactSQL(*rule)
            newrules.append(newrule)

        self.assert_sql_execution(db, callable_, *newrules)

    def assert_sql_count(self, db, callable_, count):
        """Assert that exactly ``count`` statements are executed by
        ``callable_``."""
        self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))

53
sqlalchemy/test/util.py Normal file
View File

@@ -0,0 +1,53 @@
from sqlalchemy.util import jython, function_named
import gc
import time
if jython:
    def gc_collect(*args):
        """Aggressive gc.collect for tests; Jython's collector needs
        coaxing (jython bug #1034)."""
        gc.collect()
        # give the JVM a moment, then sweep twice more
        time.sleep(0.1)
        for _ in range(2):
            gc.collect()
        return 0

    # "lazy" gc, for VMs that don't collect when refcount hits zero
    lazy_gc = gc_collect

else:
    # assume CPython: plain gc.collect; refcounting makes lazy_gc a no-op
    gc_collect = gc.collect

    def lazy_gc():
        pass
def picklers():
    """Yield ``(loads, dumps)`` pairs covering every available pickle
    implementation crossed with protocols -1, 0, 1 and 2.

    ``loads`` accepts a pickled byte string; ``dumps`` accepts an object
    and returns its pickled form under that module/protocol pair.
    """
    # local name renamed from ``picklers`` so it no longer shadows
    # this function, and the loop var no longer shadows the ``pickle``
    # module
    modules = set()
    # Py2K
    try:
        import cPickle
        modules.add(cPickle)
    except ImportError:
        pass
    # end Py2K
    import pickle
    modules.add(pickle)

    # yes, this thing needs this much testing
    for module in modules:
        for protocol in -1, 0, 1, 2:
            # Bind module/protocol as lambda defaults: a plain closure
            # late-binds, so dumpers collected before use (e.g. via
            # list(picklers())) would all serialize with the final loop
            # values instead of their own.
            yield module.loads, \
                lambda d, module=module, protocol=protocol: \
                    module.dumps(d, protocol)
def round_decimal(value, prec):
    """Reduce ``value`` to ``prec`` decimal places.

    Floats go through the builtin ``round``; Decimals are floored toward
    negative infinity at ``prec`` digits.
    """
    if isinstance(value, float):
        return round(value, prec)

    import decimal
    # can also use shift() here but that is 2.6 only
    scale = decimal.Decimal("1" + "0" * prec)
    floored = (value * scale).to_integral(decimal.ROUND_FLOOR)
    return floored / pow(10, prec)