import sys, types, weakref
from collections import deque
import config
from sqlalchemy.util import function_named, callable
import re
import warnings


class ConnectionKiller(object):
    """Track checked-out connection proxies so tests can clean them up.

    Proxies are held as keys of a ``WeakKeyDictionary``, so tracking a
    proxy here does not keep it alive.  ``testing_engine()`` adds the
    module-level instance to each engine's ``listeners`` option;
    presumably the pool then invokes ``checkout()`` on every connection
    checkout -- confirm against the pool listener interface.
    """

    def __init__(self):
        self.proxy_refs = weakref.WeakKeyDictionary()

    def checkout(self, dbapi_con, con_record, con_proxy):
        """Remember ``con_proxy``; the other arguments are unused."""
        self.proxy_refs[con_proxy] = True

    def _apply_all(self, methods):
        """Run each entry of ``methods`` against every valid tracked proxy.

        An entry may be a callable, which is called with the proxy, or a
        string naming a zero-argument method of the proxy.  Failures are
        reported as warnings rather than raised, except for
        ``SystemExit`` and ``KeyboardInterrupt`` which always propagate.
        Once one entry fails for a proxy, the remaining entries are
        skipped for that proxy.
        """
        # Snapshot the keys before looping: entries of a weak dictionary
        # can vanish at any time, and the snapshot must not depend on
        # ``keys()`` returning a list (it only does so on Python 2).
        for rec in list(self.proxy_refs.keys()):
            if rec is not None and rec.is_valid:
                try:
                    for name in methods:
                        if callable(name):
                            name(rec)
                        else:
                            getattr(rec, name)()
                except (SystemExit, KeyboardInterrupt):
                    raise
                except Exception:
                    # sys.exc_info() instead of "except X, e" / "except X
                    # as e" so the same source parses on every Python
                    # version.
                    e = sys.exc_info()[1]
                    warnings.warn("testing_reaper couldn't close connection: %s" % e)

    def rollback_all(self):
        """Call ``rollback()`` on every valid tracked proxy."""
        self._apply_all(('rollback',))

    def close_all(self):
        """Call ``rollback()`` then ``close()`` on every valid tracked proxy."""
        self._apply_all(('rollback', 'close'))

    def assert_all_closed(self):
        """Raise ``AssertionError`` if any tracked proxy is still valid."""
        # Iterate over a snapshot; iterating the weak dictionary directly
        # can fail if a proxy is garbage collected during the loop.
        for rec in list(self.proxy_refs.keys()):
            if rec.is_valid:
                # Raised explicitly (not via ``assert``) so the check is
                # not stripped under ``python -O``.
                raise AssertionError(
                    "connection is still open: %r" % (rec,))


# Shared tracker used by the helpers and decorators below.
testing_reaper = ConnectionKiller()


def drop_all_tables(metadata):
    """Close all tracked connections, then call ``metadata.drop_all()``."""
    testing_reaper.close_all()
    metadata.drop_all()


def assert_conns_closed(fn):
    """Decorator that asserts all connections are closed after fn execution.

    The check runs in a ``finally`` clause, i.e. also when ``fn`` raises.
    The wrapper returns whatever ``fn`` returns.
    """
    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.assert_all_closed()
    return function_named(decorated, fn.__name__)


def rollback_open_connections(fn):
    """Decorator that rolls back all open connections after fn execution.

    The wrapper returns whatever ``fn`` returns.
    """
    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.rollback_all()
    return function_named(decorated, fn.__name__)


def close_first(fn):
    """Decorator that closes all connections before fn execution.

    The wrapper returns whatever ``fn`` returns.
    """
    def decorated(*args, **kw):
        testing_reaper.close_all()
        return fn(*args, **kw)
    return function_named(decorated, fn.__name__)


def close_open_connections(fn):
    """Decorator that closes all connections after fn execution.

    The wrapper returns whatever ``fn`` returns.
    """
    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.close_all()
    return function_named(decorated, fn.__name__)


def all_dialects(exclude=None):
    """Yield a ``dialect()`` instance for each name in ``sqlalchemy.databases.__all__``.

    Names contained in ``exclude`` (if given) are skipped.
    """
    import sqlalchemy.databases as d
    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None)
        if not mod:
            # Not yet available as an attribute of the package: import
            # the submodule explicitly and fetch it from there.
            mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
        yield mod.dialect()


class ReconnectFixture(object):
    """DBAPI module stand-in that records every connection it opens.

    Attribute access other than ``connect`` / ``shutdown`` /
    ``connections`` / ``dbapi`` is delegated to the wrapped module via
    ``__getattr__``.
    """

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []

    def __getattr__(self, key):
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        """Open a connection through the wrapped module and record it."""
        conn = self.dbapi.connect(*args, **kwargs)
        self.connections.append(conn)
        return conn

    def shutdown(self):
        """Close every recorded connection and forget them all."""
        for c in list(self.connections):
            c.close()
        self.connections = []


def reconnecting_engine(url=None, options=None):
    """Return a testing engine whose DBAPI is wrapped in a ReconnectFixture.

    The fixture's ``shutdown()`` is exposed on the returned engine as
    ``engine.test_shutdown``.  ``url`` defaults to ``config.db_url``.
    The ``options`` mapping passed in is not modified.
    """
    url = url or config.db_url
    dbapi = config.db.dialect.dbapi
    # Copy so that adding 'module' does not leak into the caller's dict.
    options = dict(options or {})
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    engine.test_shutdown = engine.dialect.dbapi.shutdown
    return engine


def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides.

    ``url`` defaults to ``config.db_url``; a missing or empty ``options``
    falls back to ``config.db_opts``.  The ``proxy`` option defaults to
    the assertsql ``asserter`` and ``testing_reaper`` is appended to the
    ``listeners`` option.  Neither the mapping passed in nor
    ``config.db_opts`` is modified.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.test.assertsql import asserter

    url = url or config.db_url
    # Work on copies: mutating the shared config.db_opts (or its
    # 'listeners' list) in place would add one more testing_reaper entry
    # for every engine created.
    options = dict(options or config.db_opts)

    options.setdefault('proxy', asserter)

    listeners = list(options.get('listeners') or ())
    listeners.append(testing_reaper)
    options['listeners'] = listeners

    engine = create_engine(url, **options)

    # may want to call this, results
    # in first-connect initializers
    #engine.connect()

    return engine


def utf8_engine(url=None, options=None):
    """Hook for dialects or drivers that don't handle utf8 by default.

    When ``config.db.driver`` is ``'mysqldb'`` the URL gets
    ``charset=utf8`` and ``use_unicode=0`` query arguments; a
    ``RuntimeError`` is raised for the driver versions listed below.
    For any other driver the arguments are passed to ``testing_engine()``
    unchanged.
    """
    from sqlalchemy.engine import url as engine_url

    if config.db.driver == 'mysqldb':
        dbapi_ver = config.db.dialect.dbapi.version_info
        if (dbapi_ver < (1, 2, 1) or
            dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
                          (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))):
            raise RuntimeError('Character set support unavailable with this '
                               'driver version: %s' % repr(dbapi_ver))
        else:
            url = url or config.db_url
            url = engine_url.make_url(url)
            url.query['charset'] = 'utf8'
            url.query['use_unicode'] = '0'
            url = str(url)

    return testing_engine(url, options)


def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted by an
    Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer moving
    parts.

    The returned engine carries two extra attributes: ``mock``, the list
    of statements handed to the executor, and ``assert_sql(stmts)``,
    which compares that list (stringified, newlines and tabs removed)
    against ``stmts``.  ``dialect_name`` defaults to ``config.db.name``.
    """
    from sqlalchemy import create_engine

    if not dialect_name:
        dialect_name = config.db.name

    buffer = []

    def executor(sql, *a, **kw):
        buffer.append(sql)

    def assert_sql(stmts):
        recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
        assert recv == stmts, recv

    engine = create_engine(dialect_name + '://',
                           strategy='mock', executor=executor)
    # Guard against clobbering an attribute the engine already defines.
    assert not hasattr(engine, 'mock')
    engine.mock = buffer
    engine.assert_sql = assert_sql
    return engine


class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class.  It only records a session for later
    playback and makes no assertions on call consistency whatsoever.  It's
    unlikely to be suitable for anything other than DB-API recording.

    """

    # Marker stored in the buffer when a recorded result was itself
    # wrapped in a Recorder instead of being stored by value.
    Callable = object()
    # Marker stored in the buffer when an attribute lookup on the
    # recorded subject raised AttributeError.
    NoAttribute = object()
    # Types whose instances are recorded by value: every public attribute
    # of the ``types`` module except the function/method types listed in
    # difference().
    # NOTE(review): the "# Py3K" / "# Py2K" / "# end Py2K" comments look
    # like directives for a source-conversion step -- keep them intact
    # and confirm against the build tooling before reformatting.
    Natives = set([getattr(types, t)
                   for t in dir(types) if not t.startswith('_')]). \
                   difference([getattr(types, t)
# Py3K
#                           for t in ('FunctionType', 'BuiltinFunctionType',
#                                     'MethodType', 'BuiltinMethodType',
#                                     'LambdaType', )])
# Py2K
                           for t in ('FunctionType', 'BuiltinFunctionType',
                                     'MethodType', 'BuiltinMethodType',
                                     'LambdaType', 'UnboundMethodType',)])
# end Py2K

    def __init__(self):
        self.buffer = deque()

    def recorder(self, base):
        """Return a Recorder wrapping ``base`` that appends to this buffer."""
        return self.Recorder(self.buffer, base)

    def player(self):
        """Return a Player that consumes this session's buffer."""
        return self.Player(self.buffer)

    class Recorder(object):
        """Proxy that forwards to a subject and records every result.

        Results whose type is in ``ReplayableSession.Natives`` are stored
        in the buffer as-is; any other result is wrapped in a new
        Recorder and a ``Callable`` marker is stored instead.
        """

        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            # object.__getattribute__ is used directly so that reading
            # the proxy's own state cannot fall through to the recording
            # logic in __getattribute__ below.
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]

            result = subject(*args, **kw)
            if type(result) not in ReplayableSession.Natives:
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)
            else:
                buffer.append(result)
                return result

        @property
        def _sqla_unwrap(self):
            return self._subject

        def __getattribute__(self, key):
            # Attributes defined on the proxy itself win and are not
            # recorded.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            try:
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                # Record the miss so that playback raises as well.
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) not in ReplayableSession.Natives:
                    buffer.append(ReplayableSession.Callable)
                    return type(self)(buffer, result)
                else:
                    buffer.append(result)
                    return result

    class Player(object):
        """Proxy that replays a recorded buffer, oldest entry first.

        Every call or non-local attribute access pops one entry from the
        left of the buffer; arguments and attribute names are not checked
        against what was recorded.
        """

        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                # The recorded result was a wrapped object: keep playing
                # back through this same Player.
                return self
            else:
                return result

        @property
        def _sqla_unwrap(self):
            return None

        def __getattribute__(self, key):
            # Attributes defined on the proxy itself win and consume
            # nothing from the buffer.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                raise AttributeError(key)
            else:
                return result