morro
This commit is contained in:
26
sqlalchemy/test/__init__.py
Normal file
26
sqlalchemy/test/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Testing environment and utilities.
|
||||
|
||||
This package contains base classes and routines used by
|
||||
the unit tests. Tests are based on Nose and bootstrapped
|
||||
by noseplugin.NoseSQLAlchemy.
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config
|
||||
from sqlalchemy.test.schema import Column, Table
|
||||
from sqlalchemy.test.testing import \
|
||||
AssertsCompiledSQL, \
|
||||
AssertsExecutionResults, \
|
||||
ComparesTables, \
|
||||
TestBase, \
|
||||
rowset
|
||||
|
||||
|
||||
__all__ = ('testing',
|
||||
'Column', 'Table',
|
||||
'rowset',
|
||||
'TestBase', 'AssertsExecutionResults',
|
||||
'AssertsCompiledSQL', 'ComparesTables',
|
||||
'engines', 'profiling', 'pickleable')
|
||||
|
||||
|
285
sqlalchemy/test/assertsql.py
Normal file
285
sqlalchemy/test/assertsql.py
Normal file
@@ -0,0 +1,285 @@
|
||||
|
||||
from sqlalchemy.interfaces import ConnectionProxy
|
||||
from sqlalchemy.engine.default import DefaultDialect
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy import util
|
||||
import re
|
||||
|
||||
class AssertRule(object):
|
||||
def process_execute(self, clauseelement, *multiparams, **params):
|
||||
pass
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context, executemany):
|
||||
pass
|
||||
|
||||
def is_consumed(self):
|
||||
"""Return True if this rule has been consumed, False if not.
|
||||
|
||||
Should raise an AssertionError if this rule's condition has definitely failed.
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def rule_passed(self):
|
||||
"""Return True if the last test of this rule passed, False if failed, None if no test was applied."""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
def consume_final(self):
|
||||
"""Return True if this rule has been consumed.
|
||||
|
||||
Should raise an AssertionError if this rule's condition has not been consumed or has failed.
|
||||
|
||||
"""
|
||||
|
||||
if self._result is None:
|
||||
assert False, "Rule has not been consumed"
|
||||
|
||||
return self.is_consumed()
|
||||
|
||||
class SQLMatchRule(AssertRule):
|
||||
def __init__(self):
|
||||
self._result = None
|
||||
self._errmsg = ""
|
||||
|
||||
def rule_passed(self):
|
||||
return self._result
|
||||
|
||||
def is_consumed(self):
|
||||
if self._result is None:
|
||||
return False
|
||||
|
||||
assert self._result, self._errmsg
|
||||
|
||||
return True
|
||||
|
||||
class ExactSQL(SQLMatchRule):
|
||||
def __init__(self, sql, params=None):
|
||||
SQLMatchRule.__init__(self)
|
||||
self.sql = sql
|
||||
self.params = params
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context, executemany):
|
||||
if not context:
|
||||
return
|
||||
|
||||
_received_statement = _process_engine_statement(context.unicode_statement, context)
|
||||
_received_parameters = context.compiled_parameters
|
||||
|
||||
# TODO: remove this step once all unit tests
|
||||
# are migrated, as ExactSQL should really be *exact* SQL
|
||||
sql = _process_assertion_statement(self.sql, context)
|
||||
|
||||
equivalent = _received_statement == sql
|
||||
if self.params:
|
||||
if util.callable(self.params):
|
||||
params = self.params(context)
|
||||
else:
|
||||
params = self.params
|
||||
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
equivalent = equivalent and params == context.compiled_parameters
|
||||
else:
|
||||
params = {}
|
||||
|
||||
|
||||
self._result = equivalent
|
||||
if not self._result:
|
||||
self._errmsg = "Testing for exact statement %r exact params %r, " \
|
||||
"received %r with params %r" % (sql, params, _received_statement, _received_parameters)
|
||||
|
||||
|
||||
class RegexSQL(SQLMatchRule):
|
||||
def __init__(self, regex, params=None):
|
||||
SQLMatchRule.__init__(self)
|
||||
self.regex = re.compile(regex)
|
||||
self.orig_regex = regex
|
||||
self.params = params
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context, executemany):
|
||||
if not context:
|
||||
return
|
||||
|
||||
_received_statement = _process_engine_statement(context.unicode_statement, context)
|
||||
_received_parameters = context.compiled_parameters
|
||||
|
||||
equivalent = bool(self.regex.match(_received_statement))
|
||||
if self.params:
|
||||
if util.callable(self.params):
|
||||
params = self.params(context)
|
||||
else:
|
||||
params = self.params
|
||||
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
|
||||
# do a positive compare only
|
||||
for param, received in zip(params, _received_parameters):
|
||||
for k, v in param.iteritems():
|
||||
if k not in received or received[k] != v:
|
||||
equivalent = False
|
||||
break
|
||||
else:
|
||||
params = {}
|
||||
|
||||
self._result = equivalent
|
||||
if not self._result:
|
||||
self._errmsg = "Testing for regex %r partial params %r, "\
|
||||
"received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
|
||||
|
||||
class CompiledSQL(SQLMatchRule):
|
||||
def __init__(self, statement, params):
|
||||
SQLMatchRule.__init__(self)
|
||||
self.statement = statement
|
||||
self.params = params
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context, executemany):
|
||||
if not context:
|
||||
return
|
||||
|
||||
_received_parameters = context.compiled_parameters
|
||||
|
||||
# recompile from the context, using the default dialect
|
||||
compiled = context.compiled.statement.\
|
||||
compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
|
||||
|
||||
_received_statement = re.sub(r'\n', '', str(compiled))
|
||||
|
||||
equivalent = self.statement == _received_statement
|
||||
if self.params:
|
||||
if util.callable(self.params):
|
||||
params = self.params(context)
|
||||
else:
|
||||
params = self.params
|
||||
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
|
||||
# do a positive compare only
|
||||
for param, received in zip(params, _received_parameters):
|
||||
for k, v in param.iteritems():
|
||||
if k not in received or received[k] != v:
|
||||
equivalent = False
|
||||
break
|
||||
else:
|
||||
params = {}
|
||||
|
||||
self._result = equivalent
|
||||
if not self._result:
|
||||
self._errmsg = "Testing for compiled statement %r partial params %r, " \
|
||||
"received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
|
||||
|
||||
|
||||
class CountStatements(AssertRule):
|
||||
def __init__(self, count):
|
||||
self.count = count
|
||||
self._statement_count = 0
|
||||
|
||||
def process_execute(self, clauseelement, *multiparams, **params):
|
||||
self._statement_count += 1
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context, executemany):
|
||||
pass
|
||||
|
||||
def is_consumed(self):
|
||||
return False
|
||||
|
||||
def consume_final(self):
|
||||
assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count)
|
||||
return True
|
||||
|
||||
class AllOf(AssertRule):
|
||||
def __init__(self, *rules):
|
||||
self.rules = set(rules)
|
||||
|
||||
def process_execute(self, clauseelement, *multiparams, **params):
|
||||
for rule in self.rules:
|
||||
rule.process_execute(clauseelement, *multiparams, **params)
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context, executemany):
|
||||
for rule in self.rules:
|
||||
rule.process_cursor_execute(statement, parameters, context, executemany)
|
||||
|
||||
def is_consumed(self):
|
||||
if not self.rules:
|
||||
return True
|
||||
|
||||
for rule in list(self.rules):
|
||||
if rule.rule_passed(): # a rule passed, move on
|
||||
self.rules.remove(rule)
|
||||
return len(self.rules) == 0
|
||||
|
||||
assert False, "No assertion rules were satisfied for statement"
|
||||
|
||||
def consume_final(self):
|
||||
return len(self.rules) == 0
|
||||
|
||||
def _process_engine_statement(query, context):
|
||||
if util.jython:
|
||||
# oracle+zxjdbc passes a PyStatement when returning into
|
||||
query = unicode(query)
|
||||
if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
|
||||
query = query[:-25]
|
||||
|
||||
query = re.sub(r'\n', '', query)
|
||||
|
||||
return query
|
||||
|
||||
def _process_assertion_statement(query, context):
|
||||
paramstyle = context.dialect.paramstyle
|
||||
if paramstyle == 'named':
|
||||
pass
|
||||
elif paramstyle =='pyformat':
|
||||
query = re.sub(r':([\w_]+)', r"%(\1)s", query)
|
||||
else:
|
||||
# positional params
|
||||
repl = None
|
||||
if paramstyle=='qmark':
|
||||
repl = "?"
|
||||
elif paramstyle=='format':
|
||||
repl = r"%s"
|
||||
elif paramstyle=='numeric':
|
||||
repl = None
|
||||
query = re.sub(r':([\w_]+)', repl, query)
|
||||
|
||||
return query
|
||||
|
||||
class SQLAssert(ConnectionProxy):
|
||||
rules = None
|
||||
|
||||
def add_rules(self, rules):
|
||||
self.rules = list(rules)
|
||||
|
||||
def statement_complete(self):
|
||||
for rule in self.rules:
|
||||
if not rule.consume_final():
|
||||
assert False, "All statements are complete, but pending assertion rules remain"
|
||||
|
||||
def clear_rules(self):
|
||||
del self.rules
|
||||
|
||||
def execute(self, conn, execute, clauseelement, *multiparams, **params):
|
||||
result = execute(clauseelement, *multiparams, **params)
|
||||
|
||||
if self.rules is not None:
|
||||
if not self.rules:
|
||||
assert False, "All rules have been exhausted, but further statements remain"
|
||||
rule = self.rules[0]
|
||||
rule.process_execute(clauseelement, *multiparams, **params)
|
||||
if rule.is_consumed():
|
||||
self.rules.pop(0)
|
||||
|
||||
return result
|
||||
|
||||
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
|
||||
result = execute(cursor, statement, parameters, context)
|
||||
|
||||
if self.rules:
|
||||
rule = self.rules[0]
|
||||
rule.process_cursor_execute(statement, parameters, context, executemany)
|
||||
|
||||
return result
|
||||
|
||||
asserter = SQLAssert()
|
||||
|
180
sqlalchemy/test/config.py
Normal file
180
sqlalchemy/test/config.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import optparse, os, sys, re, ConfigParser, time, warnings
|
||||
|
||||
|
||||
# 2to3
|
||||
import StringIO
|
||||
|
||||
logging = None
|
||||
|
||||
__all__ = 'parser', 'configure', 'options',
|
||||
|
||||
db = None
|
||||
db_label, db_url, db_opts = None, None, {}
|
||||
|
||||
options = None
|
||||
file_config = None
|
||||
|
||||
base_config = """
|
||||
[db]
|
||||
sqlite=sqlite:///:memory:
|
||||
sqlite_file=sqlite:///querytest.db
|
||||
postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
|
||||
postgres=postgresql://scott:tiger@127.0.0.1:5432/test
|
||||
pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
|
||||
postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test
|
||||
mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
|
||||
mysql=mysql://scott:tiger@127.0.0.1:3306/test
|
||||
oracle=oracle://scott:tiger@127.0.0.1:1521
|
||||
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
|
||||
mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
|
||||
firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
|
||||
maxdb=maxdb://MONA:RED@/maxdb1
|
||||
"""
|
||||
|
||||
def _log(option, opt_str, value, parser):
|
||||
global logging
|
||||
if not logging:
|
||||
import logging
|
||||
logging.basicConfig()
|
||||
|
||||
if opt_str.endswith('-info'):
|
||||
logging.getLogger(value).setLevel(logging.INFO)
|
||||
elif opt_str.endswith('-debug'):
|
||||
logging.getLogger(value).setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def _list_dbs(*args):
|
||||
print "Available --db options (use --dburi to override)"
|
||||
for macro in sorted(file_config.options('db')):
|
||||
print "%20s\t%s" % (macro, file_config.get('db', macro))
|
||||
sys.exit(0)
|
||||
|
||||
def _server_side_cursors(options, opt_str, value, parser):
|
||||
db_opts['server_side_cursors'] = True
|
||||
|
||||
def _engine_strategy(options, opt_str, value, parser):
|
||||
if value:
|
||||
db_opts['strategy'] = value
|
||||
|
||||
class _ordered_map(object):
|
||||
def __init__(self):
|
||||
self._keys = list()
|
||||
self._data = dict()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key not in self._keys:
|
||||
self._keys.append(key)
|
||||
self._data[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
for key in self._keys:
|
||||
yield self._data[key]
|
||||
|
||||
# at one point in refactoring, modules were injecting into the config
|
||||
# process. this could probably just become a list now.
|
||||
post_configure = _ordered_map()
|
||||
|
||||
def _engine_uri(options, file_config):
|
||||
global db_label, db_url
|
||||
db_label = 'sqlite'
|
||||
if options.dburi:
|
||||
db_url = options.dburi
|
||||
db_label = db_url[:db_url.index(':')]
|
||||
elif options.db:
|
||||
db_label = options.db
|
||||
db_url = None
|
||||
|
||||
if db_url is None:
|
||||
if db_label not in file_config.options('db'):
|
||||
raise RuntimeError(
|
||||
"Unknown engine. Specify --dbs for known engines.")
|
||||
db_url = file_config.get('db', db_label)
|
||||
post_configure['engine_uri'] = _engine_uri
|
||||
|
||||
def _require(options, file_config):
|
||||
if not(options.require or
|
||||
(file_config.has_section('require') and
|
||||
file_config.items('require'))):
|
||||
return
|
||||
|
||||
try:
|
||||
import pkg_resources
|
||||
except ImportError:
|
||||
raise RuntimeError("setuptools is required for version requirements")
|
||||
|
||||
cmdline = []
|
||||
for requirement in options.require:
|
||||
pkg_resources.require(requirement)
|
||||
cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
|
||||
|
||||
if file_config.has_section('require'):
|
||||
for label, requirement in file_config.items('require'):
|
||||
if not label == db_label or label.startswith('%s.' % db_label):
|
||||
continue
|
||||
seen = [c for c in cmdline if requirement.startswith(c)]
|
||||
if seen:
|
||||
continue
|
||||
pkg_resources.require(requirement)
|
||||
post_configure['require'] = _require
|
||||
|
||||
def _engine_pool(options, file_config):
|
||||
if options.mockpool:
|
||||
from sqlalchemy import pool
|
||||
db_opts['poolclass'] = pool.AssertionPool
|
||||
post_configure['engine_pool'] = _engine_pool
|
||||
|
||||
def _create_testing_engine(options, file_config):
|
||||
from sqlalchemy.test import engines, testing
|
||||
global db
|
||||
db = engines.testing_engine(db_url, db_opts)
|
||||
testing.db = db
|
||||
post_configure['create_engine'] = _create_testing_engine
|
||||
|
||||
def _prep_testing_database(options, file_config):
|
||||
from sqlalchemy.test import engines
|
||||
from sqlalchemy import schema
|
||||
|
||||
# also create alt schemas etc. here?
|
||||
if options.dropfirst:
|
||||
e = engines.utf8_engine()
|
||||
existing = e.table_names()
|
||||
if existing:
|
||||
print "Dropping existing tables in database: " + db_url
|
||||
try:
|
||||
print "Tables: %s" % ', '.join(existing)
|
||||
except:
|
||||
pass
|
||||
print "Abort within 5 seconds..."
|
||||
time.sleep(5)
|
||||
md = schema.MetaData(e, reflect=True)
|
||||
md.drop_all()
|
||||
e.dispose()
|
||||
|
||||
post_configure['prep_db'] = _prep_testing_database
|
||||
|
||||
def _set_table_options(options, file_config):
|
||||
from sqlalchemy.test import schema
|
||||
|
||||
table_options = schema.table_options
|
||||
for spec in options.tableopts:
|
||||
key, value = spec.split('=')
|
||||
table_options[key] = value
|
||||
|
||||
if options.mysql_engine:
|
||||
table_options['mysql_engine'] = options.mysql_engine
|
||||
post_configure['table_options'] = _set_table_options
|
||||
|
||||
def _reverse_topological(options, file_config):
|
||||
if options.reversetop:
|
||||
from sqlalchemy.orm import unitofwork
|
||||
from sqlalchemy import topological
|
||||
class RevQueueDepSort(topological.QueueDependencySorter):
|
||||
def __init__(self, tuples, allitems):
|
||||
self.tuples = list(tuples)
|
||||
self.allitems = list(allitems)
|
||||
self.tuples.reverse()
|
||||
self.allitems.reverse()
|
||||
topological.QueueDependencySorter = RevQueueDepSort
|
||||
unitofwork.DependencySorter = RevQueueDepSort
|
||||
post_configure['topological'] = _reverse_topological
|
||||
|
300
sqlalchemy/test/engines.py
Normal file
300
sqlalchemy/test/engines.py
Normal file
@@ -0,0 +1,300 @@
|
||||
import sys, types, weakref
|
||||
from collections import deque
|
||||
import config
|
||||
from sqlalchemy.util import function_named, callable
|
||||
import re
|
||||
import warnings
|
||||
|
||||
class ConnectionKiller(object):
|
||||
def __init__(self):
|
||||
self.proxy_refs = weakref.WeakKeyDictionary()
|
||||
|
||||
def checkout(self, dbapi_con, con_record, con_proxy):
|
||||
self.proxy_refs[con_proxy] = True
|
||||
|
||||
def _apply_all(self, methods):
|
||||
# must copy keys atomically
|
||||
for rec in self.proxy_refs.keys():
|
||||
if rec is not None and rec.is_valid:
|
||||
try:
|
||||
for name in methods:
|
||||
if callable(name):
|
||||
name(rec)
|
||||
else:
|
||||
getattr(rec, name)()
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
raise
|
||||
except Exception, e:
|
||||
warnings.warn("testing_reaper couldn't close connection: %s" % e)
|
||||
|
||||
def rollback_all(self):
|
||||
self._apply_all(('rollback',))
|
||||
|
||||
def close_all(self):
|
||||
self._apply_all(('rollback', 'close'))
|
||||
|
||||
def assert_all_closed(self):
|
||||
for rec in self.proxy_refs:
|
||||
if rec.is_valid:
|
||||
assert False
|
||||
|
||||
testing_reaper = ConnectionKiller()
|
||||
|
||||
def drop_all_tables(metadata):
|
||||
testing_reaper.close_all()
|
||||
metadata.drop_all()
|
||||
|
||||
def assert_conns_closed(fn):
|
||||
def decorated(*args, **kw):
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.assert_all_closed()
|
||||
return function_named(decorated, fn.__name__)
|
||||
|
||||
def rollback_open_connections(fn):
|
||||
"""Decorator that rolls back all open connections after fn execution."""
|
||||
|
||||
def decorated(*args, **kw):
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.rollback_all()
|
||||
return function_named(decorated, fn.__name__)
|
||||
|
||||
def close_first(fn):
|
||||
"""Decorator that closes all connections before fn execution."""
|
||||
def decorated(*args, **kw):
|
||||
testing_reaper.close_all()
|
||||
fn(*args, **kw)
|
||||
return function_named(decorated, fn.__name__)
|
||||
|
||||
|
||||
def close_open_connections(fn):
|
||||
"""Decorator that closes all connections after fn execution."""
|
||||
|
||||
def decorated(*args, **kw):
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.close_all()
|
||||
return function_named(decorated, fn.__name__)
|
||||
|
||||
def all_dialects(exclude=None):
|
||||
import sqlalchemy.databases as d
|
||||
for name in d.__all__:
|
||||
# TEMPORARY
|
||||
if exclude and name in exclude:
|
||||
continue
|
||||
mod = getattr(d, name, None)
|
||||
if not mod:
|
||||
mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
|
||||
yield mod.dialect()
|
||||
|
||||
class ReconnectFixture(object):
|
||||
def __init__(self, dbapi):
|
||||
self.dbapi = dbapi
|
||||
self.connections = []
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.dbapi, key)
|
||||
|
||||
def connect(self, *args, **kwargs):
|
||||
conn = self.dbapi.connect(*args, **kwargs)
|
||||
self.connections.append(conn)
|
||||
return conn
|
||||
|
||||
def shutdown(self):
|
||||
for c in list(self.connections):
|
||||
c.close()
|
||||
self.connections = []
|
||||
|
||||
def reconnecting_engine(url=None, options=None):
|
||||
url = url or config.db_url
|
||||
dbapi = config.db.dialect.dbapi
|
||||
if not options:
|
||||
options = {}
|
||||
options['module'] = ReconnectFixture(dbapi)
|
||||
engine = testing_engine(url, options)
|
||||
engine.test_shutdown = engine.dialect.dbapi.shutdown
|
||||
return engine
|
||||
|
||||
def testing_engine(url=None, options=None):
|
||||
"""Produce an engine configured by --options with optional overrides."""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.test.assertsql import asserter
|
||||
|
||||
url = url or config.db_url
|
||||
options = options or config.db_opts
|
||||
|
||||
options.setdefault('proxy', asserter)
|
||||
|
||||
listeners = options.setdefault('listeners', [])
|
||||
listeners.append(testing_reaper)
|
||||
|
||||
engine = create_engine(url, **options)
|
||||
|
||||
# may want to call this, results
|
||||
# in first-connect initializers
|
||||
#engine.connect()
|
||||
|
||||
return engine
|
||||
|
||||
def utf8_engine(url=None, options=None):
|
||||
"""Hook for dialects or drivers that don't handle utf8 by default."""
|
||||
|
||||
from sqlalchemy.engine import url as engine_url
|
||||
|
||||
if config.db.driver == 'mysqldb':
|
||||
dbapi_ver = config.db.dialect.dbapi.version_info
|
||||
if (dbapi_ver < (1, 2, 1) or
|
||||
dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
|
||||
(1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))):
|
||||
raise RuntimeError('Character set support unavailable with this '
|
||||
'driver version: %s' % repr(dbapi_ver))
|
||||
else:
|
||||
url = url or config.db_url
|
||||
url = engine_url.make_url(url)
|
||||
url.query['charset'] = 'utf8'
|
||||
url.query['use_unicode'] = '0'
|
||||
url = str(url)
|
||||
|
||||
return testing_engine(url, options)
|
||||
|
||||
def mock_engine(dialect_name=None):
|
||||
"""Provides a mocking engine based on the current testing.db.
|
||||
|
||||
This is normally used to test DDL generation flow as emitted
|
||||
by an Engine.
|
||||
|
||||
It should not be used in other cases, as assert_compile() and
|
||||
assert_sql_execution() are much better choices with fewer
|
||||
moving parts.
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
if not dialect_name:
|
||||
dialect_name = config.db.name
|
||||
|
||||
buffer = []
|
||||
def executor(sql, *a, **kw):
|
||||
buffer.append(sql)
|
||||
def assert_sql(stmts):
|
||||
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
|
||||
assert recv == stmts, recv
|
||||
|
||||
engine = create_engine(dialect_name + '://',
|
||||
strategy='mock', executor=executor)
|
||||
assert not hasattr(engine, 'mock')
|
||||
engine.mock = buffer
|
||||
engine.assert_sql = assert_sql
|
||||
return engine
|
||||
|
||||
class ReplayableSession(object):
|
||||
"""A simple record/playback tool.
|
||||
|
||||
This is *not* a mock testing class. It only records a session for later
|
||||
playback and makes no assertions on call consistency whatsoever. It's
|
||||
unlikely to be suitable for anything other than DB-API recording.
|
||||
|
||||
"""
|
||||
|
||||
Callable = object()
|
||||
NoAttribute = object()
|
||||
Natives = set([getattr(types, t)
|
||||
for t in dir(types) if not t.startswith('_')]). \
|
||||
difference([getattr(types, t)
|
||||
# Py3K
|
||||
#for t in ('FunctionType', 'BuiltinFunctionType',
|
||||
# 'MethodType', 'BuiltinMethodType',
|
||||
# 'LambdaType', )])
|
||||
|
||||
# Py2K
|
||||
for t in ('FunctionType', 'BuiltinFunctionType',
|
||||
'MethodType', 'BuiltinMethodType',
|
||||
'LambdaType', 'UnboundMethodType',)])
|
||||
# end Py2K
|
||||
def __init__(self):
|
||||
self.buffer = deque()
|
||||
|
||||
def recorder(self, base):
|
||||
return self.Recorder(self.buffer, base)
|
||||
|
||||
def player(self):
|
||||
return self.Player(self.buffer)
|
||||
|
||||
class Recorder(object):
|
||||
def __init__(self, buffer, subject):
|
||||
self._buffer = buffer
|
||||
self._subject = subject
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
subject, buffer = [object.__getattribute__(self, x)
|
||||
for x in ('_subject', '_buffer')]
|
||||
|
||||
result = subject(*args, **kw)
|
||||
if type(result) not in ReplayableSession.Natives:
|
||||
buffer.append(ReplayableSession.Callable)
|
||||
return type(self)(buffer, result)
|
||||
else:
|
||||
buffer.append(result)
|
||||
return result
|
||||
|
||||
@property
|
||||
def _sqla_unwrap(self):
|
||||
return self._subject
|
||||
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
subject, buffer = [object.__getattribute__(self, x)
|
||||
for x in ('_subject', '_buffer')]
|
||||
try:
|
||||
result = type(subject).__getattribute__(subject, key)
|
||||
except AttributeError:
|
||||
buffer.append(ReplayableSession.NoAttribute)
|
||||
raise
|
||||
else:
|
||||
if type(result) not in ReplayableSession.Natives:
|
||||
buffer.append(ReplayableSession.Callable)
|
||||
return type(self)(buffer, result)
|
||||
else:
|
||||
buffer.append(result)
|
||||
return result
|
||||
|
||||
class Player(object):
|
||||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
buffer = object.__getattribute__(self, '_buffer')
|
||||
result = buffer.popleft()
|
||||
if result is ReplayableSession.Callable:
|
||||
return self
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def _sqla_unwrap(self):
|
||||
return None
|
||||
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
buffer = object.__getattribute__(self, '_buffer')
|
||||
result = buffer.popleft()
|
||||
if result is ReplayableSession.Callable:
|
||||
return self
|
||||
elif result is ReplayableSession.NoAttribute:
|
||||
raise AttributeError(key)
|
||||
else:
|
||||
return result
|
||||
|
83
sqlalchemy/test/entities.py
Normal file
83
sqlalchemy/test/entities.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exc as sa_exc
|
||||
|
||||
_repr_stack = set()
|
||||
class BasicEntity(object):
|
||||
def __init__(self, **kw):
|
||||
for key, value in kw.iteritems():
|
||||
setattr(self, key, value)
|
||||
|
||||
def __repr__(self):
|
||||
if id(self) in _repr_stack:
|
||||
return object.__repr__(self)
|
||||
_repr_stack.add(id(self))
|
||||
try:
|
||||
return "%s(%s)" % (
|
||||
(self.__class__.__name__),
|
||||
', '.join(["%s=%r" % (key, getattr(self, key))
|
||||
for key in sorted(self.__dict__.keys())
|
||||
if not key.startswith('_')]))
|
||||
finally:
|
||||
_repr_stack.remove(id(self))
|
||||
|
||||
_recursion_stack = set()
|
||||
class ComparableEntity(BasicEntity):
|
||||
def __hash__(self):
|
||||
return hash(self.__class__)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __eq__(self, other):
|
||||
"""'Deep, sparse compare.
|
||||
|
||||
Deeply compare two entities, following the non-None attributes of the
|
||||
non-persisted object, if possible.
|
||||
|
||||
"""
|
||||
if other is self:
|
||||
return True
|
||||
elif not self.__class__ == other.__class__:
|
||||
return False
|
||||
|
||||
if id(self) in _recursion_stack:
|
||||
return True
|
||||
_recursion_stack.add(id(self))
|
||||
|
||||
try:
|
||||
# pick the entity thats not SA persisted as the source
|
||||
try:
|
||||
self_key = sa.orm.attributes.instance_state(self).key
|
||||
except sa.orm.exc.NO_STATE:
|
||||
self_key = None
|
||||
|
||||
if other is None:
|
||||
a = self
|
||||
b = other
|
||||
elif self_key is not None:
|
||||
a = other
|
||||
b = self
|
||||
else:
|
||||
a = self
|
||||
b = other
|
||||
|
||||
for attr in a.__dict__.keys():
|
||||
if attr.startswith('_'):
|
||||
continue
|
||||
value = getattr(a, attr)
|
||||
|
||||
try:
|
||||
# handle lazy loader errors
|
||||
battr = getattr(b, attr)
|
||||
except (AttributeError, sa_exc.UnboundExecutionError):
|
||||
return False
|
||||
|
||||
if hasattr(value, '__iter__'):
|
||||
if list(value) != list(battr):
|
||||
return False
|
||||
else:
|
||||
if value is not None and value != battr:
|
||||
return False
|
||||
return True
|
||||
finally:
|
||||
_recursion_stack.remove(id(self))
|
162
sqlalchemy/test/noseplugin.py
Normal file
162
sqlalchemy/test/noseplugin.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
import ConfigParser
|
||||
import StringIO
|
||||
|
||||
import nose.case
|
||||
from nose.plugins import Plugin
|
||||
|
||||
from sqlalchemy import util, log as sqla_log
|
||||
from sqlalchemy.test import testing, config, requires
|
||||
from sqlalchemy.test.config import (
|
||||
_create_testing_engine, _engine_pool, _engine_strategy, _engine_uri, _list_dbs, _log,
|
||||
_prep_testing_database, _require, _reverse_topological, _server_side_cursors,
|
||||
_set_table_options, base_config, db, db_label, db_url, file_config, post_configure)
|
||||
|
||||
log = logging.getLogger('nose.plugins.sqlalchemy')
|
||||
|
||||
class NoseSQLAlchemy(Plugin):
|
||||
"""
|
||||
Handles the setup and extra properties required for testing SQLAlchemy
|
||||
"""
|
||||
enabled = True
|
||||
name = 'sqlalchemy'
|
||||
score = 100
|
||||
|
||||
def options(self, parser, env=os.environ):
|
||||
Plugin.options(self, parser, env)
|
||||
opt = parser.add_option
|
||||
opt("--log-info", action="callback", type="string", callback=_log,
|
||||
help="turn on info logging for <LOG> (multiple OK)")
|
||||
opt("--log-debug", action="callback", type="string", callback=_log,
|
||||
help="turn on debug logging for <LOG> (multiple OK)")
|
||||
opt("--require", action="append", dest="require", default=[],
|
||||
help="require a particular driver or module version (multiple OK)")
|
||||
opt("--db", action="store", dest="db", default="sqlite",
|
||||
help="Use prefab database uri")
|
||||
opt('--dbs', action='callback', callback=_list_dbs,
|
||||
help="List available prefab dbs")
|
||||
opt("--dburi", action="store", dest="dburi",
|
||||
help="Database uri (overrides --db)")
|
||||
opt("--dropfirst", action="store_true", dest="dropfirst",
|
||||
help="Drop all tables in the target database first (use with caution on Oracle, "
|
||||
"MS-SQL)")
|
||||
opt("--mockpool", action="store_true", dest="mockpool",
|
||||
help="Use mock pool (asserts only one connection used)")
|
||||
opt("--enginestrategy", action="callback", type="string",
|
||||
callback=_engine_strategy,
|
||||
help="Engine strategy (plain or threadlocal, defaults to plain)")
|
||||
opt("--reversetop", action="store_true", dest="reversetop", default=False,
|
||||
help="Reverse the collection ordering for topological sorts (helps "
|
||||
"reveal dependency issues)")
|
||||
opt("--unhashable", action="store_true", dest="unhashable", default=False,
|
||||
help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
|
||||
opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
|
||||
help="Disallow SQLAlchemy from performing == on mapped test objects.")
|
||||
opt("--truthless", action="store_true", dest="truthless", default=False,
|
||||
help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
|
||||
opt("--serverside", action="callback", callback=_server_side_cursors,
|
||||
help="Turn on server side cursors for PG")
|
||||
opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
|
||||
help="Use the specified MySQL storage engine for all tables, default is "
|
||||
"a db-default/InnoDB combo.")
|
||||
opt("--table-option", action="append", dest="tableopts", default=[],
|
||||
help="Add a dialect-specific table option, key=value")
|
||||
|
||||
global file_config
|
||||
file_config = ConfigParser.ConfigParser()
|
||||
file_config.readfp(StringIO.StringIO(base_config))
|
||||
file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
|
||||
config.file_config = file_config
|
||||
|
||||
def configure(self, options, conf):
|
||||
Plugin.configure(self, options, conf)
|
||||
self.options = options
|
||||
|
||||
def begin(self):
|
||||
testing.db = db
|
||||
testing.requires = requires
|
||||
|
||||
# Lazy setup of other options (post coverage)
|
||||
for fn in post_configure:
|
||||
fn(self.options, file_config)
|
||||
|
||||
def describeTest(self, test):
|
||||
return ""
|
||||
|
||||
def wantClass(self, cls):
|
||||
"""Return true if you want the main test selector to collect
|
||||
tests from this class, false if you don't, and None if you don't
|
||||
care.
|
||||
|
||||
:Parameters:
|
||||
cls : class
|
||||
The class being examined by the selector
|
||||
|
||||
"""
|
||||
|
||||
if not issubclass(cls, testing.TestBase):
|
||||
return False
|
||||
else:
|
||||
if (hasattr(cls, '__whitelist__') and testing.db.name in cls.__whitelist__):
|
||||
return True
|
||||
else:
|
||||
return not self.__should_skip_for(cls)
|
||||
|
||||
def __should_skip_for(self, cls):
|
||||
if hasattr(cls, '__requires__'):
|
||||
def test_suite(): return 'ok'
|
||||
test_suite.__name__ = cls.__name__
|
||||
for requirement in cls.__requires__:
|
||||
check = getattr(requires, requirement)
|
||||
if check(test_suite)() != 'ok':
|
||||
# The requirement will perform messaging.
|
||||
return True
|
||||
|
||||
if cls.__unsupported_on__:
|
||||
spec = testing.db_spec(*cls.__unsupported_on__)
|
||||
if spec(testing.db):
|
||||
print "'%s' unsupported on DB implementation '%s'" % (
|
||||
cls.__class__.__name__, testing.db.name)
|
||||
return True
|
||||
|
||||
if getattr(cls, '__only_on__', None):
|
||||
spec = testing.db_spec(*util.to_list(cls.__only_on__))
|
||||
if not spec(testing.db):
|
||||
print "'%s' unsupported on DB implementation '%s'" % (
|
||||
cls.__class__.__name__, testing.db.name)
|
||||
return True
|
||||
|
||||
if getattr(cls, '__skip_if__', False):
|
||||
for c in getattr(cls, '__skip_if__'):
|
||||
if c():
|
||||
print "'%s' skipped by %s" % (
|
||||
cls.__class__.__name__, c.__name__)
|
||||
return True
|
||||
|
||||
for rule in getattr(cls, '__excluded_on__', ()):
|
||||
if testing._is_excluded(*rule):
|
||||
print "'%s' unsupported on DB %s version %s" % (
|
||||
cls.__class__.__name__, testing.db.name,
|
||||
_server_version())
|
||||
return True
|
||||
return False
|
||||
|
||||
def beforeTest(self, test):
|
||||
testing.resetwarnings()
|
||||
|
||||
def afterTest(self, test):
|
||||
testing.resetwarnings()
|
||||
|
||||
def afterContext(self):
|
||||
testing.global_cleanup_assertions()
|
||||
|
||||
#def handleError(self, test, err):
|
||||
#pass
|
||||
|
||||
#def finalize(self, result=None):
|
||||
#pass
|
111
sqlalchemy/test/orm.py
Normal file
111
sqlalchemy/test/orm.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import inspect, re
|
||||
import config, testing
|
||||
from sqlalchemy import orm
|
||||
|
||||
__all__ = 'mapper',
|
||||
|
||||
|
||||
_whitespace = re.compile(r'^(\s+)')
|
||||
|
||||
def _find_pragma(lines, current):
|
||||
m = _whitespace.match(lines[current])
|
||||
basis = m and m.group() or ''
|
||||
|
||||
for line in reversed(lines[0:current]):
|
||||
if 'testlib.pragma' in line:
|
||||
return line
|
||||
m = _whitespace.match(line)
|
||||
indent = m and m.group() or ''
|
||||
|
||||
# simplistic detection:
|
||||
|
||||
# >> # testlib.pragma foo
|
||||
# >> center_line()
|
||||
if indent == basis:
|
||||
break
|
||||
# >> # testlib.pragma foo
|
||||
# >> if fleem:
|
||||
# >> center_line()
|
||||
if line.endswith(':'):
|
||||
break
|
||||
return None
|
||||
|
||||
def _make_blocker(method_name, fallback):
|
||||
"""Creates tripwired variant of a method, raising when called.
|
||||
|
||||
To excempt an invocation from blockage, there are two options.
|
||||
|
||||
1) add a pragma in a comment::
|
||||
|
||||
# testlib.pragma exempt:methodname
|
||||
offending_line()
|
||||
|
||||
2) add a magic cookie to the function's namespace::
|
||||
__sa_baremethodname_exempt__ = True
|
||||
...
|
||||
offending_line()
|
||||
another_offending_lines()
|
||||
|
||||
The second is useful for testing and development.
|
||||
"""
|
||||
|
||||
if method_name.startswith('__') and method_name.endswith('__'):
|
||||
frame_marker = '__sa_%s_exempt__' % method_name[2:-2]
|
||||
else:
|
||||
frame_marker = '__sa_%s_exempt__' % method_name
|
||||
pragma_marker = 'exempt:' + method_name
|
||||
|
||||
def method(self, *args, **kw):
|
||||
frame_r = None
|
||||
try:
|
||||
frame = inspect.stack()[1][0]
|
||||
frame_r = inspect.getframeinfo(frame, 9)
|
||||
|
||||
module = frame.f_globals.get('__name__', '')
|
||||
|
||||
type_ = type(self)
|
||||
|
||||
pragma = _find_pragma(*frame_r[3:5])
|
||||
|
||||
exempt = (
|
||||
(not module.startswith('sqlalchemy')) or
|
||||
(pragma and pragma_marker in pragma) or
|
||||
(frame_marker in frame.f_locals) or
|
||||
('self' in frame.f_locals and
|
||||
getattr(frame.f_locals['self'], frame_marker, False)))
|
||||
|
||||
if exempt:
|
||||
supermeth = getattr(super(type_, self), method_name, None)
|
||||
if (supermeth is None or
|
||||
getattr(supermeth, 'im_func', None) is method):
|
||||
return fallback(self, *args, **kw)
|
||||
else:
|
||||
return supermeth(*args, **kw)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"%s.%s called in %s, line %s in %s" % (
|
||||
type_.__name__, method_name, module, frame_r[1], frame_r[2]))
|
||||
finally:
|
||||
del frame
|
||||
method.__name__ = method_name
|
||||
return method
|
||||
|
||||
def mapper(type_, *args, **kw):
|
||||
forbidden = [
|
||||
('__hash__', 'unhashable', lambda s: id(s)),
|
||||
('__eq__', 'noncomparable', lambda s, o: s is o),
|
||||
('__ne__', 'noncomparable', lambda s, o: s is not o),
|
||||
('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)),
|
||||
('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)),
|
||||
('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)),
|
||||
('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)),
|
||||
('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)),
|
||||
('__nonzero__', 'truthless', lambda s: 1), ]
|
||||
|
||||
if isinstance(type_, type) and type_.__bases__ == (object,):
|
||||
for method_name, option, fallback in forbidden:
|
||||
if (getattr(config.options, option, False) and
|
||||
method_name not in type_.__dict__):
|
||||
setattr(type_, method_name, _make_blocker(method_name, fallback))
|
||||
|
||||
return orm.mapper(type_, *args, **kw)
|
75
sqlalchemy/test/pickleable.py
Normal file
75
sqlalchemy/test/pickleable.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
|
||||
some objects used for pickle tests, declared in their own module so that they
|
||||
are easily pickleable.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Foo(object):
|
||||
def __init__(self, moredata):
|
||||
self.data = 'im data'
|
||||
self.stuff = 'im stuff'
|
||||
self.moredata = moredata
|
||||
__hash__ = object.__hash__
|
||||
def __eq__(self, other):
|
||||
return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata
|
||||
|
||||
|
||||
class Bar(object):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
__hash__ = object.__hash__
|
||||
def __eq__(self, other):
|
||||
return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
|
||||
def __str__(self):
|
||||
return "Bar(%d, %d)" % (self.x, self.y)
|
||||
|
||||
class OldSchool:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
def __eq__(self, other):
|
||||
return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
|
||||
|
||||
class OldSchoolWithoutCompare:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
class BarWithoutCompare(object):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
def __str__(self):
|
||||
return "Bar(%d, %d)" % (self.x, self.y)
|
||||
|
||||
|
||||
class NotComparable(object):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class BrokenComparable(object):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
raise NotImplementedError
|
||||
|
||||
def __ne__(self, other):
|
||||
raise NotImplementedError
|
||||
|
222
sqlalchemy/test/profiling.py
Normal file
222
sqlalchemy/test/profiling.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Profiling support for unit and performance tests.
|
||||
|
||||
These are special purpose profiling methods which operate
|
||||
in a more fine-grained way than nose's profiling plugin.
|
||||
|
||||
"""
|
||||
|
||||
import os, sys
|
||||
from sqlalchemy.test import config
|
||||
from sqlalchemy.test.util import function_named, gc_collect
|
||||
from nose import SkipTest
|
||||
|
||||
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
|
||||
|
||||
all_targets = set()
|
||||
profile_config = { 'targets': set(),
|
||||
'report': True,
|
||||
'sort': ('time', 'calls'),
|
||||
'limit': None }
|
||||
profiler = None
|
||||
|
||||
def profiled(target=None, **target_opts):
|
||||
"""Optional function profiling.
|
||||
|
||||
@profiled('label')
|
||||
or
|
||||
@profiled('label', report=True, sort=('calls',), limit=20)
|
||||
|
||||
Enables profiling for a function when 'label' is targetted for
|
||||
profiling. Report options can be supplied, and override the global
|
||||
configuration and command-line options.
|
||||
"""
|
||||
|
||||
# manual or automatic namespacing by module would remove conflict issues
|
||||
if target is None:
|
||||
target = 'anonymous_target'
|
||||
elif target in all_targets:
|
||||
print "Warning: redefining profile target '%s'" % target
|
||||
all_targets.add(target)
|
||||
|
||||
filename = "%s.prof" % target
|
||||
|
||||
def decorator(fn):
|
||||
def profiled(*args, **kw):
|
||||
if (target not in profile_config['targets'] and
|
||||
not target_opts.get('always', None)):
|
||||
return fn(*args, **kw)
|
||||
|
||||
elapsed, load_stats, result = _profile(
|
||||
filename, fn, *args, **kw)
|
||||
|
||||
report = target_opts.get('report', profile_config['report'])
|
||||
if report:
|
||||
sort_ = target_opts.get('sort', profile_config['sort'])
|
||||
limit = target_opts.get('limit', profile_config['limit'])
|
||||
print "Profile report for target '%s' (%s)" % (
|
||||
target, filename)
|
||||
|
||||
stats = load_stats()
|
||||
stats.sort_stats(*sort_)
|
||||
if limit:
|
||||
stats.print_stats(limit)
|
||||
else:
|
||||
stats.print_stats()
|
||||
#stats.print_callers()
|
||||
os.unlink(filename)
|
||||
return result
|
||||
return function_named(profiled, fn.__name__)
|
||||
return decorator
|
||||
|
||||
def function_call_count(count=None, versions={}, variance=0.05):
|
||||
"""Assert a target for a test case's function call count.
|
||||
|
||||
count
|
||||
Optional, general target function call count.
|
||||
|
||||
versions
|
||||
Optional, a dictionary of Python version strings to counts,
|
||||
for example::
|
||||
|
||||
{ '2.5.1': 110,
|
||||
'2.5': 100,
|
||||
'2.4': 150 }
|
||||
|
||||
The best match for the current running python will be used.
|
||||
If none match, 'count' will be used as the fallback.
|
||||
|
||||
variance
|
||||
An +/- deviation percentage, defaults to 5%.
|
||||
"""
|
||||
|
||||
# this could easily dump the profile report if --verbose is in effect
|
||||
|
||||
version_info = list(sys.version_info)
|
||||
py_version = '.'.join([str(v) for v in sys.version_info])
|
||||
try:
|
||||
from sqlalchemy.cprocessors import to_float
|
||||
cextension = True
|
||||
except ImportError:
|
||||
cextension = False
|
||||
|
||||
while version_info:
|
||||
version = '.'.join([str(v) for v in version_info])
|
||||
if cextension:
|
||||
version += "+cextension"
|
||||
if version in versions:
|
||||
count = versions[version]
|
||||
break
|
||||
version_info.pop()
|
||||
|
||||
if count is None:
|
||||
return lambda fn: fn
|
||||
|
||||
def decorator(fn):
|
||||
def counted(*args, **kw):
|
||||
try:
|
||||
filename = "%s.prof" % fn.__name__
|
||||
|
||||
elapsed, stat_loader, result = _profile(
|
||||
filename, fn, *args, **kw)
|
||||
|
||||
stats = stat_loader()
|
||||
calls = stats.total_calls
|
||||
|
||||
stats.sort_stats('calls', 'cumulative')
|
||||
stats.print_stats()
|
||||
#stats.print_callers()
|
||||
deviance = int(count * variance)
|
||||
if (calls < (count - deviance) or
|
||||
calls > (count + deviance)):
|
||||
raise AssertionError(
|
||||
"Function call count %s not within %s%% "
|
||||
"of expected %s. (Python version %s)" % (
|
||||
calls, (variance * 100), count, py_version))
|
||||
|
||||
return result
|
||||
finally:
|
||||
if os.path.exists(filename):
|
||||
os.unlink(filename)
|
||||
return function_named(counted, fn.__name__)
|
||||
return decorator
|
||||
|
||||
def conditional_call_count(discriminator, categories):
|
||||
"""Apply a function call count conditionally at runtime.
|
||||
|
||||
Takes two arguments, a callable that returns a key value, and a dict
|
||||
mapping key values to a tuple of arguments to function_call_count.
|
||||
|
||||
The callable is not evaluated until the decorated function is actually
|
||||
invoked. If the `discriminator` returns a key not present in the
|
||||
`categories` dictionary, no call count assertion is applied.
|
||||
|
||||
Useful for integration tests, where running a named test in isolation may
|
||||
have a function count penalty not seen in the full suite, due to lazy
|
||||
initialization in the DB-API, SA, etc.
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
def at_runtime(*args, **kw):
|
||||
criteria = categories.get(discriminator(), None)
|
||||
if criteria is None:
|
||||
return fn(*args, **kw)
|
||||
|
||||
rewrapped = function_call_count(*criteria)(fn)
|
||||
return rewrapped(*args, **kw)
|
||||
return function_named(at_runtime, fn.__name__)
|
||||
return decorator
|
||||
|
||||
|
||||
def _profile(filename, fn, *args, **kw):
|
||||
global profiler
|
||||
if not profiler:
|
||||
if sys.version_info > (2, 5):
|
||||
try:
|
||||
import cProfile
|
||||
profiler = 'cProfile'
|
||||
except ImportError:
|
||||
pass
|
||||
if not profiler:
|
||||
try:
|
||||
import hotshot
|
||||
profiler = 'hotshot'
|
||||
except ImportError:
|
||||
profiler = 'skip'
|
||||
|
||||
if profiler == 'skip':
|
||||
raise SkipTest('Profiling not supported on this platform')
|
||||
elif profiler == 'cProfile':
|
||||
return _profile_cProfile(filename, fn, *args, **kw)
|
||||
else:
|
||||
return _profile_hotshot(filename, fn, *args, **kw)
|
||||
|
||||
def _profile_cProfile(filename, fn, *args, **kw):
|
||||
import cProfile, gc, pstats, time
|
||||
|
||||
load_stats = lambda: pstats.Stats(filename)
|
||||
gc_collect()
|
||||
|
||||
began = time.time()
|
||||
cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
|
||||
filename=filename)
|
||||
ended = time.time()
|
||||
|
||||
return ended - began, load_stats, locals()['result']
|
||||
|
||||
def _profile_hotshot(filename, fn, *args, **kw):
|
||||
import gc, hotshot, hotshot.stats, time
|
||||
load_stats = lambda: hotshot.stats.load(filename)
|
||||
|
||||
gc_collect()
|
||||
prof = hotshot.Profile(filename)
|
||||
began = time.time()
|
||||
prof.start()
|
||||
try:
|
||||
result = fn(*args, **kw)
|
||||
finally:
|
||||
prof.stop()
|
||||
ended = time.time()
|
||||
prof.close()
|
||||
|
||||
return ended - began, load_stats, result
|
||||
|
259
sqlalchemy/test/requires.py
Normal file
259
sqlalchemy/test/requires.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""Global database feature support policy.
|
||||
|
||||
Provides decorators to mark tests requiring specific feature support from the
|
||||
target database.
|
||||
|
||||
"""
|
||||
|
||||
from testing import \
|
||||
_block_unconditionally as no_support, \
|
||||
_chain_decorators_on, \
|
||||
exclude, \
|
||||
emits_warning_on,\
|
||||
skip_if,\
|
||||
fails_on
|
||||
|
||||
import testing
|
||||
import sys
|
||||
|
||||
def deferrable_constraints(fn):
|
||||
"""Target database must support derferable constraints."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('firebird', 'not supported by database'),
|
||||
no_support('mysql', 'not supported by database'),
|
||||
no_support('mssql', 'not supported by database'),
|
||||
)
|
||||
|
||||
def foreign_keys(fn):
|
||||
"""Target database must support foreign keys."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('sqlite', 'not supported by database'),
|
||||
)
|
||||
|
||||
|
||||
def unbounded_varchar(fn):
|
||||
"""Target database must support VARCHAR with no length"""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('firebird', 'not supported by database'),
|
||||
no_support('oracle', 'not supported by database'),
|
||||
no_support('mysql', 'not supported by database'),
|
||||
)
|
||||
|
||||
def boolean_col_expressions(fn):
|
||||
"""Target database must support boolean expressions as columns"""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('firebird', 'not supported by database'),
|
||||
no_support('oracle', 'not supported by database'),
|
||||
no_support('mssql', 'not supported by database'),
|
||||
no_support('sybase', 'not supported by database'),
|
||||
no_support('maxdb', 'FIXME: verify not supported by database'),
|
||||
)
|
||||
|
||||
def identity(fn):
|
||||
"""Target database must support GENERATED AS IDENTITY or a facsimile.
|
||||
|
||||
Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other
|
||||
column DDL feature that fills in a DB-generated identifier at INSERT-time
|
||||
without requiring pre-execution of a SEQUENCE or other artifact.
|
||||
|
||||
"""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('firebird', 'not supported by database'),
|
||||
no_support('oracle', 'not supported by database'),
|
||||
no_support('postgresql', 'not supported by database'),
|
||||
no_support('sybase', 'not supported by database'),
|
||||
)
|
||||
|
||||
def independent_cursors(fn):
|
||||
"""Target must support simultaneous, independent database cursors on a single connection."""
|
||||
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('mssql+pyodbc', 'no driver support'),
|
||||
no_support('mssql+mxodbc', 'no driver support'),
|
||||
)
|
||||
|
||||
def independent_connections(fn):
|
||||
"""Target must support simultaneous, independent database connections."""
|
||||
|
||||
# This is also true of some configurations of UnixODBC and probably win32
|
||||
# ODBC as well.
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('sqlite', 'no driver support'),
|
||||
exclude('mssql', '<', (9, 0, 0),
|
||||
'SQL Server 2005+ is required for independent connections'),
|
||||
)
|
||||
|
||||
def row_triggers(fn):
|
||||
"""Target must support standard statement-running EACH ROW triggers."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
# no access to same table
|
||||
no_support('mysql', 'requires SUPER priv'),
|
||||
exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
|
||||
|
||||
# huh? TODO: implement triggers for PG tests, remove this
|
||||
no_support('postgresql', 'PG triggers need to be implemented for tests'),
|
||||
)
|
||||
|
||||
def correlated_outer_joins(fn):
|
||||
"""Target must support an outer join to a subquery which correlates to the parent."""
|
||||
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"')
|
||||
)
|
||||
|
||||
def savepoints(fn):
|
||||
"""Target database must support savepoints."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'),
|
||||
no_support('access', 'not supported by database'),
|
||||
no_support('sqlite', 'not supported by database'),
|
||||
no_support('sybase', 'FIXME: guessing, needs confirmation'),
|
||||
exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
|
||||
)
|
||||
|
||||
def denormalized_names(fn):
|
||||
"""Target database must have 'denormalized', i.e. UPPERCASE as case insensitive names."""
|
||||
|
||||
return skip_if(
|
||||
lambda: not testing.db.dialect.requires_name_normalize,
|
||||
"Backend does not require denomralized names."
|
||||
)(fn)
|
||||
|
||||
def schemas(fn):
|
||||
"""Target database must support external schemas, and have one named 'test_schema'."""
|
||||
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('sqlite', 'no schema support'),
|
||||
no_support('firebird', 'no schema support')
|
||||
)
|
||||
|
||||
def sequences(fn):
|
||||
"""Target database must support SEQUENCEs."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('access', 'no SEQUENCE support'),
|
||||
no_support('mssql', 'no SEQUENCE support'),
|
||||
no_support('mysql', 'no SEQUENCE support'),
|
||||
no_support('sqlite', 'no SEQUENCE support'),
|
||||
no_support('sybase', 'no SEQUENCE support'),
|
||||
)
|
||||
|
||||
def subqueries(fn):
|
||||
"""Target database must support subqueries."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
|
||||
)
|
||||
|
||||
def intersect(fn):
|
||||
"""Target database must support INTERSECT or equivlaent."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
fails_on('firebird', 'no support for INTERSECT'),
|
||||
fails_on('mysql', 'no support for INTERSECT'),
|
||||
fails_on('sybase', 'no support for INTERSECT'),
|
||||
)
|
||||
|
||||
def except_(fn):
|
||||
"""Target database must support EXCEPT or equivlaent (i.e. MINUS)."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
fails_on('firebird', 'no support for EXCEPT'),
|
||||
fails_on('mysql', 'no support for EXCEPT'),
|
||||
fails_on('sybase', 'no support for EXCEPT'),
|
||||
)
|
||||
|
||||
def offset(fn):
|
||||
"""Target database must support some method of adding OFFSET or equivalent to a result set."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
fails_on('sybase', 'no support for OFFSET or equivalent'),
|
||||
)
|
||||
|
||||
def returning(fn):
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('access', 'not supported by database'),
|
||||
no_support('sqlite', 'not supported by database'),
|
||||
no_support('mysql', 'not supported by database'),
|
||||
no_support('maxdb', 'not supported by database'),
|
||||
no_support('sybase', 'not supported by database'),
|
||||
no_support('informix', 'not supported by database'),
|
||||
)
|
||||
|
||||
def two_phase_transactions(fn):
|
||||
"""Target database must support two-phase transactions."""
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('access', 'not supported by database'),
|
||||
no_support('firebird', 'no SA implementation'),
|
||||
no_support('maxdb', 'not supported by database'),
|
||||
no_support('mssql', 'FIXME: guessing, needs confirmation'),
|
||||
no_support('oracle', 'no SA implementation'),
|
||||
no_support('sqlite', 'not supported by database'),
|
||||
no_support('sybase', 'FIXME: guessing, needs confirmation'),
|
||||
no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may '
|
||||
'need separate XA implementation'),
|
||||
exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
|
||||
)
|
||||
|
||||
def unicode_connections(fn):
|
||||
"""Target driver must support some encoding of Unicode across the wire."""
|
||||
# TODO: expand to exclude MySQLdb versions w/ broken unicode
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
|
||||
)
|
||||
|
||||
def unicode_ddl(fn):
|
||||
"""Target driver must support some encoding of Unicode across the wire."""
|
||||
# TODO: expand to exclude MySQLdb versions w/ broken unicode
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
no_support('maxdb', 'database support flakey'),
|
||||
no_support('oracle', 'FIXME: no support in database?'),
|
||||
no_support('sybase', 'FIXME: guessing, needs confirmation'),
|
||||
no_support('mssql+pymssql', 'no FreeTDS support'),
|
||||
exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
|
||||
)
|
||||
|
||||
def sane_rowcount(fn):
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
skip_if(lambda: not testing.db.dialect.supports_sane_rowcount)
|
||||
)
|
||||
|
||||
def python2(fn):
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
skip_if(
|
||||
lambda: sys.version_info >= (3,),
|
||||
"Python version 2.xx is required."
|
||||
)
|
||||
)
|
||||
|
||||
def _has_sqlite():
|
||||
from sqlalchemy import create_engine
|
||||
try:
|
||||
e = create_engine('sqlite://')
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def sqlite(fn):
|
||||
return _chain_decorators_on(
|
||||
fn,
|
||||
skip_if(lambda: not _has_sqlite())
|
||||
)
|
||||
|
79
sqlalchemy/test/schema.py
Normal file
79
sqlalchemy/test/schema.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Enhanced versions of schema.Table and schema.Column which establish
|
||||
desired state for different backends.
|
||||
"""
|
||||
|
||||
from sqlalchemy.test import testing
|
||||
from sqlalchemy import schema
|
||||
|
||||
__all__ = 'Table', 'Column',
|
||||
|
||||
table_options = {}
|
||||
|
||||
def Table(*args, **kw):
|
||||
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
|
||||
if k.startswith('test_')])
|
||||
|
||||
kw.update(table_options)
|
||||
|
||||
if testing.against('mysql'):
|
||||
if 'mysql_engine' not in kw and 'mysql_type' not in kw:
|
||||
if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
|
||||
kw['mysql_engine'] = 'InnoDB'
|
||||
|
||||
# Apply some default cascading rules for self-referential foreign keys.
|
||||
# MySQL InnoDB has some issues around seleting self-refs too.
|
||||
if testing.against('firebird'):
|
||||
table_name = args[0]
|
||||
unpack = (testing.config.db.dialect.
|
||||
identifier_preparer.unformat_identifiers)
|
||||
|
||||
# Only going after ForeignKeys in Columns. May need to
|
||||
# expand to ForeignKeyConstraint too.
|
||||
fks = [fk
|
||||
for col in args if isinstance(col, schema.Column)
|
||||
for fk in col.foreign_keys]
|
||||
|
||||
for fk in fks:
|
||||
# root around in raw spec
|
||||
ref = fk._colspec
|
||||
if isinstance(ref, schema.Column):
|
||||
name = ref.table.name
|
||||
else:
|
||||
# take just the table name: on FB there cannot be
|
||||
# a schema, so the first element is always the
|
||||
# table name, possibly followed by the field name
|
||||
name = unpack(ref)[0]
|
||||
if name == table_name:
|
||||
if fk.ondelete is None:
|
||||
fk.ondelete = 'CASCADE'
|
||||
if fk.onupdate is None:
|
||||
fk.onupdate = 'CASCADE'
|
||||
|
||||
return schema.Table(*args, **kw)
|
||||
|
||||
|
||||
def Column(*args, **kw):
|
||||
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
|
||||
if k.startswith('test_')])
|
||||
|
||||
col = schema.Column(*args, **kw)
|
||||
if 'test_needs_autoincrement' in test_opts and \
|
||||
kw.get('primary_key', False) and \
|
||||
testing.against('firebird', 'oracle'):
|
||||
def add_seq(tbl, c):
|
||||
c._init_items(
|
||||
schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + c.name + '_seq'), optional=True)
|
||||
)
|
||||
col._on_table_attach(add_seq)
|
||||
return col
|
||||
|
||||
def _truncate_name(dialect, name):
|
||||
if len(name) > dialect.max_identifier_length:
|
||||
return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:]
|
||||
else:
|
||||
return name
|
||||
|
779
sqlalchemy/test/testing.py
Normal file
779
sqlalchemy/test/testing.py
Normal file
@@ -0,0 +1,779 @@
|
||||
"""TestCase and TestSuite artifacts and testing decorators."""
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import warnings
|
||||
from cStringIO import StringIO
|
||||
|
||||
from sqlalchemy.test import config, assertsql, util as testutil
|
||||
from sqlalchemy.util import function_named, py3k
|
||||
from engines import drop_all_tables
|
||||
|
||||
from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool, orm
|
||||
from sqlalchemy.engine import default
|
||||
from nose import SkipTest
|
||||
|
||||
|
||||
_ops = { '<': operator.lt,
|
||||
'>': operator.gt,
|
||||
'==': operator.eq,
|
||||
'!=': operator.ne,
|
||||
'<=': operator.le,
|
||||
'>=': operator.ge,
|
||||
'in': operator.contains,
|
||||
'between': lambda val, pair: val >= pair[0] and val <= pair[1],
|
||||
}
|
||||
|
||||
# sugar ('testing.db'); set here by config() at runtime
|
||||
db = None
|
||||
|
||||
# more sugar, installed by __init__
|
||||
requires = None
|
||||
|
||||
def fails_if(callable_, reason=None):
|
||||
"""Mark a test as expected to fail if callable_ returns True.
|
||||
|
||||
If the callable returns false, the test is run and reported as normal.
|
||||
However if the callable returns true, the test is expected to fail and the
|
||||
unit test logic is inverted: if the test fails, a success is reported. If
|
||||
the test succeeds, a failure is reported.
|
||||
"""
|
||||
|
||||
docstring = getattr(callable_, '__doc__', None) or callable_.__name__
|
||||
description = docstring.split('\n')[0]
|
||||
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if not callable_():
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
except Exception, ex:
|
||||
print ("'%s' failed as expected (condition: %s): %s " % (
|
||||
fn_name, description, str(ex)))
|
||||
return True
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Unexpected success for '%s' (condition: %s)" %
|
||||
(fn_name, description))
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
|
||||
def future(fn):
|
||||
"""Mark a test as expected to unconditionally fail.
|
||||
|
||||
Takes no arguments, omit parens when using as a decorator.
|
||||
"""
|
||||
|
||||
fn_name = fn.__name__
|
||||
def decorated(*args, **kw):
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
except Exception, ex:
|
||||
print ("Future test '%s' failed as expected: %s " % (
|
||||
fn_name, str(ex)))
|
||||
return True
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Unexpected success for future test '%s'" % fn_name)
|
||||
return function_named(decorated, fn_name)
|
||||
|
||||
def db_spec(*dbs):
|
||||
dialects = set([x for x in dbs if '+' not in x])
|
||||
drivers = set([x[1:] for x in dbs if x.startswith('+')])
|
||||
specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers])
|
||||
|
||||
def check(engine):
|
||||
return engine.name in dialects or \
|
||||
engine.driver in drivers or \
|
||||
(engine.name, engine.driver) in specs
|
||||
|
||||
return check
|
||||
|
||||
|
||||
def fails_on(dbs, reason):
|
||||
"""Mark a test as expected to fail on the specified database
|
||||
implementation.
|
||||
|
||||
Unlike ``crashes``, tests marked as ``fails_on`` will be run
|
||||
for the named databases. The test is expected to fail and the unit test
|
||||
logic is inverted: if the test fails, a success is reported. If the test
|
||||
succeeds, a failure is reported.
|
||||
"""
|
||||
|
||||
spec = db_spec(dbs)
|
||||
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if not spec(config.db):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
except Exception, ex:
|
||||
print ("'%s' failed as expected on DB implementation "
|
||||
"'%s+%s': %s" % (
|
||||
fn_name, config.db.name, config.db.driver, reason))
|
||||
return True
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Unexpected success for '%s' on DB implementation '%s+%s'" %
|
||||
(fn_name, config.db.name, config.db.driver))
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
def fails_on_everything_except(*dbs):
|
||||
"""Mark a test as expected to fail on most database implementations.
|
||||
|
||||
Like ``fails_on``, except failure is the expected outcome on all
|
||||
databases except those listed.
|
||||
"""
|
||||
|
||||
spec = db_spec(*dbs)
|
||||
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if spec(config.db):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
except Exception, ex:
|
||||
print ("'%s' failed as expected on DB implementation "
|
||||
"'%s+%s': %s" % (
|
||||
fn_name, config.db.name, config.db.driver, str(ex)))
|
||||
return True
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Unexpected success for '%s' on DB implementation '%s+%s'" %
|
||||
(fn_name, config.db.name, config.db.driver))
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
def crashes(db, reason):
|
||||
"""Mark a test as unsupported by a database implementation.
|
||||
|
||||
``crashes`` tests will be skipped unconditionally. Use for feature tests
|
||||
that cause deadlocks or other fatal problems.
|
||||
|
||||
"""
|
||||
carp = _should_carp_about_exclusion(reason)
|
||||
spec = db_spec(db)
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if spec(config.db):
|
||||
msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
|
||||
fn_name, config.db.name, config.db.driver, reason)
|
||||
print msg
|
||||
if carp:
|
||||
print >> sys.stderr, msg
|
||||
return True
|
||||
else:
|
||||
return fn(*args, **kw)
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
def _block_unconditionally(db, reason):
|
||||
"""Mark a test as unsupported by a database implementation.
|
||||
|
||||
Will never run the test against any version of the given database, ever,
|
||||
no matter what. Use when your assumptions are infallible; past, present
|
||||
and future.
|
||||
|
||||
"""
|
||||
carp = _should_carp_about_exclusion(reason)
|
||||
spec = db_spec(db)
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if spec(config.db):
|
||||
msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
|
||||
fn_name, config.db.name, config.db.driver, reason)
|
||||
print msg
|
||||
if carp:
|
||||
print >> sys.stderr, msg
|
||||
return True
|
||||
else:
|
||||
return fn(*args, **kw)
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
def only_on(db, reason):
|
||||
carp = _should_carp_about_exclusion(reason)
|
||||
spec = db_spec(db)
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if spec(config.db):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
|
||||
fn_name, config.db.name, config.db.driver, reason)
|
||||
print msg
|
||||
if carp:
|
||||
print >> sys.stderr, msg
|
||||
return True
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
def exclude(db, op, spec, reason):
|
||||
"""Mark a test as unsupported by specific database server versions.
|
||||
|
||||
Stackable, both with other excludes and other decorators. Examples::
|
||||
|
||||
# Not supported by mydb versions less than 1, 0
|
||||
@exclude('mydb', '<', (1,0))
|
||||
# Other operators work too
|
||||
@exclude('bigdb', '==', (9,0,9))
|
||||
@exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
|
||||
|
||||
"""
|
||||
carp = _should_carp_about_exclusion(reason)
|
||||
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if _is_excluded(db, op, spec):
|
||||
msg = "'%s' unsupported on DB %s version '%s': %s" % (
|
||||
fn_name, config.db.name, _server_version(), reason)
|
||||
print msg
|
||||
if carp:
|
||||
print >> sys.stderr, msg
|
||||
return True
|
||||
else:
|
||||
return fn(*args, **kw)
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
def _should_carp_about_exclusion(reason):
|
||||
"""Guard against forgotten exclusions."""
|
||||
assert reason
|
||||
for _ in ('todo', 'fixme', 'xxx'):
|
||||
if _ in reason.lower():
|
||||
return True
|
||||
else:
|
||||
if len(reason) < 4:
|
||||
return True
|
||||
|
||||
def _is_excluded(db, op, spec):
|
||||
"""Return True if the configured db matches an exclusion specification.
|
||||
|
||||
db:
|
||||
A dialect name
|
||||
op:
|
||||
An operator or stringified operator, such as '=='
|
||||
spec:
|
||||
A value that will be compared to the dialect's server_version_info
|
||||
using the supplied operator.
|
||||
|
||||
Examples::
|
||||
# Not supported by mydb versions less than 1, 0
|
||||
_is_excluded('mydb', '<', (1,0))
|
||||
# Other operators work too
|
||||
_is_excluded('bigdb', '==', (9,0,9))
|
||||
_is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
|
||||
"""
|
||||
|
||||
vendor_spec = db_spec(db)
|
||||
|
||||
if not vendor_spec(config.db):
|
||||
return False
|
||||
|
||||
version = _server_version()
|
||||
|
||||
oper = hasattr(op, '__call__') and op or _ops[op]
|
||||
return oper(version, spec)
|
||||
|
||||
def _server_version(bind=None):
|
||||
"""Return a server_version_info tuple."""
|
||||
|
||||
if bind is None:
|
||||
bind = config.db
|
||||
|
||||
# force metadata to be retrieved
|
||||
conn = bind.connect()
|
||||
version = getattr(bind.dialect, 'server_version_info', ())
|
||||
conn.close()
|
||||
return version
|
||||
|
||||
def skip_if(predicate, reason=None):
|
||||
"""Skip a test if predicate is true."""
|
||||
reason = reason or predicate.__name__
|
||||
carp = _should_carp_about_exclusion(reason)
|
||||
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
def maybe(*args, **kw):
|
||||
if predicate():
|
||||
msg = "'%s' skipped on DB %s version '%s': %s" % (
|
||||
fn_name, config.db.name, _server_version(), reason)
|
||||
print msg
|
||||
if carp:
|
||||
print >> sys.stderr, msg
|
||||
return True
|
||||
else:
|
||||
return fn(*args, **kw)
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
def emits_warning(*messages):
|
||||
"""Mark a test as emitting a warning.
|
||||
|
||||
With no arguments, squelches all SAWarning failures. Or pass one or more
|
||||
strings; these will be matched to the root of the warning description by
|
||||
warnings.filterwarnings().
|
||||
"""
|
||||
|
||||
# TODO: it would be nice to assert that a named warning was
|
||||
# emitted. should work with some monkeypatching of warnings,
|
||||
# and may work on non-CPython if they keep to the spirit of
|
||||
# warnings.showwarning's docstring.
|
||||
# - update: jython looks ok, it uses cpython's module
|
||||
def decorate(fn):
|
||||
def safe(*args, **kw):
|
||||
# todo: should probably be strict about this, too
|
||||
filters = [dict(action='ignore',
|
||||
category=sa_exc.SAPendingDeprecationWarning)]
|
||||
if not messages:
|
||||
filters.append(dict(action='ignore',
|
||||
category=sa_exc.SAWarning))
|
||||
else:
|
||||
filters.extend(dict(action='ignore',
|
||||
message=message,
|
||||
category=sa_exc.SAWarning)
|
||||
for message in messages)
|
||||
for f in filters:
|
||||
warnings.filterwarnings(**f)
|
||||
try:
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
resetwarnings()
|
||||
return function_named(safe, fn.__name__)
|
||||
return decorate
|
||||
|
||||
def emits_warning_on(db, *warnings):
|
||||
"""Mark a test as emitting a warning on a specific dialect.
|
||||
|
||||
With no arguments, squelches all SAWarning failures. Or pass one or more
|
||||
strings; these will be matched to the root of the warning description by
|
||||
warnings.filterwarnings().
|
||||
"""
|
||||
spec = db_spec(db)
|
||||
|
||||
def decorate(fn):
|
||||
def maybe(*args, **kw):
|
||||
if isinstance(db, basestring):
|
||||
if not spec(config.db):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
wrapped = emits_warning(*warnings)(fn)
|
||||
return wrapped(*args, **kw)
|
||||
else:
|
||||
if not _is_excluded(*db):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
wrapped = emits_warning(*warnings)(fn)
|
||||
return wrapped(*args, **kw)
|
||||
return function_named(maybe, fn.__name__)
|
||||
return decorate
|
||||
|
||||
def uses_deprecated(*messages):
|
||||
"""Mark a test as immune from fatal deprecation warnings.
|
||||
|
||||
With no arguments, squelches all SADeprecationWarning failures.
|
||||
Or pass one or more strings; these will be matched to the root
|
||||
of the warning description by warnings.filterwarnings().
|
||||
|
||||
As a special case, you may pass a function name prefixed with //
|
||||
and it will be re-written as needed to match the standard warning
|
||||
verbiage emitted by the sqlalchemy.util.deprecated decorator.
|
||||
"""
|
||||
|
||||
def decorate(fn):
|
||||
def safe(*args, **kw):
|
||||
# todo: should probably be strict about this, too
|
||||
filters = [dict(action='ignore',
|
||||
category=sa_exc.SAPendingDeprecationWarning)]
|
||||
if not messages:
|
||||
filters.append(dict(action='ignore',
|
||||
category=sa_exc.SADeprecationWarning))
|
||||
else:
|
||||
filters.extend(
|
||||
[dict(action='ignore',
|
||||
message=message,
|
||||
category=sa_exc.SADeprecationWarning)
|
||||
for message in
|
||||
[ (m.startswith('//') and
|
||||
('Call to deprecated function ' + m[2:]) or m)
|
||||
for m in messages] ])
|
||||
|
||||
for f in filters:
|
||||
warnings.filterwarnings(**f)
|
||||
try:
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
resetwarnings()
|
||||
return function_named(safe, fn.__name__)
|
||||
return decorate
|
||||
|
||||
def resetwarnings():
|
||||
"""Reset warning behavior to testing defaults."""
|
||||
|
||||
warnings.filterwarnings('ignore',
|
||||
category=sa_exc.SAPendingDeprecationWarning)
|
||||
warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
|
||||
warnings.filterwarnings('error', category=sa_exc.SAWarning)
|
||||
|
||||
# warnings.simplefilter('error')
|
||||
|
||||
if sys.version_info < (2, 4):
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
|
||||
def global_cleanup_assertions():
|
||||
"""Check things that have to be finalized at the end of a test suite.
|
||||
|
||||
Hardcoded at the moment, a modular system can be built here
|
||||
to support things like PG prepared transactions, tables all
|
||||
dropped, etc.
|
||||
|
||||
"""
|
||||
|
||||
testutil.lazy_gc()
|
||||
assert not pool._refs
|
||||
|
||||
|
||||
|
||||
def against(*queries):
|
||||
"""Boolean predicate, compares to testing database configuration.
|
||||
|
||||
Given one or more dialect names, returns True if one is the configured
|
||||
database engine.
|
||||
|
||||
Also supports comparison to database version when provided with one or
|
||||
more 3-tuples of dialect name, operator, and version specification::
|
||||
|
||||
testing.against('mysql', 'postgresql')
|
||||
testing.against(('mysql', '>=', (5, 0, 0))
|
||||
"""
|
||||
|
||||
for query in queries:
|
||||
if isinstance(query, basestring):
|
||||
if db_spec(query)(config.db):
|
||||
return True
|
||||
else:
|
||||
name, op, spec = query
|
||||
if not db_spec(name)(config.db):
|
||||
continue
|
||||
|
||||
have = _server_version()
|
||||
|
||||
oper = hasattr(op, '__call__') and op or _ops[op]
|
||||
if oper(have, spec):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _chain_decorators_on(fn, *decorators):
|
||||
"""Apply a series of decorators to fn, returning a decorated function."""
|
||||
for decorator in reversed(decorators):
|
||||
fn = decorator(fn)
|
||||
return fn
|
||||
|
||||
def rowset(results):
|
||||
"""Converts the results of sql execution into a plain set of column tuples.
|
||||
|
||||
Useful for asserting the results of an unordered query.
|
||||
"""
|
||||
|
||||
return set([tuple(row) for row in results])
|
||||
|
||||
|
||||
def eq_(a, b, msg=None):
|
||||
"""Assert a == b, with repr messaging on failure."""
|
||||
assert a == b, msg or "%r != %r" % (a, b)
|
||||
|
||||
def ne_(a, b, msg=None):
|
||||
"""Assert a != b, with repr messaging on failure."""
|
||||
assert a != b, msg or "%r == %r" % (a, b)
|
||||
|
||||
def is_(a, b, msg=None):
|
||||
"""Assert a is b, with repr messaging on failure."""
|
||||
assert a is b, msg or "%r is not %r" % (a, b)
|
||||
|
||||
def is_not_(a, b, msg=None):
|
||||
"""Assert a is not b, with repr messaging on failure."""
|
||||
assert a is not b, msg or "%r is %r" % (a, b)
|
||||
|
||||
def startswith_(a, fragment, msg=None):
|
||||
"""Assert a.startswith(fragment), with repr messaging on failure."""
|
||||
assert a.startswith(fragment), msg or "%r does not start with %r" % (
|
||||
a, fragment)
|
||||
|
||||
def assert_raises(except_cls, callable_, *args, **kw):
|
||||
try:
|
||||
callable_(*args, **kw)
|
||||
success = False
|
||||
except except_cls, e:
|
||||
success = True
|
||||
|
||||
# assert outside the block so it works for AssertionError too !
|
||||
assert success, "Callable did not raise an exception"
|
||||
|
||||
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
|
||||
try:
|
||||
callable_(*args, **kwargs)
|
||||
assert False, "Callable did not raise an exception"
|
||||
except except_cls, e:
|
||||
assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
|
||||
|
||||
def fail(msg):
|
||||
assert False, msg
|
||||
|
||||
def fixture(table, columns, *rows):
|
||||
"""Insert data into table after creation."""
|
||||
def onload(event, schema_item, connection):
|
||||
insert = table.insert()
|
||||
column_names = [col.key for col in columns]
|
||||
connection.execute(insert, [dict(zip(column_names, column_values))
|
||||
for column_values in rows])
|
||||
table.append_ddl_listener('after-create', onload)
|
||||
|
||||
def resolve_artifact_names(fn):
|
||||
"""Decorator, augment function globals with tables and classes.
|
||||
|
||||
Swaps out the function's globals at execution time. The 'global' statement
|
||||
will not work as expected inside a decorated function.
|
||||
|
||||
"""
|
||||
# This could be automatically applied to framework and test_ methods in
|
||||
# the MappedTest-derived test suites but... *some* explicitness for this
|
||||
# magic is probably good. Especially as 'global' won't work- these
|
||||
# rebound functions aren't regular Python..
|
||||
#
|
||||
# Also: it's lame that CPython accepts a dict-subclass for globals, but
|
||||
# only calls dict methods. That would allow 'global' to pass through to
|
||||
# the func_globals.
|
||||
def resolved(*args, **kwargs):
|
||||
self = args[0]
|
||||
context = dict(fn.func_globals)
|
||||
for source in self._artifact_registries:
|
||||
context.update(getattr(self, source))
|
||||
# jython bug #1034
|
||||
rebound = types.FunctionType(
|
||||
fn.func_code, context, fn.func_name, fn.func_defaults,
|
||||
fn.func_closure)
|
||||
return rebound(*args, **kwargs)
|
||||
return function_named(resolved, fn.func_name)
|
||||
|
||||
class adict(dict):
|
||||
"""Dict keys available as attributes. Shadows."""
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return dict.__getattribute__(self, key)
|
||||
|
||||
def get_all(self, *keys):
|
||||
return tuple([self[key] for key in keys])
|
||||
|
||||
|
||||
class TestBase(object):
|
||||
# A sequence of database names to always run, regardless of the
|
||||
# constraints below.
|
||||
__whitelist__ = ()
|
||||
|
||||
# A sequence of requirement names matching testing.requires decorators
|
||||
__requires__ = ()
|
||||
|
||||
# A sequence of dialect names to exclude from the test class.
|
||||
__unsupported_on__ = ()
|
||||
|
||||
# If present, test class is only runnable for the *single* specified
|
||||
# dialect. If you need multiple, use __unsupported_on__ and invert.
|
||||
__only_on__ = None
|
||||
|
||||
# A sequence of no-arg callables. If any are True, the entire testcase is
|
||||
# skipped.
|
||||
__skip_if__ = None
|
||||
|
||||
_artifact_registries = ()
|
||||
|
||||
def assert_(self, val, msg=None):
|
||||
assert val, msg
|
||||
|
||||
class AssertsCompiledSQL(object):
|
||||
def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False):
|
||||
if use_default_dialect:
|
||||
dialect = default.DefaultDialect()
|
||||
|
||||
if dialect is None:
|
||||
dialect = getattr(self, '__dialect__', None)
|
||||
|
||||
kw = {}
|
||||
if params is not None:
|
||||
kw['column_keys'] = params.keys()
|
||||
|
||||
if isinstance(clause, orm.Query):
|
||||
context = clause._compile_context()
|
||||
context.statement.use_labels = True
|
||||
clause = context.statement
|
||||
|
||||
c = clause.compile(dialect=dialect, **kw)
|
||||
|
||||
param_str = repr(getattr(c, 'params', {}))
|
||||
# Py3K
|
||||
#param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
|
||||
|
||||
print "\nSQL String:\n" + str(c) + param_str
|
||||
|
||||
cc = re.sub(r'[\n\t]', '', str(c))
|
||||
|
||||
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
|
||||
|
||||
if checkparams is not None:
|
||||
eq_(c.construct_params(params), checkparams)
|
||||
|
||||
class ComparesTables(object):
|
||||
def assert_tables_equal(self, table, reflected_table, strict_types=False):
|
||||
assert len(table.c) == len(reflected_table.c)
|
||||
for c, reflected_c in zip(table.c, reflected_table.c):
|
||||
eq_(c.name, reflected_c.name)
|
||||
assert reflected_c is reflected_table.c[c.name]
|
||||
eq_(c.primary_key, reflected_c.primary_key)
|
||||
eq_(c.nullable, reflected_c.nullable)
|
||||
|
||||
if strict_types:
|
||||
assert type(reflected_c.type) is type(c.type), \
|
||||
"Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
|
||||
else:
|
||||
self.assert_types_base(reflected_c, c)
|
||||
|
||||
if isinstance(c.type, sqltypes.String):
|
||||
eq_(c.type.length, reflected_c.type.length)
|
||||
|
||||
eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
|
||||
if c.server_default:
|
||||
assert isinstance(reflected_c.server_default,
|
||||
schema.FetchedValue)
|
||||
|
||||
assert len(table.primary_key) == len(reflected_table.primary_key)
|
||||
for c in table.primary_key:
|
||||
assert reflected_table.primary_key.columns[c.name] is not None
|
||||
|
||||
def assert_types_base(self, c1, c2):
|
||||
assert c1.type._compare_type_affinity(c2.type),\
|
||||
"On column %r, type '%s' doesn't correspond to type '%s'" % \
|
||||
(c1.name, c1.type, c2.type)
|
||||
|
||||
class AssertsExecutionResults(object):
|
||||
def assert_result(self, result, class_, *objects):
|
||||
result = list(result)
|
||||
print repr(result)
|
||||
self.assert_list(result, class_, objects)
|
||||
|
||||
def assert_list(self, result, class_, list):
|
||||
self.assert_(len(result) == len(list),
|
||||
"result list is not the same size as test list, " +
|
||||
"for class " + class_.__name__)
|
||||
for i in range(0, len(list)):
|
||||
self.assert_row(class_, result[i], list[i])
|
||||
|
||||
def assert_row(self, class_, rowobj, desc):
|
||||
self.assert_(rowobj.__class__ is class_,
|
||||
"item class is not " + repr(class_))
|
||||
for key, value in desc.iteritems():
|
||||
if isinstance(value, tuple):
|
||||
if isinstance(value[1], list):
|
||||
self.assert_list(getattr(rowobj, key), value[0], value[1])
|
||||
else:
|
||||
self.assert_row(value[0], getattr(rowobj, key), value[1])
|
||||
else:
|
||||
self.assert_(getattr(rowobj, key) == value,
|
||||
"attribute %s value %s does not match %s" % (
|
||||
key, getattr(rowobj, key), value))
|
||||
|
||||
def assert_unordered_result(self, result, cls, *expected):
|
||||
"""As assert_result, but the order of objects is not considered.
|
||||
|
||||
The algorithm is very expensive but not a big deal for the small
|
||||
numbers of rows that the test suite manipulates.
|
||||
"""
|
||||
|
||||
class frozendict(dict):
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
found = util.IdentitySet(result)
|
||||
expected = set([frozendict(e) for e in expected])
|
||||
|
||||
for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
|
||||
fail('Unexpected type "%s", expected "%s"' % (
|
||||
type(wrong).__name__, cls.__name__))
|
||||
|
||||
if len(found) != len(expected):
|
||||
fail('Unexpected object count "%s", expected "%s"' % (
|
||||
len(found), len(expected)))
|
||||
|
||||
NOVALUE = object()
|
||||
def _compare_item(obj, spec):
|
||||
for key, value in spec.iteritems():
|
||||
if isinstance(value, tuple):
|
||||
try:
|
||||
self.assert_unordered_result(
|
||||
getattr(obj, key), value[0], *value[1])
|
||||
except AssertionError:
|
||||
return False
|
||||
else:
|
||||
if getattr(obj, key, NOVALUE) != value:
|
||||
return False
|
||||
return True
|
||||
|
||||
for expected_item in expected:
|
||||
for found_item in found:
|
||||
if _compare_item(found_item, expected_item):
|
||||
found.remove(found_item)
|
||||
break
|
||||
else:
|
||||
fail(
|
||||
"Expected %s instance with attributes %s not found." % (
|
||||
cls.__name__, repr(expected_item)))
|
||||
return True
|
||||
|
||||
def assert_sql_execution(self, db, callable_, *rules):
|
||||
assertsql.asserter.add_rules(rules)
|
||||
try:
|
||||
callable_()
|
||||
assertsql.asserter.statement_complete()
|
||||
finally:
|
||||
assertsql.asserter.clear_rules()
|
||||
|
||||
def assert_sql(self, db, callable_, list_, with_sequences=None):
|
||||
if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
|
||||
rules = with_sequences
|
||||
else:
|
||||
rules = list_
|
||||
|
||||
newrules = []
|
||||
for rule in rules:
|
||||
if isinstance(rule, dict):
|
||||
newrule = assertsql.AllOf(*[
|
||||
assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
|
||||
])
|
||||
else:
|
||||
newrule = assertsql.ExactSQL(*rule)
|
||||
newrules.append(newrule)
|
||||
|
||||
self.assert_sql_execution(db, callable_, *newrules)
|
||||
|
||||
def assert_sql_count(self, db, callable_, count):
|
||||
self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
|
||||
|
||||
|
53
sqlalchemy/test/util.py
Normal file
53
sqlalchemy/test/util.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from sqlalchemy.util import jython, function_named
|
||||
|
||||
import gc
|
||||
import time
|
||||
|
||||
if jython:
|
||||
def gc_collect(*args):
|
||||
"""aggressive gc.collect for tests."""
|
||||
gc.collect()
|
||||
time.sleep(0.1)
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
return 0
|
||||
|
||||
# "lazy" gc, for VM's that don't GC on refcount == 0
|
||||
lazy_gc = gc_collect
|
||||
|
||||
else:
|
||||
# assume CPython - straight gc.collect, lazy_gc() is a pass
|
||||
gc_collect = gc.collect
|
||||
def lazy_gc():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def picklers():
|
||||
picklers = set()
|
||||
# Py2K
|
||||
try:
|
||||
import cPickle
|
||||
picklers.add(cPickle)
|
||||
except ImportError:
|
||||
pass
|
||||
# end Py2K
|
||||
import pickle
|
||||
picklers.add(pickle)
|
||||
|
||||
# yes, this thing needs this much testing
|
||||
for pickle in picklers:
|
||||
for protocol in -1, 0, 1, 2:
|
||||
yield pickle.loads, lambda d:pickle.dumps(d, protocol)
|
||||
|
||||
|
||||
def round_decimal(value, prec):
|
||||
if isinstance(value, float):
|
||||
return round(value, prec)
|
||||
|
||||
import decimal
|
||||
|
||||
# can also use shift() here but that is 2.6 only
|
||||
return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \
|
||||
pow(10, prec)
|
||||
|
Reference in New Issue
Block a user