Updated SQLAlchemy and added the new testing files
36 sqlalchemy/testing/__init__.py Normal file
@@ -0,0 +1,36 @@
# testing/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php


from .warnings import assert_warnings

from . import config

from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
    fails_on, fails_on_everything_except, skip, only_on, exclude, \
    against as _against, _server_version, only_if, fails


def against(*queries):
    return _against(config._current, *queries)

from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
    eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
    assert_raises_message, AssertsCompiledSQL, ComparesTables, \
    AssertsExecutionResults, expect_deprecated, expect_warnings, \
    in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false

from .util import run_as_contextmanager, rowset, fail, \
    provide_metadata, adict, force_drop_names, \
    teardown_events

crashes = skip

from .config import db
from .config import requirements as requires

from . import mock
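As an aside, a minimal sketch of how a test module might consume these re-exports; the test function and backend names here are hypothetical and not part of the commit:

# Hypothetical usage sketch of the re-exported helpers above.
from sqlalchemy.testing import eq_, against, skip

@skip('sqlite', 'hypothetical limitation')   # skip this test on SQLite
def test_addition():
    eq_(1 + 1, 2)             # assert equality with repr messaging
    if against('postgresql'):
        pass                  # backend-specific branch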
520 sqlalchemy/testing/assertions.py Normal file
@@ -0,0 +1,520 @@
# testing/assertions.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from __future__ import absolute_import

from . import util as testutil
from sqlalchemy import pool, orm, util
from sqlalchemy.engine import default, url
from sqlalchemy.util import decorator, compat
from sqlalchemy import types as sqltypes, schema, exc as sa_exc
import warnings
import re
from .exclusions import db_spec
from . import assertsql
from . import config
from .util import fail
import contextlib
from . import mock


def expect_warnings(*messages, **kw):
    """Context manager which expects one or more warnings.

    With no arguments, squelches all SAWarnings emitted via
    sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
    pass string expressions that will match selected warnings via regex;
    all non-matching warnings are sent through.

    The expect version **asserts** that the warnings were in fact seen.

    Note that the test suite sets SAWarning warnings to raise exceptions.

    """
    return _expect_warnings(sa_exc.SAWarning, messages, **kw)


@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
    """Context manager which expects one or more warnings on specific
    dialects.

    The expect version **asserts** that the warnings were in fact seen.

    """
    spec = db_spec(db)

    if isinstance(db, util.string_types) and not spec(config._current):
        yield
    else:
        with expect_warnings(*messages, **kw):
            yield


def emits_warning(*messages):
    """Decorator form of expect_warnings().

    Note that emits_warning does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings(assert_=False, *messages):
            return fn(*args, **kw)

    return decorate


def expect_deprecated(*messages, **kw):
    return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)


def emits_warning_on(db, *messages):
    """Mark a test as emitting a warning on a specific dialect.

    With no arguments, squelches all SAWarning failures. Or pass one or more
    strings; these will be matched to the root of the warning description by
    warnings.filterwarnings().

    Note that emits_warning_on does **not** assert that the warnings
    were in fact seen.

    """
    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings_on(db, assert_=False, *messages):
            return fn(*args, **kw)

    return decorate


def uses_deprecated(*messages):
    """Mark a test as immune from fatal deprecation warnings.

    With no arguments, squelches all SADeprecationWarning failures.
    Or pass one or more strings; these will be matched to the root
    of the warning description by warnings.filterwarnings().

    As a special case, you may pass a function name prefixed with //
    and it will be re-written as needed to match the standard warning
    verbiage emitted by the sqlalchemy.util.deprecated decorator.

    Note that uses_deprecated does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_deprecated(*messages, assert_=False):
            return fn(*args, **kw)
    return decorate


@contextlib.contextmanager
def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
                     py2konly=False):

    if regex:
        filters = [re.compile(msg, re.I | re.S) for msg in messages]
    else:
        filters = messages

    seen = set(filters)

    real_warn = warnings.warn

    def our_warn(msg, exception, *arg, **kw):
        if not issubclass(exception, exc_cls):
            return real_warn(msg, exception, *arg, **kw)

        if not filters:
            return

        for filter_ in filters:
            if (regex and filter_.match(msg)) or \
                    (not regex and filter_ == msg):
                seen.discard(filter_)
                break
        else:
            real_warn(msg, exception, *arg, **kw)

    with mock.patch("warnings.warn", our_warn):
        yield

    if assert_ and (not py2konly or not compat.py3k):
        assert not seen, "Warnings were not seen: %s" % \
            ", ".join("%r" % (s.pattern if regex else s) for s in seen)
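To make the mechanics above concrete: _expect_warnings() patches warnings.warn, swallows warnings matching the given regexes, and (with assert_=True, the default) raises on exit if an expected message never fired. A small hypothetical illustration:

# Hypothetical sketch of expect_warnings() in action.
import warnings
from sqlalchemy import exc as sa_exc
from sqlalchemy.testing.assertions import expect_warnings

with expect_warnings("example deprecated usage"):
    # matched by the regex above, so it is swallowed and marked as seen
    warnings.warn("example deprecated usage", sa_exc.SAWarning)
# exiting without having emitted the warning would raise AssertionError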
def global_cleanup_assertions():
    """Check things that have to be finalized at the end of a test suite.

    Hardcoded at the moment, a modular system can be built here
    to support things like PG prepared transactions, tables all
    dropped, etc.

    """
    _assert_no_stray_pool_connections()

_STRAY_CONNECTION_FAILURES = 0


def _assert_no_stray_pool_connections():
    global _STRAY_CONNECTION_FAILURES

    # lazy gc on cPython means "do nothing." pool connections
    # shouldn't be in cycles, should go away.
    testutil.lazy_gc()

    # however, once in awhile, on an EC2 machine usually,
    # there's a ref in there. usually just one.
    if pool._refs:

        # OK, let's be somewhat forgiving.
        _STRAY_CONNECTION_FAILURES += 1

        print("Encountered a stray connection in test cleanup: %s"
              % str(pool._refs))
        # then do a real GC sweep. We shouldn't even be here
        # so a single sweep should really be doing it, otherwise
        # there's probably a real unreachable cycle somewhere.
        testutil.gc_collect()

        # if after a hard gc sweep we still have pool._refs, or
        # we've already had more than ten of these occurrences,
        # now we have to raise.
        if pool._refs:
            err = str(pool._refs)

            # but clean out the pool refs collection directly,
            # and reset the counter,
            # so at least the error doesn't keep happening.
            pool._refs.clear()
            _STRAY_CONNECTION_FAILURES = 0
            warnings.warn(
                "Stray connection refused to leave "
                "after gc.collect(): %s" % err)
        elif _STRAY_CONNECTION_FAILURES > 10:
            assert False, "Encountered more than 10 stray connections"
            _STRAY_CONNECTION_FAILURES = 0


def eq_regex(a, b, msg=None):
    assert re.match(b, a), msg or "%r !~ %r" % (a, b)


def eq_(a, b, msg=None):
    """Assert a == b, with repr messaging on failure."""
    assert a == b, msg or "%r != %r" % (a, b)


def ne_(a, b, msg=None):
    """Assert a != b, with repr messaging on failure."""
    assert a != b, msg or "%r == %r" % (a, b)


def le_(a, b, msg=None):
    """Assert a <= b, with repr messaging on failure."""
    assert a <= b, msg or "%r is not <= %r" % (a, b)


def is_true(a, msg=None):
    is_(a, True, msg=msg)


def is_false(a, msg=None):
    is_(a, False, msg=msg)


def is_(a, b, msg=None):
    """Assert a is b, with repr messaging on failure."""
    assert a is b, msg or "%r is not %r" % (a, b)


def is_not_(a, b, msg=None):
    """Assert a is not b, with repr messaging on failure."""
    assert a is not b, msg or "%r is %r" % (a, b)


def in_(a, b, msg=None):
    """Assert a in b, with repr messaging on failure."""
    assert a in b, msg or "%r not in %r" % (a, b)


def not_in_(a, b, msg=None):
    """Assert a not in b, with repr messaging on failure."""
    assert a not in b, msg or "%r is in %r" % (a, b)


def startswith_(a, fragment, msg=None):
    """Assert a.startswith(fragment), with repr messaging on failure."""
    assert a.startswith(fragment), msg or "%r does not start with %r" % (
        a, fragment)


def eq_ignore_whitespace(a, b, msg=None):
    a = re.sub(r'^\s+?|\n', "", a)
    a = re.sub(r' {2,}', " ", a)
    b = re.sub(r'^\s+?|\n', "", b)
    b = re.sub(r' {2,}', " ", b)

    assert a == b, msg or "%r != %r" % (a, b)


def assert_raises(except_cls, callable_, *args, **kw):
    try:
        callable_(*args, **kw)
        success = False
    except except_cls:
        success = True

    # assert outside the block so it works for AssertionError too!
    assert success, "Callable did not raise an exception"


def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    try:
        callable_(*args, **kwargs)
        assert False, "Callable did not raise an exception"
    except except_cls as e:
        assert re.search(
            msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
        print(util.text_type(e).encode('utf-8'))
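For orientation, a small hypothetical use of the two raise helpers above; note that assert_raises_message() matches the message as a regex against the stringified exception:

# Hypothetical sketch of the raise helpers above.
from sqlalchemy.testing.assertions import (
    assert_raises, assert_raises_message)

assert_raises(ZeroDivisionError, lambda: 1 / 0)
assert_raises_message(
    ValueError,
    "invalid literal",      # regex applied to str(exc) via re.search
    int, "not-a-number")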
class AssertsCompiledSQL(object):
    def assert_compile(self, clause, result, params=None,
                       checkparams=None, dialect=None,
                       checkpositional=None,
                       check_prefetch=None,
                       use_default_dialect=False,
                       allow_dialect_select=False,
                       literal_binds=False,
                       schema_translate_map=None):
        if use_default_dialect:
            dialect = default.DefaultDialect()
        elif allow_dialect_select:
            dialect = None
        else:
            if dialect is None:
                dialect = getattr(self, '__dialect__', None)

            if dialect is None:
                dialect = config.db.dialect
            elif dialect == 'default':
                dialect = default.DefaultDialect()
            elif dialect == 'default_enhanced':
                dialect = default.StrCompileDialect()
            elif isinstance(dialect, util.string_types):
                dialect = url.URL(dialect).get_dialect()()

        kw = {}
        compile_kwargs = {}

        if schema_translate_map:
            kw['schema_translate_map'] = schema_translate_map

        if params is not None:
            kw['column_keys'] = list(params)

        if literal_binds:
            compile_kwargs['literal_binds'] = True

        if isinstance(clause, orm.Query):
            context = clause._compile_context()
            context.statement.use_labels = True
            clause = context.statement

        if compile_kwargs:
            kw['compile_kwargs'] = compile_kwargs

        c = clause.compile(dialect=dialect, **kw)

        param_str = repr(getattr(c, 'params', {}))

        if util.py3k:
            param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
            print(
                ("\nSQL String:\n" +
                 util.text_type(c) +
                 param_str).encode('utf-8'))
        else:
            print(
                "\nSQL String:\n" +
                util.text_type(c).encode('utf-8') +
                param_str)

        cc = re.sub(r'[\n\t]', '', util.text_type(c))

        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))

        if checkparams is not None:
            eq_(c.construct_params(params), checkparams)
        if checkpositional is not None:
            p = c.construct_params(params)
            eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
        if check_prefetch is not None:
            eq_(c.prefetch, check_prefetch)


class ComparesTables(object):

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        assert len(table.c) == len(reflected_table.c)
        for c, reflected_c in zip(table.c, reflected_table.c):
            eq_(c.name, reflected_c.name)
            assert reflected_c is reflected_table.c[c.name]
            eq_(c.primary_key, reflected_c.primary_key)
            eq_(c.nullable, reflected_c.nullable)

            if strict_types:
                msg = "Type '%s' doesn't correspond to type '%s'"
                assert isinstance(reflected_c.type, type(c.type)), \
                    msg % (reflected_c.type, c.type)
            else:
                self.assert_types_base(reflected_c, c)

            if isinstance(c.type, sqltypes.String):
                eq_(c.type.length, reflected_c.type.length)

            eq_(
                set([f.column.name for f in c.foreign_keys]),
                set([f.column.name for f in reflected_c.foreign_keys])
            )
            if c.server_default:
                assert isinstance(reflected_c.server_default,
                                  schema.FetchedValue)

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for c in table.primary_key:
            assert reflected_table.primary_key.columns[c.name] is not None

    def assert_types_base(self, c1, c2):
        assert c1.type._compare_type_affinity(c2.type),\
            "On column %r, type '%s' doesn't correspond to type '%s'" % \
            (c1.name, c1.type, c2.type)


class AssertsExecutionResults(object):
    def assert_result(self, result, class_, *objects):
        result = list(result)
        print(repr(result))
        self.assert_list(result, class_, objects)

    def assert_list(self, result, class_, list):
        self.assert_(len(result) == len(list),
                     "result list is not the same size as test list, " +
                     "for class " + class_.__name__)
        for i in range(0, len(list)):
            self.assert_row(class_, result[i], list[i])

    def assert_row(self, class_, rowobj, desc):
        self.assert_(rowobj.__class__ is class_,
                     "item class is not " + repr(class_))
        for key, value in desc.items():
            if isinstance(value, tuple):
                if isinstance(value[1], list):
                    self.assert_list(getattr(rowobj, key), value[0], value[1])
                else:
                    self.assert_row(value[0], getattr(rowobj, key), value[1])
            else:
                self.assert_(getattr(rowobj, key) == value,
                             "attribute %s value %s does not match %s" % (
                                 key, getattr(rowobj, key), value))

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        The algorithm is very expensive but not a big deal for the small
        numbers of rows that the test suite manipulates.
        """

        class immutabledict(dict):
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        expected = set([immutabledict(e) for e in expected])

        for wrong in util.itertools_filterfalse(lambda o:
                                                isinstance(o, cls), found):
            fail('Unexpected type "%s", expected "%s"' % (
                type(wrong).__name__, cls.__name__))

        if len(found) != len(expected):
            fail('Unexpected object count "%s", expected "%s"' % (
                len(found), len(expected)))

        NOVALUE = object()

        def _compare_item(obj, spec):
            for key, value in spec.items():
                if isinstance(value, tuple):
                    try:
                        self.assert_unordered_result(
                            getattr(obj, key), value[0], *value[1])
                    except AssertionError:
                        return False
                else:
                    if getattr(obj, key, NOVALUE) != value:
                        return False
            return True

        for expected_item in expected:
            for found_item in found:
                if _compare_item(found_item, expected_item):
                    found.remove(found_item)
                    break
            else:
                fail(
                    "Expected %s instance with attributes %s not found." % (
                        cls.__name__, repr(expected_item)))
        return True

    def sql_execution_asserter(self, db=None):
        if db is None:
            from . import db as db

        return assertsql.assert_engine(db)

    def assert_sql_execution(self, db, callable_, *rules):
        with self.sql_execution_asserter(db) as asserter:
            callable_()
        asserter.assert_(*rules)

    def assert_sql(self, db, callable_, rules):

        newrules = []
        for rule in rules:
            if isinstance(rule, dict):
                newrule = assertsql.AllOf(*[
                    assertsql.CompiledSQL(k, v) for k, v in rule.items()
                ])
            else:
                newrule = assertsql.CompiledSQL(*rule)
            newrules.append(newrule)

        self.assert_sql_execution(db, callable_, *newrules)

    def assert_sql_count(self, db, callable_, count):
        self.assert_sql_execution(
            db, callable_, assertsql.CountStatements(count))

    @contextlib.contextmanager
    def assert_execution(self, *rules):
        assertsql.asserter.add_rules(rules)
        try:
            yield
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_statement_count(self, count):
        return self.assert_execution(assertsql.CountStatements(count))
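A hypothetical sketch of AssertsCompiledSQL.assert_compile() as a test class would use it; the table and expected SQL string are illustrative only:

# Hypothetical sketch: comparing rendered SQL with assert_compile().
from sqlalchemy import Column, Integer, MetaData, Table, select
from sqlalchemy.testing.assertions import AssertsCompiledSQL

class MyCompileTest(AssertsCompiledSQL):
    __dialect__ = 'default'

    def test_select(self):
        t = Table('t', MetaData(), Column('x', Integer))
        self.assert_compile(
            select([t.c.x]).where(t.c.x == 5),
            "SELECT t.x FROM t WHERE t.x = :x_1",
            checkparams={'x_1': 5})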
377 sqlalchemy/testing/assertsql.py Normal file
@@ -0,0 +1,377 @@
# testing/assertsql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from ..engine.default import DefaultDialect
from .. import util
import re
import collections
import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
from sqlalchemy.engine import url


class AssertRule(object):

    is_consumed = False
    errormessage = None
    consume_statement = True

    def process_statement(self, execute_observed):
        pass

    def no_more_statements(self):
        assert False, 'All statements are complete, but pending '\
            'assertion rules remain'


class SQLMatchRule(AssertRule):
    pass


class CursorSQL(SQLMatchRule):
    consume_statement = False

    def __init__(self, statement, params=None):
        self.statement = statement
        self.params = params

    def process_statement(self, execute_observed):
        stmt = execute_observed.statements[0]
        if self.statement != stmt.statement or (
                self.params is not None and self.params != stmt.parameters):
            self.errormessage = \
                "Testing for exact SQL %s parameters %s received %s %s" % (
                    self.statement, self.params,
                    stmt.statement, stmt.parameters
                )
        else:
            execute_observed.statements.pop(0)
            self.is_consumed = True
            if not execute_observed.statements:
                self.consume_statement = True


class CompiledSQL(SQLMatchRule):

    def __init__(self, statement, params=None, dialect='default'):
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        stmt = re.sub(r'[\n\t]', '', self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        if self.dialect == 'default':
            return DefaultDialect()
        else:
            # ugh
            if self.dialect == 'postgresql':
                params = {'implicit_returning': True}
            else:
                params = {}
            return url.URL(self.dialect).get_dialect()(**params)

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect."""

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)
        if isinstance(context.compiled.statement, _DDLCompiles):
            compiled = \
                context.compiled.statement.compile(
                    dialect=compare_dialect,
                    schema_translate_map=context.
                    execution_options.get('schema_translate_map'))
        else:
            compiled = (
                context.compiled.statement.compile(
                    dialect=compare_dialect,
                    column_keys=context.compiled.column_keys,
                    inline=context.compiled.inline,
                    schema_translate_map=context.
                    execution_options.get('schema_translate_map'))
            )
        _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [compiled.construct_params()]
        else:
            _received_parameters = [
                compiled.construct_params(m) for m in parameters]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        context = execute_observed.context

        _received_statement, _received_parameters = \
            self._received_statement(execute_observed)
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if param_key not in received or \
                                    received[param_key] != param[param_key]:
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                'received_statement': _received_statement,
                'received_parameters': _received_parameters
            }

    def _all_params(self, context):
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, expected_params):
        return (
            'Testing for compiled statement %r partial params %r, '
            'received %%(received_statement)r with params '
            '%%(received_parameters)r' % (
                self.statement.replace('%', '%%'), expected_params
            )
        )


class RegexSQL(CompiledSQL):
    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params
        self.dialect = 'default'

    def _failure_message(self, expected_params):
        return (
            'Testing for compiled statement ~%r partial params %r, '
            'received %%(received_statement)r with params '
            '%%(received_parameters)r' % (
                self.orig_regex, expected_params
            )
        )

    def _compare_sql(self, execute_observed, received_statement):
        return bool(self.regex.match(received_statement))


class DialectSQL(CompiledSQL):
    def _compile_dialect(self, execute_observed):
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        stmt = re.sub(r'[\n\t]', '', real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        received_stmt, received_params = super(DialectSQL, self).\
            _received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt)

        return received_stmt, execute_observed.context.compiled_parameters

    def _compare_sql(self, execute_observed, received_statement):
        stmt = re.sub(r'[\n\t]', '', self.statement)
        # convert our comparison statement to have the
        # paramstyle of the received
        paramstyle = execute_observed.context.dialect.paramstyle
        if paramstyle == 'pyformat':
            stmt = re.sub(
                r':([\w_]+)', r"%(\1)s", stmt)
        else:
            # positional params
            repl = None
            if paramstyle == 'qmark':
                repl = "?"
            elif paramstyle == 'format':
                repl = r"%s"
            elif paramstyle == 'numeric':
                repl = None
            stmt = re.sub(r':([\w_]+)', repl, stmt)

        return received_statement == stmt


class CountStatements(AssertRule):

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        self._statement_count += 1

    def no_more_statements(self):
        if self.count != self._statement_count:
            assert False, 'desired statement count %d does not match %d' \
                % (self.count, self._statement_count)


class AllOf(AssertRule):

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        for rule in list(self.rules):
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage


class Or(AllOf):

    def process_statement(self, execute_observed):
        for rule in self.rules:
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.is_consumed = True
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage
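To make the composition semantics above concrete: AllOf is consumed once every child rule has matched, in any order, while Or is consumed as soon as any single child matches. A hypothetical pairing:

# Hypothetical sketch of composing the rule classes above.
from sqlalchemy.testing.assertsql import AllOf, CompiledSQL, Or

both = AllOf(
    CompiledSQL("INSERT INTO a (x) VALUES (:x)", {"x": 1}),
    CompiledSQL("INSERT INTO b (y) VALUES (:y)", {"y": 2}),
)
either = Or(
    CompiledSQL("SELECT a.x FROM a"),
    CompiledSQL("SELECT b.y FROM b"),
)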
class SQLExecuteObserved(object):
    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement
        self.parameters = _distill_params(multiparams, params)
        self.statements = []


class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        ["statement", "parameters", "context", "executemany"])
):
    pass


class SQLAsserter(object):
    def __init__(self):
        self.accumulated = []

    def _close(self):
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        rules = list(rules)
        observed = list(self._final)

        while observed and rules:
            rule = rules[0]
            rule.process_statement(observed[0])
            if rule.is_consumed:
                rules.pop(0)
            elif rule.errormessage:
                assert False, rule.errormessage

            if rule.consume_statement:
                observed.pop(0)

        if not observed and rules:
            rules[0].no_more_statements()
        elif not rules and observed:
            assert False, "Additional SQL statements remain"


@contextlib.contextmanager
def assert_engine(engine):
    asserter = SQLAsserter()

    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(conn, clauseelement, multiparams, params):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(conn, cursor, statement, parameters,
                       context, executemany):
        if not context:
            return
        # then grab real cursor statements and associate them all
        # around a single context
        if asserter.accumulated and \
                asserter.accumulated[-1].context is context:
            obs = asserter.accumulated[-1]
        else:
            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            asserter.accumulated.append(obs)
        obs.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany)
        )

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
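An end-to-end hypothetical sketch: assert_engine() registers the two event hooks above, collects one SQLExecuteObserved per execution context, and assert_() then replays the observations against the rules:

# Hypothetical sketch: capturing and asserting statements with assert_engine().
from sqlalchemy.testing.assertsql import assert_engine, CompiledSQL

def run_and_check(engine, table):
    with assert_engine(engine) as asserter:
        engine.execute(table.insert(), {"x": 5})
    asserter.assert_(
        CompiledSQL("INSERT INTO t (x) VALUES (:x)", [{"x": 5}]))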
97 sqlalchemy/testing/config.py Normal file
@@ -0,0 +1,97 @@
# testing/config.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import collections

requirements = None
db = None
db_url = None
db_opts = None
file_config = None
test_schema = None
test_schema_2 = None
_current = None

try:
    from unittest import SkipTest as _skip_test_exception
except ImportError:
    _skip_test_exception = None


class Config(object):
    def __init__(self, db, db_opts, options, file_config):
        self.db = db
        self.db_opts = db_opts
        self.options = options
        self.file_config = file_config
        self.test_schema = "test_schema"
        self.test_schema_2 = "test_schema_2"

    _stack = collections.deque()
    _configs = {}

    @classmethod
    def register(cls, db, db_opts, options, file_config):
        """add a config as one of the global configs.

        If there are no configs set up yet, this config also
        gets set as the "_current".
        """
        cfg = Config(db, db_opts, options, file_config)

        cls._configs[cfg.db.name] = cfg
        cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg
        cls._configs[cfg.db] = cfg
        return cfg

    @classmethod
    def set_as_current(cls, config, namespace):
        global db, _current, db_url, test_schema, test_schema_2, db_opts
        _current = config
        db_url = config.db.url
        db_opts = config.db_opts
        test_schema = config.test_schema
        test_schema_2 = config.test_schema_2
        namespace.db = db = config.db

    @classmethod
    def push_engine(cls, db, namespace):
        assert _current, "Can't push without a default Config set up"
        cls.push(
            Config(
                db, _current.db_opts, _current.options, _current.file_config),
            namespace
        )

    @classmethod
    def push(cls, config, namespace):
        cls._stack.append(_current)
        cls.set_as_current(config, namespace)

    @classmethod
    def reset(cls, namespace):
        if cls._stack:
            cls.set_as_current(cls._stack[0], namespace)
            cls._stack.clear()

    @classmethod
    def all_configs(cls):
        for cfg in set(cls._configs.values()):
            yield cfg

    @classmethod
    def all_dbs(cls):
        for cfg in cls.all_configs():
            yield cfg.db

    def skip_test(self, msg):
        skip_test(msg)


def skip_test(msg):
    raise _skip_test_exception(msg)
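For orientation, a hypothetical sketch of the Config lifecycle as a test runner might drive it; the SQLite URL and the use of the config module itself as the namespace are assumptions for illustration:

# Hypothetical sketch of the Config registry and stack.
from sqlalchemy import create_engine
from sqlalchemy.testing import config

engine = create_engine("sqlite://")
cfg = config.Config.register(engine, {}, options=None, file_config=None)
config.Config.set_as_current(cfg, namespace=config)  # config.db is now set
config.Config.push_engine(create_engine("sqlite://"), namespace=config)
config.Config.reset(namespace=config)                # back to the default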
349 sqlalchemy/testing/engines.py Normal file
@@ -0,0 +1,349 @@
# testing/engines.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from __future__ import absolute_import

import weakref
from . import config
from .util import decorator
from .. import event, pool
import re
import warnings


class ConnectionKiller(object):

    def __init__(self):
        self.proxy_refs = weakref.WeakKeyDictionary()
        self.testing_engines = weakref.WeakKeyDictionary()
        self.conns = set()

    def add_engine(self, engine):
        self.testing_engines[engine] = True

    def connect(self, dbapi_conn, con_record):
        self.conns.add((dbapi_conn, con_record))

    def checkout(self, dbapi_con, con_record, con_proxy):
        self.proxy_refs[con_proxy] = True

    def invalidate(self, dbapi_con, con_record, exception):
        self.conns.discard((dbapi_con, con_record))

    def _safe(self, fn):
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "testing_reaper couldn't "
                "rollback/close connection: %s" % e)

    def rollback_all(self):
        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self._safe(rec.rollback)

    def close_all(self):
        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self._safe(rec._close)

    def _after_test_ctx(self):
        # this can cause a deadlock with pg8000 - pg8000 acquires
        # prepared statement lock inside of rollback() - if async gc
        # is collecting in finalize_fairy, deadlock.
        # not sure if this should be if pypy/jython only.
        # note that firebird/fdb definitely needs this though
        for conn, rec in list(self.conns):
            self._safe(conn.rollback)

    def _stop_test_ctx(self):
        if config.options.low_connections:
            self._stop_test_ctx_minimal()
        else:
            self._stop_test_ctx_aggressive()

    def _stop_test_ctx_minimal(self):
        self.close_all()

        self.conns = set()

        for rec in list(self.testing_engines):
            if rec is not config.db:
                rec.dispose()

    def _stop_test_ctx_aggressive(self):
        self.close_all()
        for conn, rec in list(self.conns):
            self._safe(conn.close)
            rec.connection = None

        self.conns = set()
        for rec in list(self.testing_engines):
            rec.dispose()

    def assert_all_closed(self):
        for rec in self.proxy_refs:
            if rec.is_valid:
                assert False

testing_reaper = ConnectionKiller()


def drop_all_tables(metadata, bind):
    testing_reaper.close_all()
    if hasattr(bind, 'close'):
        bind.close()

    if not config.db.dialect.supports_alter:
        from . import assertions
        with assertions.expect_warnings(
                "Can't sort tables", assert_=False):
            metadata.drop_all(bind)
    else:
        metadata.drop_all(bind)


@decorator
def assert_conns_closed(fn, *args, **kw):
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.assert_all_closed()


@decorator
def rollback_open_connections(fn, *args, **kw):
    """Decorator that rolls back all open connections after fn execution."""

    try:
        fn(*args, **kw)
    finally:
        testing_reaper.rollback_all()


@decorator
def close_first(fn, *args, **kw):
    """Decorator that closes all connections before fn execution."""

    testing_reaper.close_all()
    fn(*args, **kw)


@decorator
def close_open_connections(fn, *args, **kw):
    """Decorator that closes all connections after fn execution."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.close_all()
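A hypothetical sketch of the reaper decorators above wrapping test functions:

# Hypothetical sketch: reaper decorators around test functions.
from sqlalchemy.testing.engines import (
    close_open_connections, rollback_open_connections)

@close_open_connections
def test_something():
    pass  # connections checked out here are closed afterward

@rollback_open_connections
def test_something_else():
    pass  # open connections are rolled back after the test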
def all_dialects(exclude=None):
    import sqlalchemy.databases as d
    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None)
        if not mod:
            mod = getattr(__import__(
                'sqlalchemy.databases.%s' % name).databases, name)
        yield mod.dialect()


class ReconnectFixture(object):

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []

    def __getattr__(self, key):
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        conn = self.dbapi.connect(*args, **kwargs)
        self.connections.append(conn)
        return conn

    def _safe(self, fn):
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "ReconnectFixture couldn't "
                "close connection: %s" % e)

    def shutdown(self):
        # TODO: this doesn't cover all cases
        # as nicely as we'd like, namely MySQLdb.
        # would need to implement R. Brewer's
        # proxy server idea to get better
        # coverage.
        for c in list(self.connections):
            self._safe(c.close)
        self.connections = []


def reconnecting_engine(url=None, options=None):
    url = url or config.db.url
    dbapi = config.db.dialect.dbapi
    if not options:
        options = {}
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    _dispose = engine.dispose

    def dispose():
        engine.dialect.dbapi.shutdown()
        _dispose()

    engine.test_shutdown = engine.dialect.dbapi.shutdown
    engine.dispose = dispose
    return engine


def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides."""

    from sqlalchemy import create_engine
    from sqlalchemy.engine.url import make_url

    if not options:
        use_reaper = True
    else:
        use_reaper = options.pop('use_reaper', True)

    url = url or config.db.url

    url = make_url(url)
    if options is None:
        if config.db is None or url.drivername == config.db.url.drivername:
            options = config.db_opts
        else:
            options = {}
    elif config.db is not None and url.drivername == config.db.url.drivername:
        default_opt = config.db_opts.copy()
        default_opt.update(options)

    engine = create_engine(url, **options)
    engine._has_events = True  # enable event blocks, helps with profiling

    if isinstance(engine.pool, pool.QueuePool):
        engine.pool._timeout = 0
        engine.pool._max_overflow = 0
    if use_reaper:
        event.listen(engine.pool, 'connect', testing_reaper.connect)
        event.listen(engine.pool, 'checkout', testing_reaper.checkout)
        event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
        testing_reaper.add_engine(engine)

    return engine


def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.

    """

    from sqlalchemy import create_engine

    if not dialect_name:
        dialect_name = config.db.name

    buffer = []

    def executor(sql, *a, **kw):
        buffer.append(sql)

    def assert_sql(stmts):
        recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
        assert recv == stmts, recv

    def print_sql():
        d = engine.dialect
        return "\n".join(
            str(s.compile(dialect=d))
            for s in engine.mock
        )

    engine = create_engine(dialect_name + '://',
                           strategy='mock', executor=executor)
    assert not hasattr(engine, 'mock')
    engine.mock = buffer
    engine.assert_sql = assert_sql
    engine.print_sql = print_sql
    return engine


class DBAPIProxyCursor(object):
    """Proxy a DBAPI cursor.

    Tests can provide subclasses of this to intercept
    DBAPI-level cursor operations.

    """

    def __init__(self, engine, conn, *args, **kwargs):
        self.engine = engine
        self.connection = conn
        self.cursor = conn.cursor(*args, **kwargs)

    def execute(self, stmt, parameters=None, **kw):
        if parameters:
            return self.cursor.execute(stmt, parameters, **kw)
        else:
            return self.cursor.execute(stmt, **kw)

    def executemany(self, stmt, params, **kw):
        return self.cursor.executemany(stmt, params, **kw)

    def __getattr__(self, key):
        return getattr(self.cursor, key)


class DBAPIProxyConnection(object):
    """Proxy a DBAPI connection.

    Tests can provide subclasses of this to intercept
    DBAPI-level connection operations.

    """

    def __init__(self, engine, cursor_cls):
        self.conn = self._sqla_unwrap = engine.pool._creator()
        self.engine = engine
        self.cursor_cls = cursor_cls

    def cursor(self, *args, **kwargs):
        return self.cursor_cls(self.engine, self.conn, *args, **kwargs)

    def close(self):
        self.conn.close()

    def __getattr__(self, key):
        return getattr(self.conn, key)


def proxying_engine(conn_cls=DBAPIProxyConnection,
                    cursor_cls=DBAPIProxyCursor):
    """Produce an engine that provides proxy hooks for
    common methods.

    """
    def mock_conn():
        return conn_cls(config.db, cursor_cls)
    return testing_engine(options={'creator': mock_conn})
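A hypothetical sketch of mock_engine() capturing DDL strings instead of executing them; the expected string assumes the default rendering with newlines and tabs stripped by assert_sql():

# Hypothetical sketch: collecting DDL with mock_engine().
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.testing.engines import mock_engine

m = mock_engine('sqlite')
t = Table('t', MetaData(), Column('x', Integer))
t.create(m)                                   # DDL goes into the buffer
m.assert_sql(['CREATE TABLE t (x INTEGER)'])  # normalized comparison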
101 sqlalchemy/testing/entities.py Normal file
@@ -0,0 +1,101 @@
# testing/entities.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import sqlalchemy as sa
from sqlalchemy import exc as sa_exc

_repr_stack = set()


class BasicEntity(object):

    def __init__(self, **kw):
        for key, value in kw.items():
            setattr(self, key, value)

    def __repr__(self):
        if id(self) in _repr_stack:
            return object.__repr__(self)
        _repr_stack.add(id(self))
        try:
            return "%s(%s)" % (
                (self.__class__.__name__),
                ', '.join(["%s=%r" % (key, getattr(self, key))
                           for key in sorted(self.__dict__.keys())
                           if not key.startswith('_')]))
        finally:
            _repr_stack.remove(id(self))

_recursion_stack = set()


class ComparableEntity(BasicEntity):

    def __hash__(self):
        return hash(self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __eq__(self, other):
        """Deep, sparse compare.

        Deeply compare two entities, following the non-None attributes of the
        non-persisted object, if possible.

        """
        if other is self:
            return True
        elif not self.__class__ == other.__class__:
            return False

        if id(self) in _recursion_stack:
            return True
        _recursion_stack.add(id(self))

        try:
            # pick the entity that's not SA persisted as the source
            try:
                self_key = sa.orm.attributes.instance_state(self).key
            except sa.orm.exc.NO_STATE:
                self_key = None

            if other is None:
                a = self
                b = other
            elif self_key is not None:
                a = other
                b = self
            else:
                a = self
                b = other

            for attr in list(a.__dict__):
                if attr.startswith('_'):
                    continue
                value = getattr(a, attr)

                try:
                    # handle lazy loader errors
                    battr = getattr(b, attr)
                except (AttributeError, sa_exc.UnboundExecutionError):
                    return False

                if hasattr(value, '__iter__'):
                    if hasattr(value, '__getitem__') and not hasattr(
                            value, 'keys'):
                        if list(value) != list(battr):
                            return False
                    else:
                        if set(value) != set(battr):
                            return False
                else:
                    if value is not None and value != battr:
                        return False
            return True
        finally:
            _recursion_stack.remove(id(self))
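To illustrate the comparison semantics above: only attributes set on the chosen source entity participate, so the compare is deliberately sparse. A hypothetical example:

# Hypothetical sketch of ComparableEntity's deep, sparse compare.
from sqlalchemy.testing.entities import ComparableEntity

class User(ComparableEntity):
    pass

assert User(id=1, name="ed") == User(id=1, name="ed")
assert User(id=1, name="ed") != User(id=1, name="wendy")
assert User(id=1) == User(id=1, name="ed")  # unset attrs are not compared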
443 sqlalchemy/testing/exclusions.py Normal file
@@ -0,0 +1,443 @@
|
||||
# testing/exclusions.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
import operator
|
||||
from ..util import decorator
|
||||
from . import config
|
||||
from .. import util
|
||||
import inspect
|
||||
import contextlib
|
||||
from sqlalchemy.util.compat import inspect_getargspec
|
||||
|
||||
|
||||
def skip_if(predicate, reason=None):
|
||||
rule = compound()
|
||||
pred = _as_predicate(predicate, reason)
|
||||
rule.skips.add(pred)
|
||||
return rule
|
||||
|
||||
|
||||
def fails_if(predicate, reason=None):
|
||||
rule = compound()
|
||||
pred = _as_predicate(predicate, reason)
|
||||
rule.fails.add(pred)
|
||||
return rule
|
||||
|
||||
|
||||
class compound(object):
|
||||
def __init__(self):
|
||||
self.fails = set()
|
||||
self.skips = set()
|
||||
self.tags = set()
|
||||
|
||||
def __add__(self, other):
|
||||
return self.add(other)
|
||||
|
||||
def add(self, *others):
|
||||
copy = compound()
|
||||
copy.fails.update(self.fails)
|
||||
copy.skips.update(self.skips)
|
||||
copy.tags.update(self.tags)
|
||||
for other in others:
|
||||
copy.fails.update(other.fails)
|
||||
copy.skips.update(other.skips)
|
||||
copy.tags.update(other.tags)
|
||||
return copy
|
||||
|
||||
def not_(self):
|
||||
copy = compound()
|
||||
copy.fails.update(NotPredicate(fail) for fail in self.fails)
|
||||
copy.skips.update(NotPredicate(skip) for skip in self.skips)
|
||||
copy.tags.update(self.tags)
|
||||
return copy
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.enabled_for_config(config._current)
|
||||
|
||||
def enabled_for_config(self, config):
|
||||
for predicate in self.skips.union(self.fails):
|
||||
if predicate(config):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def matching_config_reasons(self, config):
|
||||
return [
|
||||
predicate._as_string(config) for predicate
|
||||
in self.skips.union(self.fails)
|
||||
if predicate(config)
|
||||
]
|
||||
|
||||
def include_test(self, include_tags, exclude_tags):
|
||||
return bool(
|
||||
not self.tags.intersection(exclude_tags) and
|
||||
(not include_tags or self.tags.intersection(include_tags))
|
||||
)
|
||||
|
||||
def _extend(self, other):
|
||||
self.skips.update(other.skips)
|
||||
self.fails.update(other.fails)
|
||||
self.tags.update(other.tags)
|
||||
|
||||
def __call__(self, fn):
|
||||
if hasattr(fn, '_sa_exclusion_extend'):
|
||||
fn._sa_exclusion_extend._extend(self)
|
||||
return fn
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
return self._do(config._current, fn, *args, **kw)
|
||||
decorated = decorate(fn)
|
||||
decorated._sa_exclusion_extend = self
|
||||
return decorated
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fail_if(self):
|
||||
all_fails = compound()
|
||||
all_fails.fails.update(self.skips.union(self.fails))
|
||||
|
||||
try:
|
||||
yield
|
||||
except Exception as ex:
|
||||
all_fails._expect_failure(config._current, ex)
|
||||
else:
|
||||
all_fails._expect_success(config._current)
|
||||
|
||||
def _do(self, cfg, fn, *args, **kw):
|
||||
for skip in self.skips:
|
||||
if skip(cfg):
|
||||
msg = "'%s' : %s" % (
|
||||
fn.__name__,
|
||||
skip._as_string(cfg)
|
||||
)
|
||||
config.skip_test(msg)
|
||||
|
||||
try:
|
||||
return_value = fn(*args, **kw)
|
||||
except Exception as ex:
|
||||
self._expect_failure(cfg, ex, name=fn.__name__)
|
||||
else:
|
||||
self._expect_success(cfg, name=fn.__name__)
|
||||
return return_value
|
||||
|
||||
def _expect_failure(self, config, ex, name='block'):
|
||||
for fail in self.fails:
|
||||
if fail(config):
|
||||
print(("%s failed as expected (%s): %s " % (
|
||||
name, fail._as_string(config), str(ex))))
|
||||
break
|
||||
else:
|
||||
util.raise_from_cause(ex)
|
||||
|
||||
def _expect_success(self, config, name='block'):
|
||||
if not self.fails:
|
||||
return
|
||||
for fail in self.fails:
|
||||
if not fail(config):
|
||||
break
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Unexpected success for '%s' (%s)" %
|
||||
(
|
||||
name,
|
||||
" and ".join(
|
||||
fail._as_string(config)
|
||||
for fail in self.fails
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def requires_tag(tagname):
|
||||
return tags([tagname])
|
||||
|
||||
|
||||
def tags(tagnames):
|
||||
comp = compound()
|
||||
comp.tags.update(tagnames)
|
||||
return comp
|
||||
|
||||
|
||||
def only_if(predicate, reason=None):
|
||||
predicate = _as_predicate(predicate)
|
||||
return skip_if(NotPredicate(predicate), reason)
|
||||
|
||||
|
||||
def succeeds_if(predicate, reason=None):
|
||||
predicate = _as_predicate(predicate)
|
||||
return fails_if(NotPredicate(predicate), reason)
|
||||
|
||||
|
||||
class Predicate(object):
|
||||
@classmethod
|
||||
def as_predicate(cls, predicate, description=None):
|
||||
if isinstance(predicate, compound):
|
||||
return cls.as_predicate(predicate.enabled_for_config, description)
|
||||
elif isinstance(predicate, Predicate):
|
||||
if description and predicate.description is None:
|
||||
predicate.description = description
|
||||
return predicate
|
||||
elif isinstance(predicate, (list, set)):
|
||||
return OrPredicate(
|
||||
[cls.as_predicate(pred) for pred in predicate],
|
||||
description)
|
||||
elif isinstance(predicate, tuple):
|
||||
return SpecPredicate(*predicate)
|
||||
elif isinstance(predicate, util.string_types):
|
||||
tokens = predicate.split(" ", 2)
|
||||
op = spec = None
|
||||
db = tokens.pop(0)
|
||||
if tokens:
|
||||
op = tokens.pop(0)
|
||||
if tokens:
|
||||
spec = tuple(int(d) for d in tokens.pop(0).split("."))
|
||||
return SpecPredicate(db, op, spec, description=description)
|
||||
elif util.callable(predicate):
|
||||
return LambdaPredicate(predicate, description)
|
||||
else:
|
||||
assert False, "unknown predicate type: %s" % predicate
|
||||
|
||||
def _format_description(self, config, negate=False):
|
||||
bool_ = self(config)
|
||||
if negate:
|
||||
bool_ = not negate
|
||||
return self.description % {
|
||||
"driver": config.db.url.get_driver_name()
|
||||
if config else "<no driver>",
|
||||
"database": config.db.url.get_backend_name()
|
||||
if config else "<no database>",
|
||||
"doesnt_support": "doesn't support" if bool_ else "does support",
|
||||
"does_support": "does support" if bool_ else "doesn't support"
|
||||
}
|
||||
|
||||
def _as_string(self, config=None, negate=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class BooleanPredicate(Predicate):
|
||||
def __init__(self, value, description=None):
|
||||
self.value = value
|
||||
self.description = description or "boolean %s" % value
|
||||
|
||||
def __call__(self, config):
|
||||
return self.value
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
return self._format_description(config, negate=negate)
|
||||
|
||||
|
||||
class SpecPredicate(Predicate):
|
||||
def __init__(self, db, op=None, spec=None, description=None):
|
||||
self.db = db
|
||||
self.op = op
|
||||
self.spec = spec
|
||||
self.description = description
|
||||
|
||||
_ops = {
|
||||
'<': operator.lt,
|
||||
'>': operator.gt,
|
||||
'==': operator.eq,
|
||||
'!=': operator.ne,
|
||||
'<=': operator.le,
|
||||
'>=': operator.ge,
|
||||
'in': operator.contains,
|
||||
'between': lambda val, pair: val >= pair[0] and val <= pair[1],
|
||||
}
|
||||
|
||||
def __call__(self, config):
|
||||
engine = config.db
|
||||
|
||||
if "+" in self.db:
|
||||
dialect, driver = self.db.split('+')
|
||||
else:
|
||||
dialect, driver = self.db, None
|
||||
|
||||
if dialect and engine.name != dialect:
|
||||
return False
|
||||
if driver is not None and engine.driver != driver:
|
||||
return False
|
||||
|
||||
if self.op is not None:
|
||||
assert driver is None, "DBAPI version specs not supported yet"
|
||||
|
||||
version = _server_version(engine)
|
||||
oper = hasattr(self.op, '__call__') and self.op \
|
||||
or self._ops[self.op]
|
||||
return oper(version, self.spec)
|
||||
else:
|
||||
return True
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
if self.description is not None:
|
||||
return self._format_description(config)
|
||||
elif self.op is None:
|
||||
if negate:
|
||||
return "not %s" % self.db
|
||||
else:
|
||||
return "%s" % self.db
|
||||
else:
|
||||
if negate:
|
||||
return "not %s %s %s" % (
|
||||
self.db,
|
||||
self.op,
|
||||
self.spec
|
||||
)
|
||||
else:
|
||||
return "%s %s %s" % (
|
||||
self.db,
|
||||
self.op,
|
||||
self.spec
|
||||
)
|
||||
|
||||
|
||||
class LambdaPredicate(Predicate):
|
||||
def __init__(self, lambda_, description=None, args=None, kw=None):
|
||||
spec = inspect_getargspec(lambda_)
|
||||
if not spec[0]:
|
||||
self.lambda_ = lambda db: lambda_()
|
||||
else:
|
||||
self.lambda_ = lambda_
|
||||
self.args = args or ()
|
||||
self.kw = kw or {}
|
||||
if description:
|
||||
self.description = description
|
||||
elif lambda_.__doc__:
|
||||
self.description = lambda_.__doc__
|
||||
else:
|
||||
self.description = "custom function"
|
||||
|
||||
def __call__(self, config):
|
||||
return self.lambda_(config)
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
return self._format_description(config)
|
||||
|
||||
|
||||
class NotPredicate(Predicate):
|
||||
def __init__(self, predicate, description=None):
|
||||
self.predicate = predicate
|
||||
self.description = description
|
||||
|
||||
def __call__(self, config):
|
||||
return not self.predicate(config)
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
if self.description:
|
||||
return self._format_description(config, not negate)
|
||||
else:
|
||||
return self.predicate._as_string(config, not negate)
|
||||
|
||||
|
||||
class OrPredicate(Predicate):
|
||||
def __init__(self, predicates, description=None):
|
||||
self.predicates = predicates
|
||||
self.description = description
|
||||
|
||||
def __call__(self, config):
|
||||
for pred in self.predicates:
|
||||
if pred(config):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _eval_str(self, config, negate=False):
|
||||
if negate:
|
||||
conjunction = " and "
|
||||
else:
|
||||
conjunction = " or "
|
||||
return conjunction.join(p._as_string(config, negate=negate)
|
||||
for p in self.predicates)
|
||||
|
||||
def _negation_str(self, config):
|
||||
if self.description is not None:
|
||||
return "Not " + self._format_description(config)
|
||||
else:
|
||||
return self._eval_str(config, negate=True)
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
if negate:
|
||||
return self._negation_str(config)
|
||||
else:
|
||||
if self.description is not None:
|
||||
return self._format_description(config)
|
||||
else:
|
||||
return self._eval_str(config)
|
||||
|
||||
|
||||
_as_predicate = Predicate.as_predicate
|
||||
|
||||
|
||||
def _is_excluded(db, op, spec):
|
||||
return SpecPredicate(db, op, spec)(config._current)
|
||||
|
||||
|
||||
def _server_version(engine):
|
||||
"""Return a server_version_info tuple."""
|
||||
|
||||
# force metadata to be retrieved
|
||||
conn = engine.connect()
|
||||
version = getattr(engine.dialect, 'server_version_info', ())
|
||||
conn.close()
|
||||
return version
|
||||
|
||||
|
||||
def db_spec(*dbs):
|
||||
return OrPredicate(
|
||||
[Predicate.as_predicate(db) for db in dbs]
|
||||
)
|
||||
|
||||
|
||||
def open():
|
||||
return skip_if(BooleanPredicate(False, "mark as execute"))
|
||||
|
||||
|
||||
def closed():
|
||||
return skip_if(BooleanPredicate(True, "marked as skip"))
|
||||
|
||||
|
||||
def fails(reason=None):
|
||||
return fails_if(BooleanPredicate(True, reason or "expected to fail"))
|
||||
|
||||
|
||||
@decorator
|
||||
def future(fn, *arg):
|
||||
return fails_if(LambdaPredicate(fn), "Future feature")
|
||||
|
||||
|
||||
def fails_on(db, reason=None):
|
||||
return fails_if(Predicate.as_predicate(db), reason)
|
||||
|
||||
|
||||
def fails_on_everything_except(*dbs):
|
||||
return succeeds_if(
|
||||
OrPredicate([
|
||||
Predicate.as_predicate(db) for db in dbs
|
||||
])
|
||||
)
|
||||
|
||||
|
||||
def skip(db, reason=None):
|
||||
return skip_if(Predicate.as_predicate(db), reason)
|
||||
|
||||
|
||||
def only_on(dbs, reason=None):
|
||||
return only_if(
|
||||
OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)])
|
||||
)
|
||||
|
||||
|
||||
def exclude(db, op, spec, reason=None):
|
||||
return skip_if(SpecPredicate(db, op, spec), reason)
|
||||
|
||||
|
||||
def against(config, *queries):
|
||||
assert queries, "no queries sent!"
|
||||
return OrPredicate([
|
||||
Predicate.as_predicate(query)
|
||||
for query in queries
|
||||
])(config)
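For context, a brief usage sketch of how these predicates are consumed; the test class, backend names and reasons below are invented for illustration, and skip_if / only_on / exclude are the helpers defined in this module:

from sqlalchemy.testing import exclusions, fixtures


class RoundTripTest(fixtures.TestBase):

    @exclusions.skip_if("sqlite", "no sequence support")
    def test_sequence(self):
        ...

    @exclusions.only_on("postgresql+psycopg2")
    def test_pg_only(self):
        ...

    # a SpecPredicate comparison against _server_version()
    @exclusions.exclude("mysql", "<", (5, 6), "needs 5.6+ features")
    def test_modern_mysql(self):
        ...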
386
sqlalchemy/testing/fixtures.py
Normal file
@@ -0,0 +1,386 @@
# testing/fixtures.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from . import config
from . import assertions, schema
from .util import adict
from .. import util
from .engines import drop_all_tables
from .entities import BasicEntity, ComparableEntity
import sys
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta

# whether or not we use unittest changes things dramatically,
# as far as how py.test collection works.


class TestBase(object):
    # A sequence of database names to always run, regardless of the
    # constraints below.
    __whitelist__ = ()

    # A sequence of requirement names matching testing.requires decorators
    __requires__ = ()

    # A sequence of dialect names to exclude from the test class.
    __unsupported_on__ = ()

    # If present, test class is only runnable for the *single* specified
    # dialect.  If you need multiple, use __unsupported_on__ and invert.
    __only_on__ = None

    # A sequence of no-arg callables. If any are True, the entire testcase is
    # skipped.
    __skip_if__ = None

    def assert_(self, val, msg=None):
        assert val, msg

    # apparently a handful of tests are doing this....OK
    def setup(self):
        if hasattr(self, "setUp"):
            self.setUp()

    def teardown(self):
        if hasattr(self, "tearDown"):
            self.tearDown()


class TablesTest(TestBase):

    # 'once', None
    run_setup_bind = 'once'

    # 'once', 'each', None
    run_define_tables = 'once'

    # 'once', 'each', None
    run_create_tables = 'once'

    # 'once', 'each', None
    run_inserts = 'each'

    # 'each', None
    run_deletes = 'each'

    # 'once', None
    run_dispose_bind = None

    bind = None
    metadata = None
    tables = None
    other = None

    @classmethod
    def setup_class(cls):
        cls._init_class()

        cls._setup_once_tables()

        cls._setup_once_inserts()

    @classmethod
    def _init_class(cls):
        if cls.run_define_tables == 'each':
            if cls.run_create_tables == 'once':
                cls.run_create_tables = 'each'
            assert cls.run_inserts in ('each', None)

        cls.other = adict()
        cls.tables = adict()

        cls.bind = cls.setup_bind()
        cls.metadata = sa.MetaData()
        cls.metadata.bind = cls.bind

    @classmethod
    def _setup_once_inserts(cls):
        if cls.run_inserts == 'once':
            cls._load_fixtures()
            cls.insert_data()

    @classmethod
    def _setup_once_tables(cls):
        if cls.run_define_tables == 'once':
            cls.define_tables(cls.metadata)
            if cls.run_create_tables == 'once':
                cls.metadata.create_all(cls.bind)
            cls.tables.update(cls.metadata.tables)

    def _setup_each_tables(self):
        if self.run_define_tables == 'each':
            self.tables.clear()
            if self.run_create_tables == 'each':
                drop_all_tables(self.metadata, self.bind)
            self.metadata.clear()
            self.define_tables(self.metadata)
            if self.run_create_tables == 'each':
                self.metadata.create_all(self.bind)
            self.tables.update(self.metadata.tables)
        elif self.run_create_tables == 'each':
            drop_all_tables(self.metadata, self.bind)
            self.metadata.create_all(self.bind)

    def _setup_each_inserts(self):
        if self.run_inserts == 'each':
            self._load_fixtures()
            self.insert_data()

    def _teardown_each_tables(self):
        # no need to run deletes if tables are recreated on setup
        if self.run_define_tables != 'each' and self.run_deletes == 'each':
            with self.bind.connect() as conn:
                for table in reversed(self.metadata.sorted_tables):
                    try:
                        conn.execute(table.delete())
                    except sa.exc.DBAPIError as ex:
                        util.print_(
                            ("Error emptying table %s: %r" % (table, ex)),
                            file=sys.stderr)

    def setup(self):
        self._setup_each_tables()
        self._setup_each_inserts()

    def teardown(self):
        self._teardown_each_tables()

    @classmethod
    def _teardown_once_metadata_bind(cls):
        if cls.run_create_tables:
            drop_all_tables(cls.metadata, cls.bind)

        if cls.run_dispose_bind == 'once':
            cls.dispose_bind(cls.bind)

        cls.metadata.bind = None

        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def teardown_class(cls):
        cls._teardown_once_metadata_bind()

    @classmethod
    def setup_bind(cls):
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        if hasattr(bind, 'dispose'):
            bind.dispose()
        elif hasattr(bind, 'close'):
            bind.close()

    @classmethod
    def define_tables(cls, metadata):
        pass

    @classmethod
    def fixtures(cls):
        return {}

    @classmethod
    def insert_data(cls):
        pass

    def sql_count_(self, count, fn):
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements):
        self.assert_sql(self.bind, callable_, statements)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method."""
        headers, rows = {}, {}
        for table, data in cls.fixtures().items():
            if len(data) < 2:
                continue
            if isinstance(table, util.string_types):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        for table in cls.metadata.sorted_tables:
            if table not in headers:
                continue
            cls.bind.execute(
                table.insert(),
                [dict(zip(headers[table], column_values))
                 for column_values in rows[table]])


from sqlalchemy import event


class RemovesEvents(object):
    @util.memoized_property
    def _event_fns(self):
        return set()

    def event_listen(self, target, name, fn):
        self._event_fns.add((target, name, fn))
        event.listen(target, name, fn)

    def teardown(self):
        for key in self._event_fns:
            event.remove(*key)
        super_ = super(RemovesEvents, self)
        if hasattr(super_, "teardown"):
            super_.teardown()


class _ORMTest(object):

    @classmethod
    def teardown_class(cls):
        sa.orm.session.Session.close_all()
        sa.orm.clear_mappers()


class ORMTest(_ORMTest, TestBase):
    pass


class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
    # 'once', 'each', None
    run_setup_classes = 'once'

    # 'once', 'each', None
    run_setup_mappers = 'each'

    classes = None

    @classmethod
    def setup_class(cls):
        cls._init_class()

        if cls.classes is None:
            cls.classes = adict()

        cls._setup_once_tables()
        cls._setup_once_classes()
        cls._setup_once_mappers()
        cls._setup_once_inserts()

    @classmethod
    def teardown_class(cls):
        cls._teardown_once_class()
        cls._teardown_once_metadata_bind()

    def setup(self):
        self._setup_each_tables()
        self._setup_each_classes()
        self._setup_each_mappers()
        self._setup_each_inserts()

    def teardown(self):
        sa.orm.session.Session.close_all()
        self._teardown_each_mappers()
        self._teardown_each_classes()
        self._teardown_each_tables()

    @classmethod
    def _teardown_once_class(cls):
        cls.classes.clear()
        _ORMTest.teardown_class()

    @classmethod
    def _setup_once_classes(cls):
        if cls.run_setup_classes == 'once':
            cls._with_register_classes(cls.setup_classes)

    @classmethod
    def _setup_once_mappers(cls):
        if cls.run_setup_mappers == 'once':
            cls._with_register_classes(cls.setup_mappers)

    def _setup_each_mappers(self):
        if self.run_setup_mappers == 'each':
            self._with_register_classes(self.setup_mappers)

    def _setup_each_classes(self):
        if self.run_setup_classes == 'each':
            self._with_register_classes(self.setup_classes)

    @classmethod
    def _with_register_classes(cls, fn):
        """Run a setup method, framing the operation with a Base class
        that will catch new subclasses to be established within
        the "classes" registry.

        """
        cls_registry = cls.classes

        class FindFixture(type):
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                return type.__init__(cls, classname, bases, dict_)

        class _Base(util.with_metaclass(FindFixture, object)):
            pass

        class Basic(BasicEntity, _Base):
            pass

        class Comparable(ComparableEntity, _Base):
            pass

        cls.Basic = Basic
        cls.Comparable = Comparable
        fn()

    def _teardown_each_mappers(self):
        # some tests create mappers in the test bodies
        # and will define setup_mappers as None -
        # clear mappers in any case
        if self.run_setup_mappers != 'once':
            sa.orm.clear_mappers()

    def _teardown_each_classes(self):
        if self.run_setup_classes != 'once':
            self.classes.clear()

    @classmethod
    def setup_classes(cls):
        pass

    @classmethod
    def setup_mappers(cls):
        pass


class DeclarativeMappedTest(MappedTest):
    run_setup_classes = 'once'
    run_setup_mappers = 'once'

    @classmethod
    def _setup_once_tables(cls):
        pass

    @classmethod
    def _with_register_classes(cls, fn):
        cls_registry = cls.classes

        class FindFixtureDeclarative(DeclarativeMeta):
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                return DeclarativeMeta.__init__(
                    cls, classname, bases, dict_)

        class DeclarativeBasic(object):
            __table_cls__ = schema.Table

        _DeclBase = declarative_base(metadata=cls.metadata,
                                     metaclass=FindFixtureDeclarative,
                                     cls=DeclarativeBasic)
        cls.DeclarativeBasic = _DeclBase
        fn()

        if cls.metadata.tables and cls.run_create_tables:
            cls.metadata.create_all(config.db)
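A minimal sketch of how TablesTest is typically subclassed; the table, fixture rows and assertion are invented for illustration. define_tables() populates cls.metadata, fixtures() returns a header tuple plus data rows per table, and with the default run_inserts='each' the rows are reloaded before every test:

from sqlalchemy import Column, Integer, String, Table
from sqlalchemy.testing import fixtures


class UserRoundTripTest(fixtures.TablesTest):

    @classmethod
    def define_tables(cls, metadata):
        Table('users', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(50)))

    @classmethod
    def fixtures(cls):
        # first tuple is the column header, remaining tuples are rows
        return dict(
            users=(
                ('id', 'name'),
                (1, 'ed'),
                (2, 'wendy'),
            )
        )

    def test_rowcount(self):
        users = self.tables.users
        rows = self.bind.execute(users.select()).fetchall()
        assert len(rows) == 2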
21
sqlalchemy/testing/mock.py
Normal file
@@ -0,0 +1,21 @@
# testing/mock.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Import stub for mock library.
"""
from __future__ import absolute_import
from ..util import py33

if py33:
    from unittest.mock import MagicMock, Mock, call, patch, ANY
else:
    try:
        from mock import MagicMock, Mock, call, patch, ANY
    except ImportError:
        raise ImportError(
            "SQLAlchemy's test suite requires the "
            "'mock' library as of 0.8.2.")
143
sqlalchemy/testing/pickleable.py
Normal file
@@ -0,0 +1,143 @@
# testing/pickleable.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Classes used in pickling tests, need to be at the module level for
unpickling.
"""

from . import fixtures


class User(fixtures.ComparableEntity):
    pass


class Order(fixtures.ComparableEntity):
    pass


class Dingaling(fixtures.ComparableEntity):
    pass


class EmailUser(User):
    pass


class Address(fixtures.ComparableEntity):
    pass


# TODO: these are kind of arbitrary....
class Child1(fixtures.ComparableEntity):
    pass


class Child2(fixtures.ComparableEntity):
    pass


class Parent(fixtures.ComparableEntity):
    pass


class Screen(object):

    def __init__(self, obj, parent=None):
        self.obj = obj
        self.parent = parent


class Foo(object):

    def __init__(self, moredata):
        self.data = 'im data'
        self.stuff = 'im stuff'
        self.moredata = moredata

    __hash__ = object.__hash__

    def __eq__(self, other):
        return other.data == self.data and \
            other.stuff == self.stuff and \
            other.moredata == self.moredata


class Bar(object):

    def __init__(self, x, y):
        self.x = x
        self.y = y

    __hash__ = object.__hash__

    def __eq__(self, other):
        return other.__class__ is self.__class__ and \
            other.x == self.x and \
            other.y == self.y

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)


class OldSchool:

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return other.__class__ is self.__class__ and \
            other.x == self.x and \
            other.y == self.y


class OldSchoolWithoutCompare:

    def __init__(self, x, y):
        self.x = x
        self.y = y


class BarWithoutCompare(object):

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)


class NotComparable(object):

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return NotImplemented

    def __ne__(self, other):
        return NotImplemented


class BrokenComparable(object):

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        raise NotImplementedError

    def __ne__(self, other):
        raise NotImplementedError
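Module-level placement matters because pickle serializes only a reference to the class by module-qualified name, which must resolve again at load time; a quick round trip, shown for illustration:

import pickle
from sqlalchemy.testing.pickleable import Bar

b = pickle.loads(pickle.dumps(Bar(1, 2)))
assert b == Bar(1, 2)   # Bar provides __eq__
print(b)                # "Bar(1, 2)" via __str__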
0
sqlalchemy/testing/plugin/__init__.py
Normal file
44
sqlalchemy/testing/plugin/bootstrap.py
Normal file
@@ -0,0 +1,44 @@
"""
Bootstrapper for nose/pytest plugins.

The entire rationale for this system is to get the modules in plugin/
imported without importing all of the supporting library, so that we can
set up things for testing before coverage starts.

The rationale for all of plugin/ being *in* the supporting library in the
first place is so that the testing and plugin suite is available to other
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
of the same test environment and standard suites available to
SQLAlchemy/Alembic themselves without the need to ship/install a separate
package outside of SQLAlchemy.

NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
this should be removable when Alembic targets SQLAlchemy 1.0.0.

"""

import os
import sys

bootstrap_file = locals()['bootstrap_file']
to_bootstrap = locals()['to_bootstrap']


def load_file_as_module(name):
    path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
    if sys.version_info >= (3, 3):
        from importlib import machinery
        mod = machinery.SourceFileLoader(name, path).load_module()
    else:
        import imp
        mod = imp.load_source(name, path)
    return mod

if to_bootstrap == "pytest":
    sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
    sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
elif to_bootstrap == "nose":
    sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
    sys.modules["sqla_noseplugin"] = load_file_as_module("noseplugin")
else:
    raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa
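The calling side injects bootstrap_file and to_bootstrap into this module's namespace by exec()ing the file rather than importing it; a rough sketch of what a top-level conftest.py does (the path below is an assumption for illustration):

import os

bootstrap_file = os.path.join(
    os.path.dirname(__file__),
    "lib", "sqlalchemy", "testing", "plugin", "bootstrap.py")

with open(bootstrap_file) as f:
    code = compile(f.read(), "bootstrap.py", "exec")
    to_bootstrap = "pytest"
    # bootstrap.py picks bootstrap_file / to_bootstrap out of locals()
    exec(code, globals(), locals())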
107
sqlalchemy/testing/plugin/noseplugin.py
Normal file
@@ -0,0 +1,107 @@
# plugin/noseplugin.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Enhance nose with extra options and behaviors for running SQLAlchemy tests.

Must be run via ./sqla_nose.py so that it is imported in the expected
way (e.g. as a package-less import).

"""

try:
    # installed by bootstrap.py
    import sqla_plugin_base as plugin_base
except ImportError:
    # assume we're a package, use traditional import
    from . import plugin_base


import os
import sys

from nose.plugins import Plugin
import nose
fixtures = None

py3k = sys.version_info >= (3, 0)


class NoseSQLAlchemy(Plugin):
    enabled = True

    name = 'sqla_testing'
    score = 100

    def options(self, parser, env=os.environ):
        Plugin.options(self, parser, env)
        opt = parser.add_option

        def make_option(name, **kw):
            callback_ = kw.pop("callback", None)
            if callback_:
                def wrap_(option, opt_str, value, parser):
                    callback_(opt_str, value, parser)
                kw["callback"] = wrap_
            opt(name, **kw)

        plugin_base.setup_options(make_option)
        plugin_base.read_config()

    def configure(self, options, conf):
        super(NoseSQLAlchemy, self).configure(options, conf)
        plugin_base.pre_begin(options)

        plugin_base.set_coverage_flag(options.enable_plugin_coverage)

        plugin_base.set_skip_test(nose.SkipTest)

    def begin(self):
        global fixtures
        from sqlalchemy.testing import fixtures  # noqa

        plugin_base.post_begin()

    def describeTest(self, test):
        return ""

    def wantFunction(self, fn):
        return False

    def wantMethod(self, fn):
        if py3k:
            if not hasattr(fn.__self__, 'cls'):
                return False
            cls = fn.__self__.cls
        else:
            cls = fn.im_class
        return plugin_base.want_method(cls, fn)

    def wantClass(self, cls):
        return plugin_base.want_class(cls)

    def beforeTest(self, test):
        if not hasattr(test.test, 'cls'):
            return
        plugin_base.before_test(
            test,
            test.test.cls.__module__,
            test.test.cls, test.test.method.__name__)

    def afterTest(self, test):
        plugin_base.after_test(test)

    def startContext(self, ctx):
        if not isinstance(ctx, type) \
                or not issubclass(ctx, fixtures.TestBase):
            return
        plugin_base.start_test_class(ctx)

    def stopContext(self, ctx):
        if not isinstance(ctx, type) \
                or not issubclass(ctx, fixtures.TestBase):
            return
        plugin_base.stop_test_class(ctx)
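The package-less import requirement means the plugin is registered from a standalone runner script; ./sqla_nose.py is roughly equivalent to the following sketch (reconstructed for illustration, not the verbatim script):

import nose
from noseplugin import NoseSQLAlchemy

nose.main(addplugins=[NoseSQLAlchemy()])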
565
sqlalchemy/testing/plugin/plugin_base.py
Normal file
@@ -0,0 +1,565 @@
# plugin/plugin_base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Testing extensions.

this module is designed to work as a testing-framework-agnostic library,
so that we can continue to support nose and also begin adding new
functionality via py.test.

"""

from __future__ import absolute_import

import sys
import re

py3k = sys.version_info >= (3, 0)

if py3k:
    import configparser
else:
    import ConfigParser as configparser

# late imports
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
assertions = None
requirements = None
config = None
testing = None
util = None
file_config = None


logging = None
include_tags = set()
exclude_tags = set()
options = None


def setup_options(make_option):
    make_option("--log-info", action="callback", type="string", callback=_log,
                help="turn on info logging for <LOG> (multiple OK)")
    make_option("--log-debug", action="callback",
                type="string", callback=_log,
                help="turn on debug logging for <LOG> (multiple OK)")
    make_option("--db", action="append", type="string", dest="db",
                help="Use prefab database uri. Multiple OK, "
                "first one is run by default.")
    make_option('--dbs', action='callback', callback=_list_dbs,
                help="List available prefab dbs")
    make_option("--dburi", action="append", type="string", dest="dburi",
                help="Database uri. Multiple OK, "
                "first one is run by default.")
    make_option("--dropfirst", action="store_true", dest="dropfirst",
                help="Drop all tables in the target database first")
    make_option("--backend-only", action="store_true", dest="backend_only",
                help="Run only tests marked with __backend__")
    make_option("--low-connections", action="store_true",
                dest="low_connections",
                help="Use a low number of distinct connections - "
                "i.e. for Oracle TNS")
    make_option("--write-idents", type="string", dest="write_idents",
                help="write out generated follower idents to <file>, "
                "when -n<num> is used")
    make_option("--reversetop", action="store_true",
                dest="reversetop", default=False,
                help="Use a random-ordering set implementation in the ORM "
                "(helps reveal dependency issues)")
    make_option("--requirements", action="callback", type="string",
                callback=_requirements_opt,
                help="requirements class for testing, overrides setup.cfg")
    make_option("--with-cdecimal", action="store_true",
                dest="cdecimal", default=False,
                help="Monkeypatch the cdecimal library into Python 'decimal' "
                "for all tests")
    make_option("--include-tag", action="callback", callback=_include_tag,
                type="string",
                help="Include tests with tag <tag>")
    make_option("--exclude-tag", action="callback", callback=_exclude_tag,
                type="string",
                help="Exclude tests with tag <tag>")
    make_option("--write-profiles", action="store_true",
                dest="write_profiles", default=False,
                help="Write/update failing profiling data.")
    make_option("--force-write-profiles", action="store_true",
                dest="force_write_profiles", default=False,
                help="Unconditionally write/update profiling data.")


def configure_follower(follower_ident):
    """Configure required state for a follower.

    This invokes in the parent process and typically includes
    database creation.

    """
    from sqlalchemy.testing import provision
    provision.FOLLOWER_IDENT = follower_ident


def memoize_important_follower_config(dict_):
    """Store important configuration we will need to send to a follower.

    This invokes in the parent process after normal config is set up.

    This is necessary as py.test seems to not be using forking, so we
    start with nothing in memory, *but* it isn't running our argparse
    callables, so we have to just copy all of that over.

    """
    dict_['memoized_config'] = {
        'include_tags': include_tags,
        'exclude_tags': exclude_tags
    }


def restore_important_follower_config(dict_):
    """Restore important configuration needed by a follower.

    This invokes in the follower process.

    """
    global include_tags, exclude_tags
    include_tags.update(dict_['memoized_config']['include_tags'])
    exclude_tags.update(dict_['memoized_config']['exclude_tags'])


def read_config():
    global file_config
    file_config = configparser.ConfigParser()
    file_config.read(['setup.cfg', 'test.cfg'])


def pre_begin(opt):
    """things to set up early, before coverage might be setup."""
    global options
    options = opt
    for fn in pre_configure:
        fn(options, file_config)


def set_coverage_flag(value):
    options.has_coverage = value

_skip_test_exception = None


def set_skip_test(exc):
    global _skip_test_exception
    _skip_test_exception = exc


def post_begin():
    """things to set up later, once we know coverage is running."""
    # Lazy setup of other options (post coverage)
    for fn in post_configure:
        fn(options, file_config)

    # late imports, has to happen after config as well
    # as nose plugins like coverage
    global util, fixtures, engines, exclusions, \
        assertions, warnings, profiling,\
        config, testing
    from sqlalchemy import testing  # noqa
    from sqlalchemy.testing import fixtures, engines, exclusions  # noqa
    from sqlalchemy.testing import assertions, warnings, profiling  # noqa
    from sqlalchemy.testing import config  # noqa
    from sqlalchemy import util  # noqa
    warnings.setup_filters()


def _log(opt_str, value, parser):
    global logging
    if not logging:
        import logging
        logging.basicConfig()

    if opt_str.endswith('-info'):
        logging.getLogger(value).setLevel(logging.INFO)
    elif opt_str.endswith('-debug'):
        logging.getLogger(value).setLevel(logging.DEBUG)


def _list_dbs(*args):
    print("Available --db options (use --dburi to override)")
    for macro in sorted(file_config.options('db')):
        print("%20s\t%s" % (macro, file_config.get('db', macro)))
    sys.exit(0)


def _requirements_opt(opt_str, value, parser):
    _setup_requirements(value)


def _exclude_tag(opt_str, value, parser):
    exclude_tags.add(value.replace('-', '_'))


def _include_tag(opt_str, value, parser):
    include_tags.add(value.replace('-', '_'))

pre_configure = []
post_configure = []


def pre(fn):
    pre_configure.append(fn)
    return fn


def post(fn):
    post_configure.append(fn)
    return fn


@pre
def _setup_options(opt, file_config):
    global options
    options = opt


@pre
def _monkeypatch_cdecimal(options, file_config):
    if options.cdecimal:
        import cdecimal
        sys.modules['decimal'] = cdecimal


@post
def _init_skiptest(options, file_config):
    from sqlalchemy.testing import config

    config._skip_test_exception = _skip_test_exception


@post
def _engine_uri(options, file_config):
    from sqlalchemy.testing import config
    from sqlalchemy import testing
    from sqlalchemy.testing import provision

    if options.dburi:
        db_urls = list(options.dburi)
    else:
        db_urls = []

    if options.db:
        for db_token in options.db:
            for db in re.split(r'[,\s]+', db_token):
                if db not in file_config.options('db'):
                    raise RuntimeError(
                        "Unknown URI specifier '%s'. "
                        "Specify --dbs for known uris."
                        % db)
                else:
                    db_urls.append(file_config.get('db', db))

    if not db_urls:
        db_urls.append(file_config.get('db', 'default'))

    config._current = None
    for db_url in db_urls:
        cfg = provision.setup_config(
            db_url, options, file_config, provision.FOLLOWER_IDENT)

        if not config._current:
            cfg.set_as_current(cfg, testing)


@post
def _requirements(options, file_config):

    requirement_cls = file_config.get('sqla_testing', "requirement_cls")
    _setup_requirements(requirement_cls)


def _setup_requirements(argument):
    from sqlalchemy.testing import config
    from sqlalchemy import testing

    if config.requirements is not None:
        return

    modname, clsname = argument.split(":")

    # importlib.import_module() only introduced in 2.7, a little
    # late
    mod = __import__(modname)
    for component in modname.split(".")[1:]:
        mod = getattr(mod, component)
    req_cls = getattr(mod, clsname)

    config.requirements = testing.requires = req_cls()


@post
def _prep_testing_database(options, file_config):
    from sqlalchemy.testing import config, util
    from sqlalchemy.testing.exclusions import against
    from sqlalchemy import schema, inspect

    if options.dropfirst:
        for cfg in config.Config.all_configs():
            e = cfg.db
            inspector = inspect(e)
            try:
                view_names = inspector.get_view_names()
            except NotImplementedError:
                pass
            else:
                for vname in view_names:
                    e.execute(schema._DropView(
                        schema.Table(vname, schema.MetaData())
                    ))

            if config.requirements.schemas.enabled_for_config(cfg):
                try:
                    view_names = inspector.get_view_names(
                        schema="test_schema")
                except NotImplementedError:
                    pass
                else:
                    for vname in view_names:
                        e.execute(schema._DropView(
                            schema.Table(vname, schema.MetaData(),
                                         schema="test_schema")
                        ))

            util.drop_all_tables(e, inspector)

            if config.requirements.schemas.enabled_for_config(cfg):
                util.drop_all_tables(e, inspector, schema=cfg.test_schema)

            if against(cfg, "postgresql"):
                from sqlalchemy.dialects import postgresql
                for enum in inspector.get_enums("*"):
                    e.execute(postgresql.DropEnumType(
                        postgresql.ENUM(
                            name=enum['name'],
                            schema=enum['schema'])))


@post
def _reverse_topological(options, file_config):
    if options.reversetop:
        from sqlalchemy.orm.util import randomize_unitofwork
        randomize_unitofwork()


@post
def _post_setup_options(opt, file_config):
    from sqlalchemy.testing import config
    config.options = options
    config.file_config = file_config


@post
def _setup_profiling(options, file_config):
    from sqlalchemy.testing import profiling
    profiling._profile_stats = profiling.ProfileStatsFile(
        file_config.get('sqla_testing', 'profile_file'))


def want_class(cls):
    if not issubclass(cls, fixtures.TestBase):
        return False
    elif cls.__name__.startswith('_'):
        return False
    elif config.options.backend_only and not getattr(cls, '__backend__',
                                                     False):
        return False
    else:
        return True


def want_method(cls, fn):
    if not fn.__name__.startswith("test_"):
        return False
    elif fn.__module__ is None:
        return False
    elif include_tags:
        return (
            hasattr(cls, '__tags__') and
            exclusions.tags(cls.__tags__).include_test(
                include_tags, exclude_tags)
        ) or (
            hasattr(fn, '_sa_exclusion_extend') and
            fn._sa_exclusion_extend.include_test(
                include_tags, exclude_tags)
        )
    elif exclude_tags and hasattr(cls, '__tags__'):
        return exclusions.tags(cls.__tags__).include_test(
            include_tags, exclude_tags)
    elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
        return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
    else:
        return True


def generate_sub_tests(cls, module):
    if getattr(cls, '__backend__', False):
        for cfg in _possible_configs_for_cls(cls):
            name = "%s_%s_%s" % (cls.__name__, cfg.db.name, cfg.db.driver)
            subcls = type(
                name,
                (cls, ),
                {
                    "__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)),
                }
            )
            setattr(module, name, subcls)
            yield subcls
    else:
        yield cls


def start_test_class(cls):
    _do_skips(cls)
    _setup_engine(cls)


def stop_test_class(cls):
    #from sqlalchemy import inspect
    #assert not inspect(testing.db).get_table_names()
    engines.testing_reaper._stop_test_ctx()
    try:
        if not options.low_connections:
            assertions.global_cleanup_assertions()
    finally:
        _restore_engine()


def _restore_engine():
    config._current.reset(testing)


def final_process_cleanup():
    engines.testing_reaper._stop_test_ctx_aggressive()
    assertions.global_cleanup_assertions()
    _restore_engine()


def _setup_engine(cls):
    if getattr(cls, '__engine_options__', None):
        eng = engines.testing_engine(options=cls.__engine_options__)
        config._current.push_engine(eng, testing)


def before_test(test, test_module_name, test_class, test_name):

    # like a nose id, e.g.:
    # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
    name = test_class.__name__

    suffix = "_%s_%s" % (config.db.name, config.db.driver)
    if name.endswith(suffix):
        name = name[0:-(len(suffix))]

    id_ = "%s.%s.%s" % (test_module_name, name, test_name)

    profiling._current_test = id_


def after_test(test):
    engines.testing_reaper._after_test_ctx()


def _possible_configs_for_cls(cls, reasons=None):
    all_configs = set(config.Config.all_configs())

    if cls.__unsupported_on__:
        spec = exclusions.db_spec(*cls.__unsupported_on__)
        for config_obj in list(all_configs):
            if spec(config_obj):
                all_configs.remove(config_obj)

    if getattr(cls, '__only_on__', None):
        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
        for config_obj in list(all_configs):
            if not spec(config_obj):
                all_configs.remove(config_obj)

    if hasattr(cls, '__requires__'):
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__requires__:
                check = getattr(requirements, requirement)

                skip_reasons = check.matching_config_reasons(config_obj)
                if skip_reasons:
                    all_configs.remove(config_obj)
                    if reasons is not None:
                        reasons.extend(skip_reasons)
                    break

    if hasattr(cls, '__prefer_requires__'):
        non_preferred = set()
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__prefer_requires__:
                check = getattr(requirements, requirement)

                if not check.enabled_for_config(config_obj):
                    non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    return all_configs


def _do_skips(cls):
    reasons = []
    all_configs = _possible_configs_for_cls(cls, reasons)

    if getattr(cls, '__skip_if__', False):
        for c in getattr(cls, '__skip_if__'):
            if c():
                config.skip_test("'%s' skipped by %s" % (
                    cls.__name__, c.__name__)
                )

    if not all_configs:
        if getattr(cls, '__backend__', False):
            msg = "'%s' unsupported for implementation '%s'" % (
                cls.__name__, cls.__only_on__)
        else:
            msg = "'%s' unsupported on any DB implementation %s%s" % (
                cls.__name__,
                ", ".join(
                    "'%s(%s)+%s'" % (
                        config_obj.db.name,
                        ".".join(
                            str(dig) for dig in
                            config_obj.db.dialect.server_version_info),
                        config_obj.db.driver
                    )
                    for config_obj in config.Config.all_configs()
                ),
                ", ".join(reasons)
            )
        config.skip_test(msg)
    elif hasattr(cls, '__prefer_backends__'):
        non_preferred = set()
        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
        for config_obj in all_configs:
            if not spec(config_obj):
                non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if config._current not in all_configs:
        _setup_config(all_configs.pop(), cls)


def _setup_config(config_obj, ctx):
    config._current.push(config_obj, testing)
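The pre_configure / post_configure lists make the module extensible; a hypothetical downstream suite could register its own one-time hook the same way the built-in @post functions above do:

from sqlalchemy.testing.plugin import plugin_base


@plugin_base.post
def _report_target(options, file_config):
    # runs once after coverage setup, with parsed options available
    print("testing against: %s" % (options.dburi or ["default"]))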
194
sqlalchemy/testing/plugin/pytestplugin.py
Normal file
@@ -0,0 +1,194 @@
try:
    # installed by bootstrap.py
    import sqla_plugin_base as plugin_base
except ImportError:
    # assume we're a package, use traditional import
    from . import plugin_base

import pytest
import argparse
import inspect
import collections
import os

try:
    import xdist  # noqa
    has_xdist = True
except ImportError:
    has_xdist = False


def pytest_addoption(parser):
    group = parser.getgroup("sqlalchemy")

    def make_option(name, **kw):
        callback_ = kw.pop("callback", None)
        if callback_:
            class CallableAction(argparse.Action):
                def __call__(self, parser, namespace,
                             values, option_string=None):
                    callback_(option_string, values, parser)
            kw["action"] = CallableAction

        group.addoption(name, **kw)

    plugin_base.setup_options(make_option)
    plugin_base.read_config()


def pytest_configure(config):
    if hasattr(config, "slaveinput"):
        plugin_base.restore_important_follower_config(config.slaveinput)
        plugin_base.configure_follower(
            config.slaveinput["follower_ident"]
        )

        if config.option.write_idents:
            with open(config.option.write_idents, "a") as file_:
                file_.write(config.slaveinput["follower_ident"] + "\n")
    else:
        if config.option.write_idents and \
                os.path.exists(config.option.write_idents):
            os.remove(config.option.write_idents)

    plugin_base.pre_begin(config.option)

    plugin_base.set_coverage_flag(bool(getattr(config.option,
                                               "cov_source", False)))

    plugin_base.set_skip_test(pytest.skip.Exception)


def pytest_sessionstart(session):
    plugin_base.post_begin()


def pytest_sessionfinish(session):
    plugin_base.final_process_cleanup()


if has_xdist:
    import uuid

    def pytest_configure_node(node):
        # the master for each node fills slaveinput dictionary
        # which pytest-xdist will transfer to the subprocess

        plugin_base.memoize_important_follower_config(node.slaveinput)

        node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
        from sqlalchemy.testing import provision
        provision.create_follower_db(node.slaveinput["follower_ident"])

    def pytest_testnodedown(node, error):
        from sqlalchemy.testing import provision
        provision.drop_follower_db(node.slaveinput["follower_ident"])


def pytest_collection_modifyitems(session, config, items):
    # look for all those classes that specify __backend__ and
    # expand them out into per-database test cases.

    # this is much easier to do within pytest_pycollect_makeitem, however
    # pytest is iterating through cls.__dict__ as makeitem is
    # called which causes a "dictionary changed size" error on py3k.
    # I'd submit a pullreq for them to turn it into a list first, but
    # it's to suit the rather odd use case here which is that we are adding
    # new classes to a module on the fly.

    rebuilt_items = collections.defaultdict(list)
    items[:] = [
        item for item in
        items if isinstance(item.parent, pytest.Instance)
        and not item.parent.parent.name.startswith("_")]
    test_classes = set(item.parent for item in items)
    for test_class in test_classes:
        for sub_cls in plugin_base.generate_sub_tests(
                test_class.cls, test_class.parent.module):
            if sub_cls is not test_class.cls:
                list_ = rebuilt_items[test_class.cls]

                for inst in pytest.Class(
                        sub_cls.__name__,
                        parent=test_class.parent.parent).collect():
                    list_.extend(inst.collect())

    newitems = []
    for item in items:
        if item.parent.cls in rebuilt_items:
            newitems.extend(rebuilt_items[item.parent.cls])
            rebuilt_items[item.parent.cls][:] = []
        else:
            newitems.append(item)

    # seems like the functions attached to a test class aren't sorted already?
    # is that true and why's that? (when using unittest, they're sorted)
    items[:] = sorted(newitems, key=lambda item: (
        item.parent.parent.parent.name,
        item.parent.parent.name,
        item.name
    ))


def pytest_pycollect_makeitem(collector, name, obj):
    if inspect.isclass(obj) and plugin_base.want_class(obj):
        return pytest.Class(name, parent=collector)
    elif inspect.isfunction(obj) and \
            isinstance(collector, pytest.Instance) and \
            plugin_base.want_method(collector.cls, obj):
        return pytest.Function(name, parent=collector)
    else:
        return []

_current_class = None


def pytest_runtest_setup(item):
    # here we seem to get called only based on what we collected
    # in pytest_collection_modifyitems.   So to do class-based stuff
    # we have to tear that out.
    global _current_class

    if not isinstance(item, pytest.Function):
        return

    # ... so we're doing a little dance here to figure it out...
    if _current_class is None:
        class_setup(item.parent.parent)
        _current_class = item.parent.parent

        # this is needed for the class-level, to ensure that the
        # teardown runs after the class is completed with its own
        # class-level teardown...
        def finalize():
            global _current_class
            class_teardown(item.parent.parent)
            _current_class = None
        item.parent.parent.addfinalizer(finalize)

    test_setup(item)


def pytest_runtest_teardown(item):
    # ...but this works better as the hook here rather than
    # using a finalizer, as the finalizer seems to get in the way
    # of the test reporting failures correctly (you get a bunch of
    # py.test assertion stuff instead)
    test_teardown(item)


def test_setup(item):
    plugin_base.before_test(item, item.parent.module.__name__,
                            item.parent.cls, item.name)


def test_teardown(item):
    plugin_base.after_test(item)


def class_setup(item):
    plugin_base.start_test_class(item.cls)


def class_teardown(item):
    plugin_base.stop_test_class(item.cls)
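The net effect of the collection hooks, sketched with an invented class: a test class marked __backend__ = True is cloned once per configured database, each clone pinned to a single backend via __only_on__:

from sqlalchemy.testing import fixtures


class SomeBackendTest(fixtures.TestBase):
    __backend__ = True

    def test_something(self):
        ...

# run with e.g. --db postgresql --db mysql, this collects roughly as
#   SomeBackendTest_postgresql_psycopg2
#   SomeBackendTest_mysql_mysqldb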
265
sqlalchemy/testing/profiling.py
Normal file
@@ -0,0 +1,265 @@
# testing/profiling.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Profiling support for unit and performance tests.

These are special purpose profiling methods which operate
in a more fine-grained way than nose's profiling plugin.

"""

import os
import sys
from .util import gc_collect
from . import config
import pstats
import collections
import contextlib

try:
    import cProfile
except ImportError:
    cProfile = None
from ..util import jython, pypy, win32, update_wrapper

_current_test = None

# ProfileStatsFile instance, set up in plugin_base
_profile_stats = None


class ProfileStatsFile(object):
    """Store per-platform/fn profiling results in a file.

    We're still targeting Py2.5, 2.4 on 0.7 with no dependencies,
    so no json lib :( need to roll something silly

    """

    def __init__(self, filename):
        self.force_write = (
            config.options is not None and
            config.options.force_write_profiles
        )
        self.write = self.force_write or (
            config.options is not None and
            config.options.write_profiles
        )
        self.fname = os.path.abspath(filename)
        self.short_fname = os.path.split(self.fname)[-1]
        self.data = collections.defaultdict(
            lambda: collections.defaultdict(dict))
        self._read()
        if self.write:
            # rewrite for the case where features changed,
            # etc.
            self._write()

    @property
    def platform_key(self):

        dbapi_key = config.db.name + "_" + config.db.driver

        # keep it at 2.7, 3.1, 3.2, etc. for now.
        py_version = '.'.join([str(v) for v in sys.version_info[0:2]])

        platform_tokens = [py_version]
        platform_tokens.append(dbapi_key)
        if jython:
            platform_tokens.append("jython")
        if pypy:
            platform_tokens.append("pypy")
        if win32:
            platform_tokens.append("win")
        platform_tokens.append(
            "nativeunicode"
            if config.db.dialect.convert_unicode
            else "dbapiunicode"
        )
        _has_cext = config.requirements._has_cextensions()
        platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
        return "_".join(platform_tokens)

    def has_stats(self):
        test_key = _current_test
        return (
            test_key in self.data and
            self.platform_key in self.data[test_key]
        )

    def result(self, callcount):
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]

        if 'counts' not in per_platform:
            per_platform['counts'] = counts = []
        else:
            counts = per_platform['counts']

        if 'current_count' not in per_platform:
            per_platform['current_count'] = current_count = 0
        else:
            current_count = per_platform['current_count']

        has_count = len(counts) > current_count

        if not has_count:
            counts.append(callcount)
            if self.write:
                self._write()
            result = None
        else:
            result = per_platform['lineno'], counts[current_count]
        per_platform['current_count'] += 1
        return result

    def replace(self, callcount):
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]
        counts = per_platform['counts']
        current_count = per_platform['current_count']
        if current_count < len(counts):
            counts[current_count - 1] = callcount
        else:
            counts[-1] = callcount
        if self.write:
            self._write()

    def _header(self):
        return (
            "# %s\n"
            "# This file is written out on a per-environment basis.\n"
            "# For each test in aaa_profiling, the corresponding "
            "function and \n"
            "# environment is located within this file. "
            "If it doesn't exist,\n"
            "# the test is skipped.\n"
            "# If a callcount does exist, it is compared "
            "to what we received. \n"
            "# assertions are raised if the counts do not match.\n"
            "# \n"
            "# To add a new callcount test, apply the function_call_count \n"
            "# decorator and re-run the tests using the --write-profiles \n"
            "# option - this file will be rewritten including the new count.\n"
            "# \n"
        ) % (self.fname)

    def _read(self):
        try:
            profile_f = open(self.fname)
        except IOError:
            return
        for lineno, line in enumerate(profile_f):
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            test_key, platform_key, counts = line.split()
            per_fn = self.data[test_key]
            per_platform = per_fn[platform_key]
            c = [int(count) for count in counts.split(",")]
            per_platform['counts'] = c
            per_platform['lineno'] = lineno + 1
            per_platform['current_count'] = 0
        profile_f.close()

    def _write(self):
        print(("Writing profile file %s" % self.fname))
        profile_f = open(self.fname, "w")
        profile_f.write(self._header())
        for test_key in sorted(self.data):

            per_fn = self.data[test_key]
            profile_f.write("\n# TEST: %s\n\n" % test_key)
            for platform_key in sorted(per_fn):
                per_platform = per_fn[platform_key]
                c = ",".join(str(count) for count in per_platform['counts'])
                profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
        profile_f.close()


def function_call_count(variance=0.05):
    """Assert a target for a test case's function call count.

    The main purpose of this assertion is to detect changes in
    callcounts for various functions - the actual number is not as important.
    Callcounts are stored in a file keyed to Python version and OS platform
    information. This file is generated automatically for new tests,
    and versioned so that unexpected changes in callcounts will be detected.

    """

    def decorate(fn):
        def wrap(*args, **kw):
            with count_functions(variance=variance):
                return fn(*args, **kw)
        return update_wrapper(wrap, fn)
    return decorate


@contextlib.contextmanager
def count_functions(variance=0.05):
    if cProfile is None:
        config.skip_test("cProfile is not installed")

    if not _profile_stats.has_stats() and not _profile_stats.write:
        config.skip_test(
            "No profiling stats available on this "
            "platform for this function. Run tests with "
            "--write-profiles to add statistics to %s for "
            "this platform." % _profile_stats.short_fname)

    gc_collect()

    pr = cProfile.Profile()
    pr.enable()
    #began = time.time()
    yield
    #ended = time.time()
    pr.disable()

    #s = compat.StringIO()
    stats = pstats.Stats(pr, stream=sys.stdout)

    #timespent = ended - began
    callcount = stats.total_calls

    expected = _profile_stats.result(callcount)

    if expected is None:
        expected_count = None
    else:
        line_no, expected_count = expected

    print(("Pstats calls: %d Expected %s" % (
        callcount,
        expected_count
    )
    ))
    stats.sort_stats("cumulative")
    stats.print_stats()

    if expected_count:
        deviance = int(callcount * variance)
        failed = abs(callcount - expected_count) > deviance

        if failed or _profile_stats.force_write:
            if _profile_stats.write:
                _profile_stats.replace(callcount)
            else:
                raise AssertionError(
                    "Adjusted function call count %s not within %s%% "
                    "of expected %s, platform %s. Rerun with "
                    "--write-profiles to "
                    "regenerate this callcount."
                    % (
                        callcount, (variance * 100),
                        expected_count, _profile_stats.platform_key))
||||
|
||||
|
318
sqlalchemy/testing/provision.py
Normal file
@@ -0,0 +1,318 @@
from sqlalchemy.engine import url as sa_url
from sqlalchemy import text
from sqlalchemy import exc
from sqlalchemy.util import compat
from . import config, engines
import os
import time
import logging

log = logging.getLogger(__name__)

FOLLOWER_IDENT = None


class register(object):
    def __init__(self):
        self.fns = {}

    @classmethod
    def init(cls, fn):
        return register().for_db("*")(fn)

    def for_db(self, dbname):
        def decorate(fn):
            self.fns[dbname] = fn
            return self
        return decorate

    def __call__(self, cfg, *arg):
        if isinstance(cfg, compat.string_types):
            url = sa_url.make_url(cfg)
        elif isinstance(cfg, sa_url.URL):
            url = cfg
        else:
            url = cfg.db.url
        backend = url.get_backend_name()
        if backend in self.fns:
            return self.fns[backend](cfg, *arg)
        else:
            return self.fns['*'](cfg, *arg)


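# How the registry above is used throughout this module (summary of the
# code that follows, not new behavior): ``@register.init`` installs a
# default implementation under the "*" key, and ``@<fn>.for_db("<backend>")``
# adds a backend-specific override.  Calling e.g. ``_create_db(cfg, eng,
# ident)`` then dispatches on the URL's backend name, falling back to the
# "*" default when no override exists.

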
def create_follower_db(follower_ident):
    for cfg in _configs_for_db_operation():
        _create_db(cfg, cfg.db, follower_ident)


def configure_follower(follower_ident):
    for cfg in config.Config.all_configs():
        _configure_follower(cfg, follower_ident)


def setup_config(db_url, options, file_config, follower_ident):
    if follower_ident:
        db_url = _follower_url_from_main(db_url, follower_ident)
    db_opts = {}
    _update_db_opts(db_url, db_opts)
    eng = engines.testing_engine(db_url, db_opts)
    _post_configure_engine(db_url, eng, follower_ident)
    eng.connect().close()
    cfg = config.Config.register(eng, db_opts, options, file_config)
    if follower_ident:
        _configure_follower(cfg, follower_ident)
    return cfg


def drop_follower_db(follower_ident):
    for cfg in _configs_for_db_operation():
        _drop_db(cfg, cfg.db, follower_ident)


def _configs_for_db_operation():
    hosts = set()

    for cfg in config.Config.all_configs():
        cfg.db.dispose()

    for cfg in config.Config.all_configs():
        url = cfg.db.url
        backend = url.get_backend_name()
        host_conf = (
            backend,
            url.username, url.host, url.database)

        if host_conf not in hosts:
            yield cfg
            hosts.add(host_conf)

    for cfg in config.Config.all_configs():
        cfg.db.dispose()


@register.init
def _create_db(cfg, eng, ident):
    raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url)


@register.init
def _drop_db(cfg, eng, ident):
    raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url)


@register.init
def _update_db_opts(db_url, db_opts):
    pass


@register.init
def _configure_follower(cfg, ident):
    pass


@register.init
def _post_configure_engine(url, engine, follower_ident):
    pass


@register.init
def _follower_url_from_main(url, ident):
    url = sa_url.make_url(url)
    url.database = ident
    return url


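# For example (hypothetical URL and ident): with ident "test_gw0", the
# default hook above turns "postgresql://scott:tiger@localhost/test" into
# "postgresql://scott:tiger@localhost/test_gw0" by replacing the database
# name; the dialect-specific overrides below customize this where needed.

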
@_update_db_opts.for_db("mssql")
def _mssql_update_db_opts(db_url, db_opts):
    db_opts['legacy_schema_aliasing'] = False


@_follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
    url = sa_url.make_url(url)
    if not url.database or url.database == ':memory:':
        return url
    else:
        return sa_url.make_url("sqlite:///%s.db" % ident)


@_post_configure_engine.for_db("sqlite")
def _sqlite_post_configure_engine(url, engine, follower_ident):
    from sqlalchemy import event

    @event.listens_for(engine, "connect")
    def connect(dbapi_connection, connection_record):
        # use file DBs in all cases, memory acts kind of strangely
        # as an attached
        if not follower_ident:
            dbapi_connection.execute(
                'ATTACH DATABASE "test_schema.db" AS test_schema')
        else:
            dbapi_connection.execute(
                'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
                % follower_ident)


@_create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
    with eng.connect().execution_options(
            isolation_level="AUTOCOMMIT") as conn:
        try:
            _pg_drop_db(cfg, conn, ident)
        except Exception:
            pass
        currentdb = conn.scalar("select current_database()")
        for attempt in range(3):
            try:
                conn.execute(
                    "CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb))
            except exc.OperationalError as err:
                if attempt != 2 and "accessed by other users" in str(err):
                    time.sleep(.2)
                    continue
                else:
                    raise
            else:
                break


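# Note on the retry loop above: PostgreSQL refuses
# CREATE DATABASE ... TEMPLATE while the template database has other active
# connections (the "accessed by other users" error), which can happen
# transiently under parallel test runs, so creation is retried a few times
# before the error is re-raised.

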
@_create_db.for_db("mysql")
def _mysql_create_db(cfg, eng, ident):
    with eng.connect() as conn:
        try:
            _mysql_drop_db(cfg, conn, ident)
        except Exception:
            pass
        conn.execute("CREATE DATABASE %s" % ident)
        conn.execute("CREATE DATABASE %s_test_schema" % ident)
        conn.execute("CREATE DATABASE %s_test_schema_2" % ident)


@_configure_follower.for_db("mysql")
def _mysql_configure_follower(config, ident):
    config.test_schema = "%s_test_schema" % ident
    config.test_schema_2 = "%s_test_schema_2" % ident


@_create_db.for_db("sqlite")
def _sqlite_create_db(cfg, eng, ident):
    pass


@_drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
    with eng.connect().execution_options(
            isolation_level="AUTOCOMMIT") as conn:
        conn.execute(
            text(
                "select pg_terminate_backend(pid) from pg_stat_activity "
                "where usename=current_user and pid != pg_backend_pid() "
                "and datname=:dname"
            ), dname=ident)
        conn.execute("DROP DATABASE %s" % ident)


@_drop_db.for_db("sqlite")
def _sqlite_drop_db(cfg, eng, ident):
    if ident:
        os.remove("%s_test_schema.db" % ident)
    else:
        os.remove("%s.db" % ident)


@_drop_db.for_db("mysql")
def _mysql_drop_db(cfg, eng, ident):
    with eng.connect() as conn:
        conn.execute("DROP DATABASE %s_test_schema" % ident)
        conn.execute("DROP DATABASE %s_test_schema_2" % ident)
        conn.execute("DROP DATABASE %s" % ident)


@_create_db.for_db("oracle")
def _oracle_create_db(cfg, eng, ident):
    # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
    # similar, so that the default tablespace is not "system"; reflection will
    # fail otherwise
    with eng.connect() as conn:
        conn.execute("create user %s identified by xe" % ident)
        conn.execute("create user %s_ts1 identified by xe" % ident)
        conn.execute("create user %s_ts2 identified by xe" % ident)
        conn.execute("grant dba to %s" % (ident, ))
        conn.execute("grant unlimited tablespace to %s" % ident)
        conn.execute("grant unlimited tablespace to %s_ts1" % ident)
        conn.execute("grant unlimited tablespace to %s_ts2" % ident)


@_configure_follower.for_db("oracle")
def _oracle_configure_follower(config, ident):
    config.test_schema = "%s_ts1" % ident
    config.test_schema_2 = "%s_ts2" % ident


def _ora_drop_ignore(conn, dbname):
    try:
        conn.execute("drop user %s cascade" % dbname)
        log.info("Reaped db: %s", dbname)
        return True
    except exc.DatabaseError as err:
        log.warning("couldn't drop db: %s", err)
        return False


@_drop_db.for_db("oracle")
def _oracle_drop_db(cfg, eng, ident):
    with eng.connect() as conn:
        # cx_Oracle seems to occasionally leak open connections when a large
        # suite is run, even if we confirm we have zero references to
        # connection objects.
        # while there is a "kill session" command in Oracle,
        # it unfortunately does not release the connection sufficiently.
        _ora_drop_ignore(conn, ident)
        _ora_drop_ignore(conn, "%s_ts1" % ident)
        _ora_drop_ignore(conn, "%s_ts2" % ident)


@_update_db_opts.for_db("oracle")
def _oracle_update_db_opts(db_url, db_opts):
    db_opts['_retry_on_12516'] = True


def reap_oracle_dbs(eng, idents_file):
    log.info("Reaping Oracle dbs...")
    with eng.connect() as conn:
        with open(idents_file) as file_:
            idents = set(line.strip() for line in file_)

        log.info("identifiers in file: %s", ", ".join(idents))

        to_reap = conn.execute(
            "select u.username from all_users u where username "
            "like 'TEST_%' and not exists (select username "
            "from v$session where username=u.username)")
        all_names = set([username.lower() for (username, ) in to_reap])
        to_drop = set()
        for name in all_names:
            if name.endswith("_ts1") or name.endswith("_ts2"):
                continue
            elif name in idents:
                to_drop.add(name)
                if "%s_ts1" % name in all_names:
                    to_drop.add("%s_ts1" % name)
                if "%s_ts2" % name in all_names:
                    to_drop.add("%s_ts2" % name)

        dropped = total = 0
        for total, username in enumerate(to_drop, 1):
            if _ora_drop_ignore(conn, username):
                dropped += 1
        log.info(
            "Dropped %d out of %d stale databases detected", dropped, total)


@_follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
    url = sa_url.make_url(url)
    url.username = ident
    url.password = 'xe'
    return url

172
sqlalchemy/testing/replay_fixture.py
Normal file
@@ -0,0 +1,172 @@
from . import fixtures
from . import profiling
from .. import util
import types
from collections import deque
import contextlib
from . import config
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy.orm import Session


class ReplayFixtureTest(fixtures.TestBase):

    @contextlib.contextmanager
    def _dummy_ctx(self, *arg, **kw):
        yield

    def test_invocation(self):

        dbapi_session = ReplayableSession()
        creator = config.db.pool._creator
        recorder = lambda: dbapi_session.recorder(creator())
        engine = create_engine(
            config.db.url, creator=recorder,
            use_native_hstore=False)
        self.metadata = MetaData(engine)
        self.engine = engine
        self.session = Session(engine)

        self.setup_engine()
        try:
            self._run_steps(ctx=self._dummy_ctx)
        finally:
            self.teardown_engine()
            engine.dispose()

        player = lambda: dbapi_session.player()
        engine = create_engine(
            config.db.url, creator=player,
            use_native_hstore=False)

        self.metadata = MetaData(engine)
        self.engine = engine
        self.session = Session(engine)

        self.setup_engine()
        try:
            self._run_steps(ctx=profiling.count_functions)
        finally:
            self.session.close()
            engine.dispose()

    def setup_engine(self):
        pass

    def teardown_engine(self):
        pass

    def _run_steps(self, ctx):
        raise NotImplementedError()


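# Flow of the fixture above (summary): the first pass runs the test steps
# against a Recorder-wrapped DBAPI connection, filling a buffer; the second
# pass replays that buffer through a Player while profiling.count_functions
# measures pure SQLAlchemy overhead, with no real database round trips.

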
class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class.  It only records a session for later
    playback and makes no assertions on call consistency whatsoever.  It's
    unlikely to be suitable for anything other than DB-API recording.

    """

    Callable = object()
    NoAttribute = object()

    if util.py2k:
        Natives = set([getattr(types, t)
                       for t in dir(types) if not t.startswith('_')]).\
            difference([getattr(types, t)
                        for t in ('FunctionType', 'BuiltinFunctionType',
                                  'MethodType', 'BuiltinMethodType',
                                  'LambdaType', 'UnboundMethodType',)])
    else:
        Natives = set([getattr(types, t)
                       for t in dir(types) if not t.startswith('_')]).\
            union([type(t) if not isinstance(t, type)
                   else t for t in __builtins__.values()]).\
            difference([getattr(types, t)
                        for t in ('FunctionType', 'BuiltinFunctionType',
                                  'MethodType', 'BuiltinMethodType',
                                  'LambdaType', )])

    def __init__(self):
        self.buffer = deque()

    def recorder(self, base):
        return self.Recorder(self.buffer, base)

    def player(self):
        return self.Player(self.buffer)

    class Recorder(object):
        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]

            result = subject(*args, **kw)
            if type(result) not in ReplayableSession.Natives:
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)
            else:
                buffer.append(result)
                return result

        @property
        def _sqla_unwrap(self):
            return self._subject

        def __getattribute__(self, key):
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            try:
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) not in ReplayableSession.Natives:
                    buffer.append(ReplayableSession.Callable)
                    return type(self)(buffer, result)
                else:
                    buffer.append(result)
                    return result

    class Player(object):
        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            else:
                return result

        @property
        def _sqla_unwrap(self):
            return None

        def __getattribute__(self, key):
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                raise AttributeError(key)
            else:
                return result

800
sqlalchemy/testing/requirements.py
Normal file
@@ -0,0 +1,800 @@
# testing/requirements.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Global database feature support policy.

Provides decorators to mark tests requiring specific feature support from the
target database.

External dialect test suites should subclass SuiteRequirements
to provide specific inclusion/exclusions.

"""

import sys

from . import exclusions
from .. import util


class Requirements(object):
    pass


class SuiteRequirements(Requirements):

    @property
    def create_table(self):
        """target platform can emit basic CreateTable DDL."""

        return exclusions.open()

    @property
    def drop_table(self):
        """target platform can emit basic DropTable DDL."""

        return exclusions.open()

    @property
    def foreign_keys(self):
        """Target database must support foreign keys."""

        return exclusions.open()

    @property
    def on_update_cascade(self):
        """target database must support ON UPDATE..CASCADE behavior in
        foreign keys."""

        return exclusions.open()

    @property
    def non_updating_cascade(self):
        """target database must *not* support ON UPDATE..CASCADE behavior in
        foreign keys."""
        return exclusions.closed()

    @property
    def deferrable_fks(self):
        return exclusions.closed()

    @property
    def on_update_or_deferrable_fks(self):
        # TODO: exclusions should be composable,
        # somehow only_if([x, y]) isn't working here, negation/conjunctions
        # getting confused.
        return exclusions.only_if(
            lambda: self.on_update_cascade.enabled or
            self.deferrable_fks.enabled
        )

    @property
    def self_referential_foreign_keys(self):
        """Target database must support self-referential foreign keys."""

        return exclusions.open()

    @property
    def foreign_key_ddl(self):
        """Target database must support the DDL phrases for FOREIGN KEY."""

        return exclusions.open()

    @property
    def named_constraints(self):
        """target database must support names for constraints."""

        return exclusions.open()

    @property
    def subqueries(self):
        """Target database must support subqueries."""

        return exclusions.open()

    @property
    def offset(self):
        """target database can render OFFSET, or an equivalent, in a
        SELECT.
        """

        return exclusions.open()

    @property
    def bound_limit_offset(self):
        """target database can render LIMIT and/or OFFSET using a bound
        parameter
        """

        return exclusions.open()

    @property
    def parens_in_union_contained_select_w_limit_offset(self):
        """Target database must support parenthesized SELECT in UNION
        when LIMIT/OFFSET is specifically present.

        E.g. (SELECT ... LIMIT ..) UNION (SELECT .. OFFSET ..)

        This is known to fail on SQLite.

        """
        return exclusions.open()

    @property
    def parens_in_union_contained_select_wo_limit_offset(self):
        """Target database must support parenthesized SELECT in UNION
        when OFFSET/LIMIT is specifically not present.

        E.g. (SELECT ...) UNION (SELECT ..)

        This is known to fail on SQLite.  It also fails on Oracle
        because without LIMIT/OFFSET, there is currently no step that
        creates an additional subquery.

        """
        return exclusions.open()

    @property
    def boolean_col_expressions(self):
        """Target database must support boolean expressions as columns"""

        return exclusions.closed()

    @property
    def nullsordering(self):
        """Target backends that support nulls ordering."""

        return exclusions.closed()

    @property
    def standalone_binds(self):
        """target database/driver supports bound parameters as column
        expressions without being in the context of a typed column.

        """
        return exclusions.closed()

    @property
    def intersect(self):
        """Target database must support INTERSECT or equivalent."""
        return exclusions.closed()

    @property
    def except_(self):
        """Target database must support EXCEPT or equivalent (i.e. MINUS)."""
        return exclusions.closed()

    @property
    def window_functions(self):
        """Target database must support window functions."""
        return exclusions.closed()

    @property
    def autoincrement_insert(self):
        """target platform generates new surrogate integer primary key values
        when insert() is executed, excluding the pk column."""

        return exclusions.open()

    @property
    def fetch_rows_post_commit(self):
        """target platform will allow cursor.fetchone() to proceed after a
        COMMIT.

        Typically this refers to an INSERT statement with RETURNING which
        is invoked within "autocommit".  If the row can be returned
        after the autocommit, then this rule can be open.

        """

        return exclusions.open()

    @property
    def empty_inserts(self):
        """target platform supports INSERT with no values, i.e.
        INSERT DEFAULT VALUES or equivalent."""

        return exclusions.only_if(
            lambda config: config.db.dialect.supports_empty_insert or
            config.db.dialect.supports_default_values,
            "empty inserts not supported"
        )

    @property
    def insert_from_select(self):
        """target platform supports INSERT from a SELECT."""

        return exclusions.open()

    @property
    def returning(self):
        """target platform supports RETURNING."""

        return exclusions.only_if(
            lambda config: config.db.dialect.implicit_returning,
            "%(database)s %(does_support)s 'returning'"
        )

    @property
    def duplicate_names_in_cursor_description(self):
        """target platform supports a SELECT statement that has
        the same name repeated more than once in the columns list."""

        return exclusions.open()

    @property
    def denormalized_names(self):
        """Target database must have 'denormalized', i.e.
        UPPERCASE as case insensitive names."""

        return exclusions.skip_if(
            lambda config: not config.db.dialect.requires_name_normalize,
            "Backend does not require denormalized names."
        )

    @property
    def multivalues_inserts(self):
        """target database must support multiple VALUES clauses in an
        INSERT statement."""

        return exclusions.skip_if(
            lambda config: not config.db.dialect.supports_multivalues_insert,
            "Backend does not support multirow inserts."
        )

    @property
    def implements_get_lastrowid(self):
        """target dialect implements the executioncontext.get_lastrowid()
        method without reliance on RETURNING.

        """
        return exclusions.open()

    @property
    def emulated_lastrowid(self):
        """target dialect retrieves cursor.lastrowid, or fetches
        from a database-side function after an insert() construct executes,
        within the get_lastrowid() method.

        Only dialects that "pre-execute", or need RETURNING to get last
        inserted id, would return closed/fail/skip for this.

        """
        return exclusions.closed()

    @property
    def dbapi_lastrowid(self):
        """target platform includes a 'lastrowid' accessor on the DBAPI
        cursor object.

        """
        return exclusions.closed()

    @property
    def views(self):
        """Target database must support VIEWs."""

        return exclusions.closed()

    @property
    def schemas(self):
        """Target database must support external schemas, and have one
        named 'test_schema'."""

        return exclusions.closed()

    @property
    def server_side_cursors(self):
        """Target dialect must support server side cursors."""

        return exclusions.only_if([
            lambda config: config.db.dialect.supports_server_side_cursors
        ], "no server side cursors support")

    @property
    def sequences(self):
        """Target database must support SEQUENCEs."""

        return exclusions.only_if([
            lambda config: config.db.dialect.supports_sequences
        ], "no sequence support")

    @property
    def sequences_optional(self):
        """Target database supports sequences, but also optionally
        as a means of generating new PK values."""

        return exclusions.only_if([
            lambda config: config.db.dialect.supports_sequences and
            config.db.dialect.sequences_optional
        ], "no sequence support, or sequences not optional")

    @property
    def reflects_pk_names(self):
        return exclusions.closed()

    @property
    def table_reflection(self):
        return exclusions.open()

    @property
    def view_column_reflection(self):
        """target database must support retrieval of the columns in a view,
        similarly to how a table is inspected.

        This does not include the full CREATE VIEW definition.

        """
        return self.views

    @property
    def view_reflection(self):
        """target database must support inspection of the full CREATE VIEW
        definition.
        """
        return self.views

    @property
    def schema_reflection(self):
        return self.schemas

    @property
    def primary_key_constraint_reflection(self):
        return exclusions.open()

    @property
    def foreign_key_constraint_reflection(self):
        return exclusions.open()

    @property
    def foreign_key_constraint_option_reflection(self):
        return exclusions.closed()

    @property
    def temp_table_reflection(self):
        return exclusions.open()

    @property
    def temp_table_names(self):
        """target dialect supports listing of temporary table names"""
        return exclusions.closed()

    @property
    def temporary_tables(self):
        """target database supports temporary tables"""
        return exclusions.open()

    @property
    def temporary_views(self):
        """target database supports temporary views"""
        return exclusions.closed()

    @property
    def index_reflection(self):
        return exclusions.open()

    @property
    def unique_constraint_reflection(self):
        """target dialect supports reflection of unique constraints"""
        return exclusions.open()

    @property
    def duplicate_key_raises_integrity_error(self):
        """target dialect raises IntegrityError when reporting an INSERT
        with a primary key violation.  (hint: it should)

        """
        return exclusions.open()

    @property
    def unbounded_varchar(self):
        """Target database must support VARCHAR with no length"""

        return exclusions.open()

    @property
    def unicode_data(self):
        """Target database/dialect must support Python unicode objects with
        non-ASCII characters represented, delivered as bound parameters
        as well as in result rows.

        """
        return exclusions.open()

    @property
    def unicode_ddl(self):
        """Target driver must support some degree of non-ascii symbol
        names.
        """
        return exclusions.closed()

    @property
    def datetime_literals(self):
        """target dialect supports rendering of a date, time, or datetime as a
        literal string, e.g. via the TypeEngine.literal_processor() method.

        """

        return exclusions.closed()

    @property
    def datetime(self):
        """target dialect supports representation of Python
        datetime.datetime() objects."""

        return exclusions.open()

    @property
    def datetime_microseconds(self):
        """target dialect supports representation of Python
        datetime.datetime() with microsecond objects."""

        return exclusions.open()

    @property
    def datetime_historic(self):
        """target dialect supports representation of Python
        datetime.datetime() objects with historic (pre 1970) values."""

        return exclusions.closed()

    @property
    def date(self):
        """target dialect supports representation of Python
        datetime.date() objects."""

        return exclusions.open()

    @property
    def date_coerces_from_datetime(self):
        """target dialect accepts a datetime object as the target
        of a date column."""

        return exclusions.open()

    @property
    def date_historic(self):
        """target dialect supports representation of Python
        datetime.date() objects with historic (pre 1970) values."""

        return exclusions.closed()

    @property
    def time(self):
        """target dialect supports representation of Python
        datetime.time() objects."""

        return exclusions.open()

    @property
    def time_microseconds(self):
        """target dialect supports representation of Python
        datetime.time() with microsecond objects."""

        return exclusions.open()

    @property
    def binary_comparisons(self):
        """target database/driver can allow BLOB/BINARY fields to be compared
        against a bound parameter value.
        """

        return exclusions.open()

    @property
    def binary_literals(self):
        """target backend supports simple binary literals, e.g. an
        expression like::

            SELECT CAST('foo' AS BINARY)

        Where ``BINARY`` is the type emitted from :class:`.LargeBinary`,
        e.g. it could be ``BLOB`` or similar.

        Basically fails on Oracle.

        """

        return exclusions.open()

    @property
    def json_type(self):
        """target platform implements a native JSON type."""

        return exclusions.closed()

    @property
    def json_array_indexes(self):
        """target platform supports numeric array indexes
        within a JSON structure"""

        return self.json_type

    @property
    def precision_numerics_general(self):
        """target backend has general support for moderately high-precision
        numerics."""
        return exclusions.open()

    @property
    def precision_numerics_enotation_small(self):
        """target backend supports Decimal() objects using E notation
        to represent very small values."""
        return exclusions.closed()

    @property
    def precision_numerics_enotation_large(self):
        """target backend supports Decimal() objects using E notation
        to represent very large values."""
        return exclusions.closed()

    @property
    def precision_numerics_many_significant_digits(self):
        """target backend supports values with many digits on both sides,
        such as 319438950232418390.273596, 87673.594069654243

        """
        return exclusions.closed()

    @property
    def precision_numerics_retains_significant_digits(self):
        """A precision numeric type will return empty significant digits,
        i.e. a value such as 10.000 will come back in Decimal form with
        the .000 maintained."""

        return exclusions.closed()

    @property
    def precision_generic_float_type(self):
        """target backend will return native floating point numbers with at
        least seven decimal places when using the generic Float type.

        """
        return exclusions.open()

    @property
    def floats_to_four_decimals(self):
        """target backend can return a floating-point number with four
        significant digits (such as 15.7563) accurately
        (i.e. without FP inaccuracies, such as 15.75629997253418).

        """
        return exclusions.open()

    @property
    def fetch_null_from_numeric(self):
        """target backend doesn't crash when you try to select a NUMERIC
        value that has a value of NULL.

        Added to support Pyodbc bug #351.
        """

        return exclusions.open()

    @property
    def text_type(self):
        """Target database must support an unbounded Text() type
        such as TEXT or CLOB"""

        return exclusions.open()

    @property
    def empty_strings_varchar(self):
        """target database can persist/return an empty string with a
        varchar.

        """
        return exclusions.open()

    @property
    def empty_strings_text(self):
        """target database can persist/return an empty string with an
        unbounded text."""

        return exclusions.open()

    @property
    def selectone(self):
        """target driver must support the literal statement 'select 1'"""
        return exclusions.open()

    @property
    def savepoints(self):
        """Target database must support savepoints."""

        return exclusions.closed()

    @property
    def two_phase_transactions(self):
        """Target database must support two-phase transactions."""

        return exclusions.closed()

    @property
    def update_from(self):
        """Target must support UPDATE..FROM syntax"""
        return exclusions.closed()

    @property
    def update_where_target_in_subquery(self):
        """Target must support UPDATE where the same table is present in a
        subquery in the WHERE clause.

        This is an ANSI-standard syntax that apparently MySQL can't handle,
        such as:

        UPDATE documents SET flag=1 WHERE documents.title IN
            (SELECT max(documents.title) AS title
                FROM documents GROUP BY documents.user_id
            )
        """
        return exclusions.open()

    @property
    def mod_operator_as_percent_sign(self):
        """target database must use a plain percent '%' as the 'modulus'
        operator."""
        return exclusions.closed()

    @property
    def percent_schema_names(self):
        """target backend supports weird identifiers with percent signs
        in them, e.g. 'some % column'.

        this is a very weird use case but often has problems because of
        DBAPIs that use python formatting.  It's not a critical use
        case either.

        """
        return exclusions.closed()

    @property
    def order_by_label_with_expression(self):
        """target backend supports ORDER BY a column label within an
        expression.

        Basically this::

            select data as foo from test order by foo || 'bar'

        Lots of databases including PostgreSQL don't support this,
        so this is off by default.

        """
        return exclusions.closed()

    @property
    def unicode_connections(self):
        """Target driver must support non-ASCII characters being passed at
        all.
        """
        return exclusions.open()

    @property
    def graceful_disconnects(self):
        """Target driver must raise a DBAPI-level exception, such as
        InterfaceError, when the underlying connection has been closed
        and the execute() method is called.
        """
        return exclusions.open()

    @property
    def skip_mysql_on_windows(self):
        """Catchall for a large variety of MySQL on Windows failures"""
        return exclusions.open()

    @property
    def ad_hoc_engines(self):
        """Test environment must allow ad-hoc engine/connection creation.

        DBs that scale poorly for many connections, even when closed, i.e.
        Oracle, may use the "--low-connections" option which flags this
        requirement as not present.

        """
        return exclusions.skip_if(
            lambda config: config.options.low_connections)

    @property
    def timing_intensive(self):
        return exclusions.requires_tag("timing_intensive")

    @property
    def memory_intensive(self):
        return exclusions.requires_tag("memory_intensive")

    @property
    def threading_with_mock(self):
        """Mark tests that use threading and mock at the same time - stability
        issues have been observed with coverage + python 3.3

        """
        return exclusions.skip_if(
            lambda config: util.py3k and config.options.has_coverage,
            "Stability issues with coverage + py3k"
        )

    @property
    def python2(self):
        return exclusions.skip_if(
            lambda: sys.version_info >= (3,),
            "Python version 2.xx is required."
        )

    @property
    def python3(self):
        return exclusions.skip_if(
            lambda: sys.version_info < (3,),
            "Python version 3.xx is required."
        )

    @property
    def cpython(self):
        return exclusions.only_if(
            lambda: util.cpython,
            "cPython interpreter needed"
        )

    @property
    def non_broken_pickle(self):
        from sqlalchemy.util import pickle
        return exclusions.only_if(
            lambda: not util.pypy and pickle.__name__ == 'cPickle'
            or sys.version_info >= (3, 2),
            "Needs cPickle+cPython or newer Python 3 pickle"
        )

    @property
    def predictable_gc(self):
        """target platform must remove all cycles unconditionally when
        gc.collect() is called, as well as clean out unreferenced subclasses.

        """
        return self.cpython

    @property
    def no_coverage(self):
        """Test should be skipped if coverage is enabled.

        This is to block tests that exercise libraries that seem to be
        sensitive to coverage, such as PostgreSQL notice logging.

        """
        return exclusions.skip_if(
            lambda config: config.options.has_coverage,
            "Issues observed when coverage is enabled"
        )

    def _has_mysql_on_windows(self, config):
        return False

    def _has_mysql_fully_case_sensitive(self, config):
        return False

    @property
    def sqlite(self):
        return exclusions.skip_if(lambda: not self._has_sqlite())

    @property
    def cextensions(self):
        return exclusions.skip_if(
            lambda: not self._has_cextensions(), "C extensions not installed"
        )

    def _has_sqlite(self):
        from sqlalchemy import create_engine
        try:
            create_engine('sqlite://')
            return True
        except ImportError:
            return False

    def _has_cextensions(self):
        try:
            from sqlalchemy import cresultproxy, cprocessors
            return True
        except ImportError:
            return False

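# Usage sketch (illustrative; class and rule names are examples): a
# third-party dialect supplies its own requirements module by subclassing
# SuiteRequirements and flipping individual rules, e.g.
#
#     class Requirements(SuiteRequirements):
#         @property
#         def returning(self):
#             return exclusions.open()
#
# Tests then reference these rules via decorators such as
# ``@requirements.returning``, as seen in the suite modules below.
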
50
sqlalchemy/testing/runner.py
Normal file
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# testing/runner.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Nose test runner module.

This script is a front-end to "nosetests" which
installs SQLAlchemy's testing plugin into the local environment.

The script is intended to be used by third-party dialects and extensions
that run within SQLAlchemy's testing framework.  The runner can
be invoked via::

    python -m sqlalchemy.testing.runner

The script is then essentially the same as the "nosetests" script, including
all of the usual Nose options.  The test environment requires that a
setup.cfg is locally present including various required options.

Note that when using this runner, Nose's "coverage" plugin will not be
able to provide coverage for SQLAlchemy itself, since SQLAlchemy is
imported into sys.modules before coverage is started.  The special
script sqla_nose.py is provided as a top-level script which loads the
plugin in a special (somewhat hacky) way so that coverage against
SQLAlchemy itself is possible.

"""

from .plugin.noseplugin import NoseSQLAlchemy

import nose


def main():
    nose.main(addplugins=[NoseSQLAlchemy()])


def setup_py_test():
    """Runner to use for the 'test_suite' entry of your setup.py.

    Prevents any name clash shenanigans from the command line
    argument "test" that the "setup.py test" command sends
    to nose.

    """
    nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner'])

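# Usage sketch for setup_py_test (illustrative; package name is
# hypothetical): a dialect's setup.py points its test_suite entry at the
# function above, e.g.
#
#     setup(
#         name="sqlalchemy-somedialect",
#         ...,
#         test_suite="sqlalchemy.testing.runner.setup_py_test",
#     )
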
101
sqlalchemy/testing/schema.py
Normal file
@@ -0,0 +1,101 @@
# testing/schema.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from . import exclusions
from .. import schema, event
from . import config

__all__ = 'Table', 'Column',

table_options = {}


def Table(*args, **kw):
    """A schema.Table wrapper/hook for dialect-specific tweaks."""

    test_opts = dict([(k, kw.pop(k)) for k in list(kw)
                      if k.startswith('test_')])

    kw.update(table_options)

    if exclusions.against(config._current, 'mysql'):
        if 'mysql_engine' not in kw and 'mysql_type' not in kw:
            if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
                kw['mysql_engine'] = 'InnoDB'
            else:
                kw['mysql_engine'] = 'MyISAM'

    # Apply some default cascading rules for self-referential foreign keys.
    # MySQL InnoDB has some issues around selecting self-refs too.
    if exclusions.against(config._current, 'firebird'):
        table_name = args[0]
        unpack = (config.db.dialect.
                  identifier_preparer.unformat_identifiers)

        # Only going after ForeignKeys in Columns.  May need to
        # expand to ForeignKeyConstraint too.
        fks = [fk
               for col in args if isinstance(col, schema.Column)
               for fk in col.foreign_keys]

        for fk in fks:
            # root around in raw spec
            ref = fk._colspec
            if isinstance(ref, schema.Column):
                name = ref.table.name
            else:
                # take just the table name: on FB there cannot be
                # a schema, so the first element is always the
                # table name, possibly followed by the field name
                name = unpack(ref)[0]
            if name == table_name:
                if fk.ondelete is None:
                    fk.ondelete = 'CASCADE'
                if fk.onupdate is None:
                    fk.onupdate = 'CASCADE'

    return schema.Table(*args, **kw)


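# Usage sketch (illustrative; the table is hypothetical): test fixtures pass
# ``test_*`` options that are consumed by these wrappers rather than by Core,
# e.g.
#
#     Table('t', metadata,
#           Column('id', Integer, primary_key=True,
#                  test_needs_autoincrement=True),
#           test_needs_fk=True)
#
# On MySQL, test_needs_fk selects the InnoDB storage engine above;
# test_needs_autoincrement is handled by the Column wrapper below.

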
def Column(*args, **kw):
    """A schema.Column wrapper/hook for dialect-specific tweaks."""

    test_opts = dict([(k, kw.pop(k)) for k in list(kw)
                      if k.startswith('test_')])

    if not config.requirements.foreign_key_ddl.enabled_for_config(config):
        args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]

    col = schema.Column(*args, **kw)
    if test_opts.get('test_needs_autoincrement', False) and \
            kw.get('primary_key', False):

        if col.default is None and col.server_default is None:
            col.autoincrement = True

        # allow any test suite to pick up on this
        col.info['test_needs_autoincrement'] = True

        # hardcoded rule for firebird, oracle; this should
        # be moved out
        if exclusions.against(config._current, 'firebird', 'oracle'):
            def add_seq(c, tbl):
                c._init_items(
                    schema.Sequence(_truncate_name(
                        config.db.dialect, tbl.name + '_' + c.name + '_seq'),
                        optional=True)
                )
            event.listen(col, 'after_parent_attach', add_seq, propagate=True)
    return col


def _truncate_name(dialect, name):
    if len(name) > dialect.max_identifier_length:
        return name[0:max(dialect.max_identifier_length - 6, 0)] + \
            "_" + hex(hash(name) % 64)[2:]
    else:
        return name

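# Example of the truncation rule above (hypothetical values): with a dialect
# whose max_identifier_length is 30, a generated name like
# "some_very_long_table_some_very_long_col_seq" is cut to 24 characters and
# suffixed with "_" plus a short hash-derived hex token (hash(name) % 64),
# keeping the identifier under the limit while reducing collisions.
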
10
sqlalchemy/testing/suite/__init__.py
Normal file
@@ -0,0 +1,10 @@
from sqlalchemy.testing.suite.test_dialect import *
from sqlalchemy.testing.suite.test_ddl import *
from sqlalchemy.testing.suite.test_insert import *
from sqlalchemy.testing.suite.test_sequence import *
from sqlalchemy.testing.suite.test_select import *
from sqlalchemy.testing.suite.test_results import *
from sqlalchemy.testing.suite.test_update_delete import *
from sqlalchemy.testing.suite.test_reflection import *
from sqlalchemy.testing.suite.test_types import *

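# Usage sketch (illustrative): a third-party dialect's own test module pulls
# the whole compliance suite in with a star import, after which these test
# classes run against that dialect's configured database:
#
#     from sqlalchemy.testing.suite import *
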
65
sqlalchemy/testing/suite/test_ddl.py
Normal file
@@ -0,0 +1,65 @@
from .. import fixtures, config, util
from ..config import requirements
from ..assertions import eq_

from sqlalchemy import Table, Column, Integer, String


class TableDDLTest(fixtures.TestBase):
    __backend__ = True

    def _simple_fixture(self):
        return Table('test_table', self.metadata,
                     Column('id', Integer, primary_key=True,
                            autoincrement=False),
                     Column('data', String(50))
                     )

    def _underscore_fixture(self):
        return Table('_test_table', self.metadata,
                     Column('id', Integer, primary_key=True,
                            autoincrement=False),
                     Column('_data', String(50))
                     )

    def _simple_roundtrip(self, table):
        with config.db.begin() as conn:
            conn.execute(table.insert().values((1, 'some data')))
            result = conn.execute(table.select())
            eq_(
                result.first(),
                (1, 'some data')
            )

    @requirements.create_table
    @util.provide_metadata
    def test_create_table(self):
        table = self._simple_fixture()
        table.create(
            config.db, checkfirst=False
        )
        self._simple_roundtrip(table)

    @requirements.drop_table
    @util.provide_metadata
    def test_drop_table(self):
        table = self._simple_fixture()
        table.create(
            config.db, checkfirst=False
        )
        table.drop(
            config.db, checkfirst=False
        )

    @requirements.create_table
    @util.provide_metadata
    def test_underscore_names(self):
        table = self._underscore_fixture()
        table.create(
            config.db, checkfirst=False
        )
        self._simple_roundtrip(table)

__all__ = ('TableDDLTest', )

41
sqlalchemy/testing/suite/test_dialect.py
Normal file
@@ -0,0 +1,41 @@
from .. import fixtures, config
from ..config import requirements
from sqlalchemy import exc
from sqlalchemy import Integer, String
from .. import assert_raises
from ..schema import Table, Column


class ExceptionTest(fixtures.TablesTest):
    """Test basic exception wrapping.

    DBAPIs vary a lot in exception behavior so to actually anticipate
    specific exceptions from real round trips, we need to be conservative.

    """
    run_deletes = 'each'

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('manual_pk', metadata,
              Column('id', Integer, primary_key=True, autoincrement=False),
              Column('data', String(50))
              )

    @requirements.duplicate_key_raises_integrity_error
    def test_integrity_error(self):

        with config.db.begin() as conn:
            conn.execute(
                self.tables.manual_pk.insert(),
                {'id': 1, 'data': 'd1'}
            )

            assert_raises(
                exc.IntegrityError,
                conn.execute,
                self.tables.manual_pk.insert(),
                {'id': 1, 'data': 'd1'}
            )

319
sqlalchemy/testing/suite/test_insert.py
Normal file
@@ -0,0 +1,319 @@
from .. import fixtures, config
from ..config import requirements
from .. import exclusions
from ..assertions import eq_
from .. import engines

from sqlalchemy import Integer, String, select, literal_column, literal

from ..schema import Table, Column


class LastrowidTest(fixtures.TablesTest):
    run_deletes = 'each'

    __backend__ = True

    __requires__ = 'implements_get_lastrowid', 'autoincrement_insert'

    __engine_options__ = {"implicit_returning": False}

    @classmethod
    def define_tables(cls, metadata):
        Table('autoinc_pk', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('data', String(50))
              )

        Table('manual_pk', metadata,
              Column('id', Integer, primary_key=True, autoincrement=False),
              Column('data', String(50))
              )

    def _assert_round_trip(self, table, conn):
        row = conn.execute(table.select()).first()
        eq_(
            row,
            (config.db.dialect.default_sequence_base, "some data")
        )

    def test_autoincrement_on_insert(self):

        config.db.execute(
            self.tables.autoinc_pk.insert(),
            data="some data"
        )
        self._assert_round_trip(self.tables.autoinc_pk, config.db)

    def test_last_inserted_id(self):

        r = config.db.execute(
            self.tables.autoinc_pk.insert(),
            data="some data"
        )
        pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
        eq_(
            r.inserted_primary_key,
            [pk]
        )

    # failed on pypy1.9 but seems to be OK on pypy 2.1
    # @exclusions.fails_if(lambda: util.pypy,
    #                      "lastrowid not maintained after "
    #                      "connection close")
    @requirements.dbapi_lastrowid
    def test_native_lastrowid_autoinc(self):
        r = config.db.execute(
            self.tables.autoinc_pk.insert(),
            data="some data"
        )
        lastrowid = r.lastrowid
        pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
        eq_(
            lastrowid, pk
        )


class InsertBehaviorTest(fixtures.TablesTest):
    run_deletes = 'each'
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('autoinc_pk', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('data', String(50))
              )
        Table('manual_pk', metadata,
              Column('id', Integer, primary_key=True, autoincrement=False),
              Column('data', String(50))
              )
        Table('includes_defaults', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('data', String(50)),
              Column('x', Integer, default=5),
              Column('y', Integer,
                     default=literal_column("2", type_=Integer) + literal(2)))

    def test_autoclose_on_insert(self):
        if requirements.returning.enabled:
            engine = engines.testing_engine(
                options={'implicit_returning': False})
        else:
            engine = config.db

        r = engine.execute(
            self.tables.autoinc_pk.insert(),
            data="some data"
        )
        assert r._soft_closed
        assert not r.closed
        assert r.is_insert
        assert not r.returns_rows

    @requirements.returning
    def test_autoclose_on_insert_implicit_returning(self):
        r = config.db.execute(
            self.tables.autoinc_pk.insert(),
            data="some data"
        )
        assert r._soft_closed
        assert not r.closed
        assert r.is_insert
        assert not r.returns_rows

    @requirements.empty_inserts
    def test_empty_insert(self):
        r = config.db.execute(
            self.tables.autoinc_pk.insert(),
        )
        assert r._soft_closed
        assert not r.closed

        r = config.db.execute(
            self.tables.autoinc_pk.select().
            where(self.tables.autoinc_pk.c.id != None)
        )

        assert len(r.fetchall())

    @requirements.insert_from_select
    def test_insert_from_select_autoinc(self):
        src_table = self.tables.manual_pk
        dest_table = self.tables.autoinc_pk
        config.db.execute(
            src_table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ]
        )

        result = config.db.execute(
            dest_table.insert().
            from_select(
                ("data",),
                select([src_table.c.data]).
                where(src_table.c.data.in_(["data2", "data3"]))
            )
        )

        eq_(result.inserted_primary_key, [None])

        result = config.db.execute(
            select([dest_table.c.data]).order_by(dest_table.c.data)
        )
        eq_(result.fetchall(), [("data2", ), ("data3", )])

    @requirements.insert_from_select
    def test_insert_from_select_autoinc_no_rows(self):
        src_table = self.tables.manual_pk
        dest_table = self.tables.autoinc_pk

        result = config.db.execute(
            dest_table.insert().
            from_select(
                ("data",),
                select([src_table.c.data]).
                where(src_table.c.data.in_(["data2", "data3"]))
            )
        )
        eq_(result.inserted_primary_key, [None])

        result = config.db.execute(
            select([dest_table.c.data]).order_by(dest_table.c.data)
        )

        eq_(result.fetchall(), [])

    @requirements.insert_from_select
    def test_insert_from_select(self):
        table = self.tables.manual_pk
        config.db.execute(
            table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ]
        )

        config.db.execute(
            table.insert(inline=True).
            from_select(("id", "data",),
                        select([table.c.id + 5, table.c.data]).
                        where(table.c.data.in_(["data2", "data3"]))
                        ),
        )

        eq_(
            config.db.execute(
                select([table.c.data]).order_by(table.c.data)
            ).fetchall(),
            [("data1", ), ("data2", ), ("data2", ),
             ("data3", ), ("data3", )]
        )

    @requirements.insert_from_select
    def test_insert_from_select_with_defaults(self):
        table = self.tables.includes_defaults
        config.db.execute(
            table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ]
        )

        config.db.execute(
            table.insert(inline=True).
            from_select(("id", "data",),
                        select([table.c.id + 5, table.c.data]).
                        where(table.c.data.in_(["data2", "data3"]))
                        ),
        )

        eq_(
            config.db.execute(
                select([table]).order_by(table.c.data, table.c.id)
            ).fetchall(),
            [(1, 'data1', 5, 4), (2, 'data2', 5, 4),
             (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
        )


class ReturningTest(fixtures.TablesTest):
|
||||
run_create_tables = 'each'
|
||||
__requires__ = 'returning', 'autoincrement_insert'
|
||||
__backend__ = True
|
||||
|
||||
__engine_options__ = {"implicit_returning": True}

    def _assert_round_trip(self, table, conn):
        row = conn.execute(table.select()).first()
        eq_(
            row,
            (config.db.dialect.default_sequence_base, "some data")
        )

    @classmethod
    def define_tables(cls, metadata):
        Table('autoinc_pk', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('data', String(50))
              )

    @requirements.fetch_rows_post_commit
    def test_explicit_returning_pk_autocommit(self):
        engine = config.db
        table = self.tables.autoinc_pk
        r = engine.execute(
            table.insert().returning(
                table.c.id),
            data="some data"
        )
        pk = r.first()[0]
        fetched_pk = config.db.scalar(select([table.c.id]))
        eq_(fetched_pk, pk)

    def test_explicit_returning_pk_no_autocommit(self):
        engine = config.db
        table = self.tables.autoinc_pk
        with engine.begin() as conn:
            r = conn.execute(
                table.insert().returning(
                    table.c.id),
                data="some data"
            )
            pk = r.first()[0]
        fetched_pk = config.db.scalar(select([table.c.id]))
        eq_(fetched_pk, pk)
    def test_autoincrement_on_insert_implicit_returning(self):

        config.db.execute(
            self.tables.autoinc_pk.insert(),
            data="some data"
        )
        self._assert_round_trip(self.tables.autoinc_pk, config.db)

    def test_last_inserted_id_implicit_returning(self):

        r = config.db.execute(
            self.tables.autoinc_pk.insert(),
            data="some data"
        )
        pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
        eq_(
            r.inserted_primary_key,
            [pk]
        )


__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest')
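
As a usage sketch (not part of this commit), a third-party dialect's test module would typically pull these classes in from the suite package and override whatever its backend cannot support; the module and class names below are illustrative:

    # test_suite.py of a hypothetical external dialect
    from sqlalchemy.testing.suite import *  # noqa
    from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest


    class InsertBehaviorTest(_InsertBehaviorTest):
        # skip or adjust individual tests the backend cannot run
        pass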
746
sqlalchemy/testing/suite/test_reflection.py
Normal file
@@ -0,0 +1,746 @@
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import types as sql_types
from sqlalchemy import inspect
from sqlalchemy import MetaData, Integer, String
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.testing import engines, fixtures
from sqlalchemy.testing.schema import Table, Column
from sqlalchemy.testing import eq_, assert_raises_message
from sqlalchemy import testing
from .. import config
import operator
from sqlalchemy.schema import DDL, Index
from sqlalchemy import event
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy import ForeignKey

metadata, users = None, None


class HasTableTest(fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('test_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('data', String(50))
              )

    def test_has_table(self):
        with config.db.begin() as conn:
            assert config.db.dialect.has_table(conn, "test_table")
            assert not config.db.dialect.has_table(conn, "nonexistent_table")


class ComponentReflectionTest(fixtures.TablesTest):
    run_inserts = run_deletes = None

    __backend__ = True

    @classmethod
    def setup_bind(cls):
        if config.requirements.independent_connections.enabled:
            from sqlalchemy import pool
            return engines.testing_engine(
                options=dict(poolclass=pool.StaticPool))
        else:
            return config.db

    @classmethod
    def define_tables(cls, metadata):
        cls.define_reflected_tables(metadata, None)
        if testing.requires.schemas.enabled:
            cls.define_reflected_tables(metadata, testing.config.test_schema)

    @classmethod
    def define_reflected_tables(cls, metadata, schema):
        if schema:
            schema_prefix = schema + "."
        else:
            schema_prefix = ""

        if testing.requires.self_referential_foreign_keys.enabled:
            users = Table('users', metadata,
                          Column('user_id', sa.INT, primary_key=True),
                          Column('test1', sa.CHAR(5), nullable=False),
                          Column('test2', sa.Float(5), nullable=False),
                          Column('parent_user_id', sa.Integer,
                                 sa.ForeignKey('%susers.user_id' %
                                               schema_prefix)),
                          schema=schema,
                          test_needs_fk=True,
                          )
        else:
            users = Table('users', metadata,
                          Column('user_id', sa.INT, primary_key=True),
                          Column('test1', sa.CHAR(5), nullable=False),
                          Column('test2', sa.Float(5), nullable=False),
                          schema=schema,
                          test_needs_fk=True,
                          )

        Table("dingalings", metadata,
              Column('dingaling_id', sa.Integer, primary_key=True),
              Column('address_id', sa.Integer,
                     sa.ForeignKey('%semail_addresses.address_id' %
                                   schema_prefix)),
              Column('data', sa.String(30)),
              schema=schema,
              test_needs_fk=True,
              )
        Table('email_addresses', metadata,
              Column('address_id', sa.Integer),
              Column('remote_user_id', sa.Integer,
                     sa.ForeignKey(users.c.user_id)),
              Column('email_address', sa.String(20)),
              sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'),
              schema=schema,
              test_needs_fk=True,
              )

        if testing.requires.index_reflection.enabled:
            cls.define_index(metadata, users)
        if testing.requires.view_column_reflection.enabled:
            cls.define_views(metadata, schema)
        if not schema and testing.requires.temp_table_reflection.enabled:
            cls.define_temp_tables(metadata)

    @classmethod
    def define_temp_tables(cls, metadata):
        # cheat a bit, we should fix this with some dialect-level
        # temp table fixture
        if testing.against("oracle"):
            kw = {
                'prefixes': ["GLOBAL TEMPORARY"],
                'oracle_on_commit': 'PRESERVE ROWS'
            }
        else:
            kw = {
                'prefixes': ["TEMPORARY"],
            }

        user_tmp = Table(
            "user_tmp", metadata,
            Column("id", sa.INT, primary_key=True),
            Column('name', sa.VARCHAR(50)),
            Column('foo', sa.INT),
            sa.UniqueConstraint('name', name='user_tmp_uq'),
            sa.Index("user_tmp_ix", "foo"),
            **kw
        )
        if testing.requires.view_reflection.enabled and \
                testing.requires.temporary_views.enabled:
            event.listen(
                user_tmp, "after_create",
                DDL("create temporary view user_tmp_v as "
                    "select * from user_tmp")
            )
            event.listen(
                user_tmp, "before_drop",
                DDL("drop view user_tmp_v")
            )

    @classmethod
    def define_index(cls, metadata, users):
        Index("users_t_idx", users.c.test1, users.c.test2)
        Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)

    @classmethod
    def define_views(cls, metadata, schema):
        for table_name in ('users', 'email_addresses'):
            fullname = table_name
            if schema:
                fullname = "%s.%s" % (schema, table_name)
            view_name = fullname + '_v'
            query = "CREATE VIEW %s AS SELECT * FROM %s" % (
                view_name, fullname)

            event.listen(
                metadata,
                "after_create",
                DDL(query)
            )
            event.listen(
                metadata,
                "before_drop",
                DDL("DROP VIEW %s" % view_name)
            )

    @testing.requires.schema_reflection
    def test_get_schema_names(self):
        insp = inspect(testing.db)

        self.assert_(testing.config.test_schema in insp.get_schema_names())

    @testing.requires.schema_reflection
    def test_dialect_initialize(self):
        engine = engines.testing_engine()
        assert not hasattr(engine.dialect, 'default_schema_name')
        inspect(engine)
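        # inspecting the engine forces a first connect, which runs
        # dialect.initialize() and fills in connection-derived attributes
        # such as default_schema_name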
        assert hasattr(engine.dialect, 'default_schema_name')

    @testing.requires.schema_reflection
    def test_get_default_schema_name(self):
        insp = inspect(testing.db)
        eq_(insp.default_schema_name, testing.db.dialect.default_schema_name)

    @testing.provide_metadata
    def _test_get_table_names(self, schema=None, table_type='table',
                              order_by=None):
        meta = self.metadata
        users, addresses, dingalings = self.tables.users, \
            self.tables.email_addresses, self.tables.dingalings
        insp = inspect(meta.bind)

        if table_type == 'view':
            table_names = insp.get_view_names(schema)
            table_names.sort()
            answer = ['email_addresses_v', 'users_v']
            eq_(sorted(table_names), answer)
        else:
            table_names = insp.get_table_names(schema,
                                               order_by=order_by)
            if order_by == 'foreign_key':
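                # 'foreign_key' ordering sorts tables by dependency:
                # referenced tables come before the tables that reference them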
                answer = ['users', 'email_addresses', 'dingalings']
                eq_(table_names, answer)
            else:
                answer = ['dingalings', 'email_addresses', 'users']
                eq_(sorted(table_names), answer)

    @testing.requires.temp_table_names
    def test_get_temp_table_names(self):
        insp = inspect(self.bind)
        temp_table_names = insp.get_temp_table_names()
        eq_(sorted(temp_table_names), ['user_tmp'])

    @testing.requires.view_reflection
    @testing.requires.temp_table_names
    @testing.requires.temporary_views
    def test_get_temp_view_names(self):
        insp = inspect(self.bind)
        temp_table_names = insp.get_temp_view_names()
        eq_(sorted(temp_table_names), ['user_tmp_v'])

    @testing.requires.table_reflection
    def test_get_table_names(self):
        self._test_get_table_names()

    @testing.requires.table_reflection
    @testing.requires.foreign_key_constraint_reflection
    def test_get_table_names_fks(self):
        self._test_get_table_names(order_by='foreign_key')

    @testing.requires.table_reflection
    @testing.requires.schemas
    def test_get_table_names_with_schema(self):
        self._test_get_table_names(testing.config.test_schema)

    @testing.requires.view_column_reflection
    def test_get_view_names(self):
        self._test_get_table_names(table_type='view')

    @testing.requires.view_column_reflection
    @testing.requires.schemas
    def test_get_view_names_with_schema(self):
        self._test_get_table_names(
            testing.config.test_schema, table_type='view')

    @testing.requires.table_reflection
    @testing.requires.view_column_reflection
    def test_get_tables_and_views(self):
        self._test_get_table_names()
        self._test_get_table_names(table_type='view')

    def _test_get_columns(self, schema=None, table_type='table'):
        meta = MetaData(testing.db)
        users, addresses, dingalings = self.tables.users, \
            self.tables.email_addresses, self.tables.dingalings
        table_names = ['users', 'email_addresses']
        if table_type == 'view':
            table_names = ['users_v', 'email_addresses_v']
        insp = inspect(meta.bind)
        for table_name, table in zip(table_names, (users,
                                                   addresses)):
            schema_name = schema
            cols = insp.get_columns(table_name, schema=schema_name)
            self.assert_(len(cols) > 0, len(cols))

            # should be in order

            for i, col in enumerate(table.columns):
                eq_(col.name, cols[i]['name'])
                ctype = cols[i]['type'].__class__
                ctype_def = col.type
                if isinstance(ctype_def, sa.types.TypeEngine):
                    ctype_def = ctype_def.__class__

                # Oracle returns Date for DateTime.

                if testing.against('oracle') and ctype_def \
                        in (sql_types.Date, sql_types.DateTime):
                    ctype_def = sql_types.Date

                # assert that the desired type and return type share
                # a base within one of the generic types.

                self.assert_(len(set(ctype.__mro__).
                                 intersection(ctype_def.__mro__).
                                 intersection([
                                     sql_types.Integer,
                                     sql_types.Numeric,
                                     sql_types.DateTime,
                                     sql_types.Date,
                                     sql_types.Time,
                                     sql_types.String,
                                     sql_types._Binary,
                                 ])) > 0, '%s(%s), %s(%s)' %
                             (col.name, col.type, cols[i]['name'], ctype))

                if not col.primary_key:
                    assert cols[i]['default'] is None

    @testing.requires.table_reflection
    def test_get_columns(self):
        self._test_get_columns()

    @testing.provide_metadata
    def _type_round_trip(self, *types):
        t = Table('t', self.metadata,
                  *[
                      Column('t%d' % i, type_)
                      for i, type_ in enumerate(types)
                  ]
                  )
        t.create()

        return [
            c['type'] for c in
            inspect(self.metadata.bind).get_columns('t')
        ]

    @testing.requires.table_reflection
    def test_numeric_reflection(self):
        for typ in self._type_round_trip(
            sql_types.Numeric(18, 5),
        ):
            assert isinstance(typ, sql_types.Numeric)
            eq_(typ.precision, 18)
            eq_(typ.scale, 5)

    @testing.requires.table_reflection
    def test_varchar_reflection(self):
        typ = self._type_round_trip(sql_types.String(52))[0]
        assert isinstance(typ, sql_types.String)
        eq_(typ.length, 52)

    @testing.requires.table_reflection
    @testing.provide_metadata
    def test_nullable_reflection(self):
        t = Table('t', self.metadata,
                  Column('a', Integer, nullable=True),
                  Column('b', Integer, nullable=False))
        t.create()
        eq_(
            dict(
                (col['name'], col['nullable'])
                for col in inspect(self.metadata.bind).get_columns('t')
            ),
            {"a": True, "b": False}
        )

    @testing.requires.table_reflection
    @testing.requires.schemas
    def test_get_columns_with_schema(self):
        self._test_get_columns(schema=testing.config.test_schema)

    @testing.requires.temp_table_reflection
    def test_get_temp_table_columns(self):
        meta = MetaData(self.bind)
        user_tmp = self.tables.user_tmp
        insp = inspect(meta.bind)
        cols = insp.get_columns('user_tmp')
        self.assert_(len(cols) > 0, len(cols))

        for i, col in enumerate(user_tmp.columns):
            eq_(col.name, cols[i]['name'])

    @testing.requires.temp_table_reflection
    @testing.requires.view_column_reflection
    @testing.requires.temporary_views
    def test_get_temp_view_columns(self):
        insp = inspect(self.bind)
        cols = insp.get_columns('user_tmp_v')
        eq_(
            [col['name'] for col in cols],
            ['id', 'name', 'foo']
        )

    @testing.requires.view_column_reflection
    def test_get_view_columns(self):
        self._test_get_columns(table_type='view')

    @testing.requires.view_column_reflection
    @testing.requires.schemas
    def test_get_view_columns_with_schema(self):
        self._test_get_columns(
            schema=testing.config.test_schema, table_type='view')

    @testing.provide_metadata
    def _test_get_pk_constraint(self, schema=None):
        meta = self.metadata
        users, addresses = self.tables.users, self.tables.email_addresses
        insp = inspect(meta.bind)

        users_cons = insp.get_pk_constraint(users.name, schema=schema)
        users_pkeys = users_cons['constrained_columns']
        eq_(users_pkeys, ['user_id'])

        addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
        addr_pkeys = addr_cons['constrained_columns']
        eq_(addr_pkeys, ['address_id'])

        with testing.requires.reflects_pk_names.fail_if():
            eq_(addr_cons['name'], 'email_ad_pk')

    @testing.requires.primary_key_constraint_reflection
    def test_get_pk_constraint(self):
        self._test_get_pk_constraint()

    @testing.requires.table_reflection
    @testing.requires.primary_key_constraint_reflection
    @testing.requires.schemas
    def test_get_pk_constraint_with_schema(self):
        self._test_get_pk_constraint(schema=testing.config.test_schema)

    @testing.requires.table_reflection
    @testing.provide_metadata
    def test_deprecated_get_primary_keys(self):
        meta = self.metadata
        users = self.tables.users
        insp = Inspector(meta.bind)
        assert_raises_message(
            sa_exc.SADeprecationWarning,
            "Call to deprecated method get_primary_keys."
            " Use get_pk_constraint instead.",
            insp.get_primary_keys, users.name
        )

    @testing.provide_metadata
    def _test_get_foreign_keys(self, schema=None):
        meta = self.metadata
        users, addresses, dingalings = self.tables.users, \
            self.tables.email_addresses, self.tables.dingalings
        insp = inspect(meta.bind)
        expected_schema = schema
        # users

        if testing.requires.self_referential_foreign_keys.enabled:
            users_fkeys = insp.get_foreign_keys(users.name,
                                                schema=schema)
            fkey1 = users_fkeys[0]

            with testing.requires.named_constraints.fail_if():
                self.assert_(fkey1['name'] is not None)

            eq_(fkey1['referred_schema'], expected_schema)
            eq_(fkey1['referred_table'], users.name)
            eq_(fkey1['referred_columns'], ['user_id', ])
            if testing.requires.self_referential_foreign_keys.enabled:
                eq_(fkey1['constrained_columns'], ['parent_user_id'])

        # addresses
        addr_fkeys = insp.get_foreign_keys(addresses.name,
                                           schema=schema)
        fkey1 = addr_fkeys[0]

        with testing.requires.named_constraints.fail_if():
            self.assert_(fkey1['name'] is not None)

        eq_(fkey1['referred_schema'], expected_schema)
        eq_(fkey1['referred_table'], users.name)
        eq_(fkey1['referred_columns'], ['user_id', ])
        eq_(fkey1['constrained_columns'], ['remote_user_id'])

    @testing.requires.foreign_key_constraint_reflection
    def test_get_foreign_keys(self):
        self._test_get_foreign_keys()

    @testing.requires.foreign_key_constraint_reflection
    @testing.requires.schemas
    def test_get_foreign_keys_with_schema(self):
        self._test_get_foreign_keys(schema=testing.config.test_schema)

    @testing.requires.foreign_key_constraint_option_reflection
    @testing.provide_metadata
    def test_get_foreign_key_options(self):
        meta = self.metadata

        Table(
            'x', meta,
            Column('id', Integer, primary_key=True),
            test_needs_fk=True
        )

        Table('table', meta,
              Column('id', Integer, primary_key=True),
              Column('x_id', Integer, sa.ForeignKey('x.id', name='xid')),
              Column('test', String(10)),
              test_needs_fk=True)

        Table('user', meta,
              Column('id', Integer, primary_key=True),
              Column('name', String(50), nullable=False),
              Column('tid', Integer),
              sa.ForeignKeyConstraint(
                  ['tid'], ['table.id'],
                  name='myfk',
                  onupdate="SET NULL", ondelete="CASCADE"),
              test_needs_fk=True)

        meta.create_all()

        insp = inspect(meta.bind)

        # test 'options' is always present for a backend
        # that can reflect these, since alembic looks for this
        opts = insp.get_foreign_keys('table')[0]['options']

        eq_(
            dict(
                (k, opts[k])
                for k in opts if opts[k]
            ),
            {}
        )

        opts = insp.get_foreign_keys('user')[0]['options']
        eq_(
            dict(
                (k, opts[k])
                for k in opts if opts[k]
            ),
            {'onupdate': 'SET NULL', 'ondelete': 'CASCADE'}
        )

    @testing.provide_metadata
    def _test_get_indexes(self, schema=None):
        meta = self.metadata
        users, addresses, dingalings = self.tables.users, \
            self.tables.email_addresses, self.tables.dingalings
        # The database may decide to create indexes for foreign keys, etc.
        # so there may be more indexes than expected.
        insp = inspect(meta.bind)
        indexes = insp.get_indexes('users', schema=schema)
        expected_indexes = [
            {'unique': False,
             'column_names': ['test1', 'test2'],
             'name': 'users_t_idx'},
            {'unique': False,
             'column_names': ['user_id', 'test2', 'test1'],
             'name': 'users_all_idx'}
        ]
        index_names = [d['name'] for d in indexes]
        for e_index in expected_indexes:
            assert e_index['name'] in index_names
            index = indexes[index_names.index(e_index['name'])]
            for key in e_index:
                eq_(e_index[key], index[key])

    @testing.requires.index_reflection
    def test_get_indexes(self):
        self._test_get_indexes()

    @testing.requires.index_reflection
    @testing.requires.schemas
    def test_get_indexes_with_schema(self):
        self._test_get_indexes(schema=testing.config.test_schema)

    @testing.requires.unique_constraint_reflection
    def test_get_unique_constraints(self):
        self._test_get_unique_constraints()

    @testing.requires.temp_table_reflection
    @testing.requires.unique_constraint_reflection
    def test_get_temp_table_unique_constraints(self):
        insp = inspect(self.bind)
        reflected = insp.get_unique_constraints('user_tmp')
        for refl in reflected:
            # Different dialects handle duplicate index and constraints
            # differently, so ignore this flag
            refl.pop('duplicates_index', None)
        eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}])

    @testing.requires.temp_table_reflection
    def test_get_temp_table_indexes(self):
        insp = inspect(self.bind)
        indexes = insp.get_indexes('user_tmp')
        for ind in indexes:
            ind.pop('dialect_options', None)
        eq_(
            # TODO: we need to add better filtering for indexes/uq constraints
            # that are doubled up
            [idx for idx in indexes if idx['name'] == 'user_tmp_ix'],
            [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}]
        )

    @testing.requires.unique_constraint_reflection
    @testing.requires.schemas
    def test_get_unique_constraints_with_schema(self):
        self._test_get_unique_constraints(schema=testing.config.test_schema)

    @testing.provide_metadata
    def _test_get_unique_constraints(self, schema=None):
        # SQLite dialect needs to parse the names of the constraints
        # separately from what it gets from PRAGMA index_list(), and
        # then matches them up. so same set of column_names in two
        # constraints will confuse it. Perhaps we should no longer
        # bother with index_list() here since we have the whole
        # CREATE TABLE?
        uniques = sorted(
            [
                {'name': 'unique_a', 'column_names': ['a']},
                {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
                {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']},
                {'name': 'unique_asc_key', 'column_names': ['asc', 'key']},
                {'name': 'i.have.dots', 'column_names': ['b']},
                {'name': 'i have spaces', 'column_names': ['c']},
            ],
            key=operator.itemgetter('name')
        )
        orig_meta = self.metadata
        table = Table(
            'testtbl', orig_meta,
            Column('a', sa.String(20)),
            Column('b', sa.String(30)),
            Column('c', sa.Integer),
            # reserved identifiers
            Column('asc', sa.String(30)),
            Column('key', sa.String(30)),
            schema=schema
        )
        for uc in uniques:
            table.append_constraint(
                sa.UniqueConstraint(*uc['column_names'], name=uc['name'])
            )
        orig_meta.create_all()

        inspector = inspect(orig_meta.bind)
        reflected = sorted(
            inspector.get_unique_constraints('testtbl', schema=schema),
            key=operator.itemgetter('name')
        )

        for orig, refl in zip(uniques, reflected):
            # Different dialects handle duplicate index and constraints
            # differently, so ignore this flag
            refl.pop('duplicates_index', None)
            eq_(orig, refl)

    @testing.provide_metadata
    def _test_get_view_definition(self, schema=None):
        meta = self.metadata
        users, addresses, dingalings = self.tables.users, \
            self.tables.email_addresses, self.tables.dingalings
        view_name1 = 'users_v'
        view_name2 = 'email_addresses_v'
        insp = inspect(meta.bind)
        v1 = insp.get_view_definition(view_name1, schema=schema)
        self.assert_(v1)
        v2 = insp.get_view_definition(view_name2, schema=schema)
        self.assert_(v2)

    @testing.requires.view_reflection
    def test_get_view_definition(self):
        self._test_get_view_definition()

    @testing.requires.view_reflection
    @testing.requires.schemas
    def test_get_view_definition_with_schema(self):
        self._test_get_view_definition(schema=testing.config.test_schema)

    @testing.only_on("postgresql", "PG specific feature")
    @testing.provide_metadata
    def _test_get_table_oid(self, table_name, schema=None):
        meta = self.metadata
        users, addresses, dingalings = self.tables.users, \
            self.tables.email_addresses, self.tables.dingalings
        insp = inspect(meta.bind)
        oid = insp.get_table_oid(table_name, schema)
        self.assert_(isinstance(oid, int))

    def test_get_table_oid(self):
        self._test_get_table_oid('users')

    @testing.requires.schemas
    def test_get_table_oid_with_schema(self):
        self._test_get_table_oid('users', schema=testing.config.test_schema)

    @testing.requires.table_reflection
    @testing.provide_metadata
    def test_autoincrement_col(self):
        """test that 'autoincrement' is reflected according to sqla's policy.

        Don't mark this test as unsupported for any backend !

        (technically it fails with MySQL InnoDB since "id" comes before "id2")

        A backend is better off not returning "autoincrement" at all,
        instead of potentially returning "False" for an auto-incrementing
        primary key column.

        """

        meta = self.metadata
        insp = inspect(meta.bind)

        for tname, cname in [
            ('users', 'user_id'),
            ('email_addresses', 'address_id'),
            ('dingalings', 'dingaling_id'),
        ]:
            cols = insp.get_columns(tname)
            id_ = dict((c['name'], c) for c in cols)[cname]
            assert id_.get('autoincrement', True)


class NormalizedNameTest(fixtures.TablesTest):
    __requires__ = 'denormalized_names',
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            quoted_name('t1', quote=True), metadata,
            Column('id', Integer, primary_key=True),
        )
        Table(
            quoted_name('t2', quote=True), metadata,
            Column('id', Integer, primary_key=True),
            Column('t1id', ForeignKey('t1.id'))
        )

    def test_reflect_lowercase_forced_tables(self):

        m2 = MetaData(testing.db)
        t2_ref = Table(quoted_name('t2', quote=True), m2, autoload=True)
        t1_ref = m2.tables['t1']
        assert t2_ref.c.t1id.references(t1_ref.c.id)

        m3 = MetaData(testing.db)
        m3.reflect(only=lambda name, m: name.lower() in ('t1', 't2'))
        assert m3.tables['t2'].c.t1id.references(m3.tables['t1'].c.id)

    def test_get_table_names(self):
        tablenames = [
            t for t in inspect(testing.db).get_table_names()
            if t.lower() in ("t1", "t2")]

        eq_(tablenames[0].upper(), tablenames[0].lower())
        eq_(tablenames[1].upper(), tablenames[1].lower())


__all__ = ('ComponentReflectionTest', 'HasTableTest', 'NormalizedNameTest')
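
As a usage sketch (not part of this commit), an external dialect gates these suite classes through its own requirements module; the names below are illustrative:

    # requirements.py of a hypothetical external dialect
    from sqlalchemy.testing.requirements import SuiteRequirements
    from sqlalchemy.testing import exclusions


    class Requirements(SuiteRequirements):
        @property
        def temp_table_reflection(self):
            # this backend cannot reflect temporary tables; skip those tests
            return exclusions.closed()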
367
sqlalchemy/testing/suite/test_results.py
Normal file
@@ -0,0 +1,367 @@
from .. import fixtures, config
from ..config import requirements
from .. import exclusions
from ..assertions import eq_
from .. import engines
from ... import testing

from sqlalchemy import Integer, String, select, util, sql, DateTime, text, func
import datetime
from ..schema import Table, Column


class RowFetchTest(fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('plain_pk', metadata,
              Column('id', Integer, primary_key=True),
              Column('data', String(50))
              )
        Table('has_dates', metadata,
              Column('id', Integer, primary_key=True),
              Column('today', DateTime)
              )

    @classmethod
    def insert_data(cls):
        config.db.execute(
            cls.tables.plain_pk.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ]
        )

        config.db.execute(
            cls.tables.has_dates.insert(),
            [
                {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}
            ]
        )

    def test_via_string(self):
        row = config.db.execute(
            self.tables.plain_pk.select().
            order_by(self.tables.plain_pk.c.id)
        ).first()

        eq_(
            row['id'], 1
        )
        eq_(
            row['data'], "d1"
        )

    def test_via_int(self):
        row = config.db.execute(
            self.tables.plain_pk.select().
            order_by(self.tables.plain_pk.c.id)
        ).first()

        eq_(
            row[0], 1
        )
        eq_(
            row[1], "d1"
        )

    def test_via_col_object(self):
        row = config.db.execute(
            self.tables.plain_pk.select().
            order_by(self.tables.plain_pk.c.id)
        ).first()

        eq_(
            row[self.tables.plain_pk.c.id], 1
        )
        eq_(
            row[self.tables.plain_pk.c.data], "d1"
        )

    @requirements.duplicate_names_in_cursor_description
    def test_row_with_dupe_names(self):
        result = config.db.execute(
            select([self.tables.plain_pk.c.data,
                    self.tables.plain_pk.c.data.label('data')]).
            order_by(self.tables.plain_pk.c.id)
        )
        row = result.first()
        eq_(result.keys(), ['data', 'data'])
        eq_(row, ('d1', 'd1'))

    def test_row_w_scalar_select(self):
        """test that a scalar select as a column is returned as such
        and that type conversion works OK.

        (this is half a SQLAlchemy Core test and half to catch database
        backends that may have unusual behavior with scalar selects.)

        """
        datetable = self.tables.has_dates
        s = select([datetable.alias('x').c.today]).as_scalar()
        s2 = select([datetable.c.id, s.label('somelabel')])
        row = config.db.execute(s2).first()

        eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0))


class PercentSchemaNamesTest(fixtures.TablesTest):
    """tests using percent signs, spaces in table and column names.

    This is a very fringe use case, doesn't work for MySQL
    or PostgreSQL. the requirement, "percent_schema_names",
    is marked "skip" by default.

    """

    __requires__ = ('percent_schema_names', )

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        cls.tables.percent_table = Table('percent%table', metadata,
                                         Column("percent%", Integer),
                                         Column(
                                             "spaces % more spaces", Integer),
                                         )
        cls.tables.lightweight_percent_table = sql.table(
            'percent%table', sql.column("percent%"),
            sql.column("spaces % more spaces")
        )

    def test_single_roundtrip(self):
        percent_table = self.tables.percent_table
        for params in [
            {'percent%': 5, 'spaces % more spaces': 12},
            {'percent%': 7, 'spaces % more spaces': 11},
            {'percent%': 9, 'spaces % more spaces': 10},
            {'percent%': 11, 'spaces % more spaces': 9}
        ]:
            config.db.execute(percent_table.insert(), params)
        self._assert_table()

    def test_executemany_roundtrip(self):
        percent_table = self.tables.percent_table
        config.db.execute(
            percent_table.insert(),
            {'percent%': 5, 'spaces % more spaces': 12}
        )
        config.db.execute(
            percent_table.insert(),
            [{'percent%': 7, 'spaces % more spaces': 11},
             {'percent%': 9, 'spaces % more spaces': 10},
             {'percent%': 11, 'spaces % more spaces': 9}]
        )
        self._assert_table()

    def _assert_table(self):
        percent_table = self.tables.percent_table
        lightweight_percent_table = self.tables.lightweight_percent_table

        for table in (
                percent_table,
                percent_table.alias(),
                lightweight_percent_table,
                lightweight_percent_table.alias()):
            eq_(
                list(
                    config.db.execute(
                        table.select().order_by(table.c['percent%'])
                    )
                ),
                [
                    (5, 12),
                    (7, 11),
                    (9, 10),
                    (11, 9)
                ]
            )

            eq_(
                list(
                    config.db.execute(
                        table.select().
                        where(table.c['spaces % more spaces'].in_([9, 10])).
                        order_by(table.c['percent%']),
                    )
                ),
                [
                    (9, 10),
                    (11, 9)
                ]
            )

            row = config.db.execute(table.select().
                                    order_by(table.c['percent%'])).first()
            eq_(row['percent%'], 5)
            eq_(row['spaces % more spaces'], 12)

            eq_(row[table.c['percent%']], 5)
            eq_(row[table.c['spaces % more spaces']], 12)

        config.db.execute(
            percent_table.update().values(
                {percent_table.c['spaces % more spaces']: 15}
            )
        )

        eq_(
            list(
                config.db.execute(
                    percent_table.
                    select().
                    order_by(percent_table.c['percent%'])
                )
            ),
            [(5, 15), (7, 15), (9, 15), (11, 15)]
        )


class ServerSideCursorsTest(fixtures.TestBase,
                            testing.AssertsExecutionResults):

    __requires__ = ('server_side_cursors', )

    __backend__ = True

    def _is_server_side(self, cursor):
        if self.engine.url.drivername == 'postgresql':
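            # psycopg2 gives server-side ("named") cursors a non-empty
            # .name attribute; client-side cursors have no name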
            return cursor.name
        elif self.engine.url.drivername == 'mysql':
            sscursor = __import__('MySQLdb.cursors').cursors.SSCursor
            return isinstance(cursor, sscursor)
        elif self.engine.url.drivername == 'mysql+pymysql':
            sscursor = __import__('pymysql.cursors').cursors.SSCursor
            return isinstance(cursor, sscursor)
        else:
            return False

    def _fixture(self, server_side_cursors):
        self.engine = engines.testing_engine(
            options={'server_side_cursors': server_side_cursors}
        )
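        # server_side_cursors turns on streaming engine-wide; the tests
        # below also exercise the per-statement / per-connection
        # stream_results execution option against this engine default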
        return self.engine

    def tearDown(self):
        engines.testing_reaper.close_all()
        self.engine.dispose()

    def test_global_string(self):
        engine = self._fixture(True)
        result = engine.execute('select 1')
        assert self._is_server_side(result.cursor)

    def test_global_text(self):
        engine = self._fixture(True)
        result = engine.execute(text('select 1'))
        assert self._is_server_side(result.cursor)

    def test_global_expr(self):
        engine = self._fixture(True)
        result = engine.execute(select([1]))
        assert self._is_server_side(result.cursor)

    def test_global_off_explicit(self):
        engine = self._fixture(False)
        result = engine.execute(text('select 1'))

        # It should be off globally ...

        assert not self._is_server_side(result.cursor)

    def test_stmt_option(self):
        engine = self._fixture(False)

        s = select([1]).execution_options(stream_results=True)
        result = engine.execute(s)

        # ... but enabled for this one.

        assert self._is_server_side(result.cursor)

    def test_conn_option(self):
        engine = self._fixture(False)

        # and this one
        result = \
            engine.connect().execution_options(stream_results=True).\
            execute('select 1')
        assert self._is_server_side(result.cursor)

    def test_stmt_enabled_conn_option_disabled(self):
        engine = self._fixture(False)

        s = select([1]).execution_options(stream_results=True)

        # not this one
        result = \
            engine.connect().execution_options(stream_results=False).\
            execute(s)
        assert not self._is_server_side(result.cursor)

    def test_stmt_option_disabled(self):
        engine = self._fixture(True)
        s = select([1]).execution_options(stream_results=False)
        result = engine.execute(s)
        assert not self._is_server_side(result.cursor)

    def test_aliases_and_ss(self):
        engine = self._fixture(False)
        s1 = select([1]).execution_options(stream_results=True).alias()
        result = engine.execute(s1)
        assert self._is_server_side(result.cursor)

        # s1's options shouldn't affect s2 when s2 is used as a
        # from_obj.
        s2 = select([1], from_obj=s1)
        result = engine.execute(s2)
        assert not self._is_server_side(result.cursor)

    def test_for_update_expr(self):
        engine = self._fixture(True)
        s1 = select([1], for_update=True)
        result = engine.execute(s1)
        assert self._is_server_side(result.cursor)

    def test_for_update_string(self):
        engine = self._fixture(True)
        result = engine.execute('SELECT 1 FOR UPDATE')
        assert self._is_server_side(result.cursor)

    def test_text_no_ss(self):
        engine = self._fixture(False)
        s = text('select 42')
        result = engine.execute(s)
        assert not self._is_server_side(result.cursor)

    def test_text_ss_option(self):
        engine = self._fixture(False)
        s = text('select 42').execution_options(stream_results=True)
        result = engine.execute(s)
        assert self._is_server_side(result.cursor)

    @testing.provide_metadata
    def test_roundtrip(self):
        md = self.metadata

        engine = self._fixture(True)
        test_table = Table('test_table', md,
                           Column('id', Integer, primary_key=True),
                           Column('data', String(50)))
        test_table.create(checkfirst=True)
        test_table.insert().execute(data='data1')
        test_table.insert().execute(data='data2')
        eq_(test_table.select().execute().fetchall(),
            [(1, 'data1'), (2, 'data2')])
        test_table.update().where(
            test_table.c.id == 2).values(
            data=test_table.c.data + ' updated').execute()
        eq_(test_table.select().execute().fetchall(),
            [(1, 'data1'), (2, 'data2 updated')])
        test_table.delete().execute()
        eq_(select([func.count('*')]).select_from(test_table).scalar(), 0)
312
sqlalchemy/testing/suite/test_select.py
Normal file
@@ -0,0 +1,312 @@
from .. import fixtures, config
from ..assertions import eq_

from sqlalchemy import util
from sqlalchemy import Integer, String, select, func, bindparam, union
from sqlalchemy import testing

from ..schema import Table, Column


class OrderByLabelTest(fixtures.TablesTest):
    """Test the dialect sends appropriate ORDER BY expressions when
    labels are used.

    This essentially exercises the "supports_simple_order_by_label"
    setting.

    """
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table("some_table", metadata,
              Column('id', Integer, primary_key=True),
              Column('x', Integer),
              Column('y', Integer),
              Column('q', String(50)),
              Column('p', String(50))
              )

    @classmethod
    def insert_data(cls):
        config.db.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
                {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
                {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
            ]
        )

    def _assert_result(self, select, result):
        eq_(
            config.db.execute(select).fetchall(),
            result
        )

    def test_plain(self):
        table = self.tables.some_table
        lx = table.c.x.label('lx')
        self._assert_result(
            select([lx]).order_by(lx),
            [(1, ), (2, ), (3, )]
        )

    def test_composed_int(self):
        table = self.tables.some_table
        lx = (table.c.x + table.c.y).label('lx')
        self._assert_result(
            select([lx]).order_by(lx),
            [(3, ), (5, ), (7, )]
        )

    def test_composed_multiple(self):
        table = self.tables.some_table
        lx = (table.c.x + table.c.y).label('lx')
        ly = (func.lower(table.c.q) + table.c.p).label('ly')
        self._assert_result(
            select([lx, ly]).order_by(lx, ly.desc()),
            [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))]
        )

    def test_plain_desc(self):
        table = self.tables.some_table
        lx = table.c.x.label('lx')
        self._assert_result(
            select([lx]).order_by(lx.desc()),
            [(3, ), (2, ), (1, )]
        )

    def test_composed_int_desc(self):
        table = self.tables.some_table
        lx = (table.c.x + table.c.y).label('lx')
        self._assert_result(
            select([lx]).order_by(lx.desc()),
            [(7, ), (5, ), (3, )]
        )

    def test_group_by_composed(self):
        table = self.tables.some_table
        expr = (table.c.x + table.c.y).label('lx')
        stmt = select([func.count(table.c.id), expr]).\
            group_by(expr).order_by(expr)
        self._assert_result(
            stmt,
            [(1, 3), (1, 5), (1, 7)]
        )


class LimitOffsetTest(fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table("some_table", metadata,
              Column('id', Integer, primary_key=True),
              Column('x', Integer),
              Column('y', Integer))

    @classmethod
    def insert_data(cls):
        config.db.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "x": 1, "y": 2},
                {"id": 2, "x": 2, "y": 3},
                {"id": 3, "x": 3, "y": 4},
                {"id": 4, "x": 4, "y": 5},
            ]
        )

    def _assert_result(self, select, result, params=()):
        eq_(
            config.db.execute(select, params).fetchall(),
            result
        )

    def test_simple_limit(self):
        table = self.tables.some_table
        self._assert_result(
            select([table]).order_by(table.c.id).limit(2),
            [(1, 1, 2), (2, 2, 3)]
        )

    @testing.requires.offset
    def test_simple_offset(self):
        table = self.tables.some_table
        self._assert_result(
            select([table]).order_by(table.c.id).offset(2),
            [(3, 3, 4), (4, 4, 5)]
        )

    @testing.requires.offset
    def test_simple_limit_offset(self):
        table = self.tables.some_table
        self._assert_result(
            select([table]).order_by(table.c.id).limit(2).offset(1),
            [(2, 2, 3), (3, 3, 4)]
        )

    @testing.requires.offset
    def test_limit_offset_nobinds(self):
        """test that 'literal binds' mode works - no bound params."""

        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).limit(2).offset(1)
        sql = stmt.compile(
            dialect=config.db.dialect,
            compile_kwargs={"literal_binds": True})
        sql = str(sql)
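        # with literal_binds, the LIMIT/OFFSET values are rendered inline,
        # so the raw SQL string can be executed with no parameters attached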

        self._assert_result(
            sql,
            [(2, 2, 3), (3, 3, 4)]
        )

    @testing.requires.bound_limit_offset
    def test_bound_limit(self):
        table = self.tables.some_table
        self._assert_result(
            select([table]).order_by(table.c.id).limit(bindparam('l')),
            [(1, 1, 2), (2, 2, 3)],
            params={"l": 2}
        )

    @testing.requires.bound_limit_offset
    def test_bound_offset(self):
        table = self.tables.some_table
        self._assert_result(
            select([table]).order_by(table.c.id).offset(bindparam('o')),
            [(3, 3, 4), (4, 4, 5)],
            params={"o": 2}
        )

    @testing.requires.bound_limit_offset
    def test_bound_limit_offset(self):
        table = self.tables.some_table
        self._assert_result(
            select([table]).order_by(table.c.id).
            limit(bindparam("l")).offset(bindparam("o")),
            [(2, 2, 3), (3, 3, 4)],
            params={"l": 2, "o": 1}
        )


class CompoundSelectTest(fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table("some_table", metadata,
              Column('id', Integer, primary_key=True),
              Column('x', Integer),
              Column('y', Integer))

    @classmethod
    def insert_data(cls):
        config.db.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "x": 1, "y": 2},
                {"id": 2, "x": 2, "y": 3},
                {"id": 3, "x": 3, "y": 4},
                {"id": 4, "x": 4, "y": 5},
            ]
        )

    def _assert_result(self, select, result, params=()):
        eq_(
            config.db.execute(select, params).fetchall(),
            result
        )

    def test_plain_union(self):
        table = self.tables.some_table
        s1 = select([table]).where(table.c.id == 2)
        s2 = select([table]).where(table.c.id == 3)

        u1 = union(s1, s2)
        self._assert_result(
            u1.order_by(u1.c.id),
            [(2, 2, 3), (3, 3, 4)]
        )

    def test_select_from_plain_union(self):
        table = self.tables.some_table
        s1 = select([table]).where(table.c.id == 2)
        s2 = select([table]).where(table.c.id == 3)

        u1 = union(s1, s2).alias().select()
        self._assert_result(
            u1.order_by(u1.c.id),
            [(2, 2, 3), (3, 3, 4)]
        )

    @testing.requires.parens_in_union_contained_select_w_limit_offset
    def test_limit_offset_selectable_in_unions(self):
        table = self.tables.some_table
        s1 = select([table]).where(table.c.id == 2).\
            limit(1).order_by(table.c.id)
        s2 = select([table]).where(table.c.id == 3).\
            limit(1).order_by(table.c.id)

        u1 = union(s1, s2).limit(2)
        self._assert_result(
            u1.order_by(u1.c.id),
            [(2, 2, 3), (3, 3, 4)]
        )

    @testing.requires.parens_in_union_contained_select_wo_limit_offset
    def test_order_by_selectable_in_unions(self):
        table = self.tables.some_table
        s1 = select([table]).where(table.c.id == 2).\
            order_by(table.c.id)
        s2 = select([table]).where(table.c.id == 3).\
            order_by(table.c.id)

        u1 = union(s1, s2).limit(2)
        self._assert_result(
            u1.order_by(u1.c.id),
            [(2, 2, 3), (3, 3, 4)]
        )

    def test_distinct_selectable_in_unions(self):
        table = self.tables.some_table
        s1 = select([table]).where(table.c.id == 2).\
            distinct()
        s2 = select([table]).where(table.c.id == 3).\
            distinct()

        u1 = union(s1, s2).limit(2)
        self._assert_result(
            u1.order_by(u1.c.id),
            [(2, 2, 3), (3, 3, 4)]
        )

    @testing.requires.parens_in_union_contained_select_w_limit_offset
    def test_limit_offset_in_unions_from_alias(self):
        table = self.tables.some_table
        s1 = select([table]).where(table.c.id == 2).\
            limit(1).order_by(table.c.id)
        s2 = select([table]).where(table.c.id == 3).\
            limit(1).order_by(table.c.id)

        # this necessarily has double parens
        u1 = union(s1, s2).alias()
        self._assert_result(
            u1.select().limit(2).order_by(u1.c.id),
            [(2, 2, 3), (3, 3, 4)]
        )

    def test_limit_offset_aliased_selectable_in_unions(self):
        table = self.tables.some_table
        s1 = select([table]).where(table.c.id == 2).\
            limit(1).order_by(table.c.id).alias().select()
        s2 = select([table]).where(table.c.id == 3).\
            limit(1).order_by(table.c.id).alias().select()

        u1 = union(s1, s2).limit(2)
        self._assert_result(
            u1.order_by(u1.c.id),
            [(2, 2, 3), (3, 3, 4)]
        )
126
sqlalchemy/testing/suite/test_sequence.py
Normal file
@@ -0,0 +1,126 @@
from .. import fixtures, config
from ..config import requirements
from ..assertions import eq_
from ... import testing

from ... import Integer, String, Sequence, schema

from ..schema import Table, Column


class SequenceTest(fixtures.TablesTest):
    __requires__ = ('sequences',)
    __backend__ = True

    run_create_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        Table('seq_pk', metadata,
              Column('id', Integer, Sequence('tab_id_seq'), primary_key=True),
              Column('data', String(50))
              )

        Table('seq_opt_pk', metadata,
              Column('id', Integer, Sequence('tab_id_seq', optional=True),
                     primary_key=True),
              Column('data', String(50))
              )
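        # optional=True marks the sequence as a fallback: it only fires on
        # backends that cannot otherwise generate the primary key, and is
        # skipped where autoincrement/IDENTITY suffices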

    def test_insert_roundtrip(self):
        config.db.execute(
            self.tables.seq_pk.insert(),
            data="some data"
        )
        self._assert_round_trip(self.tables.seq_pk, config.db)

    def test_insert_lastrowid(self):
        r = config.db.execute(
            self.tables.seq_pk.insert(),
            data="some data"
        )
        eq_(
            r.inserted_primary_key,
            [1]
        )

    def test_nextval_direct(self):
        r = config.db.execute(
            self.tables.seq_pk.c.id.default
        )
        eq_(
            r, 1
        )

    @requirements.sequences_optional
    def test_optional_seq(self):
        r = config.db.execute(
            self.tables.seq_opt_pk.insert(),
            data="some data"
        )
        eq_(
            r.inserted_primary_key,
            [1]
        )

    def _assert_round_trip(self, table, conn):
        row = conn.execute(table.select()).first()
        eq_(
            row,
            (1, "some data")
        )


class HasSequenceTest(fixtures.TestBase):
    __requires__ = 'sequences',
    __backend__ = True

    def test_has_sequence(self):
        s1 = Sequence('user_id_seq')
        testing.db.execute(schema.CreateSequence(s1))
        try:
            eq_(testing.db.dialect.has_sequence(testing.db,
                                                'user_id_seq'), True)
        finally:
            testing.db.execute(schema.DropSequence(s1))

    @testing.requires.schemas
    def test_has_sequence_schema(self):
        s1 = Sequence('user_id_seq', schema=config.test_schema)
        testing.db.execute(schema.CreateSequence(s1))
        try:
            eq_(testing.db.dialect.has_sequence(
                testing.db, 'user_id_seq', schema=config.test_schema), True)
        finally:
            testing.db.execute(schema.DropSequence(s1))

    def test_has_sequence_neg(self):
        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
            False)

    @testing.requires.schemas
    def test_has_sequence_schemas_neg(self):
        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
                                            schema=config.test_schema),
            False)

    @testing.requires.schemas
    def test_has_sequence_default_not_in_remote(self):
        s1 = Sequence('user_id_seq')
        testing.db.execute(schema.CreateSequence(s1))
        try:
            eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
                                                schema=config.test_schema),
                False)
        finally:
            testing.db.execute(schema.DropSequence(s1))

    @testing.requires.schemas
    def test_has_sequence_remote_not_in_default(self):
        s1 = Sequence('user_id_seq', schema=config.test_schema)
        testing.db.execute(schema.CreateSequence(s1))
        try:
            eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
                False)
        finally:
            testing.db.execute(schema.DropSequence(s1))
898
sqlalchemy/testing/suite/test_types.py
Normal file
@@ -0,0 +1,898 @@
|
||||
# coding: utf-8
|
||||
|
||||
from .. import fixtures, config
|
||||
from ..assertions import eq_
|
||||
from ..config import requirements
|
||||
from sqlalchemy import Integer, Unicode, UnicodeText, select
|
||||
from sqlalchemy import Date, DateTime, Time, MetaData, String, \
|
||||
Text, Numeric, Float, literal, Boolean, cast, null, JSON, and_, type_coerce
|
||||
from ..schema import Table, Column
|
||||
from ... import testing
|
||||
import decimal
|
||||
import datetime
|
||||
from ...util import u
|
||||
from ... import util
|
||||
|
||||
|
||||
class _LiteralRoundTripFixture(object):
|
||||
@testing.provide_metadata
|
||||
def _literal_round_trip(self, type_, input_, output, filter_=None):
|
||||
"""test literal rendering """
|
||||
|
||||
# for literal, we test the literal render in an INSERT
|
||||
# into a typed column. we can then SELECT it back as its
|
||||
# official type; ideally we'd be able to use CAST here
|
||||
# but MySQL in particular can't CAST fully
|
||||
t = Table('t', self.metadata, Column('x', type_))
|
||||
t.create()
|
||||
|
||||
for value in input_:
|
||||
ins = t.insert().values(x=literal(value)).compile(
|
||||
dialect=testing.db.dialect,
|
||||
compile_kwargs=dict(literal_binds=True)
|
||||
)
|
||||
testing.db.execute(ins)
|
||||
|
||||
for row in t.select().execute():
|
||||
value = row[0]
|
||||
if filter_ is not None:
|
||||
value = filter_(value)
|
||||
assert value in output
|
||||
|
||||
|
||||

class _UnicodeFixture(_LiteralRoundTripFixture):
    __requires__ = 'unicode_data',

    data = u("Alors vous imaginez ma surprise, au lever du jour, "
             "quand une drôle de petite voix m’a réveillé. Elle "
             "disait: « S’il vous plaît… dessine-moi un mouton! »")

    @classmethod
    def define_tables(cls, metadata):
        Table('unicode_table', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('unicode_data', cls.datatype),
              )

    def test_round_trip(self):
        unicode_table = self.tables.unicode_table

        config.db.execute(
            unicode_table.insert(),
            {
                'unicode_data': self.data,
            }
        )

        row = config.db.execute(
            select([
                unicode_table.c.unicode_data,
            ])
        ).first()

        eq_(
            row,
            (self.data, )
        )
        assert isinstance(row[0], util.text_type)

    def test_round_trip_executemany(self):
        unicode_table = self.tables.unicode_table

        config.db.execute(
            unicode_table.insert(),
            [
                {
                    'unicode_data': self.data,
                }
                for i in range(3)
            ]
        )

        rows = config.db.execute(
            select([
                unicode_table.c.unicode_data,
            ])
        ).fetchall()
        eq_(
            rows,
            [(self.data, ) for i in range(3)]
        )
        for row in rows:
            assert isinstance(row[0], util.text_type)

    def _test_empty_strings(self):
        unicode_table = self.tables.unicode_table

        config.db.execute(
            unicode_table.insert(),
            {"unicode_data": u('')}
        )
        row = config.db.execute(
            select([unicode_table.c.unicode_data])
        ).first()
        eq_(row, (u(''),))

    def test_literal(self):
        self._literal_round_trip(self.datatype, [self.data], [self.data])


class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
    __requires__ = 'unicode_data',
    __backend__ = True

    datatype = Unicode(255)

    @requirements.empty_strings_varchar
    def test_empty_strings_varchar(self):
        self._test_empty_strings()


class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
    __requires__ = 'unicode_data', 'text_type'
    __backend__ = True

    datatype = UnicodeText()

    @requirements.empty_strings_text
    def test_empty_strings_text(self):
        self._test_empty_strings()


class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
    __requires__ = 'text_type',
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('text_table', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('text_data', Text),
              )

    def test_text_roundtrip(self):
        text_table = self.tables.text_table

        config.db.execute(
            text_table.insert(),
            {"text_data": 'some text'}
        )
        row = config.db.execute(
            select([text_table.c.text_data])
        ).first()
        eq_(row, ('some text',))

    def test_text_empty_strings(self):
        text_table = self.tables.text_table

        config.db.execute(
            text_table.insert(),
            {"text_data": ''}
        )
        row = config.db.execute(
            select([text_table.c.text_data])
        ).first()
        eq_(row, ('',))

    def test_literal(self):
        self._literal_round_trip(Text, ["some text"], ["some text"])

    def test_literal_quoting(self):
        data = '''some 'text' hey "hi there" that's text'''
        self._literal_round_trip(Text, [data], [data])

    def test_literal_backslashes(self):
        data = r'backslash one \ backslash two \\ end'
        self._literal_round_trip(Text, [data], [data])


class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
    __backend__ = True

    @requirements.unbounded_varchar
    def test_nolength_string(self):
        metadata = MetaData()
        foo = Table('foo', metadata,
                    Column('one', String)
                    )

        foo.create(config.db)
        foo.drop(config.db)

    def test_literal(self):
        self._literal_round_trip(String(40), ["some text"], ["some text"])

    def test_literal_quoting(self):
        data = '''some 'text' hey "hi there" that's text'''
        self._literal_round_trip(String(40), [data], [data])

    def test_literal_backslashes(self):
        data = r'backslash one \ backslash two \\ end'
        self._literal_round_trip(String(40), [data], [data])


class _DateFixture(_LiteralRoundTripFixture):
    compare = None

    @classmethod
    def define_tables(cls, metadata):
        Table('date_table', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('date_data', cls.datatype),
              )

    def test_round_trip(self):
        date_table = self.tables.date_table

        config.db.execute(
            date_table.insert(),
            {'date_data': self.data}
        )

        row = config.db.execute(
            select([
                date_table.c.date_data,
            ])
        ).first()

        compare = self.compare or self.data
        eq_(row,
            (compare, ))
        assert isinstance(row[0], type(compare))

    def test_null(self):
        date_table = self.tables.date_table

        config.db.execute(
            date_table.insert(),
            {'date_data': None}
        )

        row = config.db.execute(
            select([
                date_table.c.date_data,
            ])
        ).first()
        eq_(row, (None,))

    @testing.requires.datetime_literals
    def test_literal(self):
        compare = self.compare or self.data
        self._literal_round_trip(self.datatype, [self.data], [compare])


class DateTimeTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'datetime',
    __backend__ = True
    datatype = DateTime
    data = datetime.datetime(2012, 10, 15, 12, 57, 18)


class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'datetime_microseconds',
    __backend__ = True
    datatype = DateTime
    data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)


class TimeTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'time',
    __backend__ = True
    datatype = Time
    data = datetime.time(12, 57, 18)


class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'time_microseconds',
    __backend__ = True
    datatype = Time
    data = datetime.time(12, 57, 18, 396)


class DateTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'date',
    __backend__ = True
    datatype = Date
    data = datetime.date(2012, 10, 15)


class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'date', 'date_coerces_from_datetime'
    __backend__ = True
    datatype = Date
    data = datetime.datetime(2012, 10, 15, 12, 57, 18)
    compare = datetime.date(2012, 10, 15)


class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'datetime_historic',
    __backend__ = True
    datatype = DateTime
    data = datetime.datetime(1850, 11, 10, 11, 52, 35)


class DateHistoricTest(_DateFixture, fixtures.TablesTest):
    __requires__ = 'date_historic',
    __backend__ = True
    datatype = Date
    data = datetime.date(1727, 4, 1)


class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
    __backend__ = True

    def test_literal(self):
        self._literal_round_trip(Integer, [5], [5])


class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
    __backend__ = True

    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    @testing.provide_metadata
    def _do_test(self, type_, input_, output,
                 filter_=None, check_scale=False):
        metadata = self.metadata
        t = Table('t', metadata, Column('x', type_))
        t.create()
        t.insert().execute([{'x': x} for x in input_])

        result = set([row[0] for row in t.select().execute()])
        output = set(output)
        if filter_:
            result = set(filter_(x) for x in result)
            output = set(filter_(x) for x in output)
        eq_(result, output)
        if check_scale:
            eq_(
                [str(x) for x in result],
                [str(x) for x in output],
            )

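    # Illustrative note (added comment, not in the original file): the
    # filter_ hook above lets float-returning backends be compared loosely;
    # e.g. passing filter_=lambda n: n is not None and round(n, 5) or None
    # normalizes both sides before the set comparison, as the float tests
    # below do.
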
    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_render_literal_numeric(self):
        self._literal_round_trip(
            Numeric(precision=8, scale=4),
            [15.7563, decimal.Decimal("15.7563")],
            [decimal.Decimal("15.7563")],
        )

    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_render_literal_numeric_asfloat(self):
        self._literal_round_trip(
            Numeric(precision=8, scale=4, asdecimal=False),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563],
        )

    def test_render_literal_float(self):
        self._literal_round_trip(
            Float(4),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563, ],
            filter_=lambda n: n is not None and round(n, 5) or None
        )

    @testing.requires.precision_generic_float_type
    def test_float_custom_scale(self):
        self._do_test(
            Float(None, decimal_return_scale=7, asdecimal=True),
            [15.7563827, decimal.Decimal("15.7563827")],
            [decimal.Decimal("15.7563827"), ],
            check_scale=True
        )

    def test_numeric_as_decimal(self):
        self._do_test(
            Numeric(precision=8, scale=4),
            [15.7563, decimal.Decimal("15.7563")],
            [decimal.Decimal("15.7563")],
        )

    def test_numeric_as_float(self):
        self._do_test(
            Numeric(precision=8, scale=4, asdecimal=False),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563],
        )

    @testing.requires.fetch_null_from_numeric
    def test_numeric_null_as_decimal(self):
        self._do_test(
            Numeric(precision=8, scale=4),
            [None],
            [None],
        )

    @testing.requires.fetch_null_from_numeric
    def test_numeric_null_as_float(self):
        self._do_test(
            Numeric(precision=8, scale=4, asdecimal=False),
            [None],
            [None],
        )

    @testing.requires.floats_to_four_decimals
    def test_float_as_decimal(self):
        self._do_test(
            Float(precision=8, asdecimal=True),
            [15.7563, decimal.Decimal("15.7563"), None],
            [decimal.Decimal("15.7563"), None],
        )

    def test_float_as_float(self):
        self._do_test(
            Float(precision=8),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563],
            filter_=lambda n: n is not None and round(n, 5) or None
        )

    @testing.requires.precision_numerics_general
    def test_precision_decimal(self):
        numbers = set([
            decimal.Decimal("54.234246451650"),
            decimal.Decimal("0.004354"),
            decimal.Decimal("900.0"),
        ])

        self._do_test(
            Numeric(precision=18, scale=12),
            numbers,
            numbers,
        )

    @testing.requires.precision_numerics_enotation_large
    def test_enotation_decimal(self):
        """test exceedingly small decimals.

        Decimal reports values with E notation when the exponent
        is greater than 6.

        """

        numbers = set([
            decimal.Decimal('1E-2'),
            decimal.Decimal('1E-3'),
            decimal.Decimal('1E-4'),
            decimal.Decimal('1E-5'),
            decimal.Decimal('1E-6'),
            decimal.Decimal('1E-7'),
            decimal.Decimal('1E-8'),
            decimal.Decimal("0.01000005940696"),
            decimal.Decimal("0.00000005940696"),
            decimal.Decimal("0.00000000000696"),
            decimal.Decimal("0.70000000000696"),
            decimal.Decimal("696E-12"),
        ])
        self._do_test(
            Numeric(precision=18, scale=14),
            numbers,
            numbers
        )

    @testing.requires.precision_numerics_enotation_large
    def test_enotation_decimal_large(self):
        """test exceedingly large decimals.

        """

        numbers = set([
            decimal.Decimal('4E+8'),
            decimal.Decimal("5748E+15"),
            decimal.Decimal('1.521E+15'),
            decimal.Decimal('00000000000000.1E+12'),
        ])
        self._do_test(
            Numeric(precision=25, scale=2),
            numbers,
            numbers
        )

    @testing.requires.precision_numerics_many_significant_digits
    def test_many_significant_digits(self):
        numbers = set([
            decimal.Decimal("31943874831932418390.01"),
            decimal.Decimal("319438950232418390.273596"),
            decimal.Decimal("87673.594069654243"),
        ])
        self._do_test(
            Numeric(precision=38, scale=12),
            numbers,
            numbers
        )

    @testing.requires.precision_numerics_retains_significant_digits
    def test_numeric_no_decimal(self):
        numbers = set([
            decimal.Decimal("1.000")
        ])
        self._do_test(
            Numeric(precision=5, scale=3),
            numbers,
            numbers,
            check_scale=True
        )


class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('boolean_table', metadata,
              Column('id', Integer, primary_key=True, autoincrement=False),
              Column('value', Boolean),
              Column('unconstrained_value', Boolean(create_constraint=False)),
              )

    def test_render_literal_bool(self):
        self._literal_round_trip(
            Boolean(),
            [True, False],
            [True, False]
        )

    def test_round_trip(self):
        boolean_table = self.tables.boolean_table

        config.db.execute(
            boolean_table.insert(),
            {
                'id': 1,
                'value': True,
                'unconstrained_value': False
            }
        )

        row = config.db.execute(
            select([
                boolean_table.c.value,
                boolean_table.c.unconstrained_value
            ])
        ).first()

        eq_(
            row,
            (True, False)
        )
        assert isinstance(row[0], bool)

    def test_null(self):
        boolean_table = self.tables.boolean_table

        config.db.execute(
            boolean_table.insert(),
            {
                'id': 1,
                'value': None,
                'unconstrained_value': None
            }
        )

        row = config.db.execute(
            select([
                boolean_table.c.value,
                boolean_table.c.unconstrained_value
            ])
        ).first()

        eq_(
            row,
            (None, None)
        )


class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
    __requires__ = 'json_type',
    __backend__ = True

    datatype = JSON

    data1 = {
        "key1": "value1",
        "key2": "value2"
    }

    data2 = {
        "Key 'One'": "value1",
        "key two": "value2",
        "key three": "value ' three '"
    }

    data3 = {
        "key1": [1, 2, 3],
        "key2": ["one", "two", "three"],
        "key3": [{"four": "five"}, {"six": "seven"}]
    }

    data4 = ["one", "two", "three"]

    data5 = {
        "nested": {
            "elem1": [
                {"a": "b", "c": "d"},
                {"e": "f", "g": "h"}
            ],
            "elem2": {
                "elem3": {"elem4": "elem5"}
            }
        }
    }

    data6 = {
        "a": 5,
        "b": "some value",
        "c": {"foo": "bar"}
    }

    @classmethod
    def define_tables(cls, metadata):
        Table('data_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(30), nullable=False),
              Column('data', cls.datatype),
              Column('nulldata', cls.datatype(none_as_null=True))
              )

    def test_round_trip_data1(self):
        self._test_round_trip(self.data1)

    def _test_round_trip(self, data_element):
        data_table = self.tables.data_table

        config.db.execute(
            data_table.insert(),
            {'name': 'row1', 'data': data_element}
        )

        row = config.db.execute(
            select([
                data_table.c.data,
            ])
        ).first()

        eq_(row, (data_element, ))

    def test_round_trip_none_as_sql_null(self):
        col = self.tables.data_table.c['nulldata']

        with config.db.connect() as conn:
            conn.execute(
                self.tables.data_table.insert(),
                {"name": "r1", "data": None}
            )

            eq_(
                conn.scalar(
                    select([self.tables.data_table.c.name]).
                    where(col.is_(null()))
                ),
                "r1"
            )

            eq_(
                conn.scalar(
                    select([col])
                ),
                None
            )

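    # Illustrative note (added comment, not in the original file): the
    # 'nulldata' column is declared with none_as_null=True, so Python None is
    # persisted as SQL NULL there, while the plain 'data' column stores None
    # as the JSON 'null' value; the surrounding tests verify both behaviors.
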
    def test_round_trip_json_null_as_json_null(self):
        col = self.tables.data_table.c['data']

        with config.db.connect() as conn:
            conn.execute(
                self.tables.data_table.insert(),
                {"name": "r1", "data": JSON.NULL}
            )

            eq_(
                conn.scalar(
                    select([self.tables.data_table.c.name]).
                    where(cast(col, String) == 'null')
                ),
                "r1"
            )

            eq_(
                conn.scalar(
                    select([col])
                ),
                None
            )

    def test_round_trip_none_as_json_null(self):
        col = self.tables.data_table.c['data']

        with config.db.connect() as conn:
            conn.execute(
                self.tables.data_table.insert(),
                {"name": "r1", "data": None}
            )

            eq_(
                conn.scalar(
                    select([self.tables.data_table.c.name]).
                    where(cast(col, String) == 'null')
                ),
                "r1"
            )

            eq_(
                conn.scalar(
                    select([col])
                ),
                None
            )

    def _criteria_fixture(self):
        config.db.execute(
            self.tables.data_table.insert(),
            [{"name": "r1", "data": self.data1},
             {"name": "r2", "data": self.data2},
             {"name": "r3", "data": self.data3},
             {"name": "r4", "data": self.data4},
             {"name": "r5", "data": self.data5},
             {"name": "r6", "data": self.data6}]
        )

    def _test_index_criteria(self, crit, expected, test_literal=True):
        self._criteria_fixture()
        with config.db.connect() as conn:
            stmt = select([self.tables.data_table.c.name]).where(crit)

            eq_(
                conn.scalar(stmt),
                expected
            )

            if test_literal:
                literal_sql = str(stmt.compile(
                    config.db, compile_kwargs={"literal_binds": True}))

                eq_(conn.scalar(literal_sql), expected)

    def test_crit_spaces_in_key(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c['data']

        # limit the rows here to avoid PG error
        # "cannot extract field from a non-object", which is
        # fixed in 9.4 but may exist in 9.3
        self._test_index_criteria(
            and_(
                name.in_(["r1", "r2", "r3"]),
                cast(col["key two"], String) == '"value2"'
            ),
            "r2"
        )

    @config.requirements.json_array_indexes
    def test_crit_simple_int(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c['data']

        # limit the rows here to avoid PG error
        # "cannot extract array element from a non-array", which is
        # fixed in 9.4 but may exist in 9.3
        self._test_index_criteria(
            and_(name == 'r4', cast(col[1], String) == '"two"'),
            "r4"
        )

    def test_crit_mixed_path(self):
        col = self.tables.data_table.c['data']
        self._test_index_criteria(
            cast(col[("key3", 1, "six")], String) == '"seven"',
            "r3"
        )

    def test_crit_string_path(self):
        col = self.tables.data_table.c['data']
        self._test_index_criteria(
            cast(col[("nested", "elem2", "elem3", "elem4")], String)
            == '"elem5"',
            "r5"
        )

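    # Illustrative note (added comment, not in the original file): a tuple
    # index such as col[("nested", "elem2", "elem3", "elem4")] compiles to a
    # JSON path expression rather than chained single-key lookups; on
    # PostgreSQL, for instance, this typically renders using the #> path
    # operator. The exact SQL is dialect-specific.
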
    def test_crit_against_string_basic(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c['data']

        self._test_index_criteria(
            and_(name == 'r6', cast(col["b"], String) == '"some value"'),
            "r6"
        )

    def test_crit_against_string_coerce_type(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c['data']

        self._test_index_criteria(
            and_(name == 'r6',
                 cast(col["b"], String) == type_coerce("some value", JSON)),
            "r6",
            test_literal=False
        )

    def test_crit_against_int_basic(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c['data']

        self._test_index_criteria(
            and_(name == 'r6', cast(col["a"], String) == '5'),
            "r6"
        )

    def test_crit_against_int_coerce_type(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c['data']

        self._test_index_criteria(
            and_(name == 'r6', cast(col["a"], String) == type_coerce(5, JSON)),
            "r6",
            test_literal=False
        )

    def test_unicode_round_trip(self):
        s = select([
            cast(
                {
                    util.u('réveillé'): util.u('réveillé'),
                    "data": {"k1": util.u('drôle')}
                },
                self.datatype
            )
        ])
        eq_(
            config.db.scalar(s),
            {
                util.u('réveillé'): util.u('réveillé'),
                "data": {"k1": util.u('drôle')}
            },
        )

    def test_eval_none_flag_orm(self):
        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.orm import Session

        Base = declarative_base()

        class Data(Base):
            __table__ = self.tables.data_table

        s = Session(testing.db)

        d1 = Data(name='d1', data=None, nulldata=None)
        s.add(d1)
        s.commit()

        s.bulk_insert_mappings(
            Data, [{"name": "d2", "data": None, "nulldata": None}]
        )
        eq_(
            s.query(
                cast(self.tables.data_table.c.data,
                     String(convert_unicode="force")),
                cast(self.tables.data_table.c.nulldata, String)
            ).filter(self.tables.data_table.c.name == 'd1').first(),
            ("null", None)
        )
        eq_(
            s.query(
                cast(self.tables.data_table.c.data,
                     String(convert_unicode="force")),
                cast(self.tables.data_table.c.nulldata, String)
            ).filter(self.tables.data_table.c.name == 'd2').first(),
            ("null", None)
        )


__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'JSONTest',
           'DateTest', 'DateTimeTest', 'TextTest',
           'NumericTest', 'IntegerTest',
           'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest',
           'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest',
           'DateHistoricTest', 'StringTest', 'BooleanTest')
63
sqlalchemy/testing/suite/test_update_delete.py
Normal file
@@ -0,0 +1,63 @@
from .. import fixtures, config
from ..assertions import eq_

from sqlalchemy import Integer, String
from ..schema import Table, Column


class SimpleUpdateDeleteTest(fixtures.TablesTest):
    run_deletes = 'each'
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('plain_pk', metadata,
              Column('id', Integer, primary_key=True),
              Column('data', String(50))
              )

    @classmethod
    def insert_data(cls):
        config.db.execute(
            cls.tables.plain_pk.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ]
        )

    def test_update(self):
        t = self.tables.plain_pk
        r = config.db.execute(
            t.update().where(t.c.id == 2),
            data="d2_new"
        )
        assert not r.is_insert
        assert not r.returns_rows

        eq_(
            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
            [
                (1, "d1"),
                (2, "d2_new"),
                (3, "d3")
            ]
        )

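    # Illustrative note (added comment, not in the original file): keyword
    # arguments to config.db.execute(), like data="d2_new" above, supply the
    # statement's bound parameter values by name.
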
    def test_delete(self):
        t = self.tables.plain_pk
        r = config.db.execute(
            t.delete().where(t.c.id == 2)
        )
        assert not r.is_insert
        assert not r.returns_rows
        eq_(
            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
            [
                (1, "d1"),
                (3, "d3")
            ]
        )

__all__ = ('SimpleUpdateDeleteTest', )
280
sqlalchemy/testing/util.py
Normal file
@@ -0,0 +1,280 @@
# testing/util.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from ..util import jython, pypy, defaultdict, decorator, py2k
import decimal
import gc
import time
import random
import sys
import types

if jython:
    def jython_gc_collect(*args):
        """aggressive gc.collect for tests."""
        gc.collect()
        time.sleep(0.1)
        gc.collect()
        gc.collect()
        return 0

    # "lazy" gc, for VM's that don't GC on refcount == 0
    gc_collect = lazy_gc = jython_gc_collect
elif pypy:
    def pypy_gc_collect(*args):
        gc.collect()
        gc.collect()
    gc_collect = lazy_gc = pypy_gc_collect
else:
    # assume CPython - straight gc.collect, lazy_gc() is a pass
    gc_collect = gc.collect

    def lazy_gc():
        pass


def picklers():
    picklers = set()
    if py2k:
        try:
            import cPickle
            picklers.add(cPickle)
        except ImportError:
            pass

    import pickle
    picklers.add(pickle)

    # yes, this thing needs this much testing
    for pickle_ in picklers:
        for protocol in -1, 0, 1, 2:
            yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)

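# Illustrative usage (added comment, not in the original file): each
# (loads, dumps) pair yielded above can drive a pickling round trip, e.g.
#     for loads, dumps in picklers():
#         assert loads(dumps(obj)) == obj
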
def round_decimal(value, prec):
    if isinstance(value, float):
        return round(value, prec)

    # can also use shift() here but that is 2.6 only
    return (value * decimal.Decimal("1" + "0" * prec)
            ).to_integral(decimal.ROUND_FLOOR) / \
        pow(10, prec)

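# Worked example (added comment, not in the original file): the Decimal
# branch rounds toward the floor, so round_decimal(Decimal("2.375"), 2)
# computes 2.375 * 100 = 237.5 -> 237 -> Decimal("2.37"); a float input
# instead falls back to the builtin round().
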
class RandomSet(set):
    def __iter__(self):
        l = list(set.__iter__(self))
        random.shuffle(l)
        return iter(l)

    def pop(self):
        index = random.randint(0, len(self) - 1)
        item = list(set.__iter__(self))[index]
        self.remove(item)
        return item

    def union(self, other):
        return RandomSet(set.union(self, other))

    def difference(self, other):
        return RandomSet(set.difference(self, other))

    def intersection(self, other):
        return RandomSet(set.intersection(self, other))

    def copy(self):
        return RandomSet(self)


def conforms_partial_ordering(tuples, sorted_elements):
    """True if the given sorting conforms to the given partial ordering."""

    deps = defaultdict(set)
    for parent, child in tuples:
        deps[parent].add(child)
    for i, node in enumerate(sorted_elements):
        for n in sorted_elements[i:]:
            if node in deps[n]:
                return False
    else:
        return True

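# Illustrative example (added comment, not in the original file): each tuple
# is (parent, child), meaning parent must sort before child, so
#     conforms_partial_ordering([(1, 2), (2, 3)], [1, 2, 3])  # True
#     conforms_partial_ordering([(1, 2), (2, 3)], [2, 1, 3])  # False
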
def all_partial_orderings(tuples, elements):
    edges = defaultdict(set)
    for parent, child in tuples:
        edges[child].add(parent)

    def _all_orderings(elements):

        if len(elements) == 1:
            yield list(elements)
        else:
            for elem in elements:
                subset = set(elements).difference([elem])
                if not subset.intersection(edges[elem]):
                    for sub_ordering in _all_orderings(subset):
                        yield [elem] + sub_ordering

    return iter(_all_orderings(elements))


def function_named(fn, name):
    """Return a function with a given __name__.

    Will assign to __name__ and return the original function if possible on
    the Python implementation, otherwise a new function will be constructed.

    This function should be phased out as much as possible
    in favor of @decorator.  Tests that "generate" many named tests
    should be modernized.

    """
    try:
        fn.__name__ = name
    except TypeError:
        fn = types.FunctionType(fn.__code__, fn.__globals__, name,
                                fn.__defaults__, fn.__closure__)
    return fn

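# Illustrative usage (added comment, not in the original file):
#     fn = function_named(lambda self: None, "test_generated_case")
#     fn.__name__  # "test_generated_case"
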
def run_as_contextmanager(ctx, fn, *arg, **kw):
    """Run the given function under the given contextmanager,
    simulating the behavior of 'with' to support older
    Python versions.

    This is not necessary anymore as we have placed 2.6
    as minimum Python version, however some tests are still using
    this structure.

    """

    obj = ctx.__enter__()
    try:
        result = fn(obj, *arg, **kw)
        ctx.__exit__(None, None, None)
        return result
    except:
        exc_info = sys.exc_info()
        raise_ = ctx.__exit__(*exc_info)
        if raise_ is None:
            raise
        else:
            return raise_


def rowset(results):
    """Converts the results of sql execution into a plain set of column tuples.

    Useful for asserting the results of an unordered query.
    """

    return set([tuple(row) for row in results])

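# Illustrative usage (added comment, not in the original file; names are
# hypothetical): rowset() makes order-independent comparison straightforward:
#     rowset(config.db.execute(stmt)) == {(1, 'd1'), (2, 'd2')}
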
def fail(msg):
    assert False, msg


@decorator
def provide_metadata(fn, *args, **kw):
    """Provide bound MetaData for a single test, dropping afterwards."""

    from . import config
    from . import engines
    from sqlalchemy import schema

    metadata = schema.MetaData(config.db)
    self = args[0]
    prev_meta = getattr(self, 'metadata', None)
    self.metadata = metadata
    try:
        return fn(*args, **kw)
    finally:
        engines.drop_all_tables(metadata, config.db)
        self.metadata = prev_meta

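# Illustrative note (added comment, not in the original file): a test method
# decorated with @provide_metadata builds tables against self.metadata (as
# the literal round trip in the type suite does); everything created there is
# dropped afterwards even when the test body raises.
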
def force_drop_names(*names):
    """Force the given table names to be dropped after test complete,
    isolating for foreign key cycles

    """
    from . import config
    from sqlalchemy import inspect

    @decorator
    def go(fn, *args, **kw):

        try:
            return fn(*args, **kw)
        finally:
            drop_all_tables(
                config.db, inspect(config.db), include_names=names)
    return go


class adict(dict):
    """Dict keys available as attributes.  Shadows."""

    def __getattribute__(self, key):
        try:
            return self[key]
        except KeyError:
            return dict.__getattribute__(self, key)

    def __call__(self, *keys):
        return tuple([self[key] for key in keys])

    get_all = __call__

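# Illustrative usage (added comment, not in the original file):
#     d = adict(x=5, y=6)
#     d.x          # 5; attribute access shadows same-named dict methods
#     d('x', 'y')  # (5, 6)
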
def drop_all_tables(engine, inspector, schema=None, include_names=None):
    from sqlalchemy import Column, Table, Integer, MetaData, \
        ForeignKeyConstraint
    from sqlalchemy.schema import DropTable, DropConstraint

    if include_names is not None:
        include_names = set(include_names)

    with engine.connect() as conn:
        for tname, fkcs in reversed(
                inspector.get_sorted_table_and_fkc_names(schema=schema)):
            if tname:
                if include_names is not None and tname not in include_names:
                    continue
                conn.execute(DropTable(
                    Table(tname, MetaData(), schema=schema)
                ))
            elif fkcs:
                if not engine.dialect.supports_alter:
                    continue
                for tname, fkc in fkcs:
                    if include_names is not None and \
                            tname not in include_names:
                        continue
                    tb = Table(
                        tname, MetaData(),
                        Column('x', Integer),
                        Column('y', Integer),
                        schema=schema
                    )
                    conn.execute(DropConstraint(
                        ForeignKeyConstraint(
                            [tb.c.x], [tb.c.y], name=fkc)
                    ))

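# Note (added comment, not in the original file): the throwaway Table with
# dummy 'x'/'y' columns above exists only so a ForeignKeyConstraint of the
# right name can be handed to DropConstraint; dropping those constraints
# first breaks FK cycles so the tables themselves can then be dropped in
# reverse dependency order.
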
def teardown_events(event_cls):
    @decorator
    def decorate(fn, *arg, **kw):
        try:
            return fn(*arg, **kw)
        finally:
            event_cls._clear()
    return decorate
41
sqlalchemy/testing/warnings.py
Normal file
@@ -0,0 +1,41 @@
# testing/warnings.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from __future__ import absolute_import

import warnings
from .. import exc as sa_exc
from . import assertions


def setup_filters():
    """Set global warning behavior for the test suite."""

    warnings.filterwarnings('ignore',
                            category=sa_exc.SAPendingDeprecationWarning)
    warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
    warnings.filterwarnings('error', category=sa_exc.SAWarning)

    # some selected deprecations...
    warnings.filterwarnings('error', category=DeprecationWarning)
    warnings.filterwarnings(
        "ignore", category=DeprecationWarning, message=".*StopIteration")
    warnings.filterwarnings(
        "ignore", category=DeprecationWarning, message=".*inspect.getargspec")

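# Illustrative note (added comment, not in the original file): because
# SAWarning is routed to 'error' above, an unexpected
# warnings.warn("...", sa_exc.SAWarning) raises inside the test suite
# instead of merely printing, which is what expect_warnings() relies on.
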
def assert_warnings(fn, warning_msgs, regex=False):
    """Assert that each of the given warnings are emitted by fn.

    Deprecated.  Please use assertions.expect_warnings().

    """

    with assertions._expect_warnings(
            sa_exc.SAWarning, warning_msgs, regex=regex):
        return fn()
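
# Illustrative usage (added comment, not in the original file; do_thing is a
# hypothetical stand-in for the code under test):
#     assert_warnings(lambda: do_thing(), ["Some warning text"])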