morro
119 sqlalchemy/__init__.py Normal file
@@ -0,0 +1,119 @@
# __init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import inspect
import sys

import sqlalchemy.exc as exceptions
sys.modules['sqlalchemy.exceptions'] = exceptions

from sqlalchemy.sql import (
    alias,
    and_,
    asc,
    between,
    bindparam,
    case,
    cast,
    collate,
    delete,
    desc,
    distinct,
    except_,
    except_all,
    exists,
    extract,
    func,
    insert,
    intersect,
    intersect_all,
    join,
    literal,
    literal_column,
    modifier,
    not_,
    null,
    or_,
    outerjoin,
    outparam,
    select,
    subquery,
    text,
    tuple_,
    union,
    union_all,
    update,
    )

from sqlalchemy.types import (
    BLOB,
    BOOLEAN,
    BigInteger,
    Binary,
    Boolean,
    CHAR,
    CLOB,
    DATE,
    DATETIME,
    DECIMAL,
    Date,
    DateTime,
    Enum,
    FLOAT,
    Float,
    INT,
    INTEGER,
    Integer,
    Interval,
    LargeBinary,
    NCHAR,
    NVARCHAR,
    NUMERIC,
    Numeric,
    PickleType,
    SMALLINT,
    SmallInteger,
    String,
    TEXT,
    TIME,
    TIMESTAMP,
    Text,
    Time,
    Unicode,
    UnicodeText,
    VARCHAR,
    )


from sqlalchemy.schema import (
    CheckConstraint,
    Column,
    ColumnDefault,
    Constraint,
    DDL,
    DefaultClause,
    FetchedValue,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    MetaData,
    PassiveDefault,
    PrimaryKeyConstraint,
    Sequence,
    Table,
    ThreadLocalMetaData,
    UniqueConstraint,
    )

from sqlalchemy.engine import create_engine, engine_from_config


__all__ = sorted(name for name, obj in locals().items()
                 if not (name.startswith('_') or inspect.ismodule(obj)))

__version__ = '0.6beta3'

del inspect, sys
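Everything that survives the `__all__` filter above is importable directly from the `sqlalchemy` package. A minimal usage sketch, not part of the commit (the table, columns, and in-memory SQLite URL are illustrative):

# Illustrative only: exercises names re-exported by sqlalchemy/__init__.py.
from sqlalchemy import (MetaData, Table, Column, Integer, String,
                        create_engine, select)

metadata = MetaData()
users = Table('users', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(50)))

engine = create_engine('sqlite://')     # in-memory SQLite database
metadata.create_all(engine)

engine.execute(users.insert(), name='ed')
for row in engine.execute(select([users])):
    print row.name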
393 sqlalchemy/cextension/processors.c Normal file
@@ -0,0 +1,393 @@
/*
processors.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com

This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/

#include <Python.h>
#include <datetime.h>

static PyObject *
int_to_boolean(PyObject *self, PyObject *arg)
{
    long l = 0;
    PyObject *res;

    if (arg == Py_None)
        Py_RETURN_NONE;

    l = PyInt_AsLong(arg);
    if (l == 0) {
        res = Py_False;
    } else if (l == 1) {
        res = Py_True;
    } else if ((l == -1) && PyErr_Occurred()) {
        /* -1 can be either the actual value, or an error flag. */
        return NULL;
    } else {
        PyErr_SetString(PyExc_ValueError,
                        "int_to_boolean only accepts None, 0 or 1");
        return NULL;
    }

    Py_INCREF(res);
    return res;
}

static PyObject *
to_str(PyObject *self, PyObject *arg)
{
    if (arg == Py_None)
        Py_RETURN_NONE;

    return PyObject_Str(arg);
}

static PyObject *
to_float(PyObject *self, PyObject *arg)
{
    if (arg == Py_None)
        Py_RETURN_NONE;

    return PyNumber_Float(arg);
}

static PyObject *
str_to_datetime(PyObject *self, PyObject *arg)
{
    const char *str;
    unsigned int year, month, day, hour, minute, second, microsecond = 0;

    if (arg == Py_None)
        Py_RETURN_NONE;

    str = PyString_AsString(arg);
    if (str == NULL)
        return NULL;

    /* microseconds are optional */
    /*
    TODO: this is slightly less picky than the Python version which would
    not accept "2000-01-01 00:00:00.". I don't know which is better, but they
    should be coherent.
    */
    if (sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
               &hour, &minute, &second, &microsecond) < 6) {
        PyErr_SetString(PyExc_ValueError, "Couldn't parse datetime string.");
        return NULL;
    }
    return PyDateTime_FromDateAndTime(year, month, day,
                                      hour, minute, second, microsecond);
}

static PyObject *
str_to_time(PyObject *self, PyObject *arg)
{
    const char *str;
    unsigned int hour, minute, second, microsecond = 0;

    if (arg == Py_None)
        Py_RETURN_NONE;

    str = PyString_AsString(arg);
    if (str == NULL)
        return NULL;

    /* microseconds are optional */
    /*
    TODO: this is slightly less picky than the Python version which would
    not accept "00:00:00.". I don't know which is better, but they should be
    coherent.
    */
    if (sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
               &microsecond) < 3) {
        PyErr_SetString(PyExc_ValueError, "Couldn't parse time string.");
        return NULL;
    }
    return PyTime_FromTime(hour, minute, second, microsecond);
}

static PyObject *
str_to_date(PyObject *self, PyObject *arg)
{
    const char *str;
    unsigned int year, month, day;

    if (arg == Py_None)
        Py_RETURN_NONE;

    str = PyString_AsString(arg);
    if (str == NULL)
        return NULL;

    if (sscanf(str, "%4u-%2u-%2u", &year, &month, &day) != 3) {
        PyErr_SetString(PyExc_ValueError, "Couldn't parse date string.");
        return NULL;
    }
    return PyDate_FromDate(year, month, day);
}


/***********
 * Structs *
 ***********/

typedef struct {
    PyObject_HEAD
    PyObject *encoding;
    PyObject *errors;
} UnicodeResultProcessor;

typedef struct {
    PyObject_HEAD
    PyObject *type;
    PyObject *format;
} DecimalResultProcessor;



/**************************
 * UnicodeResultProcessor *
 **************************/

static int
UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
                            PyObject *kwds)
{
    PyObject *encoding, *errors = NULL;
    static char *kwlist[] = {"encoding", "errors", NULL};

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist,
                                     &encoding, &errors))
        return -1;

    Py_INCREF(encoding);
    self->encoding = encoding;

    if (errors) {
        Py_INCREF(errors);
    } else {
        errors = PyString_FromString("strict");
        if (errors == NULL)
            return -1;
    }
    self->errors = errors;

    return 0;
}

static PyObject *
UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
{
    const char *encoding, *errors;
    char *str;
    Py_ssize_t len;

    if (value == Py_None)
        Py_RETURN_NONE;

    if (PyString_AsStringAndSize(value, &str, &len))
        return NULL;

    encoding = PyString_AS_STRING(self->encoding);
    errors = PyString_AS_STRING(self->errors);

    return PyUnicode_Decode(str, len, encoding, errors);
}

static PyMethodDef UnicodeResultProcessor_methods[] = {
    {"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
     "The value processor itself."},
    {NULL}  /* Sentinel */
};

static PyTypeObject UnicodeResultProcessorType = {
    PyObject_HEAD_INIT(NULL)
    0,                                          /* ob_size */
    "sqlalchemy.cprocessors.UnicodeResultProcessor",    /* tp_name */
    sizeof(UnicodeResultProcessor),             /* tp_basicsize */
    0,                                          /* tp_itemsize */
    0,                                          /* tp_dealloc */
    0,                                          /* tp_print */
    0,                                          /* tp_getattr */
    0,                                          /* tp_setattr */
    0,                                          /* tp_compare */
    0,                                          /* tp_repr */
    0,                                          /* tp_as_number */
    0,                                          /* tp_as_sequence */
    0,                                          /* tp_as_mapping */
    0,                                          /* tp_hash */
    0,                                          /* tp_call */
    0,                                          /* tp_str */
    0,                                          /* tp_getattro */
    0,                                          /* tp_setattro */
    0,                                          /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,   /* tp_flags */
    "UnicodeResultProcessor objects",           /* tp_doc */
    0,                                          /* tp_traverse */
    0,                                          /* tp_clear */
    0,                                          /* tp_richcompare */
    0,                                          /* tp_weaklistoffset */
    0,                                          /* tp_iter */
    0,                                          /* tp_iternext */
    UnicodeResultProcessor_methods,             /* tp_methods */
    0,                                          /* tp_members */
    0,                                          /* tp_getset */
    0,                                          /* tp_base */
    0,                                          /* tp_dict */
    0,                                          /* tp_descr_get */
    0,                                          /* tp_descr_set */
    0,                                          /* tp_dictoffset */
    (initproc)UnicodeResultProcessor_init,      /* tp_init */
    0,                                          /* tp_alloc */
    0,                                          /* tp_new */
};

/**************************
 * DecimalResultProcessor *
 **************************/

static int
DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
                            PyObject *kwds)
{
    PyObject *type, *format;

    if (!PyArg_ParseTuple(args, "OS", &type, &format))
        return -1;

    Py_INCREF(type);
    self->type = type;

    Py_INCREF(format);
    self->format = format;

    return 0;
}

static PyObject *
DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
{
    PyObject *str, *result, *args;

    if (value == Py_None)
        Py_RETURN_NONE;

    if (PyFloat_CheckExact(value)) {
        /* Decimal does not accept float values directly */
        args = PyTuple_Pack(1, value);
        if (args == NULL)
            return NULL;

        str = PyString_Format(self->format, args);
        if (str == NULL)
            return NULL;

        result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
        Py_DECREF(str);
        return result;
    } else {
        return PyObject_CallFunctionObjArgs(self->type, value, NULL);
    }
}

static PyMethodDef DecimalResultProcessor_methods[] = {
    {"process", (PyCFunction)DecimalResultProcessor_process, METH_O,
     "The value processor itself."},
    {NULL}  /* Sentinel */
};

static PyTypeObject DecimalResultProcessorType = {
    PyObject_HEAD_INIT(NULL)
    0,                                          /* ob_size */
    "sqlalchemy.DecimalResultProcessor",        /* tp_name */
    sizeof(DecimalResultProcessor),             /* tp_basicsize */
    0,                                          /* tp_itemsize */
    0,                                          /* tp_dealloc */
    0,                                          /* tp_print */
    0,                                          /* tp_getattr */
    0,                                          /* tp_setattr */
    0,                                          /* tp_compare */
    0,                                          /* tp_repr */
    0,                                          /* tp_as_number */
    0,                                          /* tp_as_sequence */
    0,                                          /* tp_as_mapping */
    0,                                          /* tp_hash */
    0,                                          /* tp_call */
    0,                                          /* tp_str */
    0,                                          /* tp_getattro */
    0,                                          /* tp_setattro */
    0,                                          /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,   /* tp_flags */
    "DecimalResultProcessor objects",           /* tp_doc */
    0,                                          /* tp_traverse */
    0,                                          /* tp_clear */
    0,                                          /* tp_richcompare */
    0,                                          /* tp_weaklistoffset */
    0,                                          /* tp_iter */
    0,                                          /* tp_iternext */
    DecimalResultProcessor_methods,             /* tp_methods */
    0,                                          /* tp_members */
    0,                                          /* tp_getset */
    0,                                          /* tp_base */
    0,                                          /* tp_dict */
    0,                                          /* tp_descr_get */
    0,                                          /* tp_descr_set */
    0,                                          /* tp_dictoffset */
    (initproc)DecimalResultProcessor_init,      /* tp_init */
    0,                                          /* tp_alloc */
    0,                                          /* tp_new */
};

#ifndef PyMODINIT_FUNC  /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif


static PyMethodDef module_methods[] = {
    {"int_to_boolean", int_to_boolean, METH_O,
     "Convert an integer to a boolean."},
    {"to_str", to_str, METH_O,
     "Convert any value to its string representation."},
    {"to_float", to_float, METH_O,
     "Convert any value to its floating point representation."},
    {"str_to_datetime", str_to_datetime, METH_O,
     "Convert an ISO string to a datetime.datetime object."},
    {"str_to_time", str_to_time, METH_O,
     "Convert an ISO string to a datetime.time object."},
    {"str_to_date", str_to_date, METH_O,
     "Convert an ISO string to a datetime.date object."},
    {NULL, NULL, 0, NULL}   /* Sentinel */
};

PyMODINIT_FUNC
initcprocessors(void)
{
    PyObject *m;

    UnicodeResultProcessorType.tp_new = PyType_GenericNew;
    if (PyType_Ready(&UnicodeResultProcessorType) < 0)
        return;

    DecimalResultProcessorType.tp_new = PyType_GenericNew;
    if (PyType_Ready(&DecimalResultProcessorType) < 0)
        return;

    m = Py_InitModule3("cprocessors", module_methods,
                       "Module containing C versions of data processing functions.");
    if (m == NULL)
        return;

    PyDateTime_IMPORT;

    Py_INCREF(&UnicodeResultProcessorType);
    PyModule_AddObject(m, "UnicodeResultProcessor",
                       (PyObject *)&UnicodeResultProcessorType);

    Py_INCREF(&DecimalResultProcessorType);
    PyModule_AddObject(m, "DecimalResultProcessor",
                       (PyObject *)&DecimalResultProcessorType);
}
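These C functions are drop-in accelerators for pure-Python result processors; the Python module they mirror is not shown in this commit, but the intended behavior of two of them can be sketched roughly as follows (a hedged equivalent, assuming the same ISO formats the sscanf patterns above accept):

# Rough Python equivalents of two of the C processors above (illustrative only).
import datetime
import re

def str_to_datetime(value):
    if value is None:
        return None
    # mirrors "%4u-%2u-%2u %2u:%2u:%2u.%6u"; the microseconds are optional
    m = re.match(r"(\d{4})-(\d{2})-(\d{2}) (\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?",
                 value)
    if m is None:
        raise ValueError("Couldn't parse datetime string.")
    return datetime.datetime(*[int(p) for p in m.groups("0")])

def int_to_boolean(value):
    if value is None:
        return None
    if value not in (0, 1):
        raise ValueError("int_to_boolean only accepts None, 0 or 1")
    return value == 1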
586 sqlalchemy/cextension/resultproxy.c Normal file
@@ -0,0 +1,586 @@
/*
resultproxy.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com

This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/

#include <Python.h>


/***********
 * Structs *
 ***********/

typedef struct {
    PyObject_HEAD
    PyObject *parent;
    PyObject *row;
    PyObject *processors;
    PyObject *keymap;
} BaseRowProxy;

/****************
 * BaseRowProxy *
 ****************/

static PyObject *
safe_rowproxy_reconstructor(PyObject *self, PyObject *args)
{
    PyObject *cls, *state, *tmp;
    BaseRowProxy *obj;

    if (!PyArg_ParseTuple(args, "OO", &cls, &state))
        return NULL;

    obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls);
    if (obj == NULL)
        return NULL;

    tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state);
    if (tmp == NULL) {
        Py_DECREF(obj);
        return NULL;
    }
    Py_DECREF(tmp);

    if (obj->parent == NULL || obj->row == NULL ||
        obj->processors == NULL || obj->keymap == NULL) {
        PyErr_SetString(PyExc_RuntimeError,
            "__setstate__ for BaseRowProxy subclasses must set values "
            "for parent, row, processors and keymap");
        Py_DECREF(obj);
        return NULL;
    }

    return (PyObject *)obj;
}

static int
BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
{
    PyObject *parent, *row, *processors, *keymap;

    if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4,
                           &parent, &row, &processors, &keymap))
        return -1;

    Py_INCREF(parent);
    self->parent = parent;

    if (!PyTuple_CheckExact(row)) {
        PyErr_SetString(PyExc_TypeError, "row must be a tuple");
        return -1;
    }
    Py_INCREF(row);
    self->row = row;

    if (!PyList_CheckExact(processors)) {
        PyErr_SetString(PyExc_TypeError, "processors must be a list");
        return -1;
    }
    Py_INCREF(processors);
    self->processors = processors;

    if (!PyDict_CheckExact(keymap)) {
        PyErr_SetString(PyExc_TypeError, "keymap must be a dict");
        return -1;
    }
    Py_INCREF(keymap);
    self->keymap = keymap;

    return 0;
}

/* We need the reduce method because otherwise the default implementation
 * does very weird stuff for pickle protocol 0 and 1. It calls
 * BaseRowProxy.__new__(RowProxy_instance) upon *pickling*.
 */
static PyObject *
BaseRowProxy_reduce(PyObject *self)
{
    PyObject *method, *state;
    PyObject *module, *reconstructor, *cls;

    method = PyObject_GetAttrString(self, "__getstate__");
    if (method == NULL)
        return NULL;

    state = PyObject_CallObject(method, NULL);
    Py_DECREF(method);
    if (state == NULL)
        return NULL;

    module = PyImport_ImportModule("sqlalchemy.engine.base");
    if (module == NULL)
        return NULL;

    reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor");
    Py_DECREF(module);
    if (reconstructor == NULL) {
        Py_DECREF(state);
        return NULL;
    }

    cls = PyObject_GetAttrString(self, "__class__");
    if (cls == NULL) {
        Py_DECREF(reconstructor);
        Py_DECREF(state);
        return NULL;
    }

    return Py_BuildValue("(N(NN))", reconstructor, cls, state);
}

static void
BaseRowProxy_dealloc(BaseRowProxy *self)
{
    Py_XDECREF(self->parent);
    Py_XDECREF(self->row);
    Py_XDECREF(self->processors);
    Py_XDECREF(self->keymap);
    self->ob_type->tp_free((PyObject *)self);
}

static PyObject *
BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
{
    Py_ssize_t num_values, num_processors;
    PyObject **valueptr, **funcptr, **resultptr;
    PyObject *func, *result, *processed_value;

    num_values = Py_SIZE(values);
    num_processors = Py_SIZE(processors);
    if (num_values != num_processors) {
        PyErr_SetString(PyExc_RuntimeError,
            "number of values in row differ from number of column processors");
        return NULL;
    }

    if (astuple) {
        result = PyTuple_New(num_values);
    } else {
        result = PyList_New(num_values);
    }
    if (result == NULL)
        return NULL;

    /* we don't need to use PySequence_Fast as long as values, processors and
     * result are simple tuple or lists. */
    valueptr = PySequence_Fast_ITEMS(values);
    funcptr = PySequence_Fast_ITEMS(processors);
    resultptr = PySequence_Fast_ITEMS(result);
    while (--num_values >= 0) {
        func = *funcptr;
        if (func != Py_None) {
            processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
                                                           NULL);
            if (processed_value == NULL) {
                Py_DECREF(result);
                return NULL;
            }
            *resultptr = processed_value;
        } else {
            Py_INCREF(*valueptr);
            *resultptr = *valueptr;
        }
        valueptr++;
        funcptr++;
        resultptr++;
    }
    return result;
}

static PyListObject *
BaseRowProxy_values(BaseRowProxy *self)
{
    return (PyListObject *)BaseRowProxy_processvalues(self->row,
                                                      self->processors, 0);
}

static PyTupleObject *
BaseRowProxy_tuplevalues(BaseRowProxy *self)
{
    return (PyTupleObject *)BaseRowProxy_processvalues(self->row,
                                                       self->processors, 1);
}

static PyObject *
BaseRowProxy_iter(BaseRowProxy *self)
{
    PyObject *values, *result;

    values = (PyObject *)BaseRowProxy_tuplevalues(self);
    if (values == NULL)
        return NULL;

    result = PyObject_GetIter(values);
    Py_DECREF(values);
    if (result == NULL)
        return NULL;

    return result;
}

static Py_ssize_t
BaseRowProxy_length(BaseRowProxy *self)
{
    return Py_SIZE(self->row);
}

static PyObject *
BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
{
    PyObject *processors, *values;
    PyObject *processor, *value;
    PyObject *record, *result, *indexobject;
    PyObject *exc_module, *exception;
    char *cstr_key;
    long index;

    if (PyInt_CheckExact(key)) {
        index = PyInt_AS_LONG(key);
    } else if (PyLong_CheckExact(key)) {
        index = PyLong_AsLong(key);
        if ((index == -1) && PyErr_Occurred())
            /* -1 can be either the actual value, or an error flag. */
            return NULL;
    } else if (PySlice_Check(key)) {
        values = PyObject_GetItem(self->row, key);
        if (values == NULL)
            return NULL;

        processors = PyObject_GetItem(self->processors, key);
        if (processors == NULL) {
            Py_DECREF(values);
            return NULL;
        }

        result = BaseRowProxy_processvalues(values, processors, 1);
        Py_DECREF(values);
        Py_DECREF(processors);
        return result;
    } else {
        record = PyDict_GetItem((PyObject *)self->keymap, key);
        if (record == NULL) {
            record = PyObject_CallMethod(self->parent, "_key_fallback",
                                         "O", key);
            if (record == NULL)
                return NULL;
        }

        indexobject = PyTuple_GetItem(record, 1);
        if (indexobject == NULL)
            return NULL;

        if (indexobject == Py_None) {
            exc_module = PyImport_ImportModule("sqlalchemy.exc");
            if (exc_module == NULL)
                return NULL;

            exception = PyObject_GetAttrString(exc_module,
                                               "InvalidRequestError");
            Py_DECREF(exc_module);
            if (exception == NULL)
                return NULL;

            cstr_key = PyString_AsString(key);
            if (cstr_key == NULL)
                return NULL;

            PyErr_Format(exception,
                "Ambiguous column name '%s' in result set! "
                "try 'use_labels' option on select statement.", cstr_key);
            return NULL;
        }

        index = PyInt_AsLong(indexobject);
        if ((index == -1) && PyErr_Occurred())
            /* -1 can be either the actual value, or an error flag. */
            return NULL;
    }
    processor = PyList_GetItem(self->processors, index);
    if (processor == NULL)
        return NULL;

    value = PyTuple_GetItem(self->row, index);
    if (value == NULL)
        return NULL;

    if (processor != Py_None) {
        return PyObject_CallFunctionObjArgs(processor, value, NULL);
    } else {
        Py_INCREF(value);
        return value;
    }
}

static PyObject *
BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
{
    PyObject *tmp;

    if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
        if (!PyErr_ExceptionMatches(PyExc_AttributeError))
            return NULL;
        PyErr_Clear();
    }
    else
        return tmp;

    return BaseRowProxy_subscript(self, name);
}

/***********************
 * getters and setters *
 ***********************/

static PyObject *
BaseRowProxy_getparent(BaseRowProxy *self, void *closure)
{
    Py_INCREF(self->parent);
    return self->parent;
}

static int
BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
{
    PyObject *module, *cls;

    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError,
                        "Cannot delete the 'parent' attribute");
        return -1;
    }

    module = PyImport_ImportModule("sqlalchemy.engine.base");
    if (module == NULL)
        return -1;

    cls = PyObject_GetAttrString(module, "ResultMetaData");
    Py_DECREF(module);
    if (cls == NULL)
        return -1;

    if (PyObject_IsInstance(value, cls) != 1) {
        PyErr_SetString(PyExc_TypeError,
                        "The 'parent' attribute value must be an instance of "
                        "ResultMetaData");
        return -1;
    }
    Py_DECREF(cls);
    Py_XDECREF(self->parent);
    Py_INCREF(value);
    self->parent = value;

    return 0;
}

static PyObject *
BaseRowProxy_getrow(BaseRowProxy *self, void *closure)
{
    Py_INCREF(self->row);
    return self->row;
}

static int
BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
{
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError,
                        "Cannot delete the 'row' attribute");
        return -1;
    }

    if (!PyTuple_CheckExact(value)) {
        PyErr_SetString(PyExc_TypeError,
                        "The 'row' attribute value must be a tuple");
        return -1;
    }

    Py_XDECREF(self->row);
    Py_INCREF(value);
    self->row = value;

    return 0;
}

static PyObject *
BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure)
{
    Py_INCREF(self->processors);
    return self->processors;
}

static int
BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure)
{
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError,
                        "Cannot delete the 'processors' attribute");
        return -1;
    }

    if (!PyList_CheckExact(value)) {
        PyErr_SetString(PyExc_TypeError,
                        "The 'processors' attribute value must be a list");
        return -1;
    }

    Py_XDECREF(self->processors);
    Py_INCREF(value);
    self->processors = value;

    return 0;
}

static PyObject *
BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure)
{
    Py_INCREF(self->keymap);
    return self->keymap;
}

static int
BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure)
{
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError,
                        "Cannot delete the 'keymap' attribute");
        return -1;
    }

    if (!PyDict_CheckExact(value)) {
        PyErr_SetString(PyExc_TypeError,
                        "The 'keymap' attribute value must be a dict");
        return -1;
    }

    Py_XDECREF(self->keymap);
    Py_INCREF(value);
    self->keymap = value;

    return 0;
}

static PyGetSetDef BaseRowProxy_getseters[] = {
    {"_parent",
     (getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent,
     "ResultMetaData",
     NULL},
    {"_row",
     (getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow,
     "Original row tuple",
     NULL},
    {"_processors",
     (getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors,
     "list of type processors",
     NULL},
    {"_keymap",
     (getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap,
     "Key to (processor, index) dict",
     NULL},
    {NULL}
};

static PyMethodDef BaseRowProxy_methods[] = {
    {"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
     "Return the values represented by this BaseRowProxy as a list."},
    {"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
     "Pickle support method."},
    {NULL}  /* Sentinel */
};

static PySequenceMethods BaseRowProxy_as_sequence = {
    (lenfunc)BaseRowProxy_length,   /* sq_length */
    0,                              /* sq_concat */
    0,                              /* sq_repeat */
    0,                              /* sq_item */
    0,                              /* sq_slice */
    0,                              /* sq_ass_item */
    0,                              /* sq_ass_slice */
    0,                              /* sq_contains */
    0,                              /* sq_inplace_concat */
    0,                              /* sq_inplace_repeat */
};

static PyMappingMethods BaseRowProxy_as_mapping = {
    (lenfunc)BaseRowProxy_length,       /* mp_length */
    (binaryfunc)BaseRowProxy_subscript, /* mp_subscript */
    0                                   /* mp_ass_subscript */
};

static PyTypeObject BaseRowProxyType = {
    PyObject_HEAD_INIT(NULL)
    0,                                  /* ob_size */
    "sqlalchemy.cresultproxy.BaseRowProxy",     /* tp_name */
    sizeof(BaseRowProxy),               /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)BaseRowProxy_dealloc,   /* tp_dealloc */
    0,                                  /* tp_print */
    0,                                  /* tp_getattr */
    0,                                  /* tp_setattr */
    0,                                  /* tp_compare */
    0,                                  /* tp_repr */
    0,                                  /* tp_as_number */
    &BaseRowProxy_as_sequence,          /* tp_as_sequence */
    &BaseRowProxy_as_mapping,           /* tp_as_mapping */
    0,                                  /* tp_hash */
    0,                                  /* tp_call */
    0,                                  /* tp_str */
    (getattrofunc)BaseRowProxy_getattro,/* tp_getattro */
    0,                                  /* tp_setattro */
    0,                                  /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,   /* tp_flags */
    "BaseRowProxy is an abstract base class for RowProxy",  /* tp_doc */
    0,                                  /* tp_traverse */
    0,                                  /* tp_clear */
    0,                                  /* tp_richcompare */
    0,                                  /* tp_weaklistoffset */
    (getiterfunc)BaseRowProxy_iter,     /* tp_iter */
    0,                                  /* tp_iternext */
    BaseRowProxy_methods,               /* tp_methods */
    0,                                  /* tp_members */
    BaseRowProxy_getseters,             /* tp_getset */
    0,                                  /* tp_base */
    0,                                  /* tp_dict */
    0,                                  /* tp_descr_get */
    0,                                  /* tp_descr_set */
    0,                                  /* tp_dictoffset */
    (initproc)BaseRowProxy_init,        /* tp_init */
    0,                                  /* tp_alloc */
    0                                   /* tp_new */
};


#ifndef PyMODINIT_FUNC  /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif


static PyMethodDef module_methods[] = {
    {"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS,
     "reconstruct a RowProxy instance from its pickled form."},
    {NULL, NULL, 0, NULL}   /* Sentinel */
};

PyMODINIT_FUNC
initcresultproxy(void)
{
    PyObject *m;

    BaseRowProxyType.tp_new = PyType_GenericNew;
    if (PyType_Ready(&BaseRowProxyType) < 0)
        return;

    m = Py_InitModule3("cresultproxy", module_methods,
                       "Module containing C versions of core ResultProxy classes.");
    if (m == NULL)
        return;

    Py_INCREF(&BaseRowProxyType);
    PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);

}
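The per-row hot path above is easier to read in Python: BaseRowProxy pairs each raw column value with an optional type processor and supports lookup by position, slice, key, or attribute. A hedged sketch of the central value-processing loop (the C code above is the authoritative version):

# Illustrative Python rendering of BaseRowProxy_processvalues.
def process_values(values, processors, astuple):
    if len(values) != len(processors):
        raise RuntimeError(
            "number of values in row differ from number of column processors")
    # apply each column's processor where one exists, pass through otherwise
    result = [proc(val) if proc is not None else val
              for val, proc in zip(values, processors)]
    return tuple(result) if astuple else result

# e.g. with a unicode-decoding processor on the second column:
decode = lambda s: s.decode('utf-8')
print process_values((1, 'ed'), [None, decode], astuple=True)   # (1, u'ed')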
6 sqlalchemy/connectors/__init__.py Normal file
@@ -0,0 +1,6 @@


class Connector(object):
    pass


146 sqlalchemy/connectors/mxodbc.py Normal file
@@ -0,0 +1,146 @@
"""
Provide an SQLAlchemy connector for the eGenix mxODBC commercial
Python adapter for ODBC. This is not a free product, but eGenix
provides SQLAlchemy with a license for use in continuous integration
testing.

This has been tested for use with mxODBC 3.1.2 on SQL Server 2005
and 2008, using the SQL Server Native driver. However, it is
possible for this to be used on other database platforms.

For more info on mxODBC, see http://www.egenix.com/

"""

import sys
import re
import warnings
from decimal import Decimal

from sqlalchemy.connectors import Connector
from sqlalchemy import types as sqltypes
import sqlalchemy.processors as processors

class MxODBCConnector(Connector):
    driver = 'mxodbc'

    supports_sane_multi_rowcount = False
    supports_unicode_statements = False
    supports_unicode_binds = False

    supports_native_decimal = True

    @classmethod
    def dbapi(cls):
        # this classmethod will normally be replaced by an instance
        # attribute of the same name, so this is normally only called once.
        cls._load_mx_exceptions()
        platform = sys.platform
        if platform == 'win32':
            from mx.ODBC import Windows as module
        # this can be the string "linux2", and possibly others
        elif 'linux' in platform:
            from mx.ODBC import unixODBC as module
        elif platform == 'darwin':
            from mx.ODBC import iODBC as module
        else:
            raise ImportError, "Unrecognized platform for mxODBC import"
        return module

    @classmethod
    def _load_mx_exceptions(cls):
        """ Import mxODBC exception classes into the module namespace,
        as if they had been imported normally. This is done here
        to avoid requiring all SQLAlchemy users to install mxODBC.
        """
        global InterfaceError, ProgrammingError
        from mx.ODBC import InterfaceError
        from mx.ODBC import ProgrammingError

    def on_connect(self):
        def connect(conn):
            conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
            conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
            conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
            conn.errorhandler = self._error_handler()
        return connect

    def _error_handler(self):
        """ Return a handler that adjusts mxODBC's raised Warnings to
        emit Python standard warnings.
        """
        from mx.ODBC.Error import Warning as MxOdbcWarning
        def error_handler(connection, cursor, errorclass, errorvalue):

            if issubclass(errorclass, MxOdbcWarning):
                errorclass.__bases__ = (Warning,)
                warnings.warn(message=str(errorvalue),
                              category=errorclass,
                              stacklevel=2)
            else:
                raise errorclass, errorvalue
        return error_handler

    def create_connect_args(self, url):
        """ Return a tuple of *args,**kwargs for creating a connection.

        The mxODBC 3.x connection constructor looks like this:

            connect(dsn, user='', password='',
                    clear_auto_commit=1, errorhandler=None)

        This method translates the values in the provided URI
        into args and kwargs needed to instantiate an mxODBC Connection.

        The arg 'errorhandler' is not used by SQLAlchemy and will
        not be populated.

        """
        opts = url.translate_connect_args(username='user')
        opts.update(url.query)
        args = opts.pop('host')
        opts.pop('port', None)
        opts.pop('database', None)
        return (args,), opts

    def is_disconnect(self, e):
        # eGenix recommends checking connection.closed here,
        # but how can we get a handle on the current connection?
        if isinstance(e, self.dbapi.ProgrammingError):
            return "connection already closed" in str(e)
        elif isinstance(e, self.dbapi.Error):
            return '[08S01]' in str(e)
        else:
            return False

    def _get_server_version_info(self, connection):
        # eGenix suggests using conn.dbms_version instead of what we're doing here
        dbapi_con = connection.connection
        version = []
        r = re.compile('[.\-]')
        # 18 == pyodbc.SQL_DBMS_VER
        for n in r.split(dbapi_con.getinfo(18)[1]):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)

    def do_execute(self, cursor, statement, parameters, context=None):
        if context:
            native_odbc_execute = context.execution_options.\
                                    get('native_odbc_execute', 'auto')
            if native_odbc_execute is True:
                # user specified native_odbc_execute=True
                cursor.execute(statement, parameters)
            elif native_odbc_execute is False:
                # user specified native_odbc_execute=False
                cursor.executedirect(statement, parameters)
            elif context.is_crud:
                # statement is UPDATE, DELETE, INSERT
                cursor.execute(statement, parameters)
            else:
                # all other statements
                cursor.executedirect(statement, parameters)
        else:
            cursor.executedirect(statement, parameters)
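A usage sketch for this connector, not part of the commit: the URL scheme assumes the SQL Server dialect that pairs with it, and the DSN name is illustrative. As do_execute shows, the execute()/executedirect() choice defaults to 'auto' and can be overridden per statement via execution options:

# Illustrative only; assumes a configured ODBC DSN named "mydsn".
from sqlalchemy import MetaData, Table, Column, Integer, create_engine, select

engine = create_engine('mssql+mxodbc://user:pass@mydsn')
meta = MetaData()
orders = Table('orders', meta, Column('id', Integer, primary_key=True))

# Force the native ODBC execute path for this one statement.
stmt = select([orders]).execution_options(native_odbc_execute=True)
result = engine.execute(stmt)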
113 sqlalchemy/connectors/pyodbc.py Normal file
@@ -0,0 +1,113 @@
from sqlalchemy.connectors import Connector
from sqlalchemy.util import asbool

import sys
import re
import urllib
import decimal

class PyODBCConnector(Connector):
    driver = 'pyodbc'

    supports_sane_multi_rowcount = False
    # PyODBC unicode is broken on UCS-4 builds
    supports_unicode = sys.maxunicode == 65535
    supports_unicode_statements = supports_unicode
    supports_native_decimal = True
    default_paramstyle = 'named'

    # for non-DSN connections, this should
    # hold the desired driver name
    pyodbc_driver_name = None

    # will be set to True after initialize()
    # if the freetds.so is detected
    freetds = False

    @classmethod
    def dbapi(cls):
        return __import__('pyodbc')

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        opts.update(url.query)

        keys = opts
        query = url.query

        connect_args = {}
        for param in ('ansi', 'unicode_results', 'autocommit'):
            if param in keys:
                connect_args[param] = asbool(keys.pop(param))

        if 'odbc_connect' in keys:
            connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
        else:
            dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
            if dsn_connection:
                connectors = ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
            else:
                port = ''
                if 'port' in keys and 'port' not in query:
                    port = ',%d' % int(keys.pop('port'))

                connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
                              'Server=%s%s' % (keys.pop('host', ''), port),
                              'Database=%s' % keys.pop('database', '')]

            user = keys.pop("user", None)
            if user:
                connectors.append("UID=%s" % user)
                connectors.append("PWD=%s" % keys.pop('password', ''))
            else:
                connectors.append("Trusted_Connection=Yes")

            # if set to 'Yes', the ODBC layer will try to automagically convert
            # textual data from your database encoding to your client encoding
            # This should obviously be set to 'No' if you query a cp1253 encoded
            # database from a latin1 client...
            if 'odbc_autotranslate' in keys:
                connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))

            connectors.extend(['%s=%s' % (k, v) for k, v in keys.iteritems()])
        return [[";".join(connectors)], connect_args]

    def is_disconnect(self, e):
        if isinstance(e, self.dbapi.ProgrammingError):
            return "The cursor's connection has been closed." in str(e) or \
                   'Attempt to use a closed connection.' in str(e)
        elif isinstance(e, self.dbapi.Error):
            return '[08S01]' in str(e)
        else:
            return False

    def initialize(self, connection):
        # determine FreeTDS first. can't issue SQL easily
        # without getting unicode_statements/binds set up.

        pyodbc = self.dbapi

        dbapi_con = connection.connection

        self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)))

        # the "Py2K only" part here is theoretical.
        # have not tried pyodbc + python3.1 yet.
        # Py2K
        self.supports_unicode_statements = not self.freetds
        self.supports_unicode_binds = not self.freetds
        # end Py2K

        # run other initialization which asks for user name, etc.
        super(PyODBCConnector, self).initialize(connection)

    def _get_server_version_info(self, connection):
        dbapi_con = connection.connection
        version = []
        r = re.compile('[.\-]')
        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)
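create_connect_args folds the SQLAlchemy URL into a single ODBC connection string, choosing a hostname form or a DSN form. A hedged sketch of the two shapes it produces (driver name and hosts are illustrative; no database connection is made):

# Illustrative only: inspect the connection strings the connector builds.
from sqlalchemy.engine import url

connector = PyODBCConnector()
connector.pyodbc_driver_name = 'SQL Server'   # hypothetical driver name

# Hostname form: host and database both present in the URL.
u = url.make_url('mssql+pyodbc://scott:tiger@myhost:1433/mydb')
print connector.create_connect_args(u)
# [['DRIVER={SQL Server};Server=myhost,1433;Database=mydb;UID=scott;PWD=tiger'], {}]

# DSN form: a host but no database means the "host" is treated as a DSN name.
u = url.make_url('mssql+pyodbc://scott:tiger@mydsn')
print connector.create_connect_args(u)
# [['dsn=mydsn;UID=scott;PWD=tiger'], {}]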
48 sqlalchemy/connectors/zxJDBC.py Normal file
@@ -0,0 +1,48 @@
import sys
from sqlalchemy.connectors import Connector

class ZxJDBCConnector(Connector):
    driver = 'zxjdbc'

    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False

    supports_unicode_binds = True
    supports_unicode_statements = sys.version > '2.5.0+'
    description_encoding = None
    default_paramstyle = 'qmark'

    jdbc_db_name = None
    jdbc_driver_name = None

    @classmethod
    def dbapi(cls):
        from com.ziclix.python.sql import zxJDBC
        return zxJDBC

    def _driver_kwargs(self):
        """Return kw arg dict to be sent to connect()."""
        return {}

    def _create_jdbc_url(self, url):
        """Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
        return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
                                      url.port is not None and ':%s' % url.port or '',
                                      url.database)

    def create_connect_args(self, url):
        opts = self._driver_kwargs()
        opts.update(url.query)
        return [[self._create_jdbc_url(url), url.username, url.password, self.jdbc_driver_name],
                opts]

    def is_disconnect(self, e):
        if not isinstance(e, self.dbapi.ProgrammingError):
            return False
        e = str(e)
        return 'connection is closed' in e or 'cursor is closed' in e

    def _get_server_version_info(self, connection):
        # use connection.connection.dbversion, and parse appropriately
        # to get a tuple
        raise NotImplementedError()
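_create_jdbc_url is plain string assembly; subclasses supply jdbc_db_name and jdbc_driver_name. A quick hedged example for a hypothetical MySQL-on-Jython setup (the subclass, host, and driver class are illustrative):

# Illustrative only: the JDBC URL a zxJDBC subclass would build.
from sqlalchemy.engine import url

class DemoConnector(ZxJDBCConnector):       # hypothetical subclass
    jdbc_db_name = 'mysql'
    jdbc_driver_name = 'com.mysql.jdbc.Driver'

u = url.make_url('mysql+zxjdbc://scott:tiger@myhost:3306/test')
print DemoConnector()._create_jdbc_url(u)
# jdbc:mysql://myhost:3306/test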
31 sqlalchemy/databases/__init__.py Normal file
@@ -0,0 +1,31 @@
# __init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.postgresql import base as postgresql
postgres = postgresql
from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.dialects.oracle import base as oracle
from sqlalchemy.dialects.firebird import base as firebird
from sqlalchemy.dialects.maxdb import base as maxdb
from sqlalchemy.dialects.informix import base as informix
from sqlalchemy.dialects.mssql import base as mssql
from sqlalchemy.dialects.access import base as access
from sqlalchemy.dialects.sybase import base as sybase


__all__ = (
    'access',
    'firebird',
    'informix',
    'maxdb',
    'mssql',
    'mysql',
    'postgresql',
    'sqlite',
    'oracle',
    'sybase',
    )
12 sqlalchemy/dialects/__init__.py Normal file
@@ -0,0 +1,12 @@
__all__ = (
#    'access',
#    'firebird',
#    'informix',
#    'maxdb',
#    'mssql',
    'mysql',
    'oracle',
    'postgresql',
    'sqlite',
#    'sybase',
    )
0 sqlalchemy/dialects/access/__init__.py Normal file
418
sqlalchemy/dialects/access/base.py
Normal file
418
sqlalchemy/dialects/access/base.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# access.py
|
||||
# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk
|
||||
# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
Support for the Microsoft Access database.
|
||||
|
||||
This dialect is *not* ported to SQLAlchemy 0.6.
|
||||
|
||||
This dialect is *not* tested on SQLAlchemy 0.6.
|
||||
|
||||
|
||||
"""
|
||||
from sqlalchemy import sql, schema, types, exc, pool
|
||||
from sqlalchemy.sql import compiler, expression
|
||||
from sqlalchemy.engine import default, base, reflection
|
||||
from sqlalchemy import processors
|
||||
|
||||
class AcNumeric(types.Numeric):
|
||||
def get_col_spec(self):
|
||||
return "NUMERIC"
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return processors.to_str
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
class AcFloat(types.Float):
|
||||
def get_col_spec(self):
|
||||
return "FLOAT"
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
"""By converting to string, we can use Decimal types round-trip."""
|
||||
return processors.to_str
|
||||
|
||||
class AcInteger(types.Integer):
|
||||
def get_col_spec(self):
|
||||
return "INTEGER"
|
||||
|
||||
class AcTinyInteger(types.Integer):
|
||||
def get_col_spec(self):
|
||||
return "TINYINT"
|
||||
|
||||
class AcSmallInteger(types.SmallInteger):
|
||||
def get_col_spec(self):
|
||||
return "SMALLINT"
|
||||
|
||||
class AcDateTime(types.DateTime):
|
||||
def __init__(self, *a, **kw):
|
||||
super(AcDateTime, self).__init__(False)
|
||||
|
||||
def get_col_spec(self):
|
||||
return "DATETIME"
|
||||
|
||||
class AcDate(types.Date):
|
||||
def __init__(self, *a, **kw):
|
||||
super(AcDate, self).__init__(False)
|
||||
|
||||
def get_col_spec(self):
|
||||
return "DATETIME"
|
||||
|
||||
class AcText(types.Text):
|
||||
def get_col_spec(self):
|
||||
return "MEMO"
|
||||
|
||||
class AcString(types.String):
|
||||
def get_col_spec(self):
|
||||
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
|
||||
|
||||
class AcUnicode(types.Unicode):
|
||||
def get_col_spec(self):
|
||||
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
class AcChar(types.CHAR):
|
||||
def get_col_spec(self):
|
||||
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
|
||||
|
||||
class AcBinary(types.LargeBinary):
|
||||
def get_col_spec(self):
|
||||
return "BINARY"
|
||||
|
||||
class AcBoolean(types.Boolean):
|
||||
def get_col_spec(self):
|
||||
return "YESNO"
|
||||
|
||||
class AcTimeStamp(types.TIMESTAMP):
|
||||
def get_col_spec(self):
|
||||
return "TIMESTAMP"
|
||||
|
||||
class AccessExecutionContext(default.DefaultExecutionContext):
|
||||
def _has_implicit_sequence(self, column):
|
||||
if column.primary_key and column.autoincrement:
|
||||
if isinstance(column.type, types.Integer) and not column.foreign_keys:
|
||||
if column.default is None or (isinstance(column.default, schema.Sequence) and \
|
||||
column.default.optional):
|
||||
return True
|
||||
return False
|
||||
|
||||
def post_exec(self):
|
||||
"""If we inserted into a row with a COUNTER column, fetch the ID"""
|
||||
|
||||
if self.compiled.isinsert:
|
||||
tbl = self.compiled.statement.table
|
||||
if not hasattr(tbl, 'has_sequence'):
|
||||
tbl.has_sequence = None
|
||||
for column in tbl.c:
|
||||
if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
|
||||
tbl.has_sequence = column
|
||||
break
|
||||
|
||||
if bool(tbl.has_sequence):
|
||||
# TBD: for some reason _last_inserted_ids doesn't exist here
|
||||
# (but it does at corresponding point in mssql???)
|
||||
#if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
|
||||
self.cursor.execute("SELECT @@identity AS lastrowid")
|
||||
row = self.cursor.fetchone()
|
||||
self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:]
|
||||
# print "LAST ROW ID", self._last_inserted_ids
|
||||
|
||||
super(AccessExecutionContext, self).post_exec()
|
||||
|
||||
|
||||
const, daoEngine = None, None
|
||||
class AccessDialect(default.DefaultDialect):
|
||||
colspecs = {
|
||||
types.Unicode : AcUnicode,
|
||||
types.Integer : AcInteger,
|
||||
types.SmallInteger: AcSmallInteger,
|
||||
types.Numeric : AcNumeric,
|
||||
types.Float : AcFloat,
|
||||
types.DateTime : AcDateTime,
|
||||
types.Date : AcDate,
|
||||
types.String : AcString,
|
||||
types.LargeBinary : AcBinary,
|
||||
types.Boolean : AcBoolean,
|
||||
types.Text : AcText,
|
||||
types.CHAR: AcChar,
|
||||
types.TIMESTAMP: AcTimeStamp,
|
||||
}
|
||||
name = 'access'
|
||||
supports_sane_rowcount = False
|
||||
supports_sane_multi_rowcount = False
|
||||
|
||||
ported_sqla_06 = False
|
||||
|
||||
def type_descriptor(self, typeobj):
|
||||
newobj = types.adapt_type(typeobj, self.colspecs)
|
||||
return newobj
|
||||
|
||||
def __init__(self, **params):
|
||||
super(AccessDialect, self).__init__(**params)
|
||||
self.text_as_varchar = False
|
||||
self._dtbs = None
|
||||
|
||||
def dbapi(cls):
|
||||
import win32com.client, pythoncom
|
||||
|
||||
global const, daoEngine
|
||||
if const is None:
|
||||
const = win32com.client.constants
|
||||
for suffix in (".36", ".35", ".30"):
|
||||
try:
|
||||
daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix)
|
||||
break
|
||||
except pythoncom.com_error:
|
||||
pass
|
||||
else:
|
||||
raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
|
||||
|
||||
import pyodbc as module
|
||||
return module
|
||||
dbapi = classmethod(dbapi)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args()
|
||||
connectors = ["Driver={Microsoft Access Driver (*.mdb)}"]
|
||||
connectors.append("Dbq=%s" % opts["database"])
|
||||
user = opts.get("username", None)
|
||||
if user:
|
||||
connectors.append("UID=%s" % user)
|
||||
connectors.append("PWD=%s" % opts.get("password", ""))
|
||||
return [[";".join(connectors)], {}]
|
||||
|
||||
def last_inserted_ids(self):
|
||||
return self.context.last_inserted_ids
|
||||
|
||||
def do_execute(self, cursor, statement, params, **kwargs):
|
||||
if params == {}:
|
||||
params = ()
|
||||
super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs)
|
||||
|
||||
def _execute(self, c, statement, parameters):
|
||||
try:
|
||||
if parameters == {}:
|
||||
parameters = ()
|
||||
c.execute(statement, parameters)
|
||||
self.context.rowcount = c.rowcount
|
||||
except Exception, e:
|
||||
raise exc.DBAPIError.instance(statement, parameters, e)
|
||||
|
||||
def has_table(self, connection, tablename, schema=None):
|
||||
# This approach seems to be more reliable that using DAO
|
||||
try:
|
||||
connection.execute('select top 1 * from [%s]' % tablename)
|
||||
return True
|
||||
except Exception, e:
|
||||
return False
|
||||
|
||||
def reflecttable(self, connection, table, include_columns):
|
||||
# This is defined in the function, as it relies on win32com constants,
|
||||
# that aren't imported until dbapi method is called
|
||||
if not hasattr(self, 'ischema_names'):
|
||||
self.ischema_names = {
|
||||
const.dbByte: AcBinary,
|
||||
const.dbInteger: AcInteger,
|
||||
const.dbLong: AcInteger,
|
||||
const.dbSingle: AcFloat,
|
||||
const.dbDouble: AcFloat,
|
||||
const.dbDate: AcDateTime,
|
||||
const.dbLongBinary: AcBinary,
|
||||
const.dbMemo: AcText,
|
||||
const.dbBoolean: AcBoolean,
|
||||
const.dbText: AcUnicode, # All Access strings are unicode
|
||||
const.dbCurrency: AcNumeric,
|
||||
}
|
||||
|
||||
# A fresh DAO connection is opened for each reflection
|
||||
# This is necessary, so we get the latest updates
|
||||
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
|
||||
|
||||
try:
|
||||
for tbl in dtbs.TableDefs:
|
||||
if tbl.Name.lower() == table.name.lower():
|
||||
break
|
||||
else:
|
||||
raise exc.NoSuchTableError(table.name)
|
||||
|
||||
for col in tbl.Fields:
|
||||
coltype = self.ischema_names[col.Type]
|
||||
if col.Type == const.dbText:
|
||||
coltype = coltype(col.Size)
|
||||
|
||||
colargs = \
|
||||
{
|
||||
'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField),
|
||||
}
|
||||
default = col.DefaultValue
|
||||
|
||||
if col.Attributes & const.dbAutoIncrField:
|
||||
colargs['default'] = schema.Sequence(col.Name + '_seq')
|
||||
elif default:
|
||||
if col.Type == const.dbBoolean:
|
||||
default = default == 'Yes' and '1' or '0'
|
||||
colargs['server_default'] = schema.DefaultClause(sql.text(default))
|
||||
|
||||
table.append_column(schema.Column(col.Name, coltype, **colargs))
|
||||
|
||||
# TBD: check constraints
|
||||
|
||||
# Find primary key columns first
|
||||
for idx in tbl.Indexes:
|
||||
if idx.Primary:
|
||||
for col in idx.Fields:
|
||||
thecol = table.c[col.Name]
|
||||
table.primary_key.add(thecol)
|
||||
if isinstance(thecol.type, AcInteger) and \
|
||||
not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)):
|
||||
thecol.autoincrement = False
|
||||
|
||||
# Then add other indexes
|
||||
for idx in tbl.Indexes:
|
||||
if not idx.Primary:
|
||||
if len(idx.Fields) == 1:
|
||||
col = table.c[idx.Fields[0].Name]
|
||||
if not col.primary_key:
|
||||
col.index = True
|
||||
col.unique = idx.Unique
|
||||
else:
|
||||
pass # TBD: multi-column indexes
|
||||
|
||||
|
||||
for fk in dtbs.Relations:
|
||||
if fk.ForeignTable != table.name:
|
||||
continue
|
||||
scols = [c.ForeignName for c in fk.Fields]
|
||||
rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields]
|
||||
table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True))
|
||||
|
||||
finally:
|
||||
dtbs.Close()
|
||||
|
||||
@reflection.cache
|
||||
def get_table_names(self, connection, schema=None, **kw):
|
||||
# A fresh DAO connection is opened for each reflection
|
||||
# This is necessary, so we get the latest updates
|
||||
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
|
||||
|
||||
names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"]
|
||||
dtbs.Close()
|
||||
return names
|
||||
|
||||
|
||||
class AccessCompiler(compiler.SQLCompiler):
|
||||
extract_map = compiler.SQLCompiler.extract_map.copy()
|
||||
extract_map.update ({
|
||||
'month': 'm',
|
||||
'day': 'd',
|
||||
'year': 'yyyy',
|
||||
'second': 's',
|
||||
'hour': 'h',
|
||||
'doy': 'y',
|
||||
'minute': 'n',
|
||||
'quarter': 'q',
|
||||
'dow': 'w',
|
||||
'week': 'ww'
|
||||
})
|
||||
|
||||
def visit_select_precolumns(self, select):
|
||||
"""Access puts TOP, it's version of LIMIT here """
|
||||
s = select.distinct and "DISTINCT " or ""
|
||||
if select.limit:
|
||||
s += "TOP %s " % (select.limit)
|
||||
if select.offset:
|
||||
raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
|
||||
return s
|
||||
|
||||
def limit_clause(self, select):
|
||||
"""Limit in access is after the select keyword"""
|
||||
return ""
|
||||
|
||||
def binary_operator_string(self, binary):
|
||||
"""Access uses "mod" instead of "%" """
|
||||
return binary.operator == '%' and 'mod' or binary.operator
|
||||
|
||||
def label_select_column(self, select, column, asfrom):
|
||||
if isinstance(column, expression.Function):
|
||||
return column.label()
|
||||
else:
|
||||
return super(AccessCompiler, self).label_select_column(select, column, asfrom)
|
||||
|
||||
function_rewrites = {'current_date': 'now',
|
||||
'current_timestamp': 'now',
|
||||
'length': 'len',
|
||||
}
|
||||
def visit_function(self, func):
|
||||
"""Access function names differ from the ANSI SQL names; rewrite common ones"""
|
||||
func.name = self.function_rewrites.get(func.name, func.name)
|
||||
return super(AccessCompiler, self).visit_function(func)
|
||||
|
||||
def for_update_clause(self, select):
|
||||
"""FOR UPDATE is not supported by Access; silently ignore"""
|
||||
return ''

    # Strip schema
    def visit_table(self, table, asfrom=False, **kwargs):
        if asfrom:
            return self.preparer.quote(table.name, table.quote)
        else:
            return ""

    def visit_join(self, join, asfrom=False, **kwargs):
        return (self.process(join.left, asfrom=True) +
                (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") +
                self.process(join.right, asfrom=True) +
                " ON " + self.process(join.onclause))

    def visit_extract(self, extract, **kw):
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
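
    # Illustrative example (not part of the original source): with the
    # extract_map above, extract('month', t.c.posted) compiles to
    # DATEPART("m", t.posted).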


class AccessDDLCompiler(compiler.DDLCompiler):
    def get_column_specification(self, column, **kwargs):
        colspec = self.preparer.format_column(column) + " " + \
            column.type.dialect_impl(self.dialect).get_col_spec()

        # install a sequence if we have an implicit IDENTITY column
        if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
                column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys:
            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
                column.sequence = schema.Sequence(column.name + '_seq')

        if not column.nullable:
            colspec += " NOT NULL"

        if hasattr(column, 'sequence'):
            column.table.has_sequence = column
            colspec = self.preparer.format_column(column) + " counter"
        else:
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        return colspec

    def visit_drop_index(self, drop):
        index = drop.element
        self.append("\nDROP INDEX [%s].[%s]" %
                    (index.table.name, self._validate_identifier(index.name, False)))


class AccessIdentifierPreparer(compiler.IdentifierPreparer):
    reserved_words = compiler.RESERVED_WORDS.copy()
    reserved_words.update(['value', 'text'])

    def __init__(self, dialect):
        super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')


dialect = AccessDialect
dialect.poolclass = pool.SingletonThreadPool
dialect.statement_compiler = AccessCompiler
dialect.ddlcompiler = AccessDDLCompiler
dialect.preparer = AccessIdentifierPreparer
dialect.execution_ctx_cls = AccessExecutionContext
16
sqlalchemy/dialects/firebird/__init__.py
Normal file
@@ -0,0 +1,16 @@
from sqlalchemy.dialects.firebird import base, kinterbasdb

base.dialect = kinterbasdb.dialect

from sqlalchemy.dialects.firebird.base import \
    SMALLINT, BIGINT, FLOAT, DATE, TIME, \
    TEXT, NUMERIC, TIMESTAMP, VARCHAR, CHAR, BLOB, \
    dialect

__all__ = (
    'SMALLINT', 'BIGINT', 'FLOAT', 'DATE', 'TIME',
    'TEXT', 'NUMERIC', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
    'dialect'
)
619
sqlalchemy/dialects/firebird/base.py
Normal file
@@ -0,0 +1,619 @@
# firebird.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
Support for the Firebird database.

Connectivity is usually supplied via the kinterbasdb_ DBAPI module.

Dialects
~~~~~~~~

Firebird offers two distinct dialects_ (not to be confused with a
SQLAlchemy ``Dialect``):

dialect 1
  This is the old syntax and behaviour, inherited from Interbase pre-6.0.

dialect 3
  This is the newer and supported syntax, introduced in Interbase 6.0.

The SQLAlchemy Firebird dialect detects these versions and
adjusts its representation of SQL accordingly. However,
support for dialect 1 is not well tested and probably has
incompatibilities.

Locking Behavior
~~~~~~~~~~~~~~~~

Firebird locks tables aggressively. For this reason, a DROP TABLE may
hang until other transactions are released. SQLAlchemy does its best
to release transactions as quickly as possible. The most common cause
of hanging transactions is a non-fully consumed result set, i.e.::

    result = engine.execute("select * from table")
    row = result.fetchone()
    return

Where above, the ``ResultProxy`` has not been fully consumed. The
connection will be returned to the pool and the transactional state
rolled back once the Python garbage collector reclaims the objects
which hold onto the connection, which often occurs asynchronously.
The above use case can be alleviated by calling ``first()`` on the
``ResultProxy`` which will fetch the first row and immediately close
all remaining cursor/connection resources.
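
For example (an illustrative sketch, reusing the ``engine`` from above)::

    row = engine.execute("select * from table").first()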

RETURNING support
~~~~~~~~~~~~~~~~~

Firebird 2.0 supports returning a result set from inserts, and 2.1
extends that to deletes and updates. This is generically exposed by
the SQLAlchemy ``returning()`` method, such as::

    # INSERT..RETURNING
    result = table.insert().returning(table.c.col1, table.c.col2).\\
        values(name='foo')
    print result.fetchall()

    # UPDATE..RETURNING
    raises = empl.update().returning(empl.c.id, empl.c.salary).\\
        where(empl.c.sales>100).\\
        values(dict(salary=empl.c.salary * 1.1))
    print raises.fetchall()


.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html

"""

import datetime, re

from sqlalchemy import schema as sa_schema
from sqlalchemy import exc, types as sqltypes, sql, util
from sqlalchemy.sql import expression
from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler


from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
                              FLOAT, INTEGER, NUMERIC, SMALLINT,
                              TEXT, TIME, TIMESTAMP, VARCHAR)


RESERVED_WORDS = set([
    "active", "add", "admin", "after", "all", "alter", "and", "any", "as",
    "asc", "ascending", "at", "auto", "avg", "before", "begin", "between",
    "bigint", "bit_length", "blob", "both", "by", "case", "cast", "char",
    "character", "character_length", "char_length", "check", "close",
    "collate", "column", "commit", "committed", "computed", "conditional",
    "connect", "constraint", "containing", "count", "create", "cross",
    "cstring", "current", "current_connection", "current_date",
    "current_role", "current_time", "current_timestamp",
    "current_transaction", "current_user", "cursor", "database", "date",
    "day", "dec", "decimal", "declare", "default", "delete", "desc",
    "descending", "disconnect", "distinct", "do", "domain", "double",
    "drop", "else", "end", "entry_point", "escape", "exception",
    "execute", "exists", "exit", "external", "extract", "fetch", "file",
    "filter", "float", "for", "foreign", "from", "full", "function",
    "gdscode", "generator", "gen_id", "global", "grant", "group",
    "having", "hour", "if", "in", "inactive", "index", "inner",
    "input_type", "insensitive", "insert", "int", "integer", "into", "is",
    "isolation", "join", "key", "leading", "left", "length", "level",
    "like", "long", "lower", "manual", "max", "maximum_segment", "merge",
    "min", "minute", "module_name", "month", "names", "national",
    "natural", "nchar", "no", "not", "null", "numeric", "octet_length",
    "of", "on", "only", "open", "option", "or", "order", "outer",
    "output_type", "overflow", "page", "pages", "page_size", "parameter",
    "password", "plan", "position", "post_event", "precision", "primary",
    "privileges", "procedure", "protected", "rdb$db_key", "read", "real",
    "record_version", "recreate", "recursive", "references", "release",
    "reserv", "reserving", "retain", "returning_values", "returns",
    "revoke", "right", "rollback", "rows", "row_count", "savepoint",
    "schema", "second", "segment", "select", "sensitive", "set", "shadow",
    "shared", "singular", "size", "smallint", "snapshot", "some", "sort",
    "sqlcode", "stability", "start", "starting", "starts", "statistics",
    "sub_type", "sum", "suspend", "table", "then", "time", "timestamp",
    "to", "trailing", "transaction", "trigger", "trim", "uncommitted",
    "union", "unique", "update", "upper", "user", "using", "value",
    "values", "varchar", "variable", "varying", "view", "wait", "when",
    "where", "while", "with", "work", "write", "year",
])


colspecs = {
}

ischema_names = {
    'SHORT': SMALLINT,
    'LONG': BIGINT,
    'QUAD': FLOAT,
    'FLOAT': FLOAT,
    'DATE': DATE,
    'TIME': TIME,
    'TEXT': TEXT,
    'INT64': NUMERIC,
    'DOUBLE': FLOAT,
    'TIMESTAMP': TIMESTAMP,
    'VARYING': VARCHAR,
    'CSTRING': CHAR,
    'BLOB': BLOB,
}


# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
# as bind/result functionality is required)

class FBTypeCompiler(compiler.GenericTypeCompiler):
    def visit_boolean(self, type_):
        return self.visit_SMALLINT(type_)

    def visit_datetime(self, type_):
        return self.visit_TIMESTAMP(type_)

    def visit_TEXT(self, type_):
        return "BLOB SUB_TYPE 1"

    def visit_BLOB(self, type_):
        return "BLOB SUB_TYPE 0"


class FBCompiler(sql.compiler.SQLCompiler):
    """Firebird specific idiosyncrasies"""

    def visit_mod(self, binary, **kw):
        # Firebird lacks a builtin modulo operator, but there is
        # an equivalent function in the ib_udf library.
        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))

    def visit_alias(self, alias, asfrom=False, **kwargs):
        if self.dialect._version_two:
            return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs)
        else:
            # Override to not use the AS keyword which FB 1.5 does not like
            if asfrom:
                alias_name = isinstance(alias.name, expression._generated_label) and \
                    self._truncated_identifier("alias", alias.name) or alias.name

                return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \
                    self.preparer.format_alias(alias, alias_name)
            else:
                return self.process(alias.original, **kwargs)

    def visit_substring_func(self, func, **kw):
        s = self.process(func.clauses.clauses[0])
        start = self.process(func.clauses.clauses[1])
        if len(func.clauses.clauses) > 2:
            length = self.process(func.clauses.clauses[2])
            return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
        else:
            return "SUBSTRING(%s FROM %s)" % (s, start)

    def visit_length_func(self, function, **kw):
        if self.dialect._version_two:
            return "char_length" + self.function_argspec(function)
        else:
            return "strlen" + self.function_argspec(function)

    visit_char_length_func = visit_length_func

    def function_argspec(self, func, **kw):
        if func.clauses is not None and len(func.clauses):
            return self.process(func.clause_expr)
        else:
            return ""

    def default_from(self):
        return " FROM rdb$database"

    def visit_sequence(self, seq):
        return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)

    def get_select_precolumns(self, select):
        """Called when building a ``SELECT`` statement; the position is just
        before the column list. Firebird puts the limit and offset right
        after the ``SELECT`` keyword.
        """

        result = ""
        if select._limit:
            result += "FIRST %d " % select._limit
        if select._offset:
            result += "SKIP %d " % select._offset
        if select._distinct:
            result += "DISTINCT "
        return result
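
    # Illustrative example (not part of the original source): limit=10,
    # offset=5 and distinct together render as
    # "SELECT FIRST 10 SKIP 5 DISTINCT ...".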

    def limit_clause(self, select):
        """Already taken care of in the `get_select_precolumns` method."""

        return ""

    def returning_clause(self, stmt, returning_cols):

        columns = [
            self.process(
                self.label_select_column(None, c, asfrom=False),
                within_columns_clause=True,
                result_map=self.result_map
            )
            for c in expression._select_iterables(returning_cols)
        ]
        return 'RETURNING ' + ', '.join(columns)


class FBDDLCompiler(sql.compiler.DDLCompiler):
    """Firebird syntactic idiosyncrasies"""

    def visit_create_sequence(self, create):
        """Generate a ``CREATE GENERATOR`` statement for the sequence."""

        # no syntax for these
        # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
        if create.element.start is not None:
            raise NotImplementedError("Firebird SEQUENCE doesn't support START WITH")
        if create.element.increment is not None:
            raise NotImplementedError("Firebird SEQUENCE doesn't support INCREMENT BY")

        if self.dialect._version_two:
            return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
        else:
            return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element)

    def visit_drop_sequence(self, drop):
        """Generate a ``DROP GENERATOR`` statement for the sequence."""

        if self.dialect._version_two:
            return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
        else:
            return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)


class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
    """Install Firebird specific reserved words."""

    reserved_words = RESERVED_WORDS

    def __init__(self, dialect):
        super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)


class FBExecutionContext(default.DefaultExecutionContext):
    def fire_sequence(self, seq):
        """Get the next value from the sequence using ``gen_id()``."""

        return self._execute_scalar("SELECT gen_id(%s, 1) FROM rdb$database" %
                                    self.dialect.identifier_preparer.format_sequence(seq))


class FBDialect(default.DefaultDialect):
    """Firebird dialect"""

    name = 'firebird'

    max_identifier_length = 31

    supports_sequences = True
    sequences_optional = False
    supports_default_values = True
    postfetch_lastrowid = False

    supports_native_boolean = False

    requires_name_normalize = True
    supports_empty_insert = False

    statement_compiler = FBCompiler
    ddl_compiler = FBDDLCompiler
    preparer = FBIdentifierPreparer
    type_compiler = FBTypeCompiler
    execution_ctx_cls = FBExecutionContext

    colspecs = colspecs
    ischema_names = ischema_names

    # defaults to dialect ver. 3,
    # will be autodetected upon
    # first connect
    _version_two = True

    def initialize(self, connection):
        super(FBDialect, self).initialize(connection)
        self._version_two = self.server_version_info > (2, )
        if not self._version_two:
            # TODO: whatever other pre < 2.0 stuff goes here
            self.ischema_names = ischema_names.copy()
            self.ischema_names['TIMESTAMP'] = sqltypes.DATE
            self.colspecs = {
                sqltypes.DateTime: sqltypes.DATE
            }
        else:
            self.implicit_returning = True

    def normalize_name(self, name):
        # Remove trailing spaces: FB uses a CHAR() type,
        # which is padded with spaces
        name = name and name.rstrip()
        if name is None:
            return None
        elif name.upper() == name and \
                not self.identifier_preparer._requires_quotes(name.lower()):
            return name.lower()
        else:
            return name

    def denormalize_name(self, name):
        if name is None:
            return None
        elif name.lower() == name and \
                not self.identifier_preparer._requires_quotes(name.lower()):
            return name.upper()
        else:
            return name
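
    # Illustrative example (not part of the original source):
    # normalize_name("EMPLOYEE") returns "employee" for the Python side,
    # and denormalize_name("employee") returns "EMPLOYEE" for the database;
    # mixed-case or quoted names pass through unchanged.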

    def has_table(self, connection, table_name, schema=None):
        """Return ``True`` if the given table exists, ignoring the `schema`."""

        tblqry = """
        SELECT 1 FROM rdb$database
        WHERE EXISTS (SELECT rdb$relation_name
                      FROM rdb$relations
                      WHERE rdb$relation_name=?)
        """
        c = connection.execute(tblqry, [self.denormalize_name(table_name)])
        return c.first() is not None

    def has_sequence(self, connection, sequence_name, schema=None):
        """Return ``True`` if the given sequence (generator) exists."""

        genqry = """
        SELECT 1 FROM rdb$database
        WHERE EXISTS (SELECT rdb$generator_name
                      FROM rdb$generators
                      WHERE rdb$generator_name=?)
        """
        c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
        return c.first() is not None

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        s = """
        SELECT DISTINCT rdb$relation_name
        FROM rdb$relation_fields
        WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
        """
        return [self.normalize_name(row[0]) for row in connection.execute(s)]

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        s = """
        SELECT distinct rdb$view_name
        FROM rdb$view_relations
        """
        return [self.normalize_name(row[0]) for row in connection.execute(s)]

    @reflection.cache
    def get_view_definition(self, connection, view_name, schema=None, **kw):
        qry = """
        SELECT rdb$view_source AS view_source
        FROM rdb$relations
        WHERE rdb$relation_name=?
        """
        rp = connection.execute(qry, [self.denormalize_name(view_name)])
        row = rp.first()
        if row:
            return row['view_source']
        else:
            return None

    @reflection.cache
    def get_primary_keys(self, connection, table_name, schema=None, **kw):
        # Query to extract the PK/FK constrained fields of the given table
        keyqry = """
        SELECT se.rdb$field_name AS fname
        FROM rdb$relation_constraints rc
        JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
        """
        tablename = self.denormalize_name(table_name)
        # get primary key fields
        c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
        pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
        return pkfields

    @reflection.cache
    def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw):
        tablename = self.denormalize_name(table_name)
        colname = self.denormalize_name(column_name)
        # Heuristic-query to determine the generator associated to a PK field
        genqry = """
        SELECT trigdep.rdb$depended_on_name AS fgenerator
        FROM rdb$dependencies tabdep
        JOIN rdb$dependencies trigdep
            ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
            AND trigdep.rdb$depended_on_type=14
            AND trigdep.rdb$dependent_type=2
        JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name
        WHERE tabdep.rdb$depended_on_name=?
            AND tabdep.rdb$depended_on_type=0
            AND trig.rdb$trigger_type=1
            AND tabdep.rdb$field_name=?
            AND (SELECT count(*)
                 FROM rdb$dependencies trigdep2
                 WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
        """
        genr = connection.execute(genqry, [tablename, colname]).first()
        if genr is not None:
            return dict(name=self.normalize_name(genr['fgenerator']))

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        # Query to extract the details of all the fields of the given table
        tblqry = """
        SELECT DISTINCT r.rdb$field_name AS fname,
                        r.rdb$null_flag AS null_flag,
                        t.rdb$type_name AS ftype,
                        f.rdb$field_sub_type AS stype,
                        f.rdb$field_length/COALESCE(cs.rdb$bytes_per_character,1) AS flen,
                        f.rdb$field_precision AS fprec,
                        f.rdb$field_scale AS fscale,
                        COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
        FROM rdb$relation_fields r
        JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
        JOIN rdb$types t
            ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
        LEFT JOIN rdb$character_sets cs ON f.rdb$character_set_id=cs.rdb$character_set_id
        WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
        ORDER BY r.rdb$field_position
        """
        # get the PK, used to determine the eventual associated sequence
        pkey_cols = self.get_primary_keys(connection, table_name)

        tablename = self.denormalize_name(table_name)
        # get all of the fields for this table
        c = connection.execute(tblqry, [tablename])
        cols = []
        while True:
            row = c.fetchone()
            if row is None:
                break
            name = self.normalize_name(row['fname'])
            orig_colname = row['fname']

            # get the data type
            colspec = row['ftype'].rstrip()
            coltype = self.ischema_names.get(colspec)
            if coltype is None:
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (colspec, name))
                coltype = sqltypes.NULLTYPE
            elif colspec == 'INT64':
                coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1)
            elif colspec in ('VARYING', 'CSTRING'):
                coltype = coltype(row['flen'])
            elif colspec == 'TEXT':
                coltype = TEXT(row['flen'])
            elif colspec == 'BLOB':
                if row['stype'] == 1:
                    coltype = TEXT()
                else:
                    coltype = BLOB()
            else:
                coltype = coltype(row)

            # does it have a default value?
            defvalue = None
            if row['fdefault'] is not None:
                # the value comes down as "DEFAULT 'value'": there may be
                # more than one whitespace around the "DEFAULT" keyword
                # (see also http://tracker.firebirdsql.org/browse/CORE-356)
                defexpr = row['fdefault'].lstrip()
                assert defexpr[:8].rstrip() == 'DEFAULT', \
                    "Unrecognized default value: %s" % defexpr
                defvalue = defexpr[8:].strip()
                if defvalue == 'NULL':
                    # Redundant
                    defvalue = None
            col_d = {
                'name': name,
                'type': coltype,
                'nullable': not bool(row['null_flag']),
                'default': defvalue
            }

            if orig_colname.lower() == orig_colname:
                col_d['quote'] = True

            # if the PK is a single field, try to see if it's linked to
            # a sequence thru a trigger
            if len(pkey_cols) == 1 and name == pkey_cols[0]:
                seq_d = self.get_column_sequence(connection, tablename, name)
                if seq_d is not None:
                    col_d['sequence'] = seq_d

            cols.append(col_d)
        return cols

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        # Query to extract the details of each UK/FK of the given table
        fkqry = """
        SELECT rc.rdb$constraint_name AS cname,
               cse.rdb$field_name AS fname,
               ix2.rdb$relation_name AS targetrname,
               se.rdb$field_name AS targetfname
        FROM rdb$relation_constraints rc
        JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
        JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
        JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
        JOIN rdb$index_segments se
            ON se.rdb$index_name=ix2.rdb$index_name
            AND se.rdb$field_position=cse.rdb$field_position
        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
        ORDER BY se.rdb$index_name, se.rdb$field_position
        """
        tablename = self.denormalize_name(table_name)

        c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
        fks = util.defaultdict(lambda: {
            'name': None,
            'constrained_columns': [],
            'referred_schema': None,
            'referred_table': None,
            'referred_columns': []
        })

        for row in c:
            cname = self.normalize_name(row['cname'])
            fk = fks[cname]
            if not fk['name']:
                fk['name'] = cname
                fk['referred_table'] = self.normalize_name(row['targetrname'])
            fk['constrained_columns'].append(self.normalize_name(row['fname']))
            fk['referred_columns'].append(
                self.normalize_name(row['targetfname']))
        return fks.values()

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        qry = """
        SELECT ix.rdb$index_name AS index_name,
               ix.rdb$unique_flag AS unique_flag,
               ic.rdb$field_name AS field_name
        FROM rdb$indices ix
        JOIN rdb$index_segments ic
            ON ix.rdb$index_name=ic.rdb$index_name
        LEFT OUTER JOIN rdb$relation_constraints
            ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name
        WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
            AND rdb$relation_constraints.rdb$constraint_type IS NULL
        ORDER BY index_name, field_name
        """
        c = connection.execute(qry, [self.denormalize_name(table_name)])

        indexes = util.defaultdict(dict)
        for row in c:
            indexrec = indexes[row['index_name']]
            if 'name' not in indexrec:
                indexrec['name'] = self.normalize_name(row['index_name'])
                indexrec['column_names'] = []
                indexrec['unique'] = bool(row['unique_flag'])

            indexrec['column_names'].append(self.normalize_name(row['field_name']))

        return indexes.values()

    def do_execute(self, cursor, statement, parameters, **kwargs):
        # kinterbasdb does not accept a None, but wants an empty list
        # when there are no arguments.
        cursor.execute(statement, parameters or [])

    def do_rollback(self, connection):
        # Use the retaining feature, that keeps the transaction going
        connection.rollback(True)

    def do_commit(self, connection):
        # Use the retaining feature, that keeps the transaction going
        connection.commit(True)
120
sqlalchemy/dialects/firebird/kinterbasdb.py
Normal file
@@ -0,0 +1,120 @@
# kinterbasdb.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
The most common way to connect to a Firebird engine is implemented by
kinterbasdb__, currently maintained__ directly by the Firebird people.

The connection URL is of the form
``firebird[+kinterbasdb]://user:password@host:port/path/to/db[?key=value&key=value...]``.

kinterbasdb backend specific keyword arguments are:

type_conv
  select the kind of mapping done on the types: by default SQLAlchemy
  uses 200 with Unicode, datetime and decimal support (see details__).

concurrency_level
  set the backend policy with regards to threading issues: by default
  SQLAlchemy uses policy 1 (see details__).

__ http://sourceforge.net/projects/kinterbasdb
__ http://firebirdsql.org/index.php?op=devel&sub=python
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
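
For example (an illustrative sketch; the values shown here simply restate
the defaults)::

    create_engine('firebird+kinterbasdb://user:pass@localhost/path/to/db',
                  type_conv=200, concurrency_level=1)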
"""

from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
from sqlalchemy import util, types as sqltypes


class _FBNumeric_kinterbasdb(sqltypes.Numeric):
    def bind_processor(self, dialect):
        def process(value):
            if value is not None:
                return str(value)
            else:
                return value
        return process


class FBDialect_kinterbasdb(FBDialect):
    driver = 'kinterbasdb'
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False

    supports_native_decimal = True

    colspecs = util.update_copy(
        FBDialect.colspecs,
        {
            sqltypes.Numeric: _FBNumeric_kinterbasdb
        }
    )

    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
        super(FBDialect_kinterbasdb, self).__init__(**kwargs)

        self.type_conv = type_conv
        self.concurrency_level = concurrency_level

    @classmethod
    def dbapi(cls):
        k = __import__('kinterbasdb')
        return k

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        if opts.get('port'):
            opts['host'] = "%s/%s" % (opts['host'], opts['port'])
            del opts['port']
        opts.update(url.query)

        type_conv = opts.pop('type_conv', self.type_conv)
        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)

        if self.dbapi is not None:
            initialized = getattr(self.dbapi, 'initialized', None)
            if initialized is None:
                # CVS rev 1.96 changed the name of the attribute:
                # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
                initialized = getattr(self.dbapi, '_initialized', False)
            if not initialized:
                self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
        return ([], opts)

    def _get_server_version_info(self, connection):
        """Get the version of the Firebird server used by a connection.

        Returns a tuple of (`major`, `minor`, `build`), three integers
        representing the version of the attached server.
        """

        # This is the simpler approach (the other uses the services api),
        # that for backward compatibility reasons returns a string like
        #   LI-V6.3.3.12981 Firebird 2.0
        # where the first version is a fake one resembling the old
        # Interbase signature. This is more than enough for our purposes,
        # as this is mainly (only?) used by the testsuite.

        from re import match

        fbconn = connection.connection
        version = fbconn.server_version
        m = match(r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
        if not m:
            raise AssertionError("Could not determine version from string '%s'" % version)
        return tuple([int(x) for x in m.group(5, 6, 4)])
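
    # Illustrative example (not part of the original source): a version
    # string of "LI-V6.3.3.12981 Firebird 2.0" yields (2, 0, 12981).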

    def is_disconnect(self, e):
        if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)):
            msg = str(e)
            return ('Unable to complete network request to host' in msg or
                    'Invalid connection state' in msg or
                    'Invalid cursor state' in msg)
        else:
            return False


dialect = FBDialect_kinterbasdb
3
sqlalchemy/dialects/informix/__init__.py
Normal file
@@ -0,0 +1,3 @@
from sqlalchemy.dialects.informix import base, informixdb

base.dialect = informixdb.dialect
306
sqlalchemy/dialects/informix/base.py
Normal file
@@ -0,0 +1,306 @@
# informix.py
# Copyright (C) 2005,2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# coding: gbk
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for the Informix database.

This dialect is *not* tested on SQLAlchemy 0.6.

"""

import datetime

from sqlalchemy import sql, schema, exc, pool, util
from sqlalchemy.sql import compiler
from sqlalchemy.engine import default, reflection
from sqlalchemy import types as sqltypes


class InfoDateTime(sqltypes.DateTime):
    def bind_processor(self, dialect):
        def process(value):
            if value is not None:
                if value.microsecond:
                    value = value.replace(microsecond=0)
            return value
        return process


class InfoTime(sqltypes.Time):
    def bind_processor(self, dialect):
        def process(value):
            if value is not None:
                if value.microsecond:
                    value = value.replace(microsecond=0)
            return value
        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            if isinstance(value, datetime.datetime):
                return value.time()
            else:
                return value
        return process


colspecs = {
    sqltypes.DateTime: InfoDateTime,
    sqltypes.Time: InfoTime,
}


ischema_names = {
    0: sqltypes.CHAR,          # CHAR
    1: sqltypes.SMALLINT,      # SMALLINT
    2: sqltypes.INTEGER,       # INT
    3: sqltypes.FLOAT,         # Float
    4: sqltypes.Float,         # SmallFloat (the original listed a duplicate key 3 here)
    5: sqltypes.DECIMAL,       # DECIMAL
    6: sqltypes.Integer,       # Serial
    7: sqltypes.DATE,          # DATE
    8: sqltypes.Numeric,       # MONEY
    10: sqltypes.DATETIME,     # DATETIME
    11: sqltypes.LargeBinary,  # BYTE
    12: sqltypes.TEXT,         # TEXT
    13: sqltypes.VARCHAR,      # VARCHAR
    15: sqltypes.NCHAR,        # NCHAR
    16: sqltypes.NVARCHAR,     # NVARCHAR
    17: sqltypes.Integer,      # INT8
    18: sqltypes.Integer,      # Serial8
    43: sqltypes.String,       # LVARCHAR
    -1: sqltypes.BLOB,         # BLOB
    -1: sqltypes.CLOB,         # CLOB (duplicate key: this entry wins)
}


class InfoTypeCompiler(compiler.GenericTypeCompiler):
    def visit_DATETIME(self, type_):
        return "DATETIME YEAR TO SECOND"

    def visit_TIME(self, type_):
        return "DATETIME HOUR TO SECOND"

    def visit_large_binary(self, type_):
        return "BYTE"

    def visit_boolean(self, type_):
        return "SMALLINT"


class InfoSQLCompiler(compiler.SQLCompiler):

    def default_from(self):
        return " from systables where tabname = 'systables' "

    def get_select_precolumns(self, select):
        s = select._distinct and "DISTINCT " or ""
        # Informix only has a limit, no offset
        if select._limit:
            s += " FIRST %s " % select._limit
        return s

    def visit_select(self, select):
        # a column in the ORDER BY clause must appear in the SELECT list too

        def __label(c):
            try:
                return c._label.lower()
            except AttributeError:
                return ''

        # TODO: don't modify the original select, generate a new one
        a = [__label(c) for c in select._raw_columns]
        for c in select._order_by_clause.clauses:
            if __label(c) not in a:
                select.append_column(c)

        return compiler.SQLCompiler.visit_select(self, select)

    def limit_clause(self, select):
        if select._offset is not None and select._offset > 0:
            raise NotImplementedError("Informix does not support OFFSET")
        return ""

    def visit_function(self, func):
        if func.name.lower() == 'current_date':
            return "today"
        elif func.name.lower() == 'current_time':
            return "CURRENT HOUR TO SECOND"
        elif func.name.lower() in ('current_timestamp', 'now'):
            return "CURRENT YEAR TO SECOND"
        else:
            return compiler.SQLCompiler.visit_function(self, func)


class InfoDDLCompiler(compiler.DDLCompiler):
    def get_column_specification(self, column, first_pk=False):
        colspec = self.preparer.format_column(column)
        if column.primary_key and len(column.foreign_keys) == 0 and column.autoincrement and \
                isinstance(column.type, sqltypes.Integer) and first_pk:
            colspec += " SERIAL"
        else:
            colspec += " " + self.dialect.type_compiler.process(column.type)
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"

        return colspec


class InfoIdentifierPreparer(compiler.IdentifierPreparer):
    def __init__(self, dialect):
        super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")

    def format_constraint(self, constraint):
        # Informix doesn't support names for constraints
        return ''

    def _requires_quotes(self, value):
        return False

class InformixDialect(default.DefaultDialect):
    name = 'informix'

    max_identifier_length = 128  # adjusts at runtime based on server version

    type_compiler = InfoTypeCompiler
    statement_compiler = InfoSQLCompiler
    ddl_compiler = InfoDDLCompiler
    preparer = InfoIdentifierPreparer
    colspecs = colspecs
    ischema_names = ischema_names

    def initialize(self, connection):
        super(InformixDialect, self).initialize(connection)

        # http://www.querix.com/support/knowledge-base/error_number_message/error_200
        if self.server_version_info < (9, 2):
            self.max_identifier_length = 18
        else:
            self.max_identifier_length = 128

    def do_begin(self, connect):
        cu = connect.cursor()
        cu.execute('SET LOCK MODE TO WAIT')
        #cu.execute('SET ISOLATION TO REPEATABLE READ')

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        s = "select tabname from systables"
        return [row[0] for row in connection.execute(s)]

    def has_table(self, connection, table_name, schema=None):
        cursor = connection.execute(
            """select tabname from systables where tabname=?""",
            table_name.lower())
        return cursor.first() is not None

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        # 'include_columns' is pulled from kw here; the original code
        # referenced an undefined name at this point.
        include_columns = kw.get('include_columns')
        c = connection.execute(
            """select colname , coltype , collength , t3.default , t1.colno from
            syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
            where t1.tabid = t2.tabid and t2.tabname=?
            and t3.tabid = t2.tabid and t3.colno = t1.colno
            order by t1.colno""", table_name.lower())
        columns = []
        for name, colattr, collength, default, colno in c:
            name = name.lower()
            if include_columns and name not in include_columns:
                continue

            # in 7.31, coltype = 0x000
            #                       ^^-- column type
            #                      ^-- 1 not null, 0 null
            nullable, coltype = divmod(colattr, 256)
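
            # Illustrative example (not part of the original source): colattr
            # 269 (0x10D) gives divmod(269, 256) == (1, 13), i.e. the
            # not-null flag set and coltype 13 (VARCHAR) per the encoding above.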

            if coltype not in (0, 13) and default:
                default = default.split()[-1]

            if coltype == 0 or coltype == 13:  # char, varchar
                coltype = ischema_names[coltype](collength)
                if default:
                    default = "'%s'" % default
            elif coltype == 5:  # decimal
                precision, scale = (collength & 0xFF00) >> 8, collength & 0xFF
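                # Illustrative example (not part of the original source):
                # collength 2562 (0x0A02) decodes to precision 10, scale 2.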
                if scale == 255:
                    scale = 0
                coltype = sqltypes.Numeric(precision, scale)
            else:
                try:
                    coltype = ischema_names[coltype]
                except KeyError:
                    util.warn("Did not recognize type '%s' of column '%s'" %
                              (coltype, name))
                    coltype = sqltypes.NULLTYPE

            # TODO: nullability ??
            nullable = True

            column_info = dict(name=name, type=coltype, nullable=nullable,
                               default=default)
            columns.append(column_info)
        return columns

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        # FK
        c = connection.execute(
            """select t1.constrname as cons_name , t1.constrtype as cons_type ,
            t4.colname as local_column , t7.tabname as remote_table ,
            t6.colname as remote_column
            from sysconstraints as t1 , systables as t2 ,
            sysindexes as t3 , syscolumns as t4 ,
            sysreferences as t5 , syscolumns as t6 , systables as t7 ,
            sysconstraints as t8 , sysindexes as t9
            where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'R'
            and t3.tabid = t2.tabid and t3.idxname = t1.idxname
            and t4.tabid = t2.tabid and t4.colno = t3.part1
            and t5.constrid = t1.constrid and t8.constrid = t5.primary
            and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname
            and t7.tabid = t5.ptabid""", table_name.lower())

        def fkey_rec():
            return {
                'name': None,
                'constrained_columns': [],
                'referred_schema': None,
                'referred_table': None,
                'referred_columns': []
            }

        fkeys = util.defaultdict(fkey_rec)

        for cons_name, cons_type, local_column, remote_table, remote_column in c:

            rec = fkeys[cons_name]
            rec['name'] = cons_name
            local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']

            if not rec['referred_table']:
                rec['referred_table'] = remote_table

            local_cols.append(local_column)
            remote_cols.append(remote_column)

        return fkeys.values()

    @reflection.cache
    def get_primary_keys(self, connection, table_name, schema=None, **kw):
        c = connection.execute(
            """select t4.colname as local_column
            from sysconstraints as t1 , systables as t2 ,
            sysindexes as t3 , syscolumns as t4
            where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'P'
            and t3.tabid = t2.tabid and t3.idxname = t1.idxname
            and t4.tabid = t2.tabid and t4.colno = t3.part1""", table_name.lower())
        return [r[0] for r in c.fetchall()]

    @reflection.cache
    def get_indexes(self, connection, table_name, schema, **kw):
        # TODO
        return []
46
sqlalchemy/dialects/informix/informixdb.py
Normal file
@@ -0,0 +1,46 @@
from sqlalchemy.dialects.informix.base import InformixDialect
from sqlalchemy.engine import default


class InformixExecutionContext_informixdb(default.DefaultExecutionContext):
    def post_exec(self):
        if self.isinsert:
            self._lastrowid = [self.cursor.sqlerrd[1]]


class InformixDialect_informixdb(InformixDialect):
    driver = 'informixdb'
    default_paramstyle = 'qmark'
    # the dialect attribute is spelled execution_ctx_cls elsewhere in this
    # codebase; the original line read execution_context_cls
    execution_ctx_cls = InformixExecutionContext_informixdb

    @classmethod
    def dbapi(cls):
        return __import__('informixdb')

    def create_connect_args(self, url):
        if url.host:
            dsn = '%s@%s' % (url.database, url.host)
        else:
            dsn = url.database

        if url.username:
            opt = {'user': url.username, 'password': url.password}
        else:
            opt = {}

        return ([dsn], opt)

    def _get_server_version_info(self, connection):
        # http://informixdb.sourceforge.net/manual.html#inspecting-version-numbers
        vers = connection.dbms_version

        # TODO: not tested
        return tuple([int(x) for x in vers.split('.')])

    def is_disconnect(self, e):
        if isinstance(e, self.dbapi.OperationalError):
            return 'closed the connection' in str(e) or 'connection not open' in str(e)
        else:
            return False


dialect = InformixDialect_informixdb
3
sqlalchemy/dialects/maxdb/__init__.py
Normal file
@@ -0,0 +1,3 @@
from sqlalchemy.dialects.maxdb import base, sapdb

base.dialect = sapdb.dialect
1058
sqlalchemy/dialects/maxdb/base.py
Normal file
File diff suppressed because it is too large
17
sqlalchemy/dialects/maxdb/sapdb.py
Normal file
@@ -0,0 +1,17 @@
from sqlalchemy.dialects.maxdb.base import MaxDBDialect


class MaxDBDialect_sapdb(MaxDBDialect):
    driver = 'sapdb'

    @classmethod
    def dbapi(cls):
        from sapdb import dbapi as _dbapi
        return _dbapi

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        opts.update(url.query)
        return [], opts


dialect = MaxDBDialect_sapdb
19
sqlalchemy/dialects/mssql/__init__.py
Normal file
@@ -0,0 +1,19 @@
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc

base.dialect = pyodbc.dialect

from sqlalchemy.dialects.mssql.base import \
    INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \
    NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME, \
    DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \
    BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP, \
    MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, dialect


__all__ = (
    'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
    'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME',
    'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
    'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP',
    'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect'
)
59
sqlalchemy/dialects/mssql/adodbapi.py
Normal file
@@ -0,0 +1,59 @@
"""
The adodbapi dialect is not implemented for 0.6 at this time.

"""
import datetime
import sys

from sqlalchemy import types as sqltypes, util
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect


class MSDateTime_adodbapi(MSDateTime):
    def result_processor(self, dialect, coltype):
        def process(value):
            # adodbapi will return datetimes with empty time values as datetime.date() objects.
            # Promote them back to full datetime.datetime()
            if type(value) is datetime.date:
                return datetime.datetime(value.year, value.month, value.day)
            return value
        return process


class MSDialect_adodbapi(MSDialect):
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    supports_unicode = sys.maxunicode == 65535
    supports_unicode_statements = True
    driver = 'adodbapi'

    @classmethod
    def import_dbapi(cls):
        import adodbapi as module
        return module

    colspecs = util.update_copy(
        MSDialect.colspecs,
        {
            sqltypes.DateTime: MSDateTime_adodbapi
        }
    )

    def create_connect_args(self, url):
        keys = url.query

        connectors = ["Provider=SQLOLEDB"]
        if 'port' in keys:
            connectors.append("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
        else:
            connectors.append("Data Source=%s" % keys.get("host"))
        connectors.append("Initial Catalog=%s" % keys.get("database"))
        user = keys.get("user")
        if user:
            connectors.append("User Id=%s" % user)
            connectors.append("Password=%s" % keys.get("password", ""))
        else:
            connectors.append("Integrated Security=SSPI")
        return [[";".join(connectors)], {}]

    def is_disconnect(self, e):
        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
            "'connection failure'" in str(e)


dialect = MSDialect_adodbapi
1297
sqlalchemy/dialects/mssql/base.py
Normal file
File diff suppressed because it is too large
83
sqlalchemy/dialects/mssql/information_schema.py
Normal file
@@ -0,0 +1,83 @@
from sqlalchemy import Table, MetaData, Column, ForeignKey
from sqlalchemy.types import String, Unicode, Integer, TypeDecorator

ischema = MetaData()

class CoerceUnicode(TypeDecorator):
    impl = Unicode

    def process_bind_param(self, value, dialect):
        if isinstance(value, str):
            value = value.decode(dialect.encoding)
        return value

schemata = Table("SCHEMATA", ischema,
    Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
    Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
    Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
    schema="INFORMATION_SCHEMA")

tables = Table("TABLES", ischema,
    Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
    schema="INFORMATION_SCHEMA")

columns = Table("COLUMNS", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
    Column("IS_NULLABLE", Integer, key="is_nullable"),
    Column("DATA_TYPE", String, key="data_type"),
    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
    Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
    Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
    Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
    Column("COLUMN_DEFAULT", Integer, key="column_default"),
    Column("COLLATION_NAME", String, key="collation_name"),
    schema="INFORMATION_SCHEMA")

constraints = Table("TABLE_CONSTRAINTS", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"),
    schema="INFORMATION_SCHEMA")

column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    schema="INFORMATION_SCHEMA")

key_constraints = Table("KEY_COLUMN_USAGE", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
    schema="INFORMATION_SCHEMA")

ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
    Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
    Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    # TODO: is CATLOG misspelled ?
    Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"),
    Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"),
    Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"),
    Column("MATCH_OPTION", String, key="match_option"),
    Column("UPDATE_RULE", String, key="update_rule"),
    Column("DELETE_RULE", String, key="delete_rule"),
    schema="INFORMATION_SCHEMA")

views = Table("VIEWS", ischema,
    Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
    Column("CHECK_OPTION", String, key="check_option"),
    Column("IS_UPDATABLE", String, key="is_updatable"),
    schema="INFORMATION_SCHEMA")
83
sqlalchemy/dialects/mssql/mxodbc.py
Normal file
@@ -0,0 +1,83 @@
"""
Support for MS-SQL via mxODBC.

mxODBC is available at:

    http://www.egenix.com/

This was tested with mxODBC 3.1.2 and the SQL Server Native
Client connected to MSSQL 2005 and 2008 Express Editions.

Connecting
~~~~~~~~~~

Connection is via DSN::

    mssql+mxodbc://<username>:<password>@<dsnname>

Execution Modes
~~~~~~~~~~~~~~~

mxODBC features two styles of statement execution, using the ``cursor.execute()``
and ``cursor.executedirect()`` methods (the second being an extension to the
DBAPI specification). The former makes use of the native
parameter binding services of the ODBC driver, while the latter uses string escaping.
The primary advantage to native parameter binding is that the same statement, when
executed many times, is only prepared once. Whereas the primary advantage to the
latter is that the rules for bind parameter placement are relaxed. MS-SQL has very
strict rules for native binds, including that they cannot be placed within the argument
lists of function calls, anywhere outside the FROM, or even within subqueries within the
FROM clause - making the usage of bind parameters within SELECT statements impossible for
all but the most simplistic statements. For this reason, the mxODBC dialect uses the
"native" mode by default only for INSERT, UPDATE, and DELETE statements, and uses the
escaped string mode for all other statements. This behavior can be controlled completely
via :meth:`~sqlalchemy.sql.expression.Executable.execution_options`
using the ``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a value of
``True`` will unconditionally use native bind parameters and a value of ``False`` will
unconditionally use string-escaped parameters.
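
For example (an illustrative sketch)::

    stmt = table.select().execution_options(native_odbc_execute=True)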

"""

import re
import sys

from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.connectors.mxodbc import MxODBCConnector
from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc
from sqlalchemy.dialects.mssql.base import (MSExecutionContext, MSDialect,
                                            MSSQLCompiler, MSSQLStrictCompiler,
                                            _MSDateTime, _MSDate, TIME)


class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
    """
    The pyodbc execution context is useful for enabling
    SELECT SCOPE_IDENTITY in cases where OUTPUT clause
    does not work (tables with insert triggers).
    """
    # TODO: investigate whether the pyodbc execution context
    # is really only being used in cases where OUTPUT
    # won't work.


class MSDialect_mxodbc(MxODBCConnector, MSDialect):

    # TODO: may want to use this only if FreeTDS is not in use,
    # since FreeTDS doesn't seem to use native binds.
    statement_compiler = MSSQLStrictCompiler
    execution_ctx_cls = MSExecutionContext_mxodbc
    colspecs = {
        #sqltypes.Numeric : _MSNumeric,
        sqltypes.DateTime: _MSDateTime,
        sqltypes.Date: _MSDate,
        sqltypes.Time: TIME,
    }

    def __init__(self, description_encoding='latin-1', **params):
        super(MSDialect_mxodbc, self).__init__(**params)
        self.description_encoding = description_encoding

dialect = MSDialect_mxodbc
101
sqlalchemy/dialects/mssql/pymssql.py
Normal file
@@ -0,0 +1,101 @@
"""
Support for the pymssql dialect.

This dialect supports pymssql 1.0 and greater.

pymssql is available at:

    http://pymssql.sourceforge.net/

Connecting
^^^^^^^^^^

Sample connect string::

    mssql+pymssql://<username>:<password>@<freetds_name>

Adding "?charset=utf8" or similar will cause pymssql to return
strings as Python unicode objects. This can potentially improve
performance in some scenarios as decoding of strings is
handled natively.
|
||||
|
||||
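
For example, a hypothetical URL enabling this behavior (``scott``,
``tiger`` and ``mydsn`` are placeholders for real credentials and a
configured FreeTDS host entry)::

    create_engine('mssql+pymssql://scott:tiger@mydsn/?charset=utf8')
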
Limitations
^^^^^^^^^^^

pymssql inherits a lot of limitations from FreeTDS, including:

* no support for multibyte schema identifiers
* poor support for large decimals
* poor support for binary fields
* poor support for VARCHAR/CHAR fields over 255 characters

Please consult the pymssql documentation for further information.

"""
from sqlalchemy.dialects.mssql.base import MSDialect
from sqlalchemy import types as sqltypes, util, processors
import re
import decimal


class _MSNumeric_pymssql(sqltypes.Numeric):
    def result_processor(self, dialect, type_):
        if not self.asdecimal:
            return processors.to_float
        else:
            return sqltypes.Numeric.result_processor(self, dialect, type_)


class MSDialect_pymssql(MSDialect):
    supports_sane_rowcount = False
    max_identifier_length = 30
    driver = 'pymssql'

    colspecs = util.update_copy(
        MSDialect.colspecs,
        {
            sqltypes.Numeric: _MSNumeric_pymssql,
            sqltypes.Float: sqltypes.Float,
        }
    )

    @classmethod
    def dbapi(cls):
        module = __import__('pymssql')
        # pymssql doesn't have a Binary method.  we use string
        # TODO: monkeypatching here is less than ideal
        module.Binary = str

        client_ver = tuple(int(x) for x in module.__version__.split("."))
        if client_ver < (1, ):
            util.warn("The pymssql dialect expects at least "
                      "the 1.0 series of the pymssql DBAPI.")
        return module

    def __init__(self, **params):
        super(MSDialect_pymssql, self).__init__(**params)
        self.use_scope_identity = True

    def _get_server_version_info(self, connection):
        vers = connection.scalar("select @@version")
        m = re.match(r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
        if m:
            return tuple(int(x) for x in m.group(1, 2, 3, 4))
        else:
            return None

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        opts.update(url.query)
        opts.pop('port', None)
        return [[], opts]

    def is_disconnect(self, e):
        for msg in (
            "Error 10054",
            "Not connected to any MS SQL server",
            "Connection is closed"
        ):
            if msg in str(e):
                return True
        else:
            return False

dialect = MSDialect_pymssql
197
sqlalchemy/dialects/mssql/pyodbc.py
Normal file
@@ -0,0 +1,197 @@
"""
|
||||
Support for MS-SQL via pyodbc.
|
||||
|
||||
pyodbc is available at:
|
||||
|
||||
http://pypi.python.org/pypi/pyodbc/
|
||||
|
||||
Connecting
|
||||
^^^^^^^^^^
|
||||
|
||||
Examples of pyodbc connection string URLs:
|
||||
|
||||
* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``.
|
||||
The connection string that is created will appear like::
|
||||
|
||||
dsn=mydsn;Trusted_Connection=Yes
|
||||
|
||||
* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named
|
||||
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
|
||||
connection string that is created will appear like::
|
||||
|
||||
dsn=mydsn;UID=user;PWD=pass
|
||||
|
||||
* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects
|
||||
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
|
||||
information, plus the additional connection configuration option
|
||||
``LANGUAGE``. The connection string that is created will appear
|
||||
like::
|
||||
|
||||
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
|
||||
|
||||
* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection string
|
||||
dynamically created that would appear like::
|
||||
|
||||
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
|
||||
|
||||
* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection
|
||||
string that is dynamically created, which also includes the port
|
||||
information using the comma syntax. If your connection string
|
||||
requires the port information to be passed as a ``port`` keyword
|
||||
see the next example. This will create the following connection
|
||||
string::
|
||||
|
||||
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
|
||||
|
||||
* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection
|
||||
string that is dynamically created that includes the port
|
||||
information as a separate ``port`` keyword. This will create the
|
||||
following connection string::
|
||||
|
||||
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
|
||||
|
||||
If you require a connection string that is outside the options
|
||||
presented above, use the ``odbc_connect`` keyword to pass in a
|
||||
urlencoded connection string. What gets passed in will be urldecoded
|
||||
and passed directly.
|
||||
|
||||
For example::
|
||||
|
||||
mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
|
||||
|
||||
would create the following connection string::
|
||||
|
||||
dsn=mydsn;Database=db
|
||||
|
||||
Encoding your connection string can be easily accomplished through
|
||||
the python shell. For example::
|
||||
|
||||
>>> import urllib
|
||||
>>> urllib.quote_plus('dsn=mydsn;Database=db')
|
||||
'dsn%3Dmydsn%3BDatabase%3Ddb'
|
||||
|
||||
|
||||
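
The encoded string can then be interpolated into the URL; a brief
sketch, reusing the ``dsn=mydsn;Database=db`` string from above::

    import urllib
    from sqlalchemy import create_engine

    params = urllib.quote_plus('dsn=mydsn;Database=db')
    engine = create_engine("mssql+pyodbc:///?odbc_connect=%s" % params)
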
"""
|
||||
|
||||
from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
|
||||
from sqlalchemy.connectors.pyodbc import PyODBCConnector
|
||||
from sqlalchemy import types as sqltypes, util
|
||||
import decimal
|
||||
|
||||
class _MSNumeric_pyodbc(sqltypes.Numeric):
|
||||
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
|
||||
|
||||
This is the only method that is proven to work with Pyodbc+MSSQL
|
||||
without crashing (floats can be used but seem to cause sporadic
|
||||
crashes).
|
||||
|
||||
"""
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect)
|
||||
|
||||
def process(value):
|
||||
if self.asdecimal and \
|
||||
isinstance(value, decimal.Decimal):
|
||||
|
||||
adjusted = value.adjusted()
|
||||
if adjusted < 0:
|
||||
return self._small_dec_to_string(value)
|
||||
elif adjusted > 7:
|
||||
return self._large_dec_to_string(value)
|
||||
|
||||
if super_process:
|
||||
return super_process(value)
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
def _small_dec_to_string(self, value):
|
||||
return "%s0.%s%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
'0' * (abs(value.adjusted()) - 1),
|
||||
"".join([str(nint) for nint in value._int]))
|
||||
|
||||
def _large_dec_to_string(self, value):
|
||||
if 'E' in str(value):
|
||||
result = "%s%s%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
"".join([str(s) for s in value._int]),
|
||||
"0" * (value.adjusted() - (len(value._int)-1)))
|
||||
else:
|
||||
if (len(value._int) - 1) > value.adjusted():
|
||||
result = "%s%s.%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
"".join([str(s) for s in value._int][0:value.adjusted() + 1]),
|
||||
"".join([str(s) for s in value._int][value.adjusted() + 1:]))
|
||||
else:
|
||||
result = "%s%s" % (
|
||||
(value < 0 and '-' or ''),
|
||||
"".join([str(s) for s in value._int][0:value.adjusted() + 1]))
|
||||
return result
|
||||
|
||||
|
||||
class MSExecutionContext_pyodbc(MSExecutionContext):
|
||||
_embedded_scope_identity = False
|
||||
|
||||
def pre_exec(self):
|
||||
"""where appropriate, issue "select scope_identity()" in the same statement.
|
||||
|
||||
Background on why "scope_identity()" is preferable to "@@identity":
|
||||
http://msdn.microsoft.com/en-us/library/ms190315.aspx
|
||||
|
||||
Background on why we attempt to embed "scope_identity()" into the same
|
||||
statement as the INSERT:
|
||||
http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
|
||||
|
||||
"""
|
||||
|
||||
super(MSExecutionContext_pyodbc, self).pre_exec()
|
||||
|
||||
# don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES"
|
||||
if self._select_lastrowid and \
|
||||
self.dialect.use_scope_identity and \
|
||||
len(self.parameters[0]):
|
||||
self._embedded_scope_identity = True
|
||||
|
||||
self.statement += "; select scope_identity()"
|
||||
|
||||
def post_exec(self):
|
||||
if self._embedded_scope_identity:
|
||||
# Fetch the last inserted id from the manipulated statement
|
||||
# We may have to skip over a number of result sets with no data (due to triggers, etc.)
|
||||
while True:
|
||||
try:
|
||||
# fetchall() ensures the cursor is consumed
|
||||
# without closing it (FreeTDS particularly)
|
||||
row = self.cursor.fetchall()[0]
|
||||
break
|
||||
except self.dialect.dbapi.Error, e:
|
||||
# no way around this - nextset() consumes the previous set
|
||||
# so we need to just keep flipping
|
||||
self.cursor.nextset()
|
||||
|
||||
self._lastrowid = int(row[0])
|
||||
else:
|
||||
super(MSExecutionContext_pyodbc, self).post_exec()
|
||||
|
||||
|
||||
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
|
||||
|
||||
execution_ctx_cls = MSExecutionContext_pyodbc
|
||||
|
||||
pyodbc_driver_name = 'SQL Server'
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MSDialect.colspecs,
|
||||
{
|
||||
sqltypes.Numeric:_MSNumeric_pyodbc
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, description_encoding='latin-1', **params):
|
||||
super(MSDialect_pyodbc, self).__init__(**params)
|
||||
self.description_encoding = description_encoding
|
||||
self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
|
||||
|
||||
dialect = MSDialect_pyodbc
|
||||
64
sqlalchemy/dialects/mssql/zxjdbc.py
Normal file
@@ -0,0 +1,64 @@
"""Support for the Microsoft SQL Server database via the zxjdbc JDBC
|
||||
connector.
|
||||
|
||||
JDBC Driver
|
||||
-----------
|
||||
|
||||
Requires the jTDS driver, available from: http://jtds.sourceforge.net/
|
||||
|
||||
Connecting
|
||||
----------
|
||||
|
||||
URLs are of the standard form of
|
||||
``mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]``.
|
||||
|
||||
Additional arguments which may be specified either as query string
|
||||
arguments on the URL, or as keyword arguments to
|
||||
:func:`~sqlalchemy.create_engine()` will be passed as Connection
|
||||
properties to the underlying JDBC driver.
|
||||
|
||||
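
For example, both forms below are equivalent sketches; the host,
credentials, and the ``loginTimeout`` driver property are illustrative
placeholders only::

    create_engine('mssql+zxjdbc://scott:tiger@host:1433/db?loginTimeout=30')
    create_engine('mssql+zxjdbc://scott:tiger@host:1433/db', loginTimeout=30)
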
"""
|
||||
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
|
||||
from sqlalchemy.dialects.mssql.base import MSDialect, MSExecutionContext
|
||||
from sqlalchemy.engine import base
|
||||
|
||||
class MSExecutionContext_zxjdbc(MSExecutionContext):
|
||||
|
||||
_embedded_scope_identity = False
|
||||
|
||||
def pre_exec(self):
|
||||
super(MSExecutionContext_zxjdbc, self).pre_exec()
|
||||
# scope_identity after the fact returns null in jTDS so we must
|
||||
# embed it
|
||||
if self._select_lastrowid and self.dialect.use_scope_identity:
|
||||
self._embedded_scope_identity = True
|
||||
self.statement += "; SELECT scope_identity()"
|
||||
|
||||
def post_exec(self):
|
||||
if self._embedded_scope_identity:
|
||||
while True:
|
||||
try:
|
||||
row = self.cursor.fetchall()[0]
|
||||
break
|
||||
except self.dialect.dbapi.Error, e:
|
||||
self.cursor.nextset()
|
||||
self._lastrowid = int(row[0])
|
||||
|
||||
if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning:
|
||||
self._result_proxy = base.FullyBufferedResultProxy(self)
|
||||
|
||||
if self._enable_identity_insert:
|
||||
table = self.dialect.identifier_preparer.format_table(self.compiled.statement.table)
|
||||
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)
|
||||
|
||||
|
||||
class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
|
||||
jdbc_db_name = 'jtds:sqlserver'
|
||||
jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver'
|
||||
|
||||
execution_ctx_cls = MSExecutionContext_zxjdbc
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
return tuple(int(x) for x in connection.connection.dbversion.split('.'))
|
||||
|
||||
dialect = MSDialect_zxjdbc
|
||||
17
sqlalchemy/dialects/mysql/__init__.py
Normal file
@@ -0,0 +1,17 @@
from sqlalchemy.dialects.mysql import base, mysqldb, oursql, pyodbc, zxjdbc, mysqlconnector

# default dialect
base.dialect = mysqldb.dialect

from sqlalchemy.dialects.mysql.base import \
    BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, DOUBLE, ENUM,\
    FLOAT, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, MEDIUMTEXT, NCHAR, \
    NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT,\
    VARBINARY, VARCHAR, YEAR, dialect

__all__ = (
    'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE',
    'ENUM', 'FLOAT', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT',
    'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP',
    'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect'
)
2528
sqlalchemy/dialects/mysql/base.py
Normal file
File diff suppressed because it is too large
132
sqlalchemy/dialects/mysql/mysqlconnector.py
Normal file
@@ -0,0 +1,132 @@
"""Support for the MySQL database via the MySQL Connector/Python adapter.
|
||||
|
||||
MySQL Connector/Python is available at:
|
||||
|
||||
https://launchpad.net/myconnpy
|
||||
|
||||
Connecting
|
||||
-----------
|
||||
|
||||
Connect string format::
|
||||
|
||||
mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
|
||||
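
For example (a minimal sketch; credentials and database name are
placeholders)::

    create_engine('mysql+mysqlconnector://scott:tiger@localhost/test')
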
"""
|
||||
|
||||
import re
|
||||
|
||||
from sqlalchemy.dialects.mysql.base import (MySQLDialect,
|
||||
MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
|
||||
BIT)
|
||||
|
||||
from sqlalchemy.engine import base as engine_base, default
|
||||
from sqlalchemy.sql import operators as sql_operators
|
||||
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
|
||||
from sqlalchemy import processors
|
||||
|
||||
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
|
||||
|
||||
def get_lastrowid(self):
|
||||
return self.cursor.lastrowid
|
||||
|
||||
|
||||
class MySQLCompiler_mysqlconnector(MySQLCompiler):
|
||||
def visit_mod(self, binary, **kw):
|
||||
return self.process(binary.left) + " %% " + self.process(binary.right)
|
||||
|
||||
def post_process_text(self, text):
|
||||
return text.replace('%', '%%')
|
||||
|
||||
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
|
||||
|
||||
def _escape_identifier(self, value):
|
||||
value = value.replace(self.escape_quote, self.escape_to_quote)
|
||||
return value.replace("%", "%%")
|
||||
|
||||
class _myconnpyBIT(BIT):
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""MySQL-connector already converts mysql bits, so."""
|
||||
|
||||
return None
|
||||
|
||||
class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
driver = 'mysqlconnector'
|
||||
supports_unicode_statements = True
|
||||
supports_unicode_binds = True
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
default_paramstyle = 'format'
|
||||
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
|
||||
statement_compiler = MySQLCompiler_mysqlconnector
|
||||
|
||||
preparer = MySQLIdentifierPreparer_mysqlconnector
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MySQLDialect.colspecs,
|
||||
{
|
||||
BIT: _myconnpyBIT,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
from mysql import connector
|
||||
return connector
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username='user')
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, 'buffered', bool)
|
||||
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
|
||||
opts['buffered'] = True
|
||||
opts['raise_on_warnings'] = True
|
||||
|
||||
# FOUND_ROWS must be set in ClientFlag to enable
|
||||
# supports_sane_rowcount.
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
from mysql.connector.constants import ClientFlag
|
||||
client_flags = opts.get('client_flags', ClientFlag.get_default())
|
||||
client_flags |= ClientFlag.FOUND_ROWS
|
||||
opts['client_flags'] = client_flags
|
||||
except:
|
||||
pass
|
||||
return [[], opts]
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
dbapi_con = connection.connection
|
||||
|
||||
from mysql.connector.constants import ClientFlag
|
||||
dbapi_con.set_client_flag(ClientFlag.FOUND_ROWS)
|
||||
|
||||
version = dbapi_con.get_server_version()
|
||||
return tuple(version)
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.get_characterset_info()
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
try:
|
||||
return exception.orig.errno
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
def is_disconnect(self, e):
|
||||
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
|
||||
exceptions = (self.dbapi.OperationalError,self.dbapi.InterfaceError)
|
||||
if isinstance(e, exceptions):
|
||||
return e.errno in errnos
|
||||
else:
|
||||
return False
|
||||
|
||||
def _compat_fetchall(self, rp, charset=None):
|
||||
return rp.fetchall()
|
||||
|
||||
def _compat_fetchone(self, rp, charset=None):
|
||||
return rp.fetchone()
|
||||
|
||||
dialect = MySQLDialect_mysqlconnector
|
||||
202
sqlalchemy/dialects/mysql/mysqldb.py
Normal file
@@ -0,0 +1,202 @@
"""Support for the MySQL database via the MySQL-python adapter.
|
||||
|
||||
MySQL-Python is available at:
|
||||
|
||||
http://sourceforge.net/projects/mysql-python
|
||||
|
||||
At least version 1.2.1 or 1.2.2 should be used.
|
||||
|
||||
Connecting
|
||||
-----------
|
||||
|
||||
Connect string format::
|
||||
|
||||
mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
|
||||
Character Sets
|
||||
--------------
|
||||
|
||||
Many MySQL server installations default to a ``latin1`` encoding for client
|
||||
connections. All data sent through the connection will be converted into
|
||||
``latin1``, even if you have ``utf8`` or another character set on your tables
|
||||
and columns. With versions 4.1 and higher, you can change the connection
|
||||
character set either through server configuration or by including the
|
||||
``charset`` parameter in the URL used for ``create_engine``. The ``charset``
|
||||
option is passed through to MySQL-Python and has the side-effect of also
|
||||
enabling ``use_unicode`` in the driver by default. For regular encoded
|
||||
strings, also pass ``use_unicode=0`` in the connection arguments::
|
||||
|
||||
# set client encoding to utf8; all strings come back as unicode
|
||||
create_engine('mysql+mysqldb:///mydb?charset=utf8')
|
||||
|
||||
# set client encoding to utf8; all strings come back as utf8 str
|
||||
create_engine('mysql+mysqldb:///mydb?charset=utf8&use_unicode=0')
|
||||
|
||||
Known Issues
|
||||
-------------
|
||||
|
||||
MySQL-python at least as of version 1.2.2 has a serious memory leak related
|
||||
to unicode conversion, a feature which is disabled via ``use_unicode=0``.
|
||||
The recommended connection form with SQLAlchemy is::
|
||||
|
||||
engine = create_engine('mysql://scott:tiger@localhost/test?charset=utf8&use_unicode=0', pool_recycle=3600)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import re

from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext,
                                            MySQLCompiler, MySQLIdentifierPreparer)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors


class MySQLExecutionContext_mysqldb(MySQLExecutionContext):

    @property
    def rowcount(self):
        if hasattr(self, '_rowcount'):
            return self._rowcount
        else:
            return self.cursor.rowcount


class MySQLCompiler_mysqldb(MySQLCompiler):
    def visit_mod(self, binary, **kw):
        return self.process(binary.left) + " %% " + self.process(binary.right)

    def post_process_text(self, text):
        return text.replace('%', '%%')


class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer):

    def _escape_identifier(self, value):
        value = value.replace(self.escape_quote, self.escape_to_quote)
        return value.replace("%", "%%")


class MySQLDialect_mysqldb(MySQLDialect):
    driver = 'mysqldb'
    supports_unicode_statements = False
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True

    supports_native_decimal = True

    default_paramstyle = 'format'
    execution_ctx_cls = MySQLExecutionContext_mysqldb
    statement_compiler = MySQLCompiler_mysqldb
    preparer = MySQLIdentifierPreparer_mysqldb

    colspecs = util.update_copy(
        MySQLDialect.colspecs,
        {
        }
    )

    @classmethod
    def dbapi(cls):
        return __import__('MySQLdb')

    def do_executemany(self, cursor, statement, parameters, context=None):
        rowcount = cursor.executemany(statement, parameters)
        if context is not None:
            context._rowcount = rowcount

    def create_connect_args(self, url):
        opts = url.translate_connect_args(database='db', username='user',
                                          password='passwd')
        opts.update(url.query)

        util.coerce_kw_type(opts, 'compress', bool)
        util.coerce_kw_type(opts, 'connect_timeout', int)
        util.coerce_kw_type(opts, 'client_flag', int)
        util.coerce_kw_type(opts, 'local_infile', int)
        # Note: using either of the below will cause all strings to be returned
        # as Unicode, both in raw SQL operations and with column types like
        # String and MSString.
        util.coerce_kw_type(opts, 'use_unicode', bool)
        util.coerce_kw_type(opts, 'charset', str)

        # Rich values 'cursorclass' and 'conv' are not supported via
        # query string.

        ssl = {}
        for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
            if key in opts:
                ssl[key[4:]] = opts[key]
                util.coerce_kw_type(ssl, key[4:], str)
                del opts[key]
        if ssl:
            opts['ssl'] = ssl

        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
        # supports_sane_rowcount.
        client_flag = opts.get('client_flag', 0)
        if self.dbapi is not None:
            try:
                from MySQLdb.constants import CLIENT as CLIENT_FLAGS
                client_flag |= CLIENT_FLAGS.FOUND_ROWS
            except:
                pass
            opts['client_flag'] = client_flag
        return [[], opts]

    def _get_server_version_info(self, connection):
        dbapi_con = connection.connection
        version = []
        r = re.compile('[.\-]')
        for n in r.split(dbapi_con.get_server_info()):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)

    def _extract_error_code(self, exception):
        try:
            return exception.orig.args[0]
        except AttributeError:
            return None

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        # Note: MySQL-python 1.2.1c7 seems to ignore changes made
        # on a connection via set_character_set()
        if self.server_version_info < (4, 1, 0):
            try:
                return connection.connection.character_set_name()
            except AttributeError:
                # < 1.2.1 final MySQL-python drivers have no charset support.
                # a query is needed.
                pass

        # Prefer 'character_set_results' for the current connection over the
        # value in the driver.  SET NAMES or individual variable SETs will
        # change the charset without updating the driver's view of the world.
        #
        # If it's decided that issuing that sort of SQL leaves you SOL, then
        # this can prefer the driver value.
        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
        opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])

        if 'character_set_results' in opts:
            return opts['character_set_results']
        try:
            return connection.connection.character_set_name()
        except AttributeError:
            # Still no charset on < 1.2.1 final...
            if 'character_set' in opts:
                return opts['character_set']
            else:
                util.warn(
                    "Could not detect the connection character set with this "
                    "combination of MySQL server and MySQL-python. "
                    "MySQL-python >= 1.2.2 is recommended.  Assuming latin1.")
                return 'latin1'


dialect = MySQLDialect_mysqldb
255
sqlalchemy/dialects/mysql/oursql.py
Normal file
@@ -0,0 +1,255 @@
"""Support for the MySQL database via the oursql adapter.
|
||||
|
||||
OurSQL is available at:
|
||||
|
||||
http://packages.python.org/oursql/
|
||||
|
||||
Connecting
|
||||
-----------
|
||||
|
||||
Connect string format::
|
||||
|
||||
mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
|
||||
Character Sets
|
||||
--------------
|
||||
|
||||
oursql defaults to using ``utf8`` as the connection charset, but other
|
||||
encodings may be used instead. Like the MySQL-Python driver, unicode support
|
||||
can be completely disabled::
|
||||
|
||||
# oursql sets the connection charset to utf8 automatically; all strings come
|
||||
# back as utf8 str
|
||||
create_engine('mysql+oursql:///mydb?use_unicode=0')
|
||||
|
||||
To not automatically use ``utf8`` and instead use whatever the connection
|
||||
defaults to, there is a separate parameter::
|
||||
|
||||
# use the default connection charset; all strings come back as unicode
|
||||
create_engine('mysql+oursql:///mydb?default_charset=1')
|
||||
|
||||
# use latin1 as the connection charset; all strings come back as unicode
|
||||
create_engine('mysql+oursql:///mydb?charset=latin1')
|
||||
"""
|
||||
|
||||
import re

from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext,
                                            MySQLCompiler, MySQLIdentifierPreparer)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors


class _oursqlBIT(BIT):
    def result_processor(self, dialect, coltype):
        """oursql already converts mysql bits, so there is nothing left
        to do here."""

        return None


class MySQLExecutionContext_oursql(MySQLExecutionContext):

    @property
    def plain_query(self):
        return self.execution_options.get('_oursql_plain_query', False)


class MySQLDialect_oursql(MySQLDialect):
    driver = 'oursql'
    # Py3K
    # description_encoding = None
    # Py2K
    supports_unicode_binds = True
    supports_unicode_statements = True
    # end Py2K

    supports_native_decimal = True

    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    execution_ctx_cls = MySQLExecutionContext_oursql

    colspecs = util.update_copy(
        MySQLDialect.colspecs,
        {
            sqltypes.Time: sqltypes.Time,
            BIT: _oursqlBIT,
        }
    )

    @classmethod
    def dbapi(cls):
        return __import__('oursql')

    def do_execute(self, cursor, statement, parameters, context=None):
        """Provide an implementation of *cursor.execute(statement, parameters)*."""

        if context and context.plain_query:
            cursor.execute(statement, plain_query=True)
        else:
            cursor.execute(statement, parameters)

    def do_begin(self, connection):
        connection.cursor().execute('BEGIN', plain_query=True)

    def _xa_query(self, connection, query, xid):
        # Py2K
        arg = connection.connection._escape_string(xid)
        # end Py2K
        # Py3K
        # charset = self._connection_charset
        # arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
        connection.execution_options(_oursql_plain_query=True).execute(query % arg)

    # Because mysql is bad, these methods have to be
    # reimplemented to use _PlainQuery.  Basically, some queries
    # refuse to return any data if they're run through
    # the parameterized query API, or refuse to be parameterized
    # in the first place.
    def do_begin_twophase(self, connection, xid):
        self._xa_query(connection, 'XA BEGIN "%s"', xid)

    def do_prepare_twophase(self, connection, xid):
        self._xa_query(connection, 'XA END "%s"', xid)
        self._xa_query(connection, 'XA PREPARE "%s"', xid)

    def do_rollback_twophase(self, connection, xid, is_prepared=True,
                             recover=False):
        if not is_prepared:
            self._xa_query(connection, 'XA END "%s"', xid)
        self._xa_query(connection, 'XA ROLLBACK "%s"', xid)

    def do_commit_twophase(self, connection, xid, is_prepared=True,
                           recover=False):
        if not is_prepared:
            self.do_prepare_twophase(connection, xid)
        self._xa_query(connection, 'XA COMMIT "%s"', xid)

    # Q: why didn't we need all these "plain_query" overrides earlier ?
    # am i on a newer/older version of OurSQL ?
    def has_table(self, connection, table_name, schema=None):
        return MySQLDialect.has_table(self,
                                      connection.connect().\
                                          execution_options(_oursql_plain_query=True),
                                      table_name, schema)

    def get_table_options(self, connection, table_name, schema=None, **kw):
        return MySQLDialect.get_table_options(self,
                                              connection.connect().\
                                                  execution_options(_oursql_plain_query=True),
                                              table_name,
                                              schema = schema,
                                              **kw
        )

    def get_columns(self, connection, table_name, schema=None, **kw):
        return MySQLDialect.get_columns(self,
                                        connection.connect().\
                                            execution_options(_oursql_plain_query=True),
                                        table_name,
                                        schema=schema,
                                        **kw
        )

    def get_view_names(self, connection, schema=None, **kw):
        return MySQLDialect.get_view_names(self,
                                           connection.connect().\
                                               execution_options(_oursql_plain_query=True),
                                           schema=schema,
                                           **kw
        )

    def get_table_names(self, connection, schema=None, **kw):
        return MySQLDialect.get_table_names(self,
                                            connection.connect().\
                                                execution_options(_oursql_plain_query=True),
                                            schema
        )

    def get_schema_names(self, connection, **kw):
        return MySQLDialect.get_schema_names(self,
                                             connection.connect().\
                                                 execution_options(_oursql_plain_query=True),
                                             **kw
        )

    def initialize(self, connection):
        return MySQLDialect.initialize(
                            self,
                            connection.execution_options(_oursql_plain_query=True)
                            )

    def _show_create_table(self, connection, table, charset=None,
                           full_name=None):
        return MySQLDialect._show_create_table(self,
                                connection.contextual_connect(close_with_result=True).
                                execution_options(_oursql_plain_query=True),
                                table, charset, full_name)

    def is_disconnect(self, e):
        if isinstance(e, self.dbapi.ProgrammingError):
            return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
        else:
            return e.errno in (2006, 2013, 2014, 2045, 2055)

    def create_connect_args(self, url):
        opts = url.translate_connect_args(database='db', username='user',
                                          password='passwd')
        opts.update(url.query)

        util.coerce_kw_type(opts, 'port', int)
        util.coerce_kw_type(opts, 'compress', bool)
        util.coerce_kw_type(opts, 'autoping', bool)

        util.coerce_kw_type(opts, 'default_charset', bool)
        if opts.pop('default_charset', False):
            opts['charset'] = None
        else:
            util.coerce_kw_type(opts, 'charset', str)
        opts['use_unicode'] = opts.get('use_unicode', True)
        util.coerce_kw_type(opts, 'use_unicode', bool)

        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
        # supports_sane_rowcount.
        opts.setdefault('found_rows', True)

        return [[], opts]

    def _get_server_version_info(self, connection):
        dbapi_con = connection.connection
        version = []
        r = re.compile('[.\-]')
        for n in r.split(dbapi_con.server_info):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)

    def _extract_error_code(self, exception):
        try:
            return exception.orig.errno
        except AttributeError:
            return None

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        return connection.connection.charset

    def _compat_fetchall(self, rp, charset=None):
        """oursql isn't super-broken like MySQLdb, yaaay."""
        return rp.fetchall()

    def _compat_fetchone(self, rp, charset=None):
        """oursql isn't super-broken like MySQLdb, yaaay."""
        return rp.fetchone()

    def _compat_first(self, rp, charset=None):
        return rp.first()


dialect = MySQLDialect_oursql
76
sqlalchemy/dialects/mysql/pyodbc.py
Normal file
@@ -0,0 +1,76 @@
"""Support for the MySQL database via the pyodbc adapter.
|
||||
|
||||
pyodbc is available at:
|
||||
|
||||
http://pypi.python.org/pypi/pyodbc/
|
||||
|
||||
Connecting
|
||||
----------
|
||||
|
||||
Connect string::
|
||||
|
||||
mysql+pyodbc://<username>:<password>@<dsnname>
|
||||
|
||||
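
For example, given a DSN named ``mydsn`` (a hypothetical name)::

    create_engine('mysql+pyodbc://scott:tiger@mydsn')
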
Limitations
-----------

The mysql-pyodbc dialect is subject to unresolved character encoding issues
which exist within the current ODBC drivers available
(see http://code.google.com/p/pyodbc/issues/detail?id=25).  Consider usage
of OurSQL, MySQLdb, or MySQL-connector/Python.

"""

from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.engine import base as engine_base
from sqlalchemy import util
import re


class MySQLExecutionContext_pyodbc(MySQLExecutionContext):

    def get_lastrowid(self):
        cursor = self.create_cursor()
        cursor.execute("SELECT LAST_INSERT_ID()")
        lastrowid = cursor.fetchone()[0]
        cursor.close()
        return lastrowid


class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
    supports_unicode_statements = False
    execution_ctx_cls = MySQLExecutionContext_pyodbc

    pyodbc_driver_name = "MySQL"

    def __init__(self, **kw):
        # deal with http://code.google.com/p/pyodbc/issues/detail?id=25
        kw.setdefault('convert_unicode', True)
        super(MySQLDialect_pyodbc, self).__init__(**kw)

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        # Prefer 'character_set_results' for the current connection over the
        # value in the driver.  SET NAMES or individual variable SETs will
        # change the charset without updating the driver's view of the world.
        #
        # If it's decided that issuing that sort of SQL leaves you SOL, then
        # this can prefer the driver value.
        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
        opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
        for key in ('character_set_connection', 'character_set'):
            if opts.get(key, None):
                return opts[key]

        util.warn("Could not detect the connection character set. Assuming latin1.")
        return 'latin1'

    def _extract_error_code(self, exception):
        # guard against a non-matching message rather than calling
        # group() on a None match object
        m = re.compile(r"\((\d+)\)").search(str(exception.orig.args))
        if m:
            return int(m.group(1))
        else:
            return None

dialect = MySQLDialect_pyodbc
111
sqlalchemy/dialects/mysql/zxjdbc.py
Normal file
@@ -0,0 +1,111 @@
"""Support for the MySQL database via Jython's zxjdbc JDBC connector.
|
||||
|
||||
JDBC Driver
|
||||
-----------
|
||||
|
||||
The official MySQL JDBC driver is at
|
||||
http://dev.mysql.com/downloads/connector/j/.
|
||||
|
||||
Connecting
|
||||
----------
|
||||
|
||||
Connect string format:
|
||||
|
||||
mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
|
||||
|
||||
Character Sets
|
||||
--------------
|
||||
|
||||
SQLAlchemy zxjdbc dialects pass unicode straight through to the
|
||||
zxjdbc/JDBC layer. To allow multiple character sets to be sent from the
|
||||
MySQL Connector/J JDBC driver, by default SQLAlchemy sets its
|
||||
``characterEncoding`` connection property to ``UTF-8``. It may be
|
||||
overriden via a ``create_engine`` URL parameter.
|
||||
|
||||
"""
|
||||
import re

from sqlalchemy import types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext


class _ZxJDBCBit(BIT):
    def result_processor(self, dialect, coltype):
        """Converts boolean or byte arrays from MySQL Connector/J to longs."""
        def process(value):
            if value is None:
                return value
            if isinstance(value, bool):
                return int(value)
            v = 0L
            for i in value:
                v = v << 8 | (i & 0xff)
            value = v
            return value
        return process


class MySQLExecutionContext_zxjdbc(MySQLExecutionContext):
    def get_lastrowid(self):
        cursor = self.create_cursor()
        cursor.execute("SELECT LAST_INSERT_ID()")
        lastrowid = cursor.fetchone()[0]
        cursor.close()
        return lastrowid


class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
    jdbc_db_name = 'mysql'
    jdbc_driver_name = 'com.mysql.jdbc.Driver'

    execution_ctx_cls = MySQLExecutionContext_zxjdbc

    colspecs = util.update_copy(
        MySQLDialect.colspecs,
        {
            sqltypes.Time: sqltypes.Time,
            BIT: _ZxJDBCBit
        }
    )

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""
        # Prefer 'character_set_results' for the current connection over the
        # value in the driver.  SET NAMES or individual variable SETs will
        # change the charset without updating the driver's view of the world.
        #
        # If it's decided that issuing that sort of SQL leaves you SOL, then
        # this can prefer the driver value.
        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
        opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
        for key in ('character_set_connection', 'character_set'):
            if opts.get(key, None):
                return opts[key]

        util.warn("Could not detect the connection character set. Assuming latin1.")
        return 'latin1'

    def _driver_kwargs(self):
        """return kw arg dict to be sent to connect()."""
        return dict(characterEncoding='UTF-8', yearIsDateType='false')

    def _extract_error_code(self, exception):
        # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
        # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
        # guard against a non-matching message rather than calling
        # group() on a None match object
        m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args))
        if m:
            return int(m.group(1))

    def _get_server_version_info(self, connection):
        dbapi_con = connection.connection
        version = []
        r = re.compile('[.\-]')
        for n in r.split(dbapi_con.dbversion):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)

dialect = MySQLDialect_zxjdbc
17
sqlalchemy/dialects/oracle/__init__.py
Normal file
@@ -0,0 +1,17 @@
from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc

base.dialect = cx_oracle.dialect

from sqlalchemy.dialects.oracle.base import \
    VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, NUMBER,\
    BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\
    FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\
    VARCHAR2, NVARCHAR2


__all__ = (
    'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'DATETIME', 'NUMBER',
    'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
    'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
    'VARCHAR2', 'NVARCHAR2'
)
1030
sqlalchemy/dialects/oracle/base.py
Normal file
File diff suppressed because it is too large
529
sqlalchemy/dialects/oracle/cx_oracle.py
Normal file
@@ -0,0 +1,529 @@
"""Support for the Oracle database via the cx_oracle driver.
|
||||
|
||||
Driver
|
||||
------
|
||||
|
||||
The Oracle dialect uses the cx_oracle driver, available at
|
||||
http://cx-oracle.sourceforge.net/ . The dialect has several behaviors
|
||||
which are specifically tailored towards compatibility with this module.
|
||||
|
||||
Connecting
|
||||
----------
|
||||
|
||||
Connecting with create_engine() uses the standard URL approach of
|
||||
``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the
|
||||
host, port, and dbname tokens are converted to a TNS name using the cx_oracle
|
||||
:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name.
|
||||
|
||||
Additional arguments which may be specified either as query string arguments on the
|
||||
URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
|
||||
|
||||
* *allow_twophase* - enable two-phase transactions. Defaults to ``True``.
|
||||
|
||||
* *arraysize* - set the cx_oracle.arraysize value on cursors, in SQLAlchemy
|
||||
it defaults to 50. See the section on "LOB Objects" below.
|
||||
|
||||
* *auto_convert_lobs* - defaults to True, see the section on LOB objects.
|
||||
|
||||
* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
|
||||
This is required for LOB datatypes but can be disabled to reduce overhead. Defaults
|
||||
to ``True``.
|
||||
|
||||
* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
|
||||
integer value. This value is only available as a URL query string argument.
|
||||
|
||||
* *threaded* - enable multithreaded access to cx_oracle connections. Defaults
|
||||
to ``True``. Note that this is the opposite default of cx_oracle itself.
|
||||
|
||||
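
For example, a brief sketch combining a URL-only argument with keyword
arguments (the host, credentials and database name are placeholders)::

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@host:1521/dbname?mode=SYSDBA",
        arraysize=500, threaded=False)
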
Unicode
-------

As of cx_oracle 5, Python unicode objects can be bound directly to statements,
and it appears that cx_oracle can handle these even without NLS_LANG being set.
SQLAlchemy tests for version 5 and will pass unicode objects straight to cx_oracle
if this is the case.  For older versions of cx_oracle, SQLAlchemy will encode bind
parameters normally using dialect.encoding as the encoding.

LOB Objects
-----------

cx_oracle presents some challenges when fetching LOB objects.  A LOB object in a result set
is presented by cx_oracle as a cx_oracle.LOB object which has a read() method.  By default,
SQLAlchemy converts these LOB objects into Python strings.  This is for two reasons.  First,
the LOB object requires an active cursor association, meaning if you were to fetch many rows
at once such that cx_oracle had to go back to the database and fetch a new batch of rows,
the LOB objects in the already-fetched rows are now unreadable and will raise an error.
SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read.
The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
defaults to 50 (cx_oracle normally defaults this to one).

Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to
"normalize" the results to look more like that of other DBAPIs.

The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
for all statement executions, even plain string-based statements for which SQLA has no awareness
of result typing.  This is so that calls like fetchmany() and fetchall() can work in all cases
without raising cursor errors.  The conversion of LOBs in all cases, as well as the "prefetch"
of LOB objects, can be disabled using auto_convert_lobs=False.
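
For example, to receive raw ``cx_oracle.LOB`` objects unconverted (a
minimal sketch; the TNS name is a placeholder)::

    engine = create_engine("oracle+cx_oracle://scott:tiger@tnsname",
                           auto_convert_lobs=False)
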
Two Phase Transaction Support
-----------------------------

Two Phase transactions are implemented using XA transactions.  Success has been reported
with this feature but it should be regarded as experimental.

"""

from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, \
|
||||
RESERVED_WORDS, OracleExecutionContext
|
||||
from sqlalchemy.dialects.oracle import base as oracle
|
||||
from sqlalchemy.engine import base
|
||||
from sqlalchemy import types as sqltypes, util, exc
|
||||
from datetime import datetime
|
||||
import random
|
||||
|
||||
class _OracleNumeric(sqltypes.Numeric):
|
||||
# cx_oracle accepts Decimal objects, but returns
|
||||
# floats
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
class _OracleDate(sqltypes.Date):
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return value.date()
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
class _LOBMixin(object):
|
||||
def result_processor(self, dialect, coltype):
|
||||
if not dialect.auto_convert_lobs:
|
||||
# return the cx_oracle.LOB directly.
|
||||
return None
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return value.read()
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
class _NativeUnicodeMixin(object):
|
||||
# Py2K
|
||||
def bind_processor(self, dialect):
|
||||
if dialect._cx_oracle_with_unicode:
|
||||
def process(value):
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
return unicode(value)
|
||||
return process
|
||||
else:
|
||||
return super(_NativeUnicodeMixin, self).bind_processor(dialect)
|
||||
# end Py2K
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
# if we know cx_Oracle will return unicode,
|
||||
# don't process results
|
||||
if dialect._cx_oracle_with_unicode:
|
||||
return None
|
||||
elif self.convert_unicode != 'force' and \
|
||||
dialect._cx_oracle_native_nvarchar and \
|
||||
coltype in dialect._cx_oracle_unicode_types:
|
||||
return None
|
||||
else:
|
||||
return super(_NativeUnicodeMixin, self).result_processor(dialect, coltype)
|
||||
|
||||
class _OracleChar(_NativeUnicodeMixin, sqltypes.CHAR):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.FIXED_CHAR
|
||||
|
||||
class _OracleNVarChar(_NativeUnicodeMixin, sqltypes.NVARCHAR):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return getattr(dbapi, 'UNICODE', dbapi.STRING)
|
||||
|
||||
class _OracleText(_LOBMixin, sqltypes.Text):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.CLOB
|
||||
|
||||
class _OracleString(_NativeUnicodeMixin, sqltypes.String):
|
||||
pass
|
||||
|
||||
class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.NCLOB
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
lob_processor = _LOBMixin.result_processor(self, dialect, coltype)
|
||||
if lob_processor is None:
|
||||
return None
|
||||
|
||||
string_processor = _NativeUnicodeMixin.result_processor(self, dialect, coltype)
|
||||
|
||||
if string_processor is None:
|
||||
return lob_processor
|
||||
else:
|
||||
def process(value):
|
||||
return string_processor(lob_processor(value))
|
||||
return process
|
||||
|
||||
class _OracleInteger(sqltypes.Integer):
|
||||
def result_processor(self, dialect, coltype):
|
||||
def to_int(val):
|
||||
if val is not None:
|
||||
val = int(val)
|
||||
return val
|
||||
return to_int
|
||||
|
||||
class _OracleBinary(_LOBMixin, sqltypes.LargeBinary):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.BLOB
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
class _OracleInterval(oracle.INTERVAL):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.INTERVAL
|
||||
|
||||
class _OracleRaw(oracle.RAW):
|
||||
pass
|
||||
|
||||
class OracleCompiler_cx_oracle(OracleCompiler):
|
||||
def bindparam_string(self, name):
|
||||
if self.preparer._bindparam_requires_quotes(name):
|
||||
quoted_name = '"%s"' % name
|
||||
self._quoted_bind_names[name] = quoted_name
|
||||
return OracleCompiler.bindparam_string(self, quoted_name)
|
||||
else:
|
||||
return OracleCompiler.bindparam_string(self, name)
|
||||
|
||||
|
||||
class OracleExecutionContext_cx_oracle(OracleExecutionContext):
|
||||
|
||||
def pre_exec(self):
|
||||
quoted_bind_names = \
|
||||
getattr(self.compiled, '_quoted_bind_names', None)
|
||||
if quoted_bind_names:
|
||||
if not self.dialect.supports_unicode_binds:
|
||||
quoted_bind_names = \
|
||||
dict(
|
||||
(fromname, toname.encode(self.dialect.encoding))
|
||||
for fromname, toname in
|
||||
quoted_bind_names.items()
|
||||
)
|
||||
for param in self.parameters:
|
||||
for fromname, toname in quoted_bind_names.items():
|
||||
param[toname] = param[fromname]
|
||||
del param[fromname]
|
||||
|
||||
if self.dialect.auto_setinputsizes:
|
||||
# cx_oracle really has issues when you setinputsizes
|
||||
# on String, including that outparams/RETURNING
|
||||
# breaks for varchars
|
||||
self.set_input_sizes(quoted_bind_names,
|
||||
exclude_types=self.dialect._cx_oracle_string_types
|
||||
)
|
||||
|
||||
# if a single execute, check for outparams
|
||||
if len(self.compiled_parameters) == 1:
|
||||
for bindparam in self.compiled.binds.values():
|
||||
if bindparam.isoutparam:
|
||||
dbtype = bindparam.type.dialect_impl(self.dialect).\
|
||||
get_dbapi_type(self.dialect.dbapi)
|
||||
if not hasattr(self, 'out_parameters'):
|
||||
self.out_parameters = {}
|
||||
if dbtype is None:
|
||||
raise exc.InvalidRequestError("Cannot create out parameter for parameter "
|
||||
"%r - it's type %r is not supported by"
|
||||
" cx_oracle" %
|
||||
(name, bindparam.type)
|
||||
)
|
||||
name = self.compiled.bind_names[bindparam]
|
||||
self.out_parameters[name] = self.cursor.var(dbtype)
|
||||
self.parameters[0][quoted_bind_names.get(name, name)] = \
|
||||
self.out_parameters[name]
|
||||
|
||||
def create_cursor(self):
|
||||
c = self._connection.connection.cursor()
|
||||
if self.dialect.arraysize:
|
||||
c.arraysize = self.dialect.arraysize
|
||||
return c
|
||||
|
||||
def get_result_proxy(self):
|
||||
if hasattr(self, 'out_parameters') and self.compiled.returning:
|
||||
returning_params = dict(
|
||||
(k, v.getvalue())
|
||||
for k, v in self.out_parameters.items()
|
||||
)
|
||||
return ReturningResultProxy(self, returning_params)
|
||||
|
||||
result = None
|
||||
if self.cursor.description is not None:
|
||||
for column in self.cursor.description:
|
||||
type_code = column[1]
|
||||
if type_code in self.dialect._cx_oracle_binary_types:
|
||||
result = base.BufferedColumnResultProxy(self)
|
||||
|
||||
if result is None:
|
||||
result = base.ResultProxy(self)
|
||||
|
||||
if hasattr(self, 'out_parameters'):
|
||||
if self.compiled_parameters is not None and \
|
||||
len(self.compiled_parameters) == 1:
|
||||
result.out_parameters = out_parameters = {}
|
||||
|
||||
for bind, name in self.compiled.bind_names.items():
|
||||
if name in self.out_parameters:
|
||||
type = bind.type
|
||||
impl_type = type.dialect_impl(self.dialect)
|
||||
dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
|
||||
result_processor = impl_type.\
|
||||
result_processor(self.dialect,
|
||||
dbapi_type)
|
||||
if result_processor is not None:
|
||||
out_parameters[name] = \
|
||||
result_processor(self.out_parameters[name].getvalue())
|
||||
else:
|
||||
out_parameters[name] = self.out_parameters[name].getvalue()
|
||||
else:
|
||||
result.out_parameters = dict(
|
||||
(k, v.getvalue())
|
||||
for k, v in self.out_parameters.items()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_oracle):
    """Support WITH_UNICODE in Python 2.xx.

    WITH_UNICODE allows cx_Oracle's Python 3 unicode handling
    behavior under Python 2.x.  This mode in some cases disallows
    and in other cases silently passes corrupted data when
    non-Python-unicode strings (a.k.a. plain old Python strings)
    are passed as arguments to connect(), the statement sent to execute(),
    or any of the bind parameter keys or values sent to execute().
    This optional context therefore ensures that all statements are
    passed as Python unicode objects.

    """
    def __init__(self, *arg, **kw):
        OracleExecutionContext_cx_oracle.__init__(self, *arg, **kw)
        self.statement = unicode(self.statement)

    def _execute_scalar(self, stmt):
        return super(OracleExecutionContext_cx_oracle_with_unicode, self).\
            _execute_scalar(unicode(stmt))


class ReturningResultProxy(base.FullyBufferedResultProxy):
    """Result proxy which stuffs the _returning clause + outparams into the fetch."""

    def __init__(self, context, returning_params):
        self._returning_params = returning_params
        super(ReturningResultProxy, self).__init__(context)

    def _cursor_description(self):
        returning = self.context.compiled.returning

        ret = []
        for c in returning:
            if hasattr(c, 'name'):
                ret.append((c.name, c.type))
            else:
                ret.append((c.anon_label, c.type))
        return ret

    def _buffer_rows(self):
        return [tuple(self._returning_params["ret_%d" % i]
                      for i, c in enumerate(self._returning_params))]


class OracleDialect_cx_oracle(OracleDialect):
    execution_ctx_cls = OracleExecutionContext_cx_oracle
    statement_compiler = OracleCompiler_cx_oracle
    driver = "cx_oracle"

    colspecs = {
        sqltypes.Numeric: _OracleNumeric,
        sqltypes.Date: _OracleDate,     # generic type, assume datetime.date is desired
        oracle.DATE: oracle.DATE,       # non-generic type - pass through
        sqltypes.LargeBinary: _OracleBinary,
        sqltypes.Boolean: oracle._OracleBoolean,
        sqltypes.Interval: _OracleInterval,
        oracle.INTERVAL: _OracleInterval,
        sqltypes.Text: _OracleText,
        sqltypes.String: _OracleString,
        sqltypes.UnicodeText: _OracleUnicodeText,
        sqltypes.CHAR: _OracleChar,
        sqltypes.Integer: _OracleInteger,   # this is only needed for OUT parameters.
                                            # it would be nice if we could avoid it otherwise.
        oracle.NUMBER: oracle.NUMBER,       # don't let this get converted
        oracle.RAW: _OracleRaw,
        sqltypes.Unicode: _OracleNVarChar,
        sqltypes.NVARCHAR: _OracleNVarChar,
    }

    execute_sequence_format = list

    def __init__(self,
                 auto_setinputsizes=True,
                 auto_convert_lobs=True,
                 threaded=True,
                 allow_twophase=True,
                 arraysize=50, **kwargs):
        OracleDialect.__init__(self, **kwargs)
        self.threaded = threaded
        self.arraysize = arraysize
        self.allow_twophase = allow_twophase
        self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP')
        self.auto_setinputsizes = auto_setinputsizes
        self.auto_convert_lobs = auto_convert_lobs

        if hasattr(self.dbapi, 'version'):
            cx_oracle_ver = tuple([int(x) for x in self.dbapi.version.split('.')])
        else:
            cx_oracle_ver = (0, 0, 0)

        def types(*names):
            return set([
                getattr(self.dbapi, name, None) for name in names
            ]).difference([None])

        self._cx_oracle_string_types = types("STRING", "UNICODE", "NCLOB", "CLOB")
        self._cx_oracle_unicode_types = types("UNICODE", "NCLOB")
        self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB")
        self.supports_unicode_binds = cx_oracle_ver >= (5, 0)
        self._cx_oracle_native_nvarchar = cx_oracle_ver >= (5, 0)

        if cx_oracle_ver is None:
            # this occurs in tests with mock DBAPIs
            self._cx_oracle_string_types = set()
            self._cx_oracle_with_unicode = False
        elif cx_oracle_ver >= (5,) and not hasattr(self.dbapi, 'UNICODE'):
            # cx_Oracle WITH_UNICODE mode.  *only* Python
            # unicode objects accepted for anything
            self.supports_unicode_statements = True
            self.supports_unicode_binds = True
            self._cx_oracle_with_unicode = True
            # Py2K
            # There's really no reason to run with WITH_UNICODE under Python 2.x.
            # Give the user a hint.
            util.warn("cx_Oracle is compiled under Python 2.xx using the "
                      "WITH_UNICODE flag.  Consider recompiling cx_Oracle without "
                      "this flag, which is in no way necessary for full support of Unicode. "
                      "Otherwise, all string-holding bind parameters must "
                      "be explicitly typed using SQLAlchemy's String type or one of its subtypes, "
                      "or otherwise be passed as Python unicode.  Plain Python strings "
                      "passed as bind parameters will be silently corrupted by cx_Oracle."
                      )
            self.execution_ctx_cls = OracleExecutionContext_cx_oracle_with_unicode
            # end Py2K
        else:
            self._cx_oracle_with_unicode = False

        if cx_oracle_ver is None or \
                not self.auto_convert_lobs or \
                not hasattr(self.dbapi, 'CLOB'):
            self.dbapi_type_map = {}
        else:
            # only use this for LOB objects.  using it for strings, dates
            # etc. leads to a little too much magic; reflection doesn't know if it
            # should expect encoded strings or unicodes, etc.
            self.dbapi_type_map = {
                self.dbapi.CLOB: oracle.CLOB(),
                self.dbapi.NCLOB: oracle.NCLOB(),
                self.dbapi.BLOB: oracle.BLOB(),
                self.dbapi.BINARY: oracle.RAW(),
            }

    @classmethod
    def dbapi(cls):
        import cx_Oracle
        return cx_Oracle

    def create_connect_args(self, url):
        dialect_opts = dict(url.query)
        for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
                    'threaded', 'allow_twophase'):
            if opt in dialect_opts:
                util.coerce_kw_type(dialect_opts, opt, bool)
                setattr(self, opt, dialect_opts[opt])

        if url.database:
            # if we have a database, then we have a remote host
            port = url.port
            if port:
                port = int(port)
            else:
                port = 1521
            dsn = self.dbapi.makedsn(url.host, port, url.database)
        else:
            # we have a local tnsname
            dsn = url.host

        opts = dict(
            user=url.username,
            password=url.password,
            dsn=dsn,
            threaded=self.threaded,
            twophase=self.allow_twophase,
        )

        # Py2K
        if self._cx_oracle_with_unicode:
            for k, v in opts.items():
                if isinstance(v, str):
                    opts[k] = unicode(v)
        # end Py2K

        if 'mode' in url.query:
            opts['mode'] = url.query['mode']
            if isinstance(opts['mode'], basestring):
                mode = opts['mode'].upper()
                if mode == 'SYSDBA':
                    opts['mode'] = self.dbapi.SYSDBA
                elif mode == 'SYSOPER':
                    opts['mode'] = self.dbapi.SYSOPER
                else:
                    util.coerce_kw_type(opts, 'mode', int)
        return ([], opts)
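
    # A minimal usage sketch (not part of the original source): the DSN logic
    # above means these two hypothetical URLs take different paths --
    #
    #   create_engine('oracle+cx_oracle://scott:tiger@dbhost:1521/orcl')
    #       -> dsn = cx_Oracle.makedsn('dbhost', 1521, 'orcl')
    #   create_engine('oracle+cx_oracle://scott:tiger@mytns')
    #       -> dsn = 'mytns'  (a local tnsnames.ora alias)
    #
    # The names 'dbhost', 'orcl' and 'mytns' are illustrative only.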

    def _get_server_version_info(self, connection):
        return tuple(int(x) for x in connection.connection.version.split('.'))

    def is_disconnect(self, e):
        if isinstance(e, self.dbapi.InterfaceError):
            return "not connected" in str(e)
        else:
            return "ORA-03114" in str(e) or "ORA-03113" in str(e)

    def create_xid(self):
        """create a two-phase transaction ID.

        this id will be passed to do_begin_twophase(), do_rollback_twophase(),
        do_commit_twophase().  its format is unspecified."""

        id = random.randint(0, 2 ** 128)
        return (0x1234, "%032x" % id, "%032x" % 9)

    def do_begin_twophase(self, connection, xid):
        connection.connection.begin(*xid)

    def do_prepare_twophase(self, connection, xid):
        connection.connection.prepare()

    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
        self.do_rollback(connection.connection)

    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
        self.do_commit(connection.connection)

    def do_recover_twophase(self, connection):
        pass

dialect = OracleDialect_cx_oracle
209
sqlalchemy/dialects/oracle/zxjdbc.py
Normal file
@@ -0,0 +1,209 @@
"""Support for the Oracle database via the zxjdbc JDBC connector.

JDBC Driver
-----------

The official Oracle JDBC driver is at
http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html.

"""
import decimal
import re

from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext
from sqlalchemy.engine import base, default
from sqlalchemy.sql import expression

SQLException = zxJDBC = None


class _ZxJDBCDate(sqltypes.Date):

    def result_processor(self, dialect, coltype):
        def process(value):
            if value is None:
                return None
            else:
                return value.date()
        return process


class _ZxJDBCNumeric(sqltypes.Numeric):

    def result_processor(self, dialect, coltype):
        # XXX: does the dialect return Decimal or not???
        # if it does (in all cases), we could use a None processor as well as
        # the to_float generic processor
        if self.asdecimal:
            def process(value):
                if isinstance(value, decimal.Decimal):
                    return value
                else:
                    return decimal.Decimal(str(value))
        else:
            def process(value):
                if isinstance(value, decimal.Decimal):
                    return float(value)
                else:
                    return value
        return process


class OracleCompiler_zxjdbc(OracleCompiler):

    def returning_clause(self, stmt, returning_cols):
        self.returning_cols = list(expression._select_iterables(returning_cols))

        # within_columns_clause=False so that labels (foo AS bar) don't render
        columns = [self.process(c, within_columns_clause=False, result_map=self.result_map)
                   for c in self.returning_cols]

        if not hasattr(self, 'returning_parameters'):
            self.returning_parameters = []

        binds = []
        for i, col in enumerate(self.returning_cols):
            dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
            self.returning_parameters.append((i + 1, dbtype))

            bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype))
            self.binds[bindparam.key] = bindparam
            binds.append(self.bindparam_string(self._truncate_bindparam(bindparam)))

        return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
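
    # Illustrative only (not in the original source): for a hypothetical
    # INSERT with returning_cols [id, name], the clause above would render
    # roughly as
    #   RETURNING id, name INTO :ret_0, :ret_1
    # with each :ret_N bound to a ReturningParam carrying the JDBC type code.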


class OracleExecutionContext_zxjdbc(OracleExecutionContext):

    def pre_exec(self):
        if hasattr(self.compiled, 'returning_parameters'):
            # prepare a zxJDBC statement so we can grab its underlying
            # OraclePreparedStatement's getReturnResultSet later
            self.statement = self.cursor.prepare(self.statement)

    def get_result_proxy(self):
        if hasattr(self.compiled, 'returning_parameters'):
            rrs = None
            try:
                try:
                    rrs = self.statement.__statement__.getReturnResultSet()
                    rrs.next()
                except SQLException, sqle:
                    msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode())
                    if sqle.getSQLState() is not None:
                        msg += ' [SQLState: %s]' % sqle.getSQLState()
                    raise zxJDBC.Error(msg)
                else:
                    row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype)
                                for index, dbtype in self.compiled.returning_parameters)
                    return ReturningResultProxy(self, row)
            finally:
                if rrs is not None:
                    try:
                        rrs.close()
                    except SQLException:
                        pass
                self.statement.close()

        return base.ResultProxy(self)

    def create_cursor(self):
        cursor = self._connection.connection.cursor()
        cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
        return cursor


class ReturningResultProxy(base.FullyBufferedResultProxy):

    """ResultProxy backed by the RETURNING ResultSet results."""

    def __init__(self, context, returning_row):
        self._returning_row = returning_row
        super(ReturningResultProxy, self).__init__(context)

    def _cursor_description(self):
        ret = []
        for c in self.context.compiled.returning_cols:
            if hasattr(c, 'name'):
                ret.append((c.name, c.type))
            else:
                ret.append((c.anon_label, c.type))
        return ret

    def _buffer_rows(self):
        return [self._returning_row]


class ReturningParam(object):

    """A bindparam value representing a RETURNING parameter.

    Specially handled by OracleReturningDataHandler.
    """

    def __init__(self, type):
        self.type = type

    def __eq__(self, other):
        if isinstance(other, ReturningParam):
            return self.type == other.type
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, ReturningParam):
            return self.type != other.type
        return NotImplemented

    def __repr__(self):
        kls = self.__class__
        return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self),
                                                   self.type)


class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
    jdbc_db_name = 'oracle'
    jdbc_driver_name = 'oracle.jdbc.OracleDriver'

    statement_compiler = OracleCompiler_zxjdbc
    execution_ctx_cls = OracleExecutionContext_zxjdbc

    colspecs = util.update_copy(
        OracleDialect.colspecs,
        {
            sqltypes.Date: _ZxJDBCDate,
            sqltypes.Numeric: _ZxJDBCNumeric
        }
    )

    def __init__(self, *args, **kwargs):
        super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs)
        global SQLException, zxJDBC
        from java.sql import SQLException
        from com.ziclix.python.sql import zxJDBC
        from com.ziclix.python.sql.handler import OracleDataHandler

        class OracleReturningDataHandler(OracleDataHandler):

            """zxJDBC DataHandler that specially handles ReturningParam."""

            def setJDBCObject(self, statement, index, object, dbtype=None):
                if type(object) is ReturningParam:
                    statement.registerReturnParameter(index, object.type)
                elif dbtype is None:
                    OracleDataHandler.setJDBCObject(self, statement, index, object)
                else:
                    OracleDataHandler.setJDBCObject(self, statement, index, object, dbtype)

        self.DataHandler = OracleReturningDataHandler

    def initialize(self, connection):
        super(OracleDialect_zxjdbc, self).initialize(connection)
        self.implicit_returning = connection.connection.driverversion >= '10.2'

    def _create_jdbc_url(self, url):
        return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database)

    def _get_server_version_info(self, connection):
        version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
        return tuple(int(x) for x in version.split('.'))

dialect = OracleDialect_zxjdbc
10
sqlalchemy/dialects/postgres.py
Normal file
@@ -0,0 +1,10 @@
# backwards compat with the old name
from sqlalchemy.util import warn_deprecated

warn_deprecated(
    "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. "
    "The new URL format is postgresql[+driver]://<user>:<pass>@<host>/<dbname>"
)

from sqlalchemy.dialects.postgresql import *
from sqlalchemy.dialects.postgresql import base
14
sqlalchemy/dialects/postgresql/__init__.py
Normal file
@@ -0,0 +1,14 @@
from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, pypostgresql, zxjdbc

base.dialect = psycopg2.dialect

from sqlalchemy.dialects.postgresql.base import \
    INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \
    CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
    DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect

__all__ = (
    'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET',
    'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME',
    'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect'
)
1161
sqlalchemy/dialects/postgresql/base.py
Normal file
File diff suppressed because it is too large
105
sqlalchemy/dialects/postgresql/pg8000.py
Normal file
@@ -0,0 +1,105 @@
"""Support for the PostgreSQL database via the pg8000 driver.

Connecting
----------

URLs are of the form
`postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.

Unicode
-------

pg8000 requires that the PostgreSQL client encoding be configured in the
postgresql.conf file in order to use encodings other than ascii.  Set this
value to the same value as the "encoding" parameter on create_engine(),
usually "utf-8".

Interval
--------

Passing data from/to the Interval type is not supported yet.

"""
import decimal

from sqlalchemy.engine import default
from sqlalchemy import util, exc
from sqlalchemy import processors
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, \
    PGCompiler, PGIdentifierPreparer, PGExecutionContext


class _PGNumeric(sqltypes.Numeric):
    def result_processor(self, dialect, coltype):
        if self.asdecimal:
            if coltype in (700, 701):
                return processors.to_decimal_processor_factory(decimal.Decimal)
            elif coltype == 1700:
                # pg8000 returns Decimal natively for 1700
                return None
            else:
                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
        else:
            if coltype in (700, 701):
                # pg8000 returns float natively for 701
                return None
            elif coltype == 1700:
                return processors.to_float
            else:
                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)


class PGExecutionContext_pg8000(PGExecutionContext):
    pass


class PGCompiler_pg8000(PGCompiler):
    def visit_mod(self, binary, **kw):
        return self.process(binary.left) + " %% " + self.process(binary.right)

    def post_process_text(self, text):
        if '%%' in text:
            util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() "
                      "expressions to '%%'.")
        return text.replace('%', '%%')


class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
    def _escape_identifier(self, value):
        value = value.replace(self.escape_quote, self.escape_to_quote)
        return value.replace('%', '%%')


class PGDialect_pg8000(PGDialect):
    driver = 'pg8000'

    supports_unicode_statements = True
    supports_unicode_binds = True

    default_paramstyle = 'format'
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_pg8000
    statement_compiler = PGCompiler_pg8000
    preparer = PGIdentifierPreparer_pg8000

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: _PGNumeric,
        }
    )

    @classmethod
    def dbapi(cls):
        return __import__('pg8000').dbapi

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e):
        return "connection is closed" in str(e)

dialect = PGDialect_pg8000
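
# A minimal usage sketch (not part of the original source); the URL and the
# "encoding" value are illustrative, and assume client_encoding is set to
# match in postgresql.conf as described in the module docstring:
#
#   from sqlalchemy import create_engine
#   e = create_engine('postgresql+pg8000://scott:tiger@localhost/test',
#                     encoding='utf-8')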
239
sqlalchemy/dialects/postgresql/psycopg2.py
Normal file
@@ -0,0 +1,239 @@
"""Support for the PostgreSQL database via the psycopg2 driver.

Driver
------

The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
The dialect has several behaviors which are specifically tailored towards compatibility
with this module.

Note that psycopg1 is **not** supported.

Connecting
----------

URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`.

psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:

* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
  this feature.  What this essentially means from a psycopg2 point of view is that the cursor is
  created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
  are not immediately pre-fetched and buffered after statement execution, but are instead left
  on the server and only retrieved as needed.  SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
  uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
  at a time are fetched over the wire to reduce conversational overhead.
* *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection.  True
  by default.
* *isolation_level* - Sets the transaction isolation level for each transaction
  within the engine.  Valid isolation levels are `READ_COMMITTED`,
  `READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`.

Transactions
------------

The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.

NOTICE logging
--------------

The psycopg2 dialect will log PostgreSQL NOTICE messages via the
``sqlalchemy.dialects.postgresql`` logger::

    import logging
    logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)


Per-Statement Execution Options
-------------------------------

The following per-statement execution options are respected:

* *stream_results* - Enable or disable usage of server side cursors for the SELECT statement.
  If *None* or not set, the *server_side_cursors* option of the connection is used.  If
  auto-commit is enabled, the option is ignored.

"""

import random
import re
import decimal
import logging

from sqlalchemy import util, exc
from sqlalchemy import processors
from sqlalchemy.engine import base, default
from sqlalchemy.sql import expression
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \
    PGIdentifierPreparer, PGExecutionContext, \
    ENUM, ARRAY


logger = logging.getLogger('sqlalchemy.dialects.postgresql')


class _PGNumeric(sqltypes.Numeric):
    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect, coltype):
        if self.asdecimal:
            if coltype in (700, 701):
                return processors.to_decimal_processor_factory(decimal.Decimal)
            elif coltype == 1700:
                # psycopg2 returns Decimal natively for 1700
                return None
            else:
                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
        else:
            if coltype in (700, 701):
                # psycopg2 returns float natively for 701
                return None
            elif coltype == 1700:
                return processors.to_float
            else:
                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)


class _PGEnum(ENUM):
    def __init__(self, *arg, **kw):
        super(_PGEnum, self).__init__(*arg, **kw)
        if self.convert_unicode:
            self.convert_unicode = "force"


class _PGArray(ARRAY):
    def __init__(self, *arg, **kw):
        super(_PGArray, self).__init__(*arg, **kw)
        # FIXME: this check won't work for setups that
        # have convert_unicode only on their create_engine().
        if isinstance(self.item_type, sqltypes.String) and \
                self.item_type.convert_unicode:
            self.item_type.convert_unicode = "force"

# When we're handed literal SQL, ensure it's a SELECT query.  Since
# 8.3, combining cursors and "FOR UPDATE" has been fine.
SERVER_SIDE_CURSOR_RE = re.compile(
    r'\s*SELECT',
    re.I | re.UNICODE)


class PGExecutionContext_psycopg2(PGExecutionContext):
    def create_cursor(self):
        # TODO: coverage for server side cursors + select.for_update()

        if self.dialect.server_side_cursors:
            is_server_side = \
                self.execution_options.get('stream_results', True) and (
                    (self.compiled and isinstance(self.compiled.statement, expression.Selectable))
                    or
                    (
                        (not self.compiled or
                         isinstance(self.compiled.statement, expression._TextClause))
                        and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)
                    )
                )
        else:
            is_server_side = self.execution_options.get('stream_results', False)

        self.__is_server_side = is_server_side
        if is_server_side:
            # use server-side cursors:
            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
            return self._connection.connection.cursor(ident)
        else:
            return self._connection.connection.cursor()

    def get_result_proxy(self):
        if logger.isEnabledFor(logging.INFO):
            self._log_notices(self.cursor)

        if self.__is_server_side:
            return base.BufferedRowResultProxy(self)
        else:
            return base.ResultProxy(self)

    def _log_notices(self, cursor):
        for notice in cursor.connection.notices:
            # NOTICE messages have a
            # newline character at the end
            logger.info(notice.rstrip())

        cursor.connection.notices[:] = []


class PGCompiler_psycopg2(PGCompiler):
    def visit_mod(self, binary, **kw):
        return self.process(binary.left) + " %% " + self.process(binary.right)

    def post_process_text(self, text):
        return text.replace('%', '%%')


class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
    def _escape_identifier(self, value):
        value = value.replace(self.escape_quote, self.escape_to_quote)
        return value.replace('%', '%%')


class PGDialect_psycopg2(PGDialect):
    driver = 'psycopg2'
    supports_unicode_statements = False
    default_paramstyle = 'pyformat'
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_psycopg2
    statement_compiler = PGCompiler_psycopg2
    preparer = PGIdentifierPreparer_psycopg2

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: _PGNumeric,
            ENUM: _PGEnum,           # needs force_unicode
            sqltypes.Enum: _PGEnum,  # needs force_unicode
            ARRAY: _PGArray,         # needs force_unicode
        }
    )

    def __init__(self, server_side_cursors=False, use_native_unicode=True, **kwargs):
        PGDialect.__init__(self, **kwargs)
        self.server_side_cursors = server_side_cursors
        self.use_native_unicode = use_native_unicode
        self.supports_unicode_binds = use_native_unicode

    @classmethod
    def dbapi(cls):
        psycopg = __import__('psycopg2')
        return psycopg

    def on_connect(self):
        base_on_connect = super(PGDialect_psycopg2, self).on_connect()
        if self.dbapi and self.use_native_unicode:
            extensions = __import__('psycopg2.extensions').extensions
            def connect(conn):
                extensions.register_type(extensions.UNICODE, conn)
                if base_on_connect:
                    base_on_connect(conn)
            return connect
        else:
            return base_on_connect

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e):
        if isinstance(e, self.dbapi.OperationalError):
            return 'closed the connection' in str(e) or 'connection not open' in str(e)
        elif isinstance(e, self.dbapi.InterfaceError):
            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
        elif isinstance(e, self.dbapi.ProgrammingError):
            # yes, it really says "losed", not "closed"
            return "losed the connection unexpectedly" in str(e)
        else:
            return False

dialect = PGDialect_psycopg2
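
# A minimal usage sketch (not part of the original source); the URL is
# illustrative.  This enables named (server-side) cursors dialect-wide, then
# streams one SELECT without buffering the full result set client-side:
#
#   from sqlalchemy import create_engine
#   e = create_engine('postgresql+psycopg2://scott:tiger@localhost/test',
#                     server_side_cursors=True)
#   result = e.connect().execution_options(stream_results=True).\
#       execute("select * from bigtable")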
69
sqlalchemy/dialects/postgresql/pypostgresql.py
Normal file
@@ -0,0 +1,69 @@
"""Support for the PostgreSQL database via py-postgresql.

Connecting
----------

URLs are of the form `postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...]`.


"""
from sqlalchemy.engine import default
import decimal
from sqlalchemy import util
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
from sqlalchemy import processors


class PGNumeric(sqltypes.Numeric):
    def bind_processor(self, dialect):
        return processors.to_str

    def result_processor(self, dialect, coltype):
        if self.asdecimal:
            return None
        else:
            return processors.to_float


class PGExecutionContext_pypostgresql(PGExecutionContext):
    pass


class PGDialect_pypostgresql(PGDialect):
    driver = 'pypostgresql'

    supports_unicode_statements = True
    supports_unicode_binds = True
    description_encoding = None
    default_paramstyle = 'pyformat'

    # requires trunk version to support sane rowcounts
    # TODO: use dbapi version information to set this flag appropriately
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = False

    execution_ctx_cls = PGExecutionContext_pypostgresql
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: PGNumeric,
            sqltypes.Float: sqltypes.Float,  # prevents PGNumeric from being used
        }
    )

    @classmethod
    def dbapi(cls):
        from postgresql.driver import dbapi20
        return dbapi20

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        else:
            opts['port'] = 5432
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e):
        return "connection is closed" in str(e)

dialect = PGDialect_pypostgresql
19
sqlalchemy/dialects/postgresql/zxjdbc.py
Normal file
@@ -0,0 +1,19 @@
"""Support for the PostgreSQL database via the zxjdbc JDBC connector.

JDBC Driver
-----------

The official PostgreSQL JDBC driver is at http://jdbc.postgresql.org/.

"""
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.postgresql.base import PGDialect


class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect):
    jdbc_db_name = 'postgresql'
    jdbc_driver_name = 'org.postgresql.Driver'

    def _get_server_version_info(self, connection):
        return tuple(int(x) for x in connection.connection.dbversion.split('.'))

dialect = PGDialect_zxjdbc
14
sqlalchemy/dialects/sqlite/__init__.py
Normal file
@@ -0,0 +1,14 @@
from sqlalchemy.dialects.sqlite import base, pysqlite

# default dialect
base.dialect = pysqlite.dialect


from sqlalchemy.dialects.sqlite.base import \
    BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER, \
    NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect

__all__ = (
    'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'FLOAT', 'INTEGER',
    'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', 'TIMESTAMP', 'VARCHAR', 'dialect'
)
596
sqlalchemy/dialects/sqlite/base.py
Normal file
@@ -0,0 +1,596 @@
# sqlite.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for the SQLite database.

For information on connecting using a specific driver, see the documentation
section regarding that driver.

Date and Time Types
-------------------

SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide
out of the box functionality for translating values between Python `datetime` objects
and a SQLite-supported format.  SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
and related types provide date formatting and parsing functionality when SQLite is used.
The implementation classes are :class:`DATETIME`, :class:`DATE` and :class:`TIME`.
These types represent dates and times as ISO formatted strings, which also nicely
support ordering.  There's no reliance on typical "libc" internals for these functions
so historical dates are fully supported.

Auto Incrementing Behavior
--------------------------

Background on SQLite's autoincrement is at: http://sqlite.org/autoinc.html

Two things to note:

* The AUTOINCREMENT keyword is **not** required for SQLite tables to
  generate primary key values automatically.  AUTOINCREMENT only means that
  the algorithm used to generate ROWID values should be slightly different.
* SQLite does **not** generate primary key (i.e. ROWID) values, even for
  one column, if the table has a composite (i.e. multi-column) primary key.
  This is regardless of the AUTOINCREMENT keyword being present or not.

To specifically render the AUTOINCREMENT keyword on the primary key
column when rendering DDL, add the flag ``sqlite_autoincrement=True``
to the Table construct::

    Table('sometable', metadata,
            Column('id', Integer, primary_key=True),
            sqlite_autoincrement=True)

"""

import datetime, re, time

from sqlalchemy import schema as sa_schema
from sqlalchemy import sql, exc, pool, DefaultClause
from sqlalchemy.engine import default
from sqlalchemy.engine import reflection
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.sql import compiler, functions as sql_functions
from sqlalchemy.util import NoneType
from sqlalchemy import processors

from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, \
    FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME, \
    TIMESTAMP, VARCHAR


class _DateTimeMixin(object):
    _reg = None
    _storage_format = None

    def __init__(self, storage_format=None, regexp=None, **kwargs):
        if regexp is not None:
            self._reg = re.compile(regexp)
        if storage_format is not None:
            self._storage_format = storage_format


class DATETIME(_DateTimeMixin, sqltypes.DateTime):
    _storage_format = "%04d-%02d-%02d %02d:%02d:%02d.%06d"

    def bind_processor(self, dialect):
        datetime_datetime = datetime.datetime
        datetime_date = datetime.date
        format = self._storage_format
        def process(value):
            if value is None:
                return None
            elif isinstance(value, datetime_datetime):
                return format % (value.year, value.month, value.day,
                                 value.hour, value.minute, value.second,
                                 value.microsecond)
            elif isinstance(value, datetime_date):
                return format % (value.year, value.month, value.day,
                                 0, 0, 0, 0)
            else:
                raise TypeError("SQLite DateTime type only accepts Python "
                                "datetime and date objects as input.")
        return process

    def result_processor(self, dialect, coltype):
        if self._reg:
            return processors.str_to_datetime_processor_factory(
                self._reg, datetime.datetime)
        else:
            return processors.str_to_datetime


class DATE(_DateTimeMixin, sqltypes.Date):
    _storage_format = "%04d-%02d-%02d"

    def bind_processor(self, dialect):
        datetime_date = datetime.date
        format = self._storage_format
        def process(value):
            if value is None:
                return None
            elif isinstance(value, datetime_date):
                return format % (value.year, value.month, value.day)
            else:
                raise TypeError("SQLite Date type only accepts Python "
                                "date objects as input.")
        return process

    def result_processor(self, dialect, coltype):
        if self._reg:
            return processors.str_to_datetime_processor_factory(
                self._reg, datetime.date)
        else:
            return processors.str_to_date


class TIME(_DateTimeMixin, sqltypes.Time):
    _storage_format = "%02d:%02d:%02d.%06d"

    def bind_processor(self, dialect):
        datetime_time = datetime.time
        format = self._storage_format
        def process(value):
            if value is None:
                return None
            elif isinstance(value, datetime_time):
                return format % (value.hour, value.minute, value.second,
                                 value.microsecond)
            else:
                raise TypeError("SQLite Time type only accepts Python "
                                "time objects as input.")
        return process

    def result_processor(self, dialect, coltype):
        if self._reg:
            return processors.str_to_datetime_processor_factory(
                self._reg, datetime.time)
        else:
            return processors.str_to_time

colspecs = {
    sqltypes.Date: DATE,
    sqltypes.DateTime: DATETIME,
    sqltypes.Time: TIME,
}
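
# A minimal sketch (not part of the original source) of the storage_format /
# regexp hooks on _DateTimeMixin above; the format string and pattern shown
# are illustrative.  This would store dates without the dashes and parse them
# back with the matching regular expression:
#
#   from sqlalchemy import Table, Column, MetaData
#   t = Table('t', MetaData(),
#             Column('d', DATE(storage_format="%04d%02d%02d",
#                              regexp=r"(\d{4})(\d{2})(\d{2})")))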

ischema_names = {
    'BLOB': sqltypes.BLOB,
    'BOOL': sqltypes.BOOLEAN,
    'BOOLEAN': sqltypes.BOOLEAN,
    'CHAR': sqltypes.CHAR,
    'DATE': sqltypes.DATE,
    'DATETIME': sqltypes.DATETIME,
    'DECIMAL': sqltypes.DECIMAL,
    'FLOAT': sqltypes.FLOAT,
    'INT': sqltypes.INTEGER,
    'INTEGER': sqltypes.INTEGER,
    'NUMERIC': sqltypes.NUMERIC,
    'REAL': sqltypes.Numeric,
    'SMALLINT': sqltypes.SMALLINT,
    'TEXT': sqltypes.TEXT,
    'TIME': sqltypes.TIME,
    'TIMESTAMP': sqltypes.TIMESTAMP,
    'VARCHAR': sqltypes.VARCHAR,
}


class SQLiteCompiler(compiler.SQLCompiler):
    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        {
            'month': '%m',
            'day': '%d',
            'year': '%Y',
            'second': '%S',
            'hour': '%H',
            'doy': '%j',
            'minute': '%M',
            'epoch': '%s',
            'dow': '%w',
            'week': '%W'
        })

    def visit_now_func(self, fn, **kw):
        return "CURRENT_TIMESTAMP"

    def visit_char_length_func(self, fn, **kw):
        return "length%s" % self.function_argspec(fn)

    def visit_cast(self, cast, **kwargs):
        if self.dialect.supports_cast:
            return super(SQLiteCompiler, self).visit_cast(cast)
        else:
            return self.process(cast.clause)

    def visit_extract(self, extract, **kw):
        try:
            return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
                self.extract_map[extract.field], self.process(extract.expr, **kw))
        except KeyError:
            raise exc.ArgumentError(
                "%s is not a valid extract argument." % extract.field)

    def limit_clause(self, select):
        text = ""
        if select._limit is not None:
            text += " \n LIMIT " + str(select._limit)
        if select._offset is not None:
            if select._limit is None:
                text += " \n LIMIT -1"
            text += " OFFSET " + str(select._offset)
        else:
            text += " OFFSET 0"
        return text
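
    # Illustrative only (not in the original source): given the rules above,
    # a select with limit=5, offset=10 renders "LIMIT 5 OFFSET 10", while an
    # offset with no limit renders "LIMIT -1 OFFSET 10", since SQLite does
    # not accept OFFSET without a LIMIT clause.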

    def for_update_clause(self, select):
        # sqlite has no "FOR UPDATE" AFAICT
        return ''


class SQLiteDDLCompiler(compiler.DDLCompiler):

    def get_column_specification(self, column, **kwargs):
        colspec = self.preparer.format_column(column) + " " + \
            self.dialect.type_compiler.process(column.type)
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"

        if column.primary_key and \
                column.table.kwargs.get('sqlite_autoincrement', False) and \
                len(column.table.primary_key.columns) == 1 and \
                isinstance(column.type, sqltypes.Integer) and \
                not column.foreign_keys:
            colspec += " PRIMARY KEY AUTOINCREMENT"

        return colspec

    def visit_primary_key_constraint(self, constraint):
        # for columns with sqlite_autoincrement=True,
        # the PRIMARY KEY constraint can only be inline
        # with the column itself.
        if len(constraint.columns) == 1:
            c = list(constraint)[0]
            if c.primary_key and \
                    c.table.kwargs.get('sqlite_autoincrement', False) and \
                    isinstance(c.type, sqltypes.Integer) and \
                    not c.foreign_keys:
                return ''

        return super(SQLiteDDLCompiler, self).\
            visit_primary_key_constraint(constraint)

    def visit_create_index(self, create):
        index = create.element
        preparer = self.preparer
        text = "CREATE "
        if index.unique:
            text += "UNIQUE "
        text += "INDEX %s ON %s (%s)" \
            % (preparer.format_index(index,
                                     name=self._validate_identifier(index.name, True)),
               preparer.format_table(index.table, use_schema=False),
               ', '.join(preparer.quote(c.name, c.quote)
                         for c in index.columns))
        return text


class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
    def visit_large_binary(self, type_):
        return self.visit_BLOB(type_)


class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
    reserved_words = set([
        'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
        'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
        'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
        'conflict', 'constraint', 'create', 'cross', 'current_date',
        'current_time', 'current_timestamp', 'database', 'default',
        'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
        'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
        'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
        'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
        'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
        'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
        'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
        'plan', 'pragma', 'primary', 'query', 'raise', 'references',
        'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
        'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
        'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
        'vacuum', 'values', 'view', 'virtual', 'when', 'where',
    ])

    def format_index(self, index, use_schema=True, name=None):
        """Prepare a quoted index and schema name."""

        if name is None:
            name = index.name
        result = self.quote(name, index.quote)
        if not self.omit_schema and use_schema and getattr(index.table, "schema", None):
            result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result
        return result


class SQLiteDialect(default.DefaultDialect):
    name = 'sqlite'
    supports_alter = False
    supports_unicode_statements = True
    supports_unicode_binds = True
    supports_default_values = True
    supports_empty_insert = False
    supports_cast = True

    default_paramstyle = 'qmark'
    statement_compiler = SQLiteCompiler
    ddl_compiler = SQLiteDDLCompiler
    type_compiler = SQLiteTypeCompiler
    preparer = SQLiteIdentifierPreparer
    ischema_names = ischema_names
    colspecs = colspecs
    isolation_level = None

    def __init__(self, isolation_level=None, native_datetime=False, **kwargs):
        default.DefaultDialect.__init__(self, **kwargs)
        if isolation_level and isolation_level not in ('SERIALIZABLE',
                                                       'READ UNCOMMITTED'):
            raise exc.ArgumentError("Invalid value for isolation_level. "
                "Valid isolation levels for sqlite are 'SERIALIZABLE' and "
                "'READ UNCOMMITTED'.")
        self.isolation_level = isolation_level

        # this flag is used by the pysqlite dialect, and perhaps others in the
        # future, to indicate the driver is handling date/timestamp
        # conversions (and perhaps datetime/time as well on some
        # hypothetical driver ?)
        self.native_datetime = native_datetime

        if self.dbapi is not None:
            self.supports_default_values = \
                self.dbapi.sqlite_version_info >= (3, 3, 8)
            self.supports_cast = \
                self.dbapi.sqlite_version_info >= (3, 2, 3)

    def on_connect(self):
        if self.isolation_level is not None:
            if self.isolation_level == 'READ UNCOMMITTED':
                isolation_level = 1
            else:
                isolation_level = 0

            def connect(conn):
                cursor = conn.cursor()
                cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
                cursor.close()
            return connect
        else:
            return None

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        if schema is not None:
            qschema = self.identifier_preparer.quote_identifier(schema)
            master = '%s.sqlite_master' % qschema
            s = ("SELECT name FROM %s "
                 "WHERE type='table' ORDER BY name") % (master,)
            rs = connection.execute(s)
        else:
            try:
                s = ("SELECT name FROM "
                     " (SELECT * FROM sqlite_master UNION ALL "
                     "  SELECT * FROM sqlite_temp_master) "
                     "WHERE type='table' ORDER BY name")
                rs = connection.execute(s)
            except exc.DBAPIError:
                s = ("SELECT name FROM sqlite_master "
                     "WHERE type='table' ORDER BY name")
                rs = connection.execute(s)

        return [row[0] for row in rs]

    def has_table(self, connection, table_name, schema=None):
        quote = self.identifier_preparer.quote_identifier
        if schema is not None:
            pragma = "PRAGMA %s." % quote(schema)
        else:
            pragma = "PRAGMA "
        qtable = quote(table_name)
        cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
        row = cursor.fetchone()

        # consume remaining rows, to work around
        # http://www.sqlite.org/cvstrac/tktview?tn=1884
        while cursor.fetchone() is not None:
            pass

        return (row is not None)

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        if schema is not None:
            qschema = self.identifier_preparer.quote_identifier(schema)
            master = '%s.sqlite_master' % qschema
            s = ("SELECT name FROM %s "
                 "WHERE type='view' ORDER BY name") % (master,)
            rs = connection.execute(s)
        else:
            try:
                s = ("SELECT name FROM "
                     " (SELECT * FROM sqlite_master UNION ALL "
                     "  SELECT * FROM sqlite_temp_master) "
                     "WHERE type='view' ORDER BY name")
                rs = connection.execute(s)
            except exc.DBAPIError:
                s = ("SELECT name FROM sqlite_master "
                     "WHERE type='view' ORDER BY name")
                rs = connection.execute(s)

        return [row[0] for row in rs]

    @reflection.cache
    def get_view_definition(self, connection, view_name, schema=None, **kw):
        quote = self.identifier_preparer.quote_identifier
        if schema is not None:
            qschema = self.identifier_preparer.quote_identifier(schema)
            master = '%s.sqlite_master' % qschema
            s = ("SELECT sql FROM %s WHERE name = '%s' "
                 "AND type='view'") % (master, view_name)
            rs = connection.execute(s)
        else:
            try:
                s = ("SELECT sql FROM "
                     " (SELECT * FROM sqlite_master UNION ALL "
                     "  SELECT * FROM sqlite_temp_master) "
                     "WHERE name = '%s' "
                     "AND type='view'") % view_name
                rs = connection.execute(s)
            except exc.DBAPIError:
                s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
                     "AND type='view'") % view_name
                rs = connection.execute(s)

        result = rs.fetchall()
        if result:
            return result[0].sql

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        quote = self.identifier_preparer.quote_identifier
        if schema is not None:
            pragma = "PRAGMA %s." % quote(schema)
        else:
            pragma = "PRAGMA "
        qtable = quote(table_name)
        c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
        found_table = False
        columns = []
        while True:
            row = c.fetchone()
            if row is None:
                break
            (name, type_, nullable, default, has_default, primary_key) = \
                (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
            name = re.sub(r'^\"|\"$', '', name)
            if default:
                default = re.sub(r"^\'|\'$", '', default)
            match = re.match(r'(\w+)(\(.*?\))?', type_)
            if match:
                coltype = match.group(1)
                args = match.group(2)
            else:
                coltype = "VARCHAR"
                args = ''
            try:
                coltype = self.ischema_names[coltype]
            except KeyError:
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (coltype, name))
                coltype = sqltypes.NullType
            if args is not None:
                args = re.findall(r'(\d+)', args)
                coltype = coltype(*[int(a) for a in args])

            columns.append({
                'name': name,
                'type': coltype,
                'nullable': nullable,
                'default': default,
                'primary_key': primary_key
            })
        return columns

    @reflection.cache
    def get_primary_keys(self, connection, table_name, schema=None, **kw):
        cols = self.get_columns(connection, table_name, schema, **kw)
        pkeys = []
        for col in cols:
            if col['primary_key']:
                pkeys.append(col['name'])
        return pkeys

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        quote = self.identifier_preparer.quote_identifier
        if schema is not None:
            pragma = "PRAGMA %s." % quote(schema)
        else:
            pragma = "PRAGMA "
        qtable = quote(table_name)
        c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)))
        fkeys = []
        fks = {}
        while True:
            row = c.fetchone()
            if row is None:
                break
            (constraint_name, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4])
            rtbl = re.sub(r'^\"|\"$', '', rtbl)
            lcol = re.sub(r'^\"|\"$', '', lcol)
            rcol = re.sub(r'^\"|\"$', '', rcol)
            try:
                fk = fks[constraint_name]
            except KeyError:
                fk = {
                    'name': constraint_name,
                    'constrained_columns': [],
                    'referred_schema': None,
                    'referred_table': rtbl,
                    'referred_columns': []
                }
                fkeys.append(fk)
                fks[constraint_name] = fk

            # look up the table based on the given table's engine, not 'self',
            # since it could be a ProxyEngine
            if lcol not in fk['constrained_columns']:
                fk['constrained_columns'].append(lcol)
            if rcol not in fk['referred_columns']:
                fk['referred_columns'].append(rcol)
        return fkeys

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        quote = self.identifier_preparer.quote_identifier
        if schema is not None:
            pragma = "PRAGMA %s." % quote(schema)
        else:
            pragma = "PRAGMA "
        include_auto_indexes = kw.pop('include_auto_indexes', False)
        qtable = quote(table_name)
        c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable)))
        indexes = []
        while True:
            row = c.fetchone()
            if row is None:
                break
            # ignore the implicit primary key index.
            # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html
            elif not include_auto_indexes and row[1].startswith('sqlite_autoindex'):
                continue

            indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
        # loop thru unique indexes to get the column names.
        for idx in indexes:
            c = connection.execute("%sindex_info(%s)" % (pragma, quote(idx['name'])))
            cols = idx['column_names']
            while True:
                row = c.fetchone()
                if row is None:
                    break
                cols.append(row[2])
        return indexes


def _pragma_cursor(cursor):
    """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows."""

    if cursor.closed:
        cursor._fetchone_impl = lambda: None
    return cursor
236
sqlalchemy/dialects/sqlite/pysqlite.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Support for the SQLite database via pysqlite.
|
||||
|
||||
Note that pysqlite is the same driver as the ``sqlite3``
|
||||
module included with the Python distribution.
|
||||
|
||||
Driver
|
||||
------
|
||||
|
||||
When using Python 2.5 and above, the built in ``sqlite3`` driver is
|
||||
already installed and no additional installation is needed. Otherwise,
|
||||
the ``pysqlite2`` driver needs to be present. This is the same driver as
|
||||
``sqlite3``, just with a different name.
|
||||
|
||||
The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
|
||||
is loaded. This allows an explicitly installed pysqlite driver to take
|
||||
precedence over the built in one. As with all dialects, a specific
|
||||
DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control
|
||||
this explicitly::
|
||||
|
||||
from sqlite3 import dbapi2 as sqlite
|
||||
e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)
|
||||
|
||||
Full documentation on pysqlite is available at:
|
||||
`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
|
||||
|
||||
Connect Strings
|
||||
---------------
|
||||
|
||||
The file specification for the SQLite database is taken as the "database" portion of
|
||||
the URL. Note that the format of a url is::
|
||||
|
||||
driver://user:pass@host/database
|
||||
|
||||
This means that the actual filename to be used starts with the characters to the
|
||||
**right** of the third slash. So connecting to a relative filepath looks like::
|
||||
|
||||
# relative path
|
||||
e = create_engine('sqlite:///path/to/database.db')
|
||||
|
||||
An absolute path, which is denoted by starting with a slash, means you need **four**
|
||||
slashes::
|
||||
|
||||
# absolute path
|
||||
e = create_engine('sqlite:////path/to/database.db')
|
||||
|
||||
To use a Windows path, regular drive specifications and backslashes can be used.
|
||||
Double backslashes are probably needed::
|
||||
|
||||
# absolute path on Windows
|
||||
e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
|
||||
|
||||
The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify
|
||||
``sqlite://`` and nothing else::
|
||||
|
||||
# in-memory database
|
||||
e = create_engine('sqlite://')
|
||||
|

Compatibility with sqlite3 "native" date and datetime types
-----------------------------------------------------------

The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
sqlite3.PARSE_COLNAMES options, which have the effect that any column
or expression explicitly cast as "date" or "timestamp" will be converted
to a Python date or datetime object. The date and datetime types provided
with the pysqlite dialect are not currently compatible with these options,
since they render the ISO date/datetime including microseconds, which
pysqlite's driver does not. Additionally, SQLAlchemy does not at
this time automatically render the "cast" syntax required for the
freestanding functions "current_timestamp" and "current_date" to return
datetime/date types natively. Unfortunately, pysqlite
does not provide the standard DBAPI types in ``cursor.description``,
leaving SQLAlchemy with no way to detect these types on the fly
without expensive per-row type checks.

Usage of PARSE_DECLTYPES can be forced if one configures
"native_datetime=True" on create_engine()::

    engine = create_engine('sqlite://',
                    connect_args={'detect_types': sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
                    native_datetime=True
                    )

With this flag enabled, the DATE and TIMESTAMP types (but note - not the DATETIME
or TIME types...confused yet?) will not perform any bind parameter or result
processing. Execution of "func.current_date()" will return a string.
"func.current_timestamp()" is registered as returning a DATETIME type in
SQLAlchemy, so this function still receives SQLAlchemy-level result processing.

Threading Behavior
------------------

Pysqlite connections do not support being moved between threads, unless
the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition,
when using an in-memory SQLite database, the full database exists only within
the scope of a single connection. It is reported that an in-memory
database does not support being shared between threads regardless of the
``check_same_thread`` flag - which means that a multithreaded
application **cannot** share data from a ``:memory:`` database across threads
unless access to the connection is limited to a single worker thread which communicates
through a queueing mechanism to concurrent threads.

To provide a default which accommodates SQLite's default threading capabilities
somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
be used by default. This pool maintains a single SQLite connection per thread
that is held open up to a count of five concurrent threads. When more than five threads
are used, a cleanup mechanism will dispose of excess unused connections.

Two optional pool implementations may be appropriate for particular SQLite usage
scenarios (a usage sketch follows this list):

* the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
  application using an in-memory database, assuming the threading issues inherent in
  pysqlite are somehow accommodated for. This pool holds persistently onto a single connection
  which is never closed, and is returned for all requests.

* the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
  makes use of a file-based sqlite database. This pool disables any actual "pooling"
  behavior, and simply opens and closes real connections corresponding to the :func:`connect()`
  and :func:`close()` methods. SQLite can "connect" to a particular file with very high
  efficiency, so this option may actually perform better without the extra overhead
  of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection
  useless since the database would be lost as soon as the connection is "returned" to the pool.
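
A short sketch of selecting these pools (an editorial illustration; the URLs
and filenames are hypothetical, while ``poolclass`` and ``connect_args`` are
the standard :func:`create_engine` arguments)::

    from sqlalchemy import create_engine
    from sqlalchemy.pool import StaticPool, NullPool

    # in-memory database served from one persistent connection,
    # accepting pysqlite's threading caveats described above
    mem_engine = create_engine('sqlite://',
                    connect_args={'check_same_thread': False},
                    poolclass=StaticPool)

    # file-based database with no pooling overhead at all
    file_engine = create_engine('sqlite:///some.db', poolclass=NullPool)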

Unicode
-------

In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's
default behavior regarding Unicode is that all strings are returned as Python unicode objects
in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is
*not* used, you will still always receive unicode data back from a result set. It is
**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
to represent strings, since it will raise a warning if a non-unicode Python string is
passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can
quickly create confusion, particularly when using the ORM, as internal data is not
always represented by an actual database result string.
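
A minimal sketch (an editorial illustration; the table and column names are
hypothetical)::

    from sqlalchemy import Table, Column, Integer, Unicode, MetaData

    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', Unicode(50)),  # warns if a plain bytestring is bound
    )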

"""

from sqlalchemy.dialects.sqlite.base import SQLiteDialect, DATETIME, DATE
from sqlalchemy import schema, exc, pool
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes
from sqlalchemy import util


class _SQLite_pysqliteTimeStamp(DATETIME):
    def bind_processor(self, dialect):
        if dialect.native_datetime:
            return None
        else:
            return DATETIME.bind_processor(self, dialect)

    def result_processor(self, dialect, coltype):
        if dialect.native_datetime:
            return None
        else:
            return DATETIME.result_processor(self, dialect, coltype)

class _SQLite_pysqliteDate(DATE):
    def bind_processor(self, dialect):
        if dialect.native_datetime:
            return None
        else:
            return DATE.bind_processor(self, dialect)

    def result_processor(self, dialect, coltype):
        if dialect.native_datetime:
            return None
        else:
            return DATE.result_processor(self, dialect, coltype)

class SQLiteDialect_pysqlite(SQLiteDialect):
    default_paramstyle = 'qmark'
    poolclass = pool.SingletonThreadPool

    colspecs = util.update_copy(
        SQLiteDialect.colspecs,
        {
            sqltypes.Date: _SQLite_pysqliteDate,
            sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
        }
    )

    # Py3K
    #description_encoding = None

    driver = 'pysqlite'

    def __init__(self, **kwargs):
        SQLiteDialect.__init__(self, **kwargs)

        if self.dbapi is not None:
            sqlite_ver = self.dbapi.version_info
            if sqlite_ver < (2, 1, 3):
                util.warn(
                    ("The installed version of pysqlite2 (%s) is out-dated "
                     "and will cause errors in some cases. Version 2.1.3 "
                     "or greater is recommended.") %
                    '.'.join([str(subver) for subver in sqlite_ver]))

    @classmethod
    def dbapi(cls):
        try:
            from pysqlite2 import dbapi2 as sqlite
        except ImportError, e:
            try:
                from sqlite3 import dbapi2 as sqlite  # try the 2.5+ stdlib name.
            except ImportError:
                raise e
        return sqlite

    def _get_server_version_info(self, connection):
        return self.dbapi.sqlite_version_info

    def create_connect_args(self, url):
        if url.username or url.password or url.host or url.port:
            raise exc.ArgumentError(
                "Invalid SQLite URL: %s\n"
                "Valid SQLite URL forms are:\n"
                " sqlite:///:memory: (or, sqlite://)\n"
                " sqlite:///relative/path/to/file.db\n"
                " sqlite:////absolute/path/to/file.db" % (url,))
        filename = url.database or ':memory:'

        opts = url.query.copy()
        util.coerce_kw_type(opts, 'timeout', float)
        util.coerce_kw_type(opts, 'isolation_level', str)
        util.coerce_kw_type(opts, 'detect_types', int)
        util.coerce_kw_type(opts, 'check_same_thread', bool)
        util.coerce_kw_type(opts, 'cached_statements', int)

        return ([filename], opts)

    def is_disconnect(self, e):
        return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)

dialect = SQLiteDialect_pysqlite
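
# Editor's sketch (not part of the original commit): the coerce_kw_type()
# calls above mean pysqlite options can also ride along on the URL query
# string, e.g. (filename hypothetical):
#
#     e = create_engine('sqlite:///data.db?timeout=15')
#
# where 'timeout' arrives in url.query as a string and is coerced to a
# float before being handed to the DBAPI connect().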
20
sqlalchemy/dialects/sybase/__init__.py
Normal file
@@ -0,0 +1,20 @@
from sqlalchemy.dialects.sybase import base, pysybase, pyodbc


from base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
    TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
    BIGINT, INT, INTEGER, SMALLINT, BINARY,\
    VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\
    IMAGE, BIT, MONEY, SMALLMONEY, TINYINT

# default dialect
base.dialect = pyodbc.dialect

__all__ = (
    'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR',
    'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC',
    'BIGINT', 'INT', 'INTEGER', 'SMALLINT', 'BINARY',
    'VARBINARY', 'UNITEXT', 'UNICHAR', 'UNIVARCHAR',
    'IMAGE', 'BIT', 'MONEY', 'SMALLMONEY', 'TINYINT',
    'dialect'
)
420
sqlalchemy/dialects/sybase/base.py
Normal file
@@ -0,0 +1,420 @@
# sybase/base.py
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com
# get_select_precolumns(), limit_clause() implementation
# copyright (C) 2007 Fisch Asset Management
# AG http://www.fam.ch, with coding by Alexander Houben
# alexander.houben@thor-solutions.ch
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Support for Sybase Adaptive Server Enterprise (ASE).

Note that this dialect is no longer specific to Sybase iAnywhere.
ASE is the primary support platform.

"""

import operator
from sqlalchemy.sql import compiler, expression, text, bindparam
from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import schema as sa_schema
from sqlalchemy import util, sql, exc

from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
    TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
    BIGINT, INT, INTEGER, SMALLINT, BINARY,\
    VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
    UnicodeText

RESERVED_WORDS = set([
    "add", "all", "alter", "and",
    "any", "as", "asc", "backup",
    "begin", "between", "bigint", "binary",
    "bit", "bottom", "break", "by",
    "call", "capability", "cascade", "case",
    "cast", "char", "char_convert", "character",
    "check", "checkpoint", "close", "comment",
    "commit", "connect", "constraint", "contains",
    "continue", "convert", "create", "cross",
    "cube", "current", "current_timestamp", "current_user",
    "cursor", "date", "dbspace", "deallocate",
    "dec", "decimal", "declare", "default",
    "delete", "deleting", "desc", "distinct",
    "do", "double", "drop", "dynamic",
    "else", "elseif", "encrypted", "end",
    "endif", "escape", "except", "exception",
    "exec", "execute", "existing", "exists",
    "externlogin", "fetch", "first", "float",
    "for", "force", "foreign", "forward",
    "from", "full", "goto", "grant",
    "group", "having", "holdlock", "identified",
    "if", "in", "index", "index_lparen",
    "inner", "inout", "insensitive", "insert",
    "inserting", "install", "instead", "int",
    "integer", "integrated", "intersect", "into",
    "iq", "is", "isolation", "join",
    "key", "lateral", "left", "like",
    "lock", "login", "long", "match",
    "membership", "message", "mode", "modify",
    "natural", "new", "no", "noholdlock",
    "not", "notify", "null", "numeric",
    "of", "off", "on", "open",
    "option", "options", "or", "order",
    "others", "out", "outer", "over",
    "passthrough", "precision", "prepare", "primary",
    "print", "privileges", "proc", "procedure",
    "publication", "raiserror", "readtext", "real",
    "reference", "references", "release", "remote",
    "remove", "rename", "reorganize", "resource",
    "restore", "restrict", "return", "revoke",
    "right", "rollback", "rollup", "save",
    "savepoint", "scroll", "select", "sensitive",
    "session", "set", "setuser", "share",
    "smallint", "some", "sqlcode", "sqlstate",
    "start", "stop", "subtrans", "subtransaction",
    "synchronize", "syntax_error", "table", "temporary",
    "then", "time", "timestamp", "tinyint",
    "to", "top", "tran", "trigger",
    "truncate", "tsequal", "unbounded", "union",
    "unique", "unknown", "unsigned", "update",
    "updating", "user", "using", "validate",
    "values", "varbinary", "varchar", "variable",
    "varying", "view", "wait", "waitfor",
    "when", "where", "while", "window",
    "with", "with_cube", "with_lparen", "with_rollup",
    "within", "work", "writetext",
])


class _SybaseUnitypeMixin(object):
    """these types appear to return a buffer object."""

    def result_processor(self, dialect, coltype):
        def process(value):
            if value is not None:
                return str(value) #.decode("ucs-2")
            else:
                return None
        return process

class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
    __visit_name__ = 'UNICHAR'

class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
    __visit_name__ = 'UNIVARCHAR'

class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText):
    __visit_name__ = 'UNITEXT'

class TINYINT(sqltypes.Integer):
    __visit_name__ = 'TINYINT'

class BIT(sqltypes.TypeEngine):
    __visit_name__ = 'BIT'

class MONEY(sqltypes.TypeEngine):
    __visit_name__ = "MONEY"

class SMALLMONEY(sqltypes.TypeEngine):
    __visit_name__ = "SMALLMONEY"

class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
    __visit_name__ = "UNIQUEIDENTIFIER"

class IMAGE(sqltypes.LargeBinary):
    __visit_name__ = 'IMAGE'


class SybaseTypeCompiler(compiler.GenericTypeCompiler):
    def visit_large_binary(self, type_):
        return self.visit_IMAGE(type_)

    def visit_boolean(self, type_):
        return self.visit_BIT(type_)

    def visit_unicode(self, type_):
        return self.visit_NVARCHAR(type_)

    def visit_UNICHAR(self, type_):
        return "UNICHAR(%d)" % type_.length

    def visit_UNIVARCHAR(self, type_):
        return "UNIVARCHAR(%d)" % type_.length

    def visit_UNITEXT(self, type_):
        return "UNITEXT"

    def visit_TINYINT(self, type_):
        return "TINYINT"

    def visit_IMAGE(self, type_):
        return "IMAGE"

    def visit_BIT(self, type_):
        return "BIT"

    def visit_MONEY(self, type_):
        return "MONEY"

    def visit_SMALLMONEY(self, type_):
        return "SMALLMONEY"

    def visit_UNIQUEIDENTIFIER(self, type_):
        return "UNIQUEIDENTIFIER"

ischema_names = {
    'integer': INTEGER,
    'unsigned int': INTEGER,        # TODO: unsigned flags
    'unsigned smallint': SMALLINT,  # TODO: unsigned flags
    'unsigned bigint': BIGINT,      # TODO: unsigned flags
    'bigint': BIGINT,
    'smallint': SMALLINT,
    'tinyint': TINYINT,
    'varchar': VARCHAR,
    'long varchar': TEXT,           # TODO
    'char': CHAR,
    'decimal': DECIMAL,
    'numeric': NUMERIC,
    'float': FLOAT,
    'double': NUMERIC,              # TODO
    'binary': BINARY,
    'varbinary': VARBINARY,
    'bit': BIT,
    'image': IMAGE,
    'timestamp': TIMESTAMP,
    'money': MONEY,
    'smallmoney': MONEY,
    'uniqueidentifier': UNIQUEIDENTIFIER,
}


class SybaseExecutionContext(default.DefaultExecutionContext):
    _enable_identity_insert = False

    def set_ddl_autocommit(self, connection, value):
        """Must be implemented by subclasses to accommodate DDL executions.

        "connection" is the raw unwrapped DBAPI connection. "value"
        is True or False. When True, the connection should be configured
        such that a DDL can take place subsequently. When False,
        a DDL has taken place and the connection should be resumed
        into non-autocommit mode.

        """
        raise NotImplementedError()

    def pre_exec(self):
        if self.isinsert:
            tbl = self.compiled.statement.table
            seq_column = tbl._autoincrement_column
            insert_has_sequence = seq_column is not None

            if insert_has_sequence:
                self._enable_identity_insert = seq_column.key in self.compiled_parameters[0]
            else:
                self._enable_identity_insert = False

            if self._enable_identity_insert:
                self.cursor.execute("SET IDENTITY_INSERT %s ON" %
                    self.dialect.identifier_preparer.format_table(tbl))

        if self.isddl:
            # TODO: to enhance this, we can detect "ddl in tran" on the
            # database settings. this error message should be improved to
            # include a note about that.
            if not self.should_autocommit:
                raise exc.InvalidRequestError("The Sybase dialect only supports "
                                              "DDL in 'autocommit' mode at this time.")

            self.root_connection.engine.logger.info("AUTOCOMMIT (Assuming no Sybase 'ddl in tran')")

            self.set_ddl_autocommit(self.root_connection.connection.connection, True)


    def post_exec(self):
        if self.isddl:
            self.set_ddl_autocommit(self.root_connection, False)

        if self._enable_identity_insert:
            self.cursor.execute(
                "SET IDENTITY_INSERT %s OFF" %
                self.dialect.identifier_preparer.
                    format_table(self.compiled.statement.table)
            )

    def get_lastrowid(self):
        cursor = self.create_cursor()
        cursor.execute("SELECT @@identity AS lastrowid")
        lastrowid = cursor.fetchone()[0]
        cursor.close()
        return lastrowid

class SybaseSQLCompiler(compiler.SQLCompiler):
    ansi_bind_rules = True

    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        {
            'doy': 'dayofyear',
            'dow': 'weekday',
            'milliseconds': 'millisecond'
        })

    def get_select_precolumns(self, select):
        s = select._distinct and "DISTINCT " or ""
        if select._limit:
            #if select._limit == 1:
                #s += "FIRST "
            #else:
                #s += "TOP %s " % (select._limit,)
            s += "TOP %s " % (select._limit,)
        if select._offset:
            if not select._limit:
                # FIXME: sybase doesn't allow an offset without a limit
                # so use a huge value for TOP here
                s += "TOP 1000000 "
            s += "START AT %s " % (select._offset+1,)
        return s
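
    # Editor's sketch (not part of the original commit): with the method
    # above plus limit_clause() below, a statement like
    #
    #     select([t]).limit(10).offset(5)
    #
    # renders roughly as "SELECT TOP 10 START AT 6 ..." -- TOP comes from
    # get_select_precolumns(), START AT is the offset plus one, and
    # nothing is appended at the end of the statement.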

    def get_from_hint_text(self, table, text):
        return text

    def limit_clause(self, select):
        # Limit in sybase is after the select keyword
        return ""

    def visit_extract(self, extract, **kw):
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))

    def for_update_clause(self, select):
        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
        return ''

    def order_by_clause(self, select, **kw):
        kw['literal_binds'] = True
        order_by = self.process(select._order_by_clause, **kw)

        # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
        if order_by and (not self.is_subquery() or select._limit):
            return " ORDER BY " + order_by
        else:
            return ""


class SybaseDDLCompiler(compiler.DDLCompiler):
    def get_column_specification(self, column, **kwargs):
        colspec = self.preparer.format_column(column) + " " + \
            self.dialect.type_compiler.process(column.type)

        if column.table is None:
            raise exc.InvalidRequestError("The Sybase dialect requires Table-bound "
                                          "columns in order to generate DDL")
        seq_col = column.table._autoincrement_column

        # install an IDENTITY Sequence if we have an implicit IDENTITY column
        if seq_col is column:
            sequence = isinstance(column.default, sa_schema.Sequence) and column.default
            if sequence:
                start, increment = sequence.start or 1, sequence.increment or 1
            else:
                start, increment = 1, 1
            if (start, increment) == (1, 1):
                colspec += " IDENTITY"
            else:
                # TODO: need correct syntax for this
                colspec += " IDENTITY(%s,%s)" % (start, increment)
        else:
            if column.nullable is not None:
                if not column.nullable or column.primary_key:
                    colspec += " NOT NULL"
                else:
                    colspec += " NULL"

            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        return colspec

    def visit_drop_index(self, drop):
        index = drop.element
        return "\nDROP INDEX %s.%s" % (
            self.preparer.quote_identifier(index.table.name),
            self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
            )

class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
    reserved_words = RESERVED_WORDS

class SybaseDialect(default.DefaultDialect):
    name = 'sybase'
    supports_unicode_statements = False
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False

    supports_native_boolean = False
    supports_unicode_binds = False
    postfetch_lastrowid = True

    colspecs = {}
    ischema_names = ischema_names

    type_compiler = SybaseTypeCompiler
    statement_compiler = SybaseSQLCompiler
    ddl_compiler = SybaseDDLCompiler
    preparer = SybaseIdentifierPreparer

    def _get_default_schema_name(self, connection):
        return connection.scalar(
            text("SELECT user_name() as user_name", typemap={'user_name': Unicode})
        )

    def initialize(self, connection):
        super(SybaseDialect, self).initialize(connection)
        if self.server_version_info is not None and\
                self.server_version_info < (15, ):
            self.max_identifier_length = 30
        else:
            self.max_identifier_length = 255

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        if schema is None:
            schema = self.default_schema_name

        result = connection.execute(
            text("select sysobjects.name from sysobjects, sysusers "
                 "where sysobjects.uid=sysusers.uid and "
                 "sysusers.name=:schemaname and "
                 "sysobjects.type='U'",
                 bindparams=[
                     bindparam('schemaname', schema)
                 ])
        )
        return [r[0] for r in result]

    def has_table(self, connection, tablename, schema=None):
        if schema is None:
            schema = self.default_schema_name

        result = connection.execute(
            text("select sysobjects.name from sysobjects, sysusers "
                 "where sysobjects.uid=sysusers.uid and "
                 "sysobjects.name=:tablename and "
                 "sysusers.name=:schemaname and "
                 "sysobjects.type='U'",
                 bindparams=[
                     bindparam('tablename', tablename),
                     bindparam('schemaname', schema)
                 ])
        )
        return result.scalar() is not None

    def reflecttable(self, connection, table, include_columns):
        raise NotImplementedError()
17
sqlalchemy/dialects/sybase/mxodbc.py
Normal file
@@ -0,0 +1,17 @@
"""
Support for Sybase via mxodbc.

This dialect is a stub only and is likely non-functional at this time.

"""
from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
from sqlalchemy.connectors.mxodbc import MxODBCConnector

class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
    pass

class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
    execution_ctx_cls = SybaseExecutionContext_mxodbc

dialect = SybaseDialect_mxodbc
75
sqlalchemy/dialects/sybase/pyodbc.py
Normal file
@@ -0,0 +1,75 @@
"""
Support for Sybase via pyodbc.

http://pypi.python.org/pypi/pyodbc/

Connect strings are of the form::

    sybase+pyodbc://<username>:<password>@<dsn>/
    sybase+pyodbc://<username>:<password>@<host>/<database>

Unicode Support
---------------

The pyodbc driver currently supports usage of these Sybase types with
Unicode or multibyte strings::

    CHAR
    NCHAR
    NVARCHAR
    TEXT
    VARCHAR

Currently *not* supported are::

    UNICHAR
    UNITEXT
    UNIVARCHAR

"""

from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector
import decimal
from sqlalchemy import types as sqltypes, util, processors

class _SybNumeric_pyodbc(sqltypes.Numeric):
    """Turns Decimals with adjusted() < -6 into floats.

    It's not yet known how to get decimals with many
    significant digits or very large adjusted() into Sybase
    via pyodbc.

    """

    def bind_processor(self, dialect):
        super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)

        def process(value):
            if self.asdecimal and \
                    isinstance(value, decimal.Decimal):
                if value.adjusted() < -6:
                    return processors.to_float(value)

            if super_process:
                return super_process(value)
            else:
                return value
        return process
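
# Editor's sketch (not part of the original commit): Decimal.adjusted()
# is the exponent of the most significant digit, so:
#
#     decimal.Decimal('0.01').adjusted()       # -2: bound as a Decimal
#     decimal.Decimal('0.0000001').adjusted()  # -7: converted to float above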

class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
    def set_ddl_autocommit(self, connection, value):
        if value:
            connection.autocommit = True
        else:
            connection.autocommit = False

class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
    execution_ctx_cls = SybaseExecutionContext_pyodbc

    colspecs = {
        sqltypes.Numeric: _SybNumeric_pyodbc,
    }

dialect = SybaseDialect_pyodbc
98
sqlalchemy/dialects/sybase/pysybase.py
Normal file
@@ -0,0 +1,98 @@
# pysybase.py
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
Support for Sybase via the python-sybase driver.

http://python-sybase.sourceforge.net/

Connect strings are of the form::

    sybase+pysybase://<username>:<password>@<dsn>/[database name]

Unicode Support
---------------

The python-sybase driver does not appear to support non-ASCII strings of any
kind at this time.

"""

from sqlalchemy import types as sqltypes, processors
from sqlalchemy.dialects.sybase.base import SybaseDialect, \
    SybaseExecutionContext, SybaseSQLCompiler


class _SybNumeric(sqltypes.Numeric):
    def result_processor(self, dialect, type_):
        if not self.asdecimal:
            return processors.to_float
        else:
            return sqltypes.Numeric.result_processor(self, dialect, type_)

class SybaseExecutionContext_pysybase(SybaseExecutionContext):

    def set_ddl_autocommit(self, dbapi_connection, value):
        if value:
            # call commit() on the Sybase connection directly,
            # to avoid any side effects of calling a Connection
            # transactional method inside of pre_exec()
            dbapi_connection.commit()

    def pre_exec(self):
        SybaseExecutionContext.pre_exec(self)

        for param in self.parameters:
            for key in list(param):
                param["@" + key] = param[key]
                del param[key]


class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
    def bindparam_string(self, name):
        return "@" + name

class SybaseDialect_pysybase(SybaseDialect):
    driver = 'pysybase'
    execution_ctx_cls = SybaseExecutionContext_pysybase
    statement_compiler = SybaseSQLCompiler_pysybase

    colspecs = {
        sqltypes.Numeric: _SybNumeric,
        sqltypes.Float: sqltypes.Float
    }

    @classmethod
    def dbapi(cls):
        import Sybase
        return Sybase

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user', password='passwd')

        return ([opts.pop('host')], opts)

    def do_executemany(self, cursor, statement, parameters, context=None):
        # calling python-sybase executemany yields:
        # TypeError: string too long for buffer
        for param in parameters:
            cursor.execute(statement, param)

    def _get_server_version_info(self, connection):
        vers = connection.scalar("select @@version_number")
        # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0), (12, 5, 0, 0)
        return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10)
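
        # Editor's sketch (not part of the original commit): for
        # vers == 15500 the arithmetic above works out digit by digit:
        #     15500 / 1000       == 15
        #     15500 % 1000 / 100 == 5
        #     15500 % 100 / 10   == 0
        #     15500 % 10         == 0
        # giving (15, 5, 0, 0) under Python 2 integer division.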

    def is_disconnect(self, e):
        if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)):
            msg = str(e)
            return ('Unable to complete network request to host' in msg or
                    'Invalid connection state' in msg or
                    'Invalid cursor state' in msg)
        else:
            return False

dialect = SybaseDialect_pysybase
145
sqlalchemy/dialects/type_migration_guidelines.txt
Normal file
@@ -0,0 +1,145 @@
Rules for Migrating TypeEngine classes to 0.6
---------------------------------------------

1. the TypeEngine classes are used for:

    a. Specifying behavior which needs to occur for bind parameters
    or result row columns.

    b. Specifying types that are entirely specific to the database
    in use and have no analogue in the sqlalchemy.types package.

    c. Specifying types where there is an analogue in sqlalchemy.types,
    but the database in use takes vendor-specific flags for those
    types.

    d. If a TypeEngine class doesn't provide any of this, it should be
    *removed* from the dialect.

2. the TypeEngine classes are *no longer* used for generating DDL. Dialects
now have a TypeCompiler subclass which uses the same visit_XXX model as
other compilers.

3. the "ischema_names" and "colspecs" dictionaries are now required members on
the Dialect class.

4. The names of types within dialects are now important. If a dialect-specific type
is a subclass of an existing generic type and is only provided for bind/result behavior,
the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case,
end users would never need to use _PGNumeric directly. However, if a dialect-specific
type is specifying a type *or* arguments that are not present generically, it should
match the real name of the type on that backend, in uppercase. E.g. postgresql.INET,
mysql.ENUM, postgresql.ARRAY.

Or follow this handy flowchart:

* Is the type meant to provide bind/result behavior to a generic
  (i.e. MixedCase) type in types.py?

    - yes: name the type using _MixedCase, i.e. _OracleBoolean. It stays
      private to the dialect and is invoked *only* via the colspecs dict.
      Subclass the closest MixedCase type in types.py,
      i.e. class _DateTime(types.DateTime).

    - no: is the type the same name as an UPPERCASE type in types.py?

        - yes, and it needs no special behavior or arguments: don't make
          a type. Make sure the dialect's base.py imports the types.py
          UPPERCASE name into its namespace (i.e. BIT, NCHAR, INTERVAL).
          Users can import it.

        - yes, and it does need special behavior or arguments: name the
          type identically as within the DB, using UPPERCASE; since that
          name is identical to an UPPERCASE name in types.py, the type
          should subclass the UPPERCASE type in types.py,
          i.e. class BLOB(types.BLOB).

        - no: name the type identically as within the DB, using
          UPPERCASE, and subclass the closest type in types.py,
          i.e. class DATETIME2(types.DateTime),
          class BIT(types.TypeEngine).

Example 1. pysqlite needs bind/result processing for the DateTime type in types.py,
which applies to all DateTimes and subclasses. It's named _SLDateTime and
subclasses types.DateTime.

Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument
that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py,
and subclasses types.TIME. Users can then say mssql.TIME(precision=10).

Example 3. MS-SQL dialects also need special bind/result processing for dates,
but its DATE type doesn't render DDL differently than that of a plain
DATE, i.e. it takes no special arguments. Therefore we are just adding behavior
to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses
types.Date.

Example 4. MySQL has a SET type, there's no analogue for this in types.py. So
MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
it ultimately deals with strings.

Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly,
and no special arguments are used in PG's DDL beyond what types.py provides.
Postgresql dialect therefore imports types.DATETIME into its base.py.

Ideally one should be able to specify a schema using names imported completely from a
dialect, all matching the real name on that backend:

    from sqlalchemy.dialects.postgresql import base as pg

    t = Table('mytable', metadata,
        Column('id', pg.INTEGER, primary_key=True),
        Column('name', pg.VARCHAR(300)),
        Column('inetaddr', pg.INET)
    )

where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types,
but the PG dialect makes them available in its own namespace.

5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types
linked to types specified in the dialect. Again, if a type in the dialect does not
specify any special behavior for bind_processor() or result_processor() and does not
indicate a special type only available in this database, it must be *removed* from the
module and from this dictionary.
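
As a sketch of rule 5 (an editorial illustration; the _MyDate name and its
behavior are hypothetical), a dialect pairs the generic type with its own
implementation in colspecs:

    from sqlalchemy import types as sqltypes

    class _MyDate(sqltypes.Date):
        # dialect-private: adds bind behavior only, so it keeps the
        # _MixedCase naming convention from rule 4
        def bind_processor(self, dialect):
            def process(value):
                if value is not None:
                    return value.isoformat()
                return None
            return process

    colspecs = {
        sqltypes.Date: _MyDate,   # generic type -> dialect implementation
    }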

6. "ischema_names" indicates string descriptions of types as returned from the database
linked to TypeEngine classes.

    a. The string name should be matched to the most specific type possible within
    sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
    case it points to a dialect type. *It doesn't matter* if the dialect has its
    own subclass of that type with special bind/result behavior - reflect to the types.py
    UPPERCASE type as much as possible. With very few exceptions, all types
    should reflect to an UPPERCASE type.

    b. If the dialect contains a matching dialect-specific type that takes extra arguments
    which the generic one does not, then point to the dialect-specific type. E.g.
    mssql.VARCHAR takes a "collation" parameter which should be preserved.

7. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
a subclass of compiler.GenericTypeCompiler.

    a. your TypeCompiler class will receive generic and uppercase types from
    sqlalchemy.types. Do not assume the presence of dialect-specific attributes on
    these types.

    b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
    methods that produce a different DDL name. Uppercase types don't do any kind of
    "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
    all cases, regardless of whether or not that type is legal on the backend database.

    c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
    arguments and flags to those types.

    d. the visit_lowercase methods are overridden to provide an interpretation of a generic
    type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)".

    e. visit_lowercase methods should *never* render strings directly - it should always
    be via calling a visit_UPPERCASE() method.
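
Rules 7(c) through 7(e) in miniature (an editorial sketch; the compiler
class and collation argument are hypothetical, but SybaseTypeCompiler in
this same commit follows the identical shape):

    from sqlalchemy.sql import compiler

    class MyDialectTypeCompiler(compiler.GenericTypeCompiler):
        # (c) an UPPERCASE method may add backend-specific arguments
        def visit_VARCHAR(self, type_):
            spec = "VARCHAR(%d)" % type_.length
            if getattr(type_, 'collation', None):
                spec += " COLLATE " + type_.collation
            return spec

        # (d)/(e) a lowercase method interprets a generic type by
        # delegating to an UPPERCASE method, never rendering directly
        def visit_large_binary(self, type_):
            return self.visit_IMAGE(type_)

        def visit_IMAGE(self, type_):
            return "IMAGE"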
274
sqlalchemy/engine/__init__.py
Normal file
@@ -0,0 +1,274 @@
# engine/__init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""SQL connections, SQL execution and high-level DB-API interface.

The engine package defines the basic components used to interface
DB-API modules with higher-level statement construction,
connection-management, execution and result contexts. The primary
"entry point" class into this package is the Engine and its public
constructor ``create_engine()``.

This package includes:

base.py
    Defines interface classes and some implementation classes which
    comprise the basic components used to interface between a DB-API,
    constructed and plain-text statements, connections, transactions,
    and results.

default.py
    Contains default implementations of some of the components defined
    in base.py. All current database dialects use the classes in
    default.py as base classes for their own database-specific
    implementations.

strategies.py
    The mechanics of constructing ``Engine`` objects are represented
    here. Defines the ``EngineStrategy`` class which represents how
    to go from arguments specified to the ``create_engine()``
    function, to a fully constructed ``Engine``, including
    initialization of connection pooling, dialects, and specific
    subclasses of ``Engine``.

threadlocal.py
    The ``TLEngine`` class is defined here, which is a subclass of
    the generic ``Engine`` and tracks ``Connection`` and
    ``Transaction`` objects against the identity of the current
    thread. This allows certain programming patterns based around
    the concept of a "thread-local connection" to be possible.
    The ``TLEngine`` is created by using the "threadlocal" engine
    strategy in conjunction with the ``create_engine()`` function.

url.py
    Defines the ``URL`` class which represents the individual
    components of a string URL passed to ``create_engine()``. Also
    defines a basic module-loading strategy for the dialect specifier
    within a URL.
"""

# not sure what this was used for
#import sqlalchemy.databases

from sqlalchemy.engine.base import (
    BufferedColumnResultProxy,
    BufferedColumnRow,
    BufferedRowResultProxy,
    Compiled,
    Connectable,
    Connection,
    Dialect,
    Engine,
    ExecutionContext,
    NestedTransaction,
    ResultProxy,
    RootTransaction,
    RowProxy,
    Transaction,
    TwoPhaseTransaction,
    TypeCompiler
    )
from sqlalchemy.engine import strategies
from sqlalchemy import util


__all__ = (
    'BufferedColumnResultProxy',
    'BufferedColumnRow',
    'BufferedRowResultProxy',
    'Compiled',
    'Connectable',
    'Connection',
    'Dialect',
    'Engine',
    'ExecutionContext',
    'NestedTransaction',
    'ResultProxy',
    'RootTransaction',
    'RowProxy',
    'Transaction',
    'TwoPhaseTransaction',
    'TypeCompiler',
    'create_engine',
    'engine_from_config',
    )


default_strategy = 'plain'
def create_engine(*args, **kwargs):
    """Create a new Engine instance.

    The standard method of specifying the engine is via URL as the
    first positional argument, to indicate the appropriate database
    dialect and connection arguments, with additional keyword
    arguments sent as options to the dialect and resulting Engine.

    The URL is a string in the form
    ``dialect+driver://user:password@host/dbname[?key=value..]``, where
    ``dialect`` is a database name such as ``mysql``, ``oracle``,
    ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
    ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
    the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.

    `**kwargs` takes a wide variety of options which are routed
    towards their appropriate components. Arguments may be
    specific to the Engine, the underlying Dialect, as well as the
    Pool. Specific dialects also accept keyword arguments that
    are unique to that dialect. Here, we describe the parameters
    that are common to most ``create_engine()`` usage.

    :param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode
      object is passed when SQLAlchemy would coerce into an encoding
      (note: but **not** when the DBAPI handles unicode objects natively).
      To suppress or raise this warning to an
      error, use the Python warnings filter documented at:
      http://docs.python.org/library/warnings.html

    :param connect_args: a dictionary of options which will be
      passed directly to the DBAPI's ``connect()`` method as
      additional keyword arguments.

    :param convert_unicode=False: if set to True, all
      String/character based types will convert Unicode values to raw
      byte values going into the database, and all raw byte values to
      Python Unicode coming out in result sets. This is an
      engine-wide method to provide unicode conversion across the
      board. For unicode conversion on a column-by-column level, use
      the ``Unicode`` column type instead, described in `types`.

    :param creator: a callable which returns a DBAPI connection.
      This creation function will be passed to the underlying
      connection pool and will be used to create all new database
      connections. Usage of this function causes connection
      parameters specified in the URL argument to be bypassed.

    :param echo=False: if True, the Engine will log all statements
      as well as a repr() of their parameter lists to the engines
      logger, which defaults to sys.stdout. The ``echo`` attribute of
      ``Engine`` can be modified at any time to turn logging on and
      off. If set to the string ``"debug"``, result rows will be
      printed to the standard output as well. This flag ultimately
      controls a Python logger; see :ref:`dbengine_logging` for
      information on how to configure logging directly.

    :param echo_pool=False: if True, the connection pool will log
      all checkouts/checkins to the logging stream, which defaults to
      sys.stdout. This flag ultimately controls a Python logger; see
      :ref:`dbengine_logging` for information on how to configure logging
      directly.

    :param encoding='utf-8': the encoding to use for all Unicode
      translations, both by engine-wide unicode conversion as well as
      the ``Unicode`` type object.

    :param label_length=None: optional integer value which limits
      the size of dynamically generated column labels to that many
      characters. If less than 6, labels are generated as
      "_(counter)". If ``None``, the value of
      ``dialect.max_identifier_length`` is used instead.

    :param listeners: A list of one or more
      :class:`~sqlalchemy.interfaces.PoolListener` objects which will
      receive connection pool events.

    :param logging_name: String identifier which will be used within
      the "name" field of logging records generated within the
      "sqlalchemy.engine" logger. Defaults to a hexstring of the
      object's id.

    :param max_overflow=10: the number of connections to allow in
      connection pool "overflow", that is connections that can be
      opened above and beyond the pool_size setting, which defaults
      to five. This is only used with :class:`~sqlalchemy.pool.QueuePool`.

    :param module=None: used by database implementations which
      support multiple DBAPI modules, this is a reference to a DBAPI2
      module to be used instead of the engine's default module. For
      PostgreSQL, the default is psycopg2. For Oracle, it's cx_Oracle.

    :param pool=None: an already-constructed instance of
      :class:`~sqlalchemy.pool.Pool`, such as a
      :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this
      pool will be used directly as the underlying connection pool
      for the engine, bypassing whatever connection parameters are
      present in the URL argument. For information on constructing
      connection pools manually, see `pooling`.

    :param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
      subclass, which will be used to create a connection pool
      instance using the connection parameters given in the URL. Note
      this differs from ``pool`` in that you don't actually
      instantiate the pool in this case, you just indicate what type
      of pool to be used.

    :param pool_logging_name: String identifier which will be used within
      the "name" field of logging records generated within the
      "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
      id.

    :param pool_size=5: the number of connections to keep open
      inside the connection pool. This is used with :class:`~sqlalchemy.pool.QueuePool` as
      well as :class:`~sqlalchemy.pool.SingletonThreadPool`.

    :param pool_recycle=-1: this setting causes the pool to recycle
      connections after the given number of seconds has passed. It
      defaults to -1, or no timeout. For example, setting to 3600
      means connections will be recycled after one hour. Note that
      MySQL in particular will disconnect automatically if no
      activity is detected on a connection for eight hours (although
      this is configurable with the MySQLdb connection itself and the
      server configuration as well).

    :param pool_timeout=30: number of seconds to wait before giving
      up on getting a connection from the pool. This is only used
      with :class:`~sqlalchemy.pool.QueuePool`.

    :param strategy='plain': used to invoke alternate :class:`~sqlalchemy.engine.base.Engine`
      implementations. Currently available is the ``threadlocal``
      strategy, which is described in :ref:`threadlocal_strategy`.

    """

    strategy = kwargs.pop('strategy', default_strategy)
    strategy = strategies.strategies[strategy]
    return strategy.create(*args, **kwargs)
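
# Editor's sketch (not part of the original commit): typical usage of the
# parameters documented above; the URL and values are illustrative only:
#
#     engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test',
#                            echo=True, pool_size=10, pool_recycle=3600)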

def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
    """Create a new Engine instance using a configuration dictionary.

    The dictionary is typically produced from a config file where keys
    are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The
    'prefix' argument indicates the prefix to be searched for.

    A select set of keyword arguments will be "coerced" to their
    expected type based on string values. In a future release, this
    functionality will be expanded and include dialect-specific
    arguments.
    """

    opts = _coerce_config(configuration, prefix)
    opts.update(kwargs)
    url = opts.pop('url')
    return create_engine(url, **opts)
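
# Editor's sketch (not part of the original commit): a dictionary such as
# this (values illustrative) yields a working engine:
#
#     cfg = {'sqlalchemy.url': 'sqlite:///app.db',
#            'sqlalchemy.echo': 'true',
#            'sqlalchemy.pool_recycle': '3600'}
#     engine = engine_from_config(cfg, prefix='sqlalchemy.')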

def _coerce_config(configuration, prefix):
    """Convert configuration values to expected types."""

    options = dict((key[len(prefix):], configuration[key])
                   for key in configuration
                   if key.startswith(prefix))
    for option, type_ in (
        ('convert_unicode', bool),
        ('pool_timeout', int),
        ('echo', bool),
        ('echo_pool', bool),
        ('pool_recycle', int),
        ('pool_size', int),
        ('max_overflow', int),
        ('pool_threadlocal', bool),
    ):
        util.coerce_kw_type(options, option, type_)
    return options
2422
sqlalchemy/engine/base.py
Normal file
File diff suppressed because it is too large
128
sqlalchemy/engine/ddl.py
Normal file
@@ -0,0 +1,128 @@
# engine/ddl.py
# Copyright (C) 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Routines to handle CREATE/DROP workflow."""

from sqlalchemy import engine, schema
from sqlalchemy.sql import util as sql_util


class DDLBase(schema.SchemaVisitor):
    def __init__(self, connection):
        self.connection = connection

class SchemaGenerator(DDLBase):
    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaGenerator, self).__init__(connection, **kwargs)
        self.checkfirst = checkfirst
        self.tables = tables and set(tables) or None
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def _can_create(self, table):
        self.dialect.validate_identifier(table.name)
        if table.schema:
            self.dialect.validate_identifier(table.schema)
        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)

    def visit_metadata(self, metadata):
        if self.tables:
            tables = self.tables
        else:
            tables = metadata.tables.values()
        collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]

        for listener in metadata.ddl_listeners['before-create']:
            listener('before-create', metadata, self.connection, tables=collection)

        for table in collection:
            self.traverse_single(table)

        for listener in metadata.ddl_listeners['after-create']:
            listener('after-create', metadata, self.connection, tables=collection)
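
        # Editor's sketch (not part of the original commit): given the
        # calls above, a DDL listener is any callable of this shape
        # (names hypothetical):
        #
        #     def on_before_create(event, target, connection, **kw):
        #         print "emitting CREATE for", target
        #
        #     metadata.ddl_listeners['before-create'].append(on_before_create)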

    def visit_table(self, table):
        for listener in table.ddl_listeners['before-create']:
            listener('before-create', table, self.connection)

        for column in table.columns:
            if column.default is not None:
                self.traverse_single(column.default)

        self.connection.execute(schema.CreateTable(table))

        if hasattr(table, 'indexes'):
            for index in table.indexes:
                self.traverse_single(index)

        for listener in table.ddl_listeners['after-create']:
            listener('after-create', table, self.connection)

    def visit_sequence(self, sequence):
        if self.dialect.supports_sequences:
            if ((not self.dialect.sequences_optional or
                 not sequence.optional) and
                (not self.checkfirst or
                 not self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
                self.connection.execute(schema.CreateSequence(sequence))

    def visit_index(self, index):
        self.connection.execute(schema.CreateIndex(index))


class SchemaDropper(DDLBase):
    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaDropper, self).__init__(connection, **kwargs)
        self.checkfirst = checkfirst
        self.tables = tables
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def visit_metadata(self, metadata):
        if self.tables:
            tables = self.tables
        else:
            tables = metadata.tables.values()
        collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]

        for listener in metadata.ddl_listeners['before-drop']:
            listener('before-drop', metadata, self.connection, tables=collection)

        for table in collection:
            self.traverse_single(table)

        for listener in metadata.ddl_listeners['after-drop']:
            listener('after-drop', metadata, self.connection, tables=collection)

    def _can_drop(self, table):
        self.dialect.validate_identifier(table.name)
        if table.schema:
            self.dialect.validate_identifier(table.schema)
        return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)

    def visit_index(self, index):
        self.connection.execute(schema.DropIndex(index))

    def visit_table(self, table):
        for listener in table.ddl_listeners['before-drop']:
            listener('before-drop', table, self.connection)

        for column in table.columns:
            if column.default is not None:
                self.traverse_single(column.default)

        self.connection.execute(schema.DropTable(table))

        for listener in table.ddl_listeners['after-drop']:
            listener('after-drop', table, self.connection)

    def visit_sequence(self, sequence):
        if self.dialect.supports_sequences:
            if ((not self.dialect.sequences_optional or
                 not sequence.optional) and
                (not self.checkfirst or
                 self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
                self.connection.execute(schema.DropSequence(sequence))
700
sqlalchemy/engine/default.py
Normal file
@@ -0,0 +1,700 @@
|
||||
# engine/default.py
|
||||
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Default implementations of per-dialect sqlalchemy.engine classes.
|
||||
|
||||
These are semi-private implementation classes which are only of importance
|
||||
to database dialect authors; dialects will usually use the classes here
|
||||
as the base class for their own corresponding classes.
|
||||
|
||||
"""
|
||||
|
||||
import re, random
|
||||
from sqlalchemy.engine import base, reflection
|
||||
from sqlalchemy.sql import compiler, expression
|
||||
from sqlalchemy import exc, types as sqltypes, util
|
||||
|
||||
AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
|
||||
re.I | re.UNICODE)
|
||||
|
||||
|
||||
class DefaultDialect(base.Dialect):
    """Default implementation of Dialect"""

    statement_compiler = compiler.SQLCompiler
    ddl_compiler = compiler.DDLCompiler
    type_compiler = compiler.GenericTypeCompiler
    preparer = compiler.IdentifierPreparer
    supports_alter = True

    # most DBAPIs are happy with a tuple for execute();
    # cx_oracle is a known exception.
    execute_sequence_format = tuple

    supports_sequences = False
    sequences_optional = False
    preexecute_autoincrement_sequences = False
    postfetch_lastrowid = True
    implicit_returning = False

    supports_native_enum = False
    supports_native_boolean = False

    # whether the NUMERIC type returns decimal.Decimal;
    # this does *not* apply to the FLOAT type.
    supports_native_decimal = False

    # Py3K
    #supports_unicode_statements = True
    #supports_unicode_binds = True
    # Py2K
    supports_unicode_statements = False
    supports_unicode_binds = False
    returns_unicode_strings = False
    # end Py2K

    name = 'default'
    max_identifier_length = 9999
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    dbapi_type_map = {}
    colspecs = {}
    default_paramstyle = 'named'
    supports_default_values = False
    supports_empty_insert = True

    server_version_info = None

    # indicates symbol names are UPPERCASEd if they are
    # case insensitive within the database.  if this is True,
    # the methods normalize_name() and denormalize_name()
    # must be provided.
    requires_name_normalize = False

    reflection_options = ()

    def __init__(self, convert_unicode=False, assert_unicode=False,
                 encoding='utf-8', paramstyle=None, dbapi=None,
                 implicit_returning=None,
                 label_length=None, **kwargs):

        if not getattr(self, 'ported_sqla_06', True):
            util.warn(
                "The %s dialect is not yet ported to SQLAlchemy 0.6" % self.name)

        self.convert_unicode = convert_unicode
        if assert_unicode:
            util.warn_deprecated("assert_unicode is deprecated. "
                "SQLAlchemy emits a warning in all cases where it "
                "would otherwise like to encode a Python unicode object "
                "into a specific encoding but a plain bytestring is received. "
                "This does *not* apply to DBAPIs that coerce Unicode natively."
                )

        self.encoding = encoding
        self.positional = False
        self._ischema = None
        self.dbapi = dbapi
        if paramstyle is not None:
            self.paramstyle = paramstyle
        elif self.dbapi is not None:
            self.paramstyle = self.dbapi.paramstyle
        else:
            self.paramstyle = self.default_paramstyle
        if implicit_returning is not None:
            self.implicit_returning = implicit_returning
        self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
        self.identifier_preparer = self.preparer(self)
        self.type_compiler = self.type_compiler(self)

        if label_length and label_length > self.max_identifier_length:
            raise exc.ArgumentError("Label length of %d is greater than this dialect's"
                                    " maximum identifier length of %d" %
                                    (label_length, self.max_identifier_length))
        self.label_length = label_length

        if not hasattr(self, 'description_encoding'):
            # subclasses may set description_encoding at the class level;
            # otherwise default to the connect-time encoding.
            self.description_encoding = encoding

    @property
    def dialect_description(self):
        return self.name + "+" + self.driver

    def initialize(self, connection):
        try:
            self.server_version_info = self._get_server_version_info(connection)
        except NotImplementedError:
            self.server_version_info = None
        try:
            self.default_schema_name = self._get_default_schema_name(connection)
        except NotImplementedError:
            self.default_schema_name = None

        self.returns_unicode_strings = self._check_unicode_returns(connection)

        self.do_rollback(connection.connection)

    def on_connect(self):
        """return a callable which sets up a newly created DBAPI connection.

        This is used to set dialect-wide per-connection options such as
        isolation modes, unicode modes, etc.

        If a callable is returned, it will be assembled into a pool listener
        that receives the direct DBAPI connection, with all wrappers removed.

        If None is returned, no listener will be generated.

        """
        return None
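on_connect() is the hook a dialect overrides to perform per-connection setup.  A minimal
sketch, assuming a hypothetical dialect subclass (the PRAGMA shown is purely illustrative
and not part of this module):

    class MyDialect(DefaultDialect):
        name = 'mydialect'

        def on_connect(self):
            def connect(dbapi_conn):
                # runs once for each raw DBAPI connection as it enters the pool
                cursor = dbapi_conn.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")   # illustrative setup
                cursor.close()
            return connect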
    def _check_unicode_returns(self, connection):
        # Py2K
        if self.supports_unicode_statements:
            cast_to = unicode
        else:
            cast_to = str
        # end Py2K
        # Py3K
        #cast_to = str

        def check_unicode(type_):
            cursor = connection.connection.cursor()
            try:
                cursor.execute(
                    cast_to(
                        expression.select(
                        [expression.cast(
                            expression.literal_column("'test unicode returns'"), type_)
                        ]).compile(dialect=self)
                    )
                )
                row = cursor.fetchone()

                return isinstance(row[0], unicode)
            finally:
                cursor.close()

        # detect plain VARCHAR
        unicode_for_varchar = check_unicode(sqltypes.VARCHAR(60))

        # detect if there's an NVARCHAR type with different behavior available
        unicode_for_unicode = check_unicode(sqltypes.Unicode(60))

        if unicode_for_unicode and not unicode_for_varchar:
            return "conditional"
        else:
            return unicode_for_varchar

    def type_descriptor(self, typeobj):
        """Provide a database-specific ``TypeEngine`` object, given
        the generic object which comes from the types module.

        This method looks for a dictionary called
        ``colspecs`` as a class or instance-level variable,
        and passes on to ``types.adapt_type()``.

        """
        return sqltypes.adapt_type(typeobj, self.colspecs)
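type_descriptor() and the colspecs map are how a dialect swaps its own implementation in
for a generic type.  A sketch, assuming a hypothetical dialect and type (neither exists
in this module):

    from sqlalchemy import types as sqltypes

    class MyNumeric(sqltypes.Numeric):
        """Hypothetical dialect-specific Numeric variant."""

    class MyDialect(DefaultDialect):
        colspecs = {sqltypes.Numeric: MyNumeric}

    # MyDialect().type_descriptor(sqltypes.Numeric()) returns a MyNumeric instance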
    def reflecttable(self, connection, table, include_columns):
        insp = reflection.Inspector.from_engine(connection)
        return insp.reflecttable(table, include_columns)

    def validate_identifier(self, ident):
        if len(ident) > self.max_identifier_length:
            raise exc.IdentifierError(
                "Identifier '%s' exceeds maximum length of %d characters" %
                (ident, self.max_identifier_length)
            )

    def connect(self, *cargs, **cparams):
        return self.dbapi.connect(*cargs, **cparams)

    def create_connect_args(self, url):
        opts = url.translate_connect_args()
        opts.update(url.query)
        return [[], opts]

    def do_begin(self, connection):
        """Implementations might want to put logic here for turning
        autocommit on/off, etc.
        """

        pass

    def do_rollback(self, connection):
        """Implementations might want to put logic here for turning
        autocommit on/off, etc.
        """

        connection.rollback()

    def do_commit(self, connection):
        """Implementations might want to put logic here for turning
        autocommit on/off, etc.
        """

        connection.commit()

    def create_xid(self):
        """Create a random two-phase transaction ID.

        This id will be passed to do_begin_twophase(), do_rollback_twophase(),
        do_commit_twophase().  Its format is unspecified.
        """

        return "_sa_%032x" % random.randint(0, 2 ** 128)

    def do_savepoint(self, connection, name):
        connection.execute(expression.SavepointClause(name))

    def do_rollback_to_savepoint(self, connection, name):
        connection.execute(expression.RollbackToSavepointClause(name))

    def do_release_savepoint(self, connection, name):
        connection.execute(expression.ReleaseSavepointClause(name))

    def do_executemany(self, cursor, statement, parameters, context=None):
        cursor.executemany(statement, parameters)

    def do_execute(self, cursor, statement, parameters, context=None):
        cursor.execute(statement, parameters)

    def is_disconnect(self, e):
        return False


class DefaultExecutionContext(base.ExecutionContext):
    execution_options = util.frozendict()
    isinsert = False
    isupdate = False
    isdelete = False
    isddl = False
    executemany = False
    result_map = None
    compiled = None
    statement = None

    def __init__(self,
                 dialect,
                 connection,
                 compiled_sql=None,
                 compiled_ddl=None,
                 statement=None,
                 parameters=None):

        self.dialect = dialect
        self._connection = self.root_connection = connection
        self.engine = connection.engine

        if compiled_ddl is not None:
            self.compiled = compiled = compiled_ddl
            self.isddl = True

            if compiled.statement._execution_options:
                self.execution_options = compiled.statement._execution_options
            if connection._execution_options:
                self.execution_options = self.execution_options.union(
                    connection._execution_options
                )

            if not dialect.supports_unicode_statements:
                self.unicode_statement = unicode(compiled)
                self.statement = self.unicode_statement.encode(self.dialect.encoding)
            else:
                self.statement = self.unicode_statement = unicode(compiled)

            self.cursor = self.create_cursor()
            self.compiled_parameters = []
            self.parameters = [self._default_params]

        elif compiled_sql is not None:
            self.compiled = compiled = compiled_sql

            if not compiled.can_execute:
                raise exc.ArgumentError("Not an executable clause: %s" % compiled)

            if compiled.statement._execution_options:
                self.execution_options = compiled.statement._execution_options
            if connection._execution_options:
                self.execution_options = self.execution_options.union(
                    connection._execution_options
                )

            # compiled clauseelement.  process bind params, process table defaults,
            # track collections used by ResultProxy to target and process results

            self.processors = dict(
                (key, value) for key, value in
                ( (compiled.bind_names[bindparam],
                   bindparam.bind_processor(self.dialect))
                  for bindparam in compiled.bind_names )
                if value is not None)

            self.result_map = compiled.result_map

            if not dialect.supports_unicode_statements:
                self.unicode_statement = unicode(compiled)
                self.statement = self.unicode_statement.encode(self.dialect.encoding)
            else:
                self.statement = self.unicode_statement = unicode(compiled)

            self.isinsert = compiled.isinsert
            self.isupdate = compiled.isupdate
            self.isdelete = compiled.isdelete

            if not parameters:
                self.compiled_parameters = [compiled.construct_params()]
            else:
                self.compiled_parameters = \
                    [compiled.construct_params(m, _group_number=grp) for
                     grp, m in enumerate(parameters)]

                self.executemany = len(parameters) > 1

            self.cursor = self.create_cursor()
            if self.isinsert or self.isupdate:
                self.__process_defaults()
            self.parameters = self.__convert_compiled_params(self.compiled_parameters)

        elif statement is not None:
            # plain text statement
            if connection._execution_options:
                self.execution_options = self.execution_options.union(connection._execution_options)
            self.parameters = self.__encode_param_keys(parameters)
            self.executemany = len(parameters) > 1

            if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
                self.unicode_statement = statement
                self.statement = statement.encode(self.dialect.encoding)
            else:
                self.statement = self.unicode_statement = statement

            self.cursor = self.create_cursor()
        else:
            # no statement.  used for standalone ColumnDefault execution.
            if connection._execution_options:
                self.execution_options = self.execution_options.union(connection._execution_options)
            self.cursor = self.create_cursor()

    @util.memoized_property
    def is_crud(self):
        return self.isinsert or self.isupdate or self.isdelete

    @util.memoized_property
    def should_autocommit(self):
        autocommit = self.execution_options.get('autocommit',
                                                not self.compiled and
                                                self.statement and
                                                expression.PARSE_AUTOCOMMIT
                                                or False)

        if autocommit is expression.PARSE_AUTOCOMMIT:
            return self.should_autocommit_text(self.unicode_statement)
        else:
            return autocommit

    @util.memoized_property
    def _is_explicit_returning(self):
        return self.compiled and \
            getattr(self.compiled.statement, '_returning', False)

    @util.memoized_property
    def _is_implicit_returning(self):
        return self.compiled and \
            bool(self.compiled.returning) and \
            not self.compiled.statement._returning

    @util.memoized_property
    def _default_params(self):
        if self.dialect.positional:
            return self.dialect.execute_sequence_format()
        else:
            return {}

    def _execute_scalar(self, stmt):
        """Execute a string statement on the current cursor, returning a scalar result.

        Used to fire off sequences, default phrases, and "select lastrowid"
        types of statements individually or in the context of a parent
        INSERT or UPDATE statement.

        """

        conn = self._connection
        if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
            stmt = stmt.encode(self.dialect.encoding)
        conn._cursor_execute(self.cursor, stmt, self._default_params)
        return self.cursor.fetchone()[0]

    @property
    def connection(self):
        return self._connection._branch()

    def __encode_param_keys(self, params):
        """Apply string encoding to the keys of dictionary-based bind parameters.

        This is only used when executing textual, non-compiled SQL expressions.

        """

        if not params:
            return [self._default_params]
        elif isinstance(params[0], self.dialect.execute_sequence_format):
            return params
        elif isinstance(params[0], dict):
            if self.dialect.supports_unicode_statements:
                return params
            else:
                def proc(d):
                    return dict((k.encode(self.dialect.encoding), d[k]) for k in d)
                return [proc(d) for d in params] or [{}]
        else:
            return [self.dialect.execute_sequence_format(p) for p in params]

    def __convert_compiled_params(self, compiled_parameters):
        """Convert the dictionary of bind parameter values into a dict or list
        to be sent to the DBAPI's execute() or executemany() method.
        """

        processors = self.processors
        parameters = []
        if self.dialect.positional:
            for compiled_params in compiled_parameters:
                param = []
                for key in self.compiled.positiontup:
                    if key in processors:
                        param.append(processors[key](compiled_params[key]))
                    else:
                        param.append(compiled_params[key])
                parameters.append(self.dialect.execute_sequence_format(param))
        else:
            encode = not self.dialect.supports_unicode_statements
            for compiled_params in compiled_parameters:
                param = {}
                if encode:
                    encoding = self.dialect.encoding
                    for key in compiled_params:
                        if key in processors:
                            param[key.encode(encoding)] = processors[key](compiled_params[key])
                        else:
                            param[key.encode(encoding)] = compiled_params[key]
                else:
                    for key in compiled_params:
                        if key in processors:
                            param[key] = processors[key](compiled_params[key])
                        else:
                            param[key] = compiled_params[key]
                parameters.append(param)
        return self.dialect.execute_sequence_format(parameters)

    def should_autocommit_text(self, statement):
        return AUTOCOMMIT_REGEXP.match(statement)

    def create_cursor(self):
        return self._connection.connection.cursor()

    def pre_exec(self):
        pass

    def post_exec(self):
        pass

    def get_lastrowid(self):
        """return self.cursor.lastrowid, or equivalent, after an INSERT.

        This may involve calling special cursor functions,
        issuing a new SELECT on the cursor (or a new one),
        or returning a stored value that was
        calculated within post_exec().

        This function will only be called for dialects
        which support "implicit" primary key generation,
        keep preexecute_autoincrement_sequences set to False,
        and when no explicit id value was bound to the
        statement.

        The function is called once, directly after
        post_exec() and before the transaction is committed
        or ResultProxy is generated.   If the post_exec()
        method assigns a value to `self._lastrowid`, the
        value is used in place of calling get_lastrowid().

        Note that this method is *not* equivalent to the
        ``lastrowid`` method on ``ResultProxy``, which is a
        direct proxy to the DBAPI ``lastrowid`` accessor
        in all cases.

        """

        return self.cursor.lastrowid

    def handle_dbapi_exception(self, e):
        pass

    def get_result_proxy(self):
        return base.ResultProxy(self)

    @property
    def rowcount(self):
        return self.cursor.rowcount

    def supports_sane_rowcount(self):
        return self.dialect.supports_sane_rowcount

    def supports_sane_multi_rowcount(self):
        return self.dialect.supports_sane_multi_rowcount

    def post_insert(self):
        if self.dialect.postfetch_lastrowid and \
                (not len(self._inserted_primary_key) or
                 None in self._inserted_primary_key):

            table = self.compiled.statement.table
            lastrowid = self.get_lastrowid()
            self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
                                          for c, v in zip(table.primary_key, self._inserted_primary_key)
                                          ]

    def _fetch_implicit_returning(self, resultproxy):
        table = self.compiled.statement.table
        row = resultproxy.fetchone()

        self._inserted_primary_key = [v is not None and v or row[c]
                                      for c, v in zip(table.primary_key, self._inserted_primary_key)
                                      ]

    def last_inserted_params(self):
        return self._last_inserted_params

    def last_updated_params(self):
        return self._last_updated_params

    def lastrow_has_defaults(self):
        return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)

    def set_input_sizes(self, translate=None, exclude_types=None):
        """Given a cursor and ClauseParameters, call the appropriate
        style of ``setinputsizes()`` on the cursor, using DB-API types
        from the bind parameters' ``TypeEngine`` objects.
        """

        if not hasattr(self.compiled, 'bind_names'):
            return

        types = dict(
            (self.compiled.bind_names[bindparam], bindparam.type)
            for bindparam in self.compiled.bind_names)

        if self.dialect.positional:
            inputsizes = []
            for key in self.compiled.positiontup:
                typeengine = types[key]
                dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
                if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
                    inputsizes.append(dbtype)
            try:
                self.cursor.setinputsizes(*inputsizes)
            except Exception, e:
                self._connection._handle_dbapi_exception(e, None, None, None, self)
                raise
        else:
            inputsizes = {}
            for key in self.compiled.bind_names.values():
                typeengine = types[key]
                dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
                if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
                    if translate:
                        key = translate.get(key, key)
                    inputsizes[key.encode(self.dialect.encoding)] = dbtype
            try:
                self.cursor.setinputsizes(**inputsizes)
            except Exception, e:
                self._connection._handle_dbapi_exception(e, None, None, None, self)
                raise

    def _exec_default(self, default):
        if default.is_sequence:
            return self.fire_sequence(default)
        elif default.is_callable:
            return default.arg(self)
        elif default.is_clause_element:
            # TODO: expensive branching here should be
            # pulled into _exec_scalar()
            conn = self.connection
            c = expression.select([default.arg]).compile(bind=conn)
            return conn._execute_compiled(c, (), {}).scalar()
        else:
            return default.arg

    def get_insert_default(self, column):
        if column.default is None:
            return None
        else:
            return self._exec_default(column.default)

    def get_update_default(self, column):
        if column.onupdate is None:
            return None
        else:
            return self._exec_default(column.onupdate)

    def __process_defaults(self):
        """Generate default values for compiled insert/update statements,
        and generate the inserted_primary_key collection.
        """

        if self.executemany:
            if len(self.compiled.prefetch):
                scalar_defaults = {}

                # pre-determine scalar Python-side defaults
                # to avoid many calls of get_insert_default()/get_update_default()
                for c in self.compiled.prefetch:
                    if self.isinsert and c.default and c.default.is_scalar:
                        scalar_defaults[c] = c.default.arg
                    elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
                        scalar_defaults[c] = c.onupdate.arg

                for param in self.compiled_parameters:
                    self.current_parameters = param
                    for c in self.compiled.prefetch:
                        if c in scalar_defaults:
                            val = scalar_defaults[c]
                        elif self.isinsert:
                            val = self.get_insert_default(c)
                        else:
                            val = self.get_update_default(c)
                        if val is not None:
                            param[c.key] = val
                    del self.current_parameters

        else:
            self.current_parameters = compiled_parameters = self.compiled_parameters[0]

            for c in self.compiled.prefetch:
                if self.isinsert:
                    val = self.get_insert_default(c)
                else:
                    val = self.get_update_default(c)

                if val is not None:
                    compiled_parameters[c.key] = val
            del self.current_parameters

            if self.isinsert:
                self._inserted_primary_key = [compiled_parameters.get(c.key, None)
                                              for c in self.compiled.statement.table.primary_key]
                self._last_inserted_params = compiled_parameters
            else:
                self._last_updated_params = compiled_parameters

            self.postfetch_cols = self.compiled.postfetch
            self.prefetch_cols = self.compiled.prefetch

DefaultDialect.execution_ctx_cls = DefaultExecutionContext
370
sqlalchemy/engine/reflection.py
Normal file
@@ -0,0 +1,370 @@
"""Provides an abstraction for obtaining database schema information.

Usage Notes:

Here are some general conventions when accessing the low-level inspector
methods such as get_table_names, get_columns, etc.

1. Inspector methods return lists of dicts in most cases for the following
   reasons:

   * They're both standard types that can be serialized.
   * Using a dict instead of a tuple allows easy expansion of attributes.
   * Using a list for the outer structure maintains order and is easy to work
     with (e.g. list comprehension [d['name'] for d in cols]).

2. Records that contain a name, such as the column name in a column record,
   use the key 'name'.  So for most return values, each record will have a
   'name' attribute.
"""

import sqlalchemy
from sqlalchemy import exc, sql
from sqlalchemy import util
from sqlalchemy.types import TypeEngine
from sqlalchemy import schema as sa_schema


@util.decorator
def cache(fn, self, con, *args, **kw):
    info_cache = kw.get('info_cache', None)
    if info_cache is None:
        return fn(self, con, *args, **kw)
    key = (
        fn.__name__,
        tuple(a for a in args if isinstance(a, basestring)),
        tuple((k, v) for k, v in kw.iteritems() if isinstance(v, (basestring, int, float)))
    )
    ret = info_cache.get(key)
    if ret is None:
        ret = fn(self, con, *args, **kw)
        info_cache[key] = ret
    return ret


class Inspector(object):
    """Performs database schema inspection.

    The Inspector acts as a proxy to the dialects' reflection methods and
    provides higher level functions for accessing database schema information.
    """

    def __init__(self, conn):
        """Initialize the instance.

        :param conn: a :class:`~sqlalchemy.engine.base.Connectable`
        """

        self.conn = conn
        # set the engine
        if hasattr(conn, 'engine'):
            self.engine = conn.engine
        else:
            self.engine = conn
        self.dialect = self.engine.dialect
        self.info_cache = {}

    @classmethod
    def from_engine(cls, engine):
        if hasattr(engine.dialect, 'inspector'):
            return engine.dialect.inspector(engine)
        return Inspector(engine)
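Typical use of the Inspector, following the conventions in the module docstring (each
record is a dict with a 'name' key); the table name shown is hypothetical:

    from sqlalchemy import create_engine

    engine = create_engine('sqlite://')
    insp = Inspector.from_engine(engine)

    print insp.get_table_names()
    for col in insp.get_columns('users'):        # 'users' is a hypothetical table
        print col['name'], col['type'], col['nullable']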
    @property
    def default_schema_name(self):
        return self.dialect.default_schema_name

    def get_schema_names(self):
        """Return all schema names."""

        if hasattr(self.dialect, 'get_schema_names'):
            return self.dialect.get_schema_names(self.conn,
                                                 info_cache=self.info_cache)
        return []

    def get_table_names(self, schema=None, order_by=None):
        """Return all table names in `schema`.

        :param schema: Optional, retrieve names from a non-default schema.
        :param order_by: Optional, may be the string "foreign_key" to sort
            the result on foreign key dependencies.

        This should probably not return view names, or perhaps it should
        return them with an indicator t or v.
        """

        if hasattr(self.dialect, 'get_table_names'):
            tnames = self.dialect.get_table_names(self.conn,
                                                  schema,
                                                  info_cache=self.info_cache)
        else:
            tnames = self.engine.table_names(schema)
        if order_by == 'foreign_key':
            ordered_tnames = tnames[:]
            # Order based on foreign key dependencies.
            for tname in tnames:
                table_pos = tnames.index(tname)
                fkeys = self.get_foreign_keys(tname, schema)
                for fkey in fkeys:
                    rtable = fkey['referred_table']
                    if rtable in ordered_tnames:
                        ref_pos = ordered_tnames.index(rtable)
                        # Make sure it's lower in the list than anything it
                        # references.
                        if table_pos > ref_pos:
                            ordered_tnames.pop(table_pos)  # rtable moves up 1
                            # insert just below rtable
                            ordered_tnames.insert(ref_pos, tname)
            tnames = ordered_tnames
        return tnames

    def get_table_options(self, table_name, schema=None, **kw):
        if hasattr(self.dialect, 'get_table_options'):
            return self.dialect.get_table_options(self.conn, table_name, schema,
                                                  info_cache=self.info_cache,
                                                  **kw)
        return {}

    def get_view_names(self, schema=None):
        """Return all view names in `schema`.

        :param schema: Optional, retrieve names from a non-default schema.
        """

        return self.dialect.get_view_names(self.conn, schema,
                                           info_cache=self.info_cache)

    def get_view_definition(self, view_name, schema=None):
        """Return definition for `view_name`.

        :param schema: Optional, retrieve names from a non-default schema.
        """

        return self.dialect.get_view_definition(
            self.conn, view_name, schema, info_cache=self.info_cache)

    def get_columns(self, table_name, schema=None, **kw):
        """Return information about columns in `table_name`.

        Given a string `table_name` and an optional string `schema`, return
        column information as a list of dicts with these keys:

        name
          the column's name

        type
          :class:`~sqlalchemy.types.TypeEngine`

        nullable
          boolean

        default
          the column's default value

        attrs
          dict containing optional column attributes
        """

        col_defs = self.dialect.get_columns(self.conn, table_name, schema,
                                            info_cache=self.info_cache,
                                            **kw)
        for col_def in col_defs:
            # make this easy and only return instances for coltype
            coltype = col_def['type']
            if not isinstance(coltype, TypeEngine):
                col_def['type'] = coltype()
        return col_defs

    def get_primary_keys(self, table_name, schema=None, **kw):
        """Return information about primary keys in `table_name`.

        Given a string `table_name`, and an optional string `schema`, return
        primary key information as a list of column names.
        """

        pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
                                              info_cache=self.info_cache,
                                              **kw)

        return pkeys

    def get_foreign_keys(self, table_name, schema=None, **kw):
        """Return information about foreign_keys in `table_name`.

        Given a string `table_name`, and an optional string `schema`, return
        foreign key information as a list of dicts with these keys:

        constrained_columns
          a list of column names that make up the foreign key

        referred_schema
          the name of the referred schema

        referred_table
          the name of the referred table

        referred_columns
          a list of column names in the referred table that correspond to
          constrained_columns

        \**kw
          other options passed to the dialect's get_foreign_keys() method.

        """

        fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
                                                info_cache=self.info_cache,
                                                **kw)
        return fk_defs

    def get_indexes(self, table_name, schema=None, **kw):
        """Return information about indexes in `table_name`.

        Given a string `table_name` and an optional string `schema`, return
        index information as a list of dicts with these keys:

        name
          the index's name

        column_names
          list of column names in order

        unique
          boolean

        \**kw
          other options passed to the dialect's get_indexes() method.
        """

        indexes = self.dialect.get_indexes(self.conn, table_name,
                                           schema,
                                           info_cache=self.info_cache, **kw)
        return indexes

    def reflecttable(self, table, include_columns):

        dialect = self.conn.dialect

        # MySQL dialect does this.  Applicable with other dialects?
        if hasattr(dialect, '_connection_charset') \
                and hasattr(dialect, '_adjust_casing'):
            charset = dialect._connection_charset
            dialect._adjust_casing(table)

        # table attributes we might need.
        reflection_options = dict(
            (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)

        schema = table.schema
        table_name = table.name

        # apply table options
        tbl_opts = self.get_table_options(table_name, schema, **table.kwargs)
        if tbl_opts:
            table.kwargs.update(tbl_opts)

        # table.kwargs will need to be passed to each reflection method.  Make
        # sure keywords are strings.
        tblkw = table.kwargs.copy()
        for (k, v) in tblkw.items():
            del tblkw[k]
            tblkw[str(k)] = v

        # Py2K
        if isinstance(schema, str):
            schema = schema.decode(dialect.encoding)
        if isinstance(table_name, str):
            table_name = table_name.decode(dialect.encoding)
        # end Py2K

        # columns
        found_table = False
        for col_d in self.get_columns(table_name, schema, **tblkw):
            found_table = True
            name = col_d['name']
            if include_columns and name not in include_columns:
                continue

            coltype = col_d['type']
            col_kw = {
                'nullable': col_d['nullable'],
            }
            if 'autoincrement' in col_d:
                col_kw['autoincrement'] = col_d['autoincrement']
            if 'quote' in col_d:
                col_kw['quote'] = col_d['quote']

            colargs = []
            if col_d.get('default') is not None:
                # the "default" value is assumed to be a literal SQL expression,
                # so is wrapped in text() so that no quoting occurs on re-issuance.
                colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))

            if 'sequence' in col_d:
                # TODO: mssql, maxdb and sybase are using this.
                seq = col_d['sequence']
                sequence = sa_schema.Sequence(seq['name'], 1, 1)
                if 'start' in seq:
                    sequence.start = seq['start']
                if 'increment' in seq:
                    sequence.increment = seq['increment']
                colargs.append(sequence)

            col = sa_schema.Column(name, coltype, *colargs, **col_kw)
            table.append_column(col)

        if not found_table:
            raise exc.NoSuchTableError(table.name)

        # Primary keys
        primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
            table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
            if pk in table.c
        ])

        table.append_constraint(primary_key_constraint)

        # Foreign keys
        fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
        for fkey_d in fkeys:
            conname = fkey_d['name']
            constrained_columns = fkey_d['constrained_columns']
            referred_schema = fkey_d['referred_schema']
            referred_table = fkey_d['referred_table']
            referred_columns = fkey_d['referred_columns']
            refspec = []
            if referred_schema is not None:
                sa_schema.Table(referred_table, table.metadata,
                                autoload=True, schema=referred_schema,
                                autoload_with=self.conn,
                                **reflection_options
                                )
                for column in referred_columns:
                    refspec.append(".".join(
                        [referred_schema, referred_table, column]))
            else:
                sa_schema.Table(referred_table, table.metadata, autoload=True,
                                autoload_with=self.conn,
                                **reflection_options
                                )
                for column in referred_columns:
                    refspec.append(".".join([referred_table, column]))
            table.append_constraint(
                sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
                                               conname, link_to_name=True))
        # Indexes
        indexes = self.get_indexes(table_name, schema)
        for index_d in indexes:
            name = index_d['name']
            columns = index_d['column_names']
            unique = index_d['unique']
            flavor = index_d.get('type', 'unknown type')
            if include_columns and \
                    not set(columns).issubset(include_columns):
                util.warn(
                    "Omitting %s KEY for (%s), key covers omitted columns." %
                    (flavor, ', '.join(columns)))
                continue
            sa_schema.Index(name, *[table.columns[c] for c in columns],
                            **dict(unique=unique))
227
sqlalchemy/engine/strategies.py
Normal file
@@ -0,0 +1,227 @@
"""Strategies for creating new instances of Engine types.

These are semi-private implementation classes which provide the
underlying behavior for the "strategy" keyword argument available on
:func:`~sqlalchemy.engine.create_engine`.  Currently available options are
``plain``, ``threadlocal``, and ``mock``.

New strategies can be added via new ``EngineStrategy`` classes.
"""

from operator import attrgetter

from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exc
from sqlalchemy import pool as poollib

strategies = {}


class EngineStrategy(object):
    """An adaptor that processes input arguments and produces an Engine.

    Provides a ``create`` method that receives input arguments and
    produces an instance of base.Engine or a subclass.

    """

    def __init__(self):
        strategies[self.name] = self

    def create(self, *args, **kwargs):
        """Given arguments, returns a new Engine instance."""

        raise NotImplementedError()


class DefaultEngineStrategy(EngineStrategy):
    """Base class for built-in strategies."""

    pool_threadlocal = False

    def create(self, name_or_url, **kwargs):
        # create url.URL object
        u = url.make_url(name_or_url)

        dialect_cls = u.get_dialect()

        dialect_args = {}
        # consume dialect arguments from kwargs
        for k in util.get_cls_kwargs(dialect_cls):
            if k in kwargs:
                dialect_args[k] = kwargs.pop(k)

        dbapi = kwargs.pop('module', None)
        if dbapi is None:
            dbapi_args = {}
            for k in util.get_func_kwargs(dialect_cls.dbapi):
                if k in kwargs:
                    dbapi_args[k] = kwargs.pop(k)
            dbapi = dialect_cls.dbapi(**dbapi_args)

        dialect_args['dbapi'] = dbapi

        # create dialect
        dialect = dialect_cls(**dialect_args)

        # assemble connection arguments
        (cargs, cparams) = dialect.create_connect_args(u)
        cparams.update(kwargs.pop('connect_args', {}))

        # look for existing pool or create
        pool = kwargs.pop('pool', None)
        if pool is None:
            def connect():
                try:
                    return dialect.connect(*cargs, **cparams)
                except Exception, e:
                    # Py3K
                    #raise exc.DBAPIError.instance(None, None, e) from e
                    # Py2K
                    import sys
                    raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2]
                    # end Py2K

            creator = kwargs.pop('creator', connect)

            poolclass = (kwargs.pop('poolclass', None) or
                         getattr(dialect_cls, 'poolclass', poollib.QueuePool))
            pool_args = {}

            # consume pool arguments from kwargs, translating a few of
            # the arguments
            translate = {'logging_name': 'pool_logging_name',
                         'echo': 'echo_pool',
                         'timeout': 'pool_timeout',
                         'recycle': 'pool_recycle',
                         'use_threadlocal': 'pool_threadlocal'}
            for k in util.get_cls_kwargs(poolclass):
                tk = translate.get(k, k)
                if tk in kwargs:
                    pool_args[k] = kwargs.pop(tk)
            pool_args.setdefault('use_threadlocal', self.pool_threadlocal)
            pool = poolclass(creator, **pool_args)
        else:
            if isinstance(pool, poollib._DBProxy):
                pool = pool.get_pool(*cargs, **cparams)
            else:
                pool = pool

        # create engine.
        engineclass = self.engine_cls
        engine_args = {}
        for k in util.get_cls_kwargs(engineclass):
            if k in kwargs:
                engine_args[k] = kwargs.pop(k)

        _initialize = kwargs.pop('_initialize', True)

        # all kwargs should be consumed
        if kwargs:
            raise TypeError(
                "Invalid argument(s) %s sent to create_engine(), "
                "using configuration %s/%s/%s.  Please check that the "
                "keyword arguments are appropriate for this combination "
                "of components." % (','.join("'%s'" % k for k in kwargs),
                                    dialect.__class__.__name__,
                                    pool.__class__.__name__,
                                    engineclass.__name__))

        engine = engineclass(pool, dialect, u, **engine_args)

        if _initialize:
            do_on_connect = dialect.on_connect()
            if do_on_connect:
                def on_connect(conn, rec):
                    conn = getattr(conn, '_sqla_unwrap', conn)
                    if conn is None:
                        return
                    do_on_connect(conn)

                pool.add_listener({'first_connect': on_connect, 'connect': on_connect})

            def first_connect(conn, rec):
                c = base.Connection(engine, connection=conn)
                dialect.initialize(c)
            pool.add_listener({'first_connect': first_connect})

        return engine


class PlainEngineStrategy(DefaultEngineStrategy):
    """Strategy for configuring a regular Engine."""

    name = 'plain'
    engine_cls = base.Engine

PlainEngineStrategy()


class ThreadLocalEngineStrategy(DefaultEngineStrategy):
    """Strategy for configuring an Engine with threadlocal behavior."""

    name = 'threadlocal'
    pool_threadlocal = True
    engine_cls = threadlocal.TLEngine

ThreadLocalEngineStrategy()


class MockEngineStrategy(EngineStrategy):
    """Strategy for configuring an Engine-like object with mocked execution.

    Produces a single mock Connectable object which dispatches
    statement execution to a passed-in function.

    """

    name = 'mock'

    def create(self, name_or_url, executor, **kwargs):
        # create url.URL object
        u = url.make_url(name_or_url)

        dialect_cls = u.get_dialect()

        dialect_args = {}
        # consume dialect arguments from kwargs
        for k in util.get_cls_kwargs(dialect_cls):
            if k in kwargs:
                dialect_args[k] = kwargs.pop(k)

        # create dialect
        dialect = dialect_cls(**dialect_args)

        return MockEngineStrategy.MockConnection(dialect, executor)

    class MockConnection(base.Connectable):
        def __init__(self, dialect, execute):
            self._dialect = dialect
            self.execute = execute

        engine = property(lambda s: s)
        dialect = property(attrgetter('_dialect'))
        name = property(lambda s: s._dialect.name)

        def contextual_connect(self, **kwargs):
            return self

        def compiler(self, statement, parameters, **kwargs):
            return self._dialect.compiler(
                statement, parameters, engine=self, **kwargs)

        def create(self, entity, **kwargs):
            kwargs['checkfirst'] = False
            from sqlalchemy.engine import ddl

            ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)

        def drop(self, entity, **kwargs):
            kwargs['checkfirst'] = False
            from sqlalchemy.engine import ddl
            ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)

        def execute(self, object, *multiparams, **params):
            raise NotImplementedError()

MockEngineStrategy()
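The "mock" strategy is handy for emitting DDL to a script rather than a live database:
the executor receives each compiled construct.  A small sketch (the metadata object is
assumed to exist already):

    from sqlalchemy import create_engine

    def dump(sql, *multiparams, **params):
        print str(sql).strip() + ';'

    engine = create_engine('postgresql://', strategy='mock', executor=dump)
    metadata.create_all(engine)   # prints CREATE TABLE statements instead of executing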
103
sqlalchemy/engine/threadlocal.py
Normal file
@@ -0,0 +1,103 @@
"""Provides a thread-local transactional wrapper around the root Engine class.

The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag
with :func:`~sqlalchemy.engine.create_engine`.  This module is semi-private and is
invoked automatically when the threadlocal engine strategy is used.
"""

from sqlalchemy import util
from sqlalchemy.engine import base
import weakref


class TLConnection(base.Connection):
    def __init__(self, *arg, **kw):
        super(TLConnection, self).__init__(*arg, **kw)
        self.__opencount = 0

    def _increment_connect(self):
        self.__opencount += 1
        return self

    def close(self):
        if self.__opencount == 1:
            base.Connection.close(self)
        self.__opencount -= 1

    def _force_close(self):
        self.__opencount = 0
        base.Connection.close(self)


class TLEngine(base.Engine):
    """An Engine that includes support for thread-local managed transactions."""

    def __init__(self, *args, **kwargs):
        super(TLEngine, self).__init__(*args, **kwargs)
        self._connections = util.threading.local()
        proxy = kwargs.get('proxy')
        if proxy:
            self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
        else:
            self.TLConnection = TLConnection

    def contextual_connect(self, **kw):
        if not hasattr(self._connections, 'conn'):
            connection = None
        else:
            connection = self._connections.conn()

        if connection is None or connection.closed:
            # guards against pool-level reapers, if desired.
            # or not connection.connection.is_valid:
            connection = self.TLConnection(self, self.pool.connect(), **kw)
            self._connections.conn = weakref.ref(connection)

        return connection._increment_connect()

    def begin_twophase(self, xid=None):
        if not hasattr(self._connections, 'trans'):
            self._connections.trans = []
        self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid))

    def begin_nested(self):
        if not hasattr(self._connections, 'trans'):
            self._connections.trans = []
        self._connections.trans.append(self.contextual_connect().begin_nested())

    def begin(self):
        if not hasattr(self._connections, 'trans'):
            self._connections.trans = []
        self._connections.trans.append(self.contextual_connect().begin())

    def prepare(self):
        self._connections.trans[-1].prepare()

    def commit(self):
        trans = self._connections.trans.pop(-1)
        trans.commit()

    def rollback(self):
        trans = self._connections.trans.pop(-1)
        trans.rollback()

    def dispose(self):
        self._connections = util.threading.local()
        super(TLEngine, self).dispose()

    @property
    def closed(self):
        return not hasattr(self._connections, 'conn') or \
            self._connections.conn() is None or \
            self._connections.conn().closed

    def close(self):
        if not self.closed:
            self.contextual_connect().close()
            connection = self._connections.conn()
            connection._force_close()
            del self._connections.conn
            self._connections.trans = []

    def __repr__(self):
        return 'TLEngine(%s)' % str(self.url)
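With strategy='threadlocal', each thread reuses one underlying connection, and begin()/
commit()/rollback() manage a transaction scoped to that thread.  A brief sketch (the
table t is assumed to exist):

    from sqlalchemy import create_engine

    engine = create_engine('sqlite://', strategy='threadlocal')
    engine.begin()
    try:
        engine.execute("insert into t values (1)")
        engine.commit()
    except:
        engine.rollback()
        raise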
214
sqlalchemy/engine/url.py
Normal file
@@ -0,0 +1,214 @@
"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
information about a database connection specification.

The URL object is created automatically when :func:`~sqlalchemy.engine.create_engine` is called
with a string argument; alternatively, the URL is a public-facing construct which can
be used directly and is also accepted directly by ``create_engine()``.
"""

import re, cgi, sys, urllib
from sqlalchemy import exc


class URL(object):
    """
    Represent the components of a URL used to connect to a database.

    This object is suitable to be passed directly to a
    ``create_engine()`` call.  The fields of the URL are parsed from a
    string by the module-level ``make_url()`` function.  The string
    format of the URL is an RFC-1738-style string.

    All initialization parameters are available as public attributes.

    :param drivername: the name of the database backend.
      This name will correspond to a module in sqlalchemy/databases
      or a third party plug-in.

    :param username: The user name.

    :param password: database password.

    :param host: The name of the host.

    :param port: The port number.

    :param database: The database name.

    :param query: A dictionary of options to be passed to the
      dialect and/or the DBAPI upon connect.

    """

    def __init__(self, drivername, username=None, password=None,
                 host=None, port=None, database=None, query=None):
        self.drivername = drivername
        self.username = username
        self.password = password
        self.host = host
        if port is not None:
            self.port = int(port)
        else:
            self.port = None
        self.database = database
        self.query = query or {}

    def __str__(self):
        s = self.drivername + "://"
        if self.username is not None:
            s += self.username
            if self.password is not None:
                s += ':' + urllib.quote_plus(self.password)
            s += "@"
        if self.host is not None:
            s += self.host
        if self.port is not None:
            s += ':' + str(self.port)
        if self.database is not None:
            s += '/' + self.database
        if self.query:
            keys = self.query.keys()
            keys.sort()
            s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
        return s

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other):
        return \
            isinstance(other, URL) and \
            self.drivername == other.drivername and \
            self.username == other.username and \
            self.password == other.password and \
            self.host == other.host and \
            self.port == other.port and \
            self.database == other.database and \
            self.query == other.query

    def get_dialect(self):
        """Return the SQLAlchemy database dialect class corresponding
        to this URL's driver name.
        """

        try:
            if '+' in self.drivername:
                dialect, driver = self.drivername.split('+')
            else:
                dialect, driver = self.drivername, 'base'

            module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
            module = getattr(module, dialect)
            module = getattr(module, driver)

            return module.dialect
        except ImportError:
            module = self._load_entry_point()
            if module is not None:
                return module
            else:
                raise

    def _load_entry_point(self):
        """attempt to load this url's dialect from entry points, or return None
        if pkg_resources is not installed or there is no matching entry point.

        Raise ImportError if the actual load fails.

        """
        try:
            import pkg_resources
        except ImportError:
            return None

        for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'):
            if res.name == self.drivername:
                return res.load()
        else:
            return None

    def translate_connect_args(self, names=None, **kw):
        """Translate url attributes into a dictionary of connection arguments.

        Returns attributes of this url (`host`, `database`, `username`,
        `password`, `port`) as a plain dictionary.  The attribute names are
        used as the keys by default.  Unset or false attributes are omitted
        from the final dictionary.

        :param \**kw: Optional, alternate key names for url attributes.

        :param names: Deprecated.  Same purpose as the keyword-based alternate
            names, but correlates the name to the original positionally.
        """

        # copy the deprecated positional list so repeated calls don't
        # mutate a shared default
        names = list(names or [])
        translated = {}
        attribute_names = ['host', 'database', 'username', 'password', 'port']
        for sname in attribute_names:
            if names:
                name = names.pop(0)
            elif sname in kw:
                name = kw[sname]
            else:
                name = sname
            if name is not None and getattr(self, sname, False):
                translated[name] = getattr(self, sname)
        return translated

def make_url(name_or_url):
    """Given a string or unicode instance, produce a new URL instance.

    The given string is parsed according to the RFC 1738 spec.  If an
    existing URL object is passed, just returns the object.
    """

    if isinstance(name_or_url, basestring):
        return _parse_rfc1738_args(name_or_url)
    else:
        return name_or_url

def _parse_rfc1738_args(name):
    pattern = re.compile(r'''
            (?P<name>[\w\+]+)://
            (?:
                (?P<username>[^:/]*)
                (?::(?P<password>[^/]*))?
            @)?
            (?:
                (?P<host>[^/:]*)
                (?::(?P<port>[^/]*))?
            )?
            (?:/(?P<database>.*))?
            ''', re.X)

    m = pattern.match(name)
    if m is not None:
        components = m.groupdict()
        if components['database'] is not None:
            tokens = components['database'].split('?', 2)
            components['database'] = tokens[0]
            query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None
            # Py2K
            if query is not None:
                query = dict((k.encode('ascii'), query[k]) for k in query)
            # end Py2K
        else:
            query = None
        components['query'] = query

        if components['password'] is not None:
            components['password'] = urllib.unquote_plus(components['password'])

        name = components.pop('name')
        return URL(name, **components)
    else:
        raise exc.ArgumentError(
            "Could not parse rfc1738 URL from string '%s'" % name)

def _parse_keyvalue_args(name):
    m = re.match(r'(\w+)://(.*)', name)
    if m is not None:
        (name, args) = m.group(1, 2)
        opts = dict(cgi.parse_qsl(args))
        # pass the parsed options as keyword arguments, not positionally
        return URL(name, **opts)
    else:
        return None
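Round-tripping a URL through make_url() shows the parsed fields:

    from sqlalchemy.engine.url import make_url

    u = make_url('postgresql://scott:tiger@localhost:5432/test?sslmode=require')
    print u.drivername, u.username, u.password   # postgresql scott tiger
    print u.host, u.port, u.database             # localhost 5432 test
    print u.query                                # {'sslmode': 'require'}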
191
sqlalchemy/exc.py
Normal file
@@ -0,0 +1,191 @@
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Exceptions used with SQLAlchemy.

The base exception class is SQLAlchemyError.  Exceptions which are raised as a
result of DBAPI exceptions are all subclasses of
:class:`~sqlalchemy.exc.DBAPIError`.

"""


class SQLAlchemyError(Exception):
    """Generic error class."""


class ArgumentError(SQLAlchemyError):
    """Raised when an invalid or conflicting function argument is supplied.

    This error generally corresponds to construction time state errors.

    """


class CircularDependencyError(SQLAlchemyError):
    """Raised by topological sorts when a circular dependency is detected"""


class CompileError(SQLAlchemyError):
    """Raised when an error occurs during SQL compilation"""

class IdentifierError(SQLAlchemyError):
    """Raised when a schema name is beyond the max character limit"""

# Moved to orm.exc; compatibility definition installed by orm import until 0.6
ConcurrentModificationError = None

class DisconnectionError(SQLAlchemyError):
    """A disconnect is detected on a raw DB-API connection.

    This error is raised and consumed internally by a connection pool.  It can
    be raised by a ``PoolListener`` so that the host pool forces a disconnect.

    """


# Moved to orm.exc; compatibility definition installed by orm import until 0.6
FlushError = None

class TimeoutError(SQLAlchemyError):
    """Raised when a connection pool times out on getting a connection."""


class InvalidRequestError(SQLAlchemyError):
    """SQLAlchemy was asked to do something it can't do.

    This error generally corresponds to runtime state errors.

    """

class NoSuchColumnError(KeyError, InvalidRequestError):
    """A nonexistent column is requested from a ``RowProxy``."""

class NoReferenceError(InvalidRequestError):
    """Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""

class NoReferencedTableError(NoReferenceError):
    """Raised by ``ForeignKey`` when the referred ``Table`` cannot be located."""

class NoReferencedColumnError(NoReferenceError):
    """Raised by ``ForeignKey`` when the referred ``Column`` cannot be located."""

class NoSuchTableError(InvalidRequestError):
    """Table does not exist or is not visible to a connection."""


class UnboundExecutionError(InvalidRequestError):
    """SQL was attempted without a database connection to execute it on."""


# Moved to orm.exc; compatibility definition installed by orm import until 0.6
UnmappedColumnError = None

class DBAPIError(SQLAlchemyError):
    """Raised when the execution of a database operation fails.

    ``DBAPIError`` wraps exceptions raised by the DB-API underlying the
    database operation.  Driver-specific implementations of the standard
    DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
    ``DBAPIError`` when possible.  DB-API's ``Error`` type maps to
    ``DBAPIError`` in SQLAlchemy, otherwise the names are identical.  Note
    that there is no guarantee that different DB-API implementations will
    raise the same exception type for any given error condition.

    If the error-raising operation occurred in the execution of a SQL
    statement, that statement and its parameters will be available on
    the exception object in the ``statement`` and ``params`` attributes.

    The wrapped exception object is available in the ``orig`` attribute.
    Its type and properties are DB-API implementation specific.

    """

    @classmethod
    def instance(cls, statement, params, orig, connection_invalidated=False):
        # Don't ever wrap these, just return them directly as if
        # DBAPIError didn't exist.
        if isinstance(orig, (KeyboardInterrupt, SystemExit)):
            return orig

        if orig is not None:
            name, glob = orig.__class__.__name__, globals()
            if name in glob and issubclass(glob[name], DBAPIError):
                cls = glob[name]

        return cls(statement, params, orig, connection_invalidated)

    def __init__(self, statement, params, orig, connection_invalidated=False):
        try:
            text = str(orig)
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception, e:
            text = 'Error in str() of DB-API-generated exception: ' + str(e)
        SQLAlchemyError.__init__(
            self, '(%s) %s' % (orig.__class__.__name__, text))
        self.statement = statement
        self.params = params
        self.orig = orig
        self.connection_invalidated = connection_invalidated

    def __str__(self):
        if isinstance(self.params, (list, tuple)) and \
                len(self.params) > 10 and \
                isinstance(self.params[0], (list, dict, tuple)):
            return ' '.join((SQLAlchemyError.__str__(self),
                             repr(self.statement),
                             repr(self.params[:2]),
                             '... and a total of %i bound parameter sets' % len(self.params)))
        return ' '.join((SQLAlchemyError.__str__(self),
                         repr(self.statement), repr(self.params)))


# As of 0.4, SQLError is now DBAPIError.
# SQLError alias will be removed in 0.6.
SQLError = DBAPIError
class InterfaceError(DBAPIError):
|
||||
"""Wraps a DB-API InterfaceError."""
|
||||
|
||||
|
||||
class DatabaseError(DBAPIError):
|
||||
"""Wraps a DB-API DatabaseError."""
|
||||
|
||||
|
||||
class DataError(DatabaseError):
|
||||
"""Wraps a DB-API DataError."""
|
||||
|
||||
|
||||
class OperationalError(DatabaseError):
|
||||
"""Wraps a DB-API OperationalError."""
|
||||
|
||||
|
||||
class IntegrityError(DatabaseError):
|
||||
"""Wraps a DB-API IntegrityError."""
|
||||
|
||||
|
||||
class InternalError(DatabaseError):
|
||||
"""Wraps a DB-API InternalError."""
|
||||
|
||||
|
||||
class ProgrammingError(DatabaseError):
|
||||
"""Wraps a DB-API ProgrammingError."""
|
||||
|
||||
|
||||
class NotSupportedError(DatabaseError):
|
||||
"""Wraps a DB-API NotSupportedError."""
|
||||
|
||||
|
||||
# Warnings
|
||||
|
||||
class SADeprecationWarning(DeprecationWarning):
|
||||
"""Issued once per usage of a deprecated API."""
|
||||
|
||||
|
||||
class SAPendingDeprecationWarning(PendingDeprecationWarning):
|
||||
"""Issued once per usage of a deprecated API."""
|
||||
|
||||
|
||||
class SAWarning(RuntimeWarning):
|
||||
"""Issued at runtime."""
|
||||
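
# Editor's note: the demo below is illustrative and not part of the original
# commit.  ``_FakeDriverException`` is an invented stand-in for a DB-API
# driver's exception class; everything else comes from this module.  It shows
# how DBAPIError.instance() re-maps a driver exception onto the matching
# wrapper purely by class name.
if __name__ == '__main__':
    class _FakeDriverException(Exception):
        pass
    # give it the class name a real driver's OperationalError would carry
    _FakeDriverException.__name__ = 'OperationalError'

    fake = _FakeDriverException("server has gone away")
    wrapped = DBAPIError.instance("SELECT * FROM t", (), fake)
    # instance() matched the name against this module's wrapper classes,
    # so the result is the SQLAlchemy OperationalError defined above.
    assert isinstance(wrapped, OperationalError)
    assert wrapped.orig is fake
    print wrapped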
1
sqlalchemy/ext/__init__.py
Normal file
@@ -0,0 +1 @@

878
sqlalchemy/ext/associationproxy.py
Normal file
@@ -0,0 +1,878 @@
"""Contains the ``AssociationProxy`` class.

The ``AssociationProxy`` is a Python property object which provides
transparent proxied access to the endpoint of an association object.

See the example ``examples/association/proxied_association.py``.

"""
import itertools
import operator
import weakref
from sqlalchemy import exceptions
from sqlalchemy import orm
from sqlalchemy import util
from sqlalchemy.orm import collections
from sqlalchemy.sql import not_


def association_proxy(target_collection, attr, **kw):
    """Return a Python property implementing a view of *attr* over a collection.

    Implements a read/write view over an instance's *target_collection*,
    extracting *attr* from each member of the collection.  The property acts
    somewhat like this list comprehension::

        [getattr(member, *attr*)
         for member in getattr(instance, *target_collection*)]

    Unlike the list comprehension, the collection returned by the property is
    always in sync with *target_collection*, and mutations made to either
    collection will be reflected in both.

    Implements a Python property representing a relationship as a collection
    of simpler values.  The proxied property will mimic the collection type of
    the target (list, dict or set), or, in the case of a one-to-one
    relationship, a simple scalar value.

    :param target_collection: Name of the relationship attribute we'll proxy to,
      usually created with :func:`~sqlalchemy.orm.relationship`.

    :param attr: Attribute on the associated instances we'll proxy for.

      For example, given a target collection of [obj1, obj2], a list created
      by this proxy property would look like [getattr(obj1, *attr*),
      getattr(obj2, *attr*)]

      If the relationship is one-to-one or otherwise uselist=False, then
      simply: getattr(obj, *attr*)

    :param creator: optional.

      When new items are added to this proxied collection, new instances of
      the class collected by the target collection will be created.  For list
      and set collections, the target class constructor will be called with
      the 'value' for the new instance.  For dict types, two arguments are
      passed: key and value.

      If you want to construct instances differently, supply a *creator*
      function that takes arguments as above and returns instances.

      For scalar relationships, creator() will be called if the target is None.
      If the target is present, set operations are proxied to setattr() on the
      associated object.

      If you have an associated object with multiple attributes, you may set
      up multiple association proxies mapping to different attributes.  See
      the unit tests for examples, and for examples of how creator() functions
      can be used to construct the scalar relationship on-demand in this
      situation.

    :param \*\*kw: Passes along any other keyword arguments to
      :class:`AssociationProxy`.

    """
    return AssociationProxy(target_collection, attr, **kw)
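
# Editor's note: the sketch below is illustrative and not part of the
# original commit.  It assumes a ``User`` class whose ``kw`` relationship()
# collects ``Keyword`` objects carrying a string attribute ``keyword``::
#
#     class User(Base):
#         # ... columns, plus the 'kw' relationship to Keyword ...
#         keywords = association_proxy('kw', 'keyword',
#                                      creator=lambda kw: Keyword(keyword=kw))
#
#     user.keywords.append('cheese inspector')  # a Keyword is created for us
#     print user.keywords                       # ['cheese inspector', ...]
#
# A fully runnable version appears at the end of this file.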

class AssociationProxy(object):
    """A descriptor that presents a read/write view of an object attribute."""

    def __init__(self, target_collection, attr, creator=None,
                 getset_factory=None, proxy_factory=None, proxy_bulk_set=None):
        """Arguments are:

        target_collection
          Name of the collection we'll proxy to, usually created with
          'relationship()' in a mapper setup.

        attr
          Attribute on the collected instances we'll proxy for.  For example,
          given a target collection of [obj1, obj2], a list created by this
          proxy property would look like [getattr(obj1, attr), getattr(obj2,
          attr)]

        creator
          Optional. When new items are added to this proxied collection, new
          instances of the class collected by the target collection will be
          created.  For list and set collections, the target class constructor
          will be called with the 'value' for the new instance.  For dict
          types, two arguments are passed: key and value.

          If you want to construct instances differently, supply a 'creator'
          function that takes arguments as above and returns instances.

        getset_factory
          Optional.  Proxied attribute access is automatically handled by
          routines that get and set values based on the `attr` argument for
          this proxy.

          If you would like to customize this behavior, you may supply a
          `getset_factory` callable that produces a tuple of `getter` and
          `setter` functions.  The factory is called with two arguments, the
          abstract type of the underlying collection and this proxy instance.

        proxy_factory
          Optional.  The type of collection to emulate is determined by
          sniffing the target collection.  If your collection type can't be
          determined by duck typing or you'd like to use a different
          collection implementation, you may supply a factory function to
          produce those collections.  Only applicable to non-scalar
          relationships.

        proxy_bulk_set
          Optional, use with proxy_factory.  See the _set() method for
          details.

        """
        self.target_collection = target_collection
        self.value_attr = attr
        self.creator = creator
        self.getset_factory = getset_factory
        self.proxy_factory = proxy_factory
        self.proxy_bulk_set = proxy_bulk_set

        self.scalar = None
        self.owning_class = None
        self.key = '_%s_%s_%s' % (
            type(self).__name__, target_collection, id(self))
        self.collection_class = None

    def _get_property(self):
        return (orm.class_mapper(self.owning_class).
                get_property(self.target_collection))

    @property
    def target_class(self):
        """The class the proxy is attached to."""
        return self._get_property().mapper.class_

    def _target_is_scalar(self):
        return not self._get_property().uselist

    def __get__(self, obj, class_):
        if self.owning_class is None:
            self.owning_class = class_ and class_ or type(obj)
        if obj is None:
            return self
        elif self.scalar is None:
            self.scalar = self._target_is_scalar()
            if self.scalar:
                self._initialize_scalar_accessors()

        if self.scalar:
            return self._scalar_get(getattr(obj, self.target_collection))
        else:
            try:
                # If the owning instance is reborn (orm session resurrect,
                # etc.), refresh the proxy cache.
                creator_id, proxy = getattr(obj, self.key)
                if id(obj) == creator_id:
                    return proxy
            except AttributeError:
                pass
            proxy = self._new(_lazy_collection(obj, self.target_collection))
            setattr(obj, self.key, (id(obj), proxy))
            return proxy

    def __set__(self, obj, values):
        if self.owning_class is None:
            self.owning_class = type(obj)
        if self.scalar is None:
            self.scalar = self._target_is_scalar()
            if self.scalar:
                self._initialize_scalar_accessors()

        if self.scalar:
            creator = self.creator and self.creator or self.target_class
            target = getattr(obj, self.target_collection)
            if target is None:
                setattr(obj, self.target_collection, creator(values))
            else:
                self._scalar_set(target, values)
        else:
            proxy = self.__get__(obj, None)
            if proxy is not values:
                proxy.clear()
                self._set(proxy, values)

    def __delete__(self, obj):
        if self.owning_class is None:
            self.owning_class = type(obj)
        delattr(obj, self.key)

    def _initialize_scalar_accessors(self):
        if self.getset_factory:
            get, set = self.getset_factory(None, self)
        else:
            get, set = self._default_getset(None)
        self._scalar_get, self._scalar_set = get, set

    def _default_getset(self, collection_class):
        attr = self.value_attr
        getter = operator.attrgetter(attr)
        if collection_class is dict:
            setter = lambda o, k, v: setattr(o, attr, v)
        else:
            setter = lambda o, v: setattr(o, attr, v)
        return getter, setter

    def _new(self, lazy_collection):
        creator = self.creator and self.creator or self.target_class
        self.collection_class = util.duck_type_collection(lazy_collection())

        if self.proxy_factory:
            return self.proxy_factory(lazy_collection, creator, self.value_attr, self)

        if self.getset_factory:
            getter, setter = self.getset_factory(self.collection_class, self)
        else:
            getter, setter = self._default_getset(self.collection_class)

        if self.collection_class is list:
            return _AssociationList(lazy_collection, creator, getter, setter, self)
        elif self.collection_class is dict:
            return _AssociationDict(lazy_collection, creator, getter, setter, self)
        elif self.collection_class is set:
            return _AssociationSet(lazy_collection, creator, getter, setter, self)
        else:
            raise exceptions.ArgumentError(
                'could not guess which interface to use for '
                'collection_class "%s" backing "%s"; specify a '
                'proxy_factory and proxy_bulk_set manually' %
                (self.collection_class.__name__, self.target_collection))

    def _inflate(self, proxy):
        creator = self.creator and self.creator or self.target_class

        if self.getset_factory:
            getter, setter = self.getset_factory(self.collection_class, self)
        else:
            getter, setter = self._default_getset(self.collection_class)

        proxy.creator = creator
        proxy.getter = getter
        proxy.setter = setter

    def _set(self, proxy, values):
        if self.proxy_bulk_set:
            self.proxy_bulk_set(proxy, values)
        elif self.collection_class is list:
            proxy.extend(values)
        elif self.collection_class is dict:
            proxy.update(values)
        elif self.collection_class is set:
            proxy.update(values)
        else:
            raise exceptions.ArgumentError(
                'no proxy_bulk_set supplied for custom '
                'collection_class implementation')

    @property
    def _comparator(self):
        return self._get_property().comparator

    def any(self, criterion=None, **kwargs):
        return self._comparator.any(
            getattr(self.target_class, self.value_attr).has(criterion, **kwargs))

    def has(self, criterion=None, **kwargs):
        return self._comparator.has(
            getattr(self.target_class, self.value_attr).has(criterion, **kwargs))

    def contains(self, obj):
        return self._comparator.any(**{self.value_attr: obj})

    def __eq__(self, obj):
        return self._comparator.has(**{self.value_attr: obj})

    def __ne__(self, obj):
        return not_(self.__eq__(obj))


class _lazy_collection(object):
    def __init__(self, obj, target):
        self.ref = weakref.ref(obj)
        self.target = target

    def __call__(self):
        obj = self.ref()
        if obj is None:
            raise exceptions.InvalidRequestError(
                "stale association proxy, parent object has gone out of "
                "scope")
        return getattr(obj, self.target)

    def __getstate__(self):
        return {'obj':self.ref(), 'target':self.target}

    def __setstate__(self, state):
        self.ref = weakref.ref(state['obj'])
        self.target = state['target']


class _AssociationCollection(object):
    def __init__(self, lazy_collection, creator, getter, setter, parent):
        """Constructs an _AssociationCollection.

        This will always be a subclass of either _AssociationList,
        _AssociationSet, or _AssociationDict.

        lazy_collection
          A callable returning a list-based collection of entities (usually an
          object attribute managed by a SQLAlchemy relationship())

        creator
          A function that creates new target entities.  Given one parameter:
          value.  This assertion is assumed::

            obj = creator(somevalue)
            assert getter(obj) == somevalue

        getter
          A function.  Given an associated object, return the 'value'.

        setter
          A function.  Given an associated object and a value, store that
          value on the object.

        """
        self.lazy_collection = lazy_collection
        self.creator = creator
        self.getter = getter
        self.setter = setter
        self.parent = parent

    col = property(lambda self: self.lazy_collection())

    def __len__(self):
        return len(self.col)

    def __nonzero__(self):
        return bool(self.col)

    def __getstate__(self):
        return {'parent':self.parent, 'lazy_collection':self.lazy_collection}

    def __setstate__(self, state):
        self.parent = state['parent']
        self.lazy_collection = state['lazy_collection']
        self.parent._inflate(self)


class _AssociationList(_AssociationCollection):
    """Generic, converting, list-to-list proxy."""

    def _create(self, value):
        return self.creator(value)

    def _get(self, object):
        return self.getter(object)

    def _set(self, object, value):
        return self.setter(object, value)

    def __getitem__(self, index):
        return self._get(self.col[index])

    def __setitem__(self, index, value):
        if not isinstance(index, slice):
            self._set(self.col[index], value)
        else:
            if index.stop is None:
                stop = len(self)
            elif index.stop < 0:
                stop = len(self) + index.stop
            else:
                stop = index.stop
            step = index.step or 1

            rng = range(index.start or 0, stop, step)
            if step == 1:
                for i in rng:
                    del self[index.start]
                i = index.start
                for item in value:
                    self.insert(i, item)
                    i += 1
            else:
                if len(value) != len(rng):
                    raise ValueError(
                        "attempt to assign sequence of size %s to "
                        "extended slice of size %s" % (len(value),
                                                       len(rng)))
                for i, item in zip(rng, value):
                    self._set(self.col[i], item)

    def __delitem__(self, index):
        del self.col[index]

    def __contains__(self, value):
        for member in self.col:
            # testlib.pragma exempt:__eq__
            if self._get(member) == value:
                return True
        return False

    def __getslice__(self, start, end):
        return [self._get(member) for member in self.col[start:end]]

    def __setslice__(self, start, end, values):
        members = [self._create(v) for v in values]
        self.col[start:end] = members

    def __delslice__(self, start, end):
        del self.col[start:end]

    def __iter__(self):
        """Iterate over proxied values.

        For the actual domain objects, iterate over .col instead or
        just use the underlying collection directly from its property
        on the parent.
        """

        for member in self.col:
            yield self._get(member)
        raise StopIteration

    def append(self, value):
        item = self._create(value)
        self.col.append(item)

    def count(self, value):
        return sum([1 for _ in
                    itertools.ifilter(lambda v: v == value, iter(self))])

    def extend(self, values):
        for v in values:
            self.append(v)

    def insert(self, index, value):
        self.col[index:index] = [self._create(value)]

    def pop(self, index=-1):
        return self.getter(self.col.pop(index))

    def remove(self, value):
        for i, val in enumerate(self):
            if val == value:
                del self.col[i]
                return
        raise ValueError("value not in list")

    def reverse(self):
        """Not supported, use reversed(mylist)"""

        raise NotImplementedError

    def sort(self):
        """Not supported, use sorted(mylist)"""

        raise NotImplementedError

    def clear(self):
        del self.col[0:len(self.col)]

    def __eq__(self, other):
        return list(self) == other

    def __ne__(self, other):
        return list(self) != other

    def __lt__(self, other):
        return list(self) < other

    def __le__(self, other):
        return list(self) <= other

    def __gt__(self, other):
        return list(self) > other

    def __ge__(self, other):
        return list(self) >= other

    def __cmp__(self, other):
        return cmp(list(self), other)

    def __add__(self, iterable):
        try:
            other = list(iterable)
        except TypeError:
            return NotImplemented
        return list(self) + other

    def __radd__(self, iterable):
        try:
            other = list(iterable)
        except TypeError:
            return NotImplemented
        return other + list(self)

    def __mul__(self, n):
        if not isinstance(n, int):
            return NotImplemented
        return list(self) * n
    __rmul__ = __mul__

    def __iadd__(self, iterable):
        self.extend(iterable)
        return self

    def __imul__(self, n):
        # unlike a regular list *=, proxied __imul__ will generate unique
        # backing objects for each copy.  *= on proxied lists is a bit of
        # a stretch anyhow, and this interpretation of the __imul__ contract
        # is more plausibly useful than copying the backing objects.
        if not isinstance(n, int):
            return NotImplemented
        if n == 0:
            self.clear()
        elif n > 1:
            self.extend(list(self) * (n - 1))
        return self

    def copy(self):
        return list(self)

    def __repr__(self):
        return repr(list(self))

    def __hash__(self):
        raise TypeError("%s objects are unhashable" % type(self).__name__)

    for func_name, func in locals().items():
        if (util.callable(func) and func.func_name == func_name and
            not func.__doc__ and hasattr(list, func_name)):
            func.__doc__ = getattr(list, func_name).__doc__
    del func_name, func


_NotProvided = util.symbol('_NotProvided')

class _AssociationDict(_AssociationCollection):
    """Generic, converting, dict-to-dict proxy."""

    def _create(self, key, value):
        return self.creator(key, value)

    def _get(self, object):
        return self.getter(object)

    def _set(self, object, key, value):
        return self.setter(object, key, value)

    def __getitem__(self, key):
        return self._get(self.col[key])

    def __setitem__(self, key, value):
        if key in self.col:
            self._set(self.col[key], key, value)
        else:
            self.col[key] = self._create(key, value)

    def __delitem__(self, key):
        del self.col[key]

    def __contains__(self, key):
        # testlib.pragma exempt:__hash__
        return key in self.col

    def has_key(self, key):
        # testlib.pragma exempt:__hash__
        return key in self.col

    def __iter__(self):
        return self.col.iterkeys()

    def clear(self):
        self.col.clear()

    def __eq__(self, other):
        return dict(self) == other

    def __ne__(self, other):
        return dict(self) != other

    def __lt__(self, other):
        return dict(self) < other

    def __le__(self, other):
        return dict(self) <= other

    def __gt__(self, other):
        return dict(self) > other

    def __ge__(self, other):
        return dict(self) >= other

    def __cmp__(self, other):
        return cmp(dict(self), other)

    def __repr__(self):
        return repr(dict(self.items()))

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def setdefault(self, key, default=None):
        if key not in self.col:
            self.col[key] = self._create(key, default)
            return default
        else:
            return self[key]

    def keys(self):
        return self.col.keys()

    def iterkeys(self):
        return self.col.iterkeys()

    def values(self):
        return [ self._get(member) for member in self.col.values() ]

    def itervalues(self):
        for key in self.col:
            yield self._get(self.col[key])
        raise StopIteration

    def items(self):
        return [(k, self._get(self.col[k])) for k in self]

    def iteritems(self):
        for key in self.col:
            yield (key, self._get(self.col[key]))
        raise StopIteration

    def pop(self, key, default=_NotProvided):
        if default is _NotProvided:
            member = self.col.pop(key)
        else:
            member = self.col.pop(key, default)
        return self._get(member)

    def popitem(self):
        item = self.col.popitem()
        return (item[0], self._get(item[1]))

    def update(self, *a, **kw):
        if len(a) > 1:
            raise TypeError('update expected at most 1 arguments, got %i' %
                            len(a))
        elif len(a) == 1:
            seq_or_map = a[0]
            for item in seq_or_map:
                if isinstance(item, tuple):
                    self[item[0]] = item[1]
                else:
                    self[item] = seq_or_map[item]

        # iterate key/value pairs; bare iteration over the dict would
        # yield keys only
        for key, value in kw.items():
            self[key] = value

    def copy(self):
        return dict(self.items())

    def __hash__(self):
        raise TypeError("%s objects are unhashable" % type(self).__name__)

    for func_name, func in locals().items():
        if (util.callable(func) and func.func_name == func_name and
            not func.__doc__ and hasattr(dict, func_name)):
            func.__doc__ = getattr(dict, func_name).__doc__
    del func_name, func


class _AssociationSet(_AssociationCollection):
    """Generic, converting, set-to-set proxy."""

    def _create(self, value):
        return self.creator(value)

    def _get(self, object):
        return self.getter(object)

    def _set(self, object, value):
        return self.setter(object, value)

    def __len__(self):
        return len(self.col)

    def __nonzero__(self):
        if self.col:
            return True
        else:
            return False

    def __contains__(self, value):
        for member in self.col:
            # testlib.pragma exempt:__eq__
            if self._get(member) == value:
                return True
        return False

    def __iter__(self):
        """Iterate over proxied values.

        For the actual domain objects, iterate over .col instead or just use
        the underlying collection directly from its property on the parent.

        """
        for member in self.col:
            yield self._get(member)
        raise StopIteration

    def add(self, value):
        if value not in self:
            self.col.add(self._create(value))

    # for discard and remove, choosing a more expensive check strategy rather
    # than call self.creator()
    def discard(self, value):
        for member in self.col:
            if self._get(member) == value:
                self.col.discard(member)
                break

    def remove(self, value):
        for member in self.col:
            if self._get(member) == value:
                self.col.discard(member)
                return
        raise KeyError(value)

    def pop(self):
        if not self.col:
            raise KeyError('pop from an empty set')
        member = self.col.pop()
        return self._get(member)

    def update(self, other):
        for value in other:
            self.add(value)

    def __ior__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        for value in other:
            self.add(value)
        return self

    def _set(self):
        return set(iter(self))

    def union(self, other):
        return set(self).union(other)

    __or__ = union

    def difference(self, other):
        return set(self).difference(other)

    __sub__ = difference

    def difference_update(self, other):
        for value in other:
            self.discard(value)

    def __isub__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        for value in other:
            self.discard(value)
        return self

    def intersection(self, other):
        return set(self).intersection(other)

    __and__ = intersection

    def intersection_update(self, other):
        want, have = self.intersection(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)

    def __iand__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        want, have = self.intersection(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)
        return self

    def symmetric_difference(self, other):
        return set(self).symmetric_difference(other)

    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        want, have = self.symmetric_difference(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)

    def __ixor__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        want, have = self.symmetric_difference(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)
        return self

    def issubset(self, other):
        return set(self).issubset(other)

    def issuperset(self, other):
        return set(self).issuperset(other)

    def clear(self):
        self.col.clear()

    def copy(self):
        return set(self)

    def __eq__(self, other):
        return set(self) == other

    def __ne__(self, other):
        return set(self) != other

    def __lt__(self, other):
        return set(self) < other

    def __le__(self, other):
        return set(self) <= other

    def __gt__(self, other):
        return set(self) > other

    def __ge__(self, other):
        return set(self) >= other

    def __repr__(self):
        return repr(set(self))

    def __hash__(self):
        raise TypeError("%s objects are unhashable" % type(self).__name__)

    for func_name, func in locals().items():
        if (util.callable(func) and func.func_name == func_name and
            not func.__doc__ and hasattr(set, func_name)):
            func.__doc__ = getattr(set, func_name).__doc__
    del func_name, func
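
# Editor's note: a self-contained demo, not part of the original commit.
# The User/Keyword schema and the in-memory sqlite database are assumptions
# made for illustration only.
if __name__ == '__main__':
    from sqlalchemy import Column, Integer, String, ForeignKey, create_engine
    from sqlalchemy.orm import relationship, sessionmaker
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Keyword(Base):
        __tablename__ = 'keywords'
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey('users.id'))
        keyword = Column(String(50))

        def __init__(self, keyword):
            self.keyword = keyword

    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        kws = relationship(Keyword)
        # a list view of the .keyword string on each related Keyword;
        # the default creator is the Keyword constructor itself
        keywords = association_proxy('kws', 'keyword')

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    u = User()
    u.keywords.append('cheese inspector')   # Keyword created behind the scenes
    session.add(u)
    session.commit()
    print u.keywords                         # ['cheese inspector'] (roughly)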
194
sqlalchemy/ext/compiler.py
Normal file
@@ -0,0 +1,194 @@
"""Provides an API for creation of custom ClauseElements and compilers.

Synopsis
========

Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement`
subclasses and one or more callables defining its compilation::

    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.sql.expression import ColumnClause

    class MyColumn(ColumnClause):
        pass

    @compiles(MyColumn)
    def compile_mycolumn(element, compiler, **kw):
        return "[%s]" % element.name

Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
the base expression element for named column objects.  The ``compiles``
decorator registers itself with the ``MyColumn`` class so that it is invoked
when the object is compiled to a string::

    from sqlalchemy import select

    s = select([MyColumn('x'), MyColumn('y')])
    print str(s)

Produces::

    SELECT [x], [y]

Dialect-specific compilation rules
==================================

Compilers can also be made dialect-specific.  The appropriate compiler will be
invoked for the dialect in use::

    from sqlalchemy.schema import DDLElement

    class AlterColumn(DDLElement):

        def __init__(self, column, cmd):
            self.column = column
            self.cmd = cmd

    @compiles(AlterColumn)
    def visit_alter_column(element, compiler, **kw):
        return "ALTER COLUMN %s ..." % element.column.name

    @compiles(AlterColumn, 'postgresql')
    def visit_alter_column(element, compiler, **kw):
        return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
                                                       element.column.name)

The second ``visit_alter_column`` will be invoked when any ``postgresql``
dialect is used.

Compiling sub-elements of a custom expression construct
=======================================================

The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled`
object in use.  This object can be inspected for any information about the
in-progress compilation, including ``compiler.dialect``,
``compiler.statement`` etc.  The :class:`~sqlalchemy.sql.compiler.SQLCompiler`
and :class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
method which can be used for compilation of embedded attributes::

    from sqlalchemy.sql.expression import Executable, ClauseElement

    class InsertFromSelect(Executable, ClauseElement):
        def __init__(self, table, select):
            self.table = table
            self.select = select

    @compiles(InsertFromSelect)
    def visit_insert_from_select(element, compiler, **kw):
        return "INSERT INTO %s (%s)" % (
            compiler.process(element.table, asfrom=True),
            compiler.process(element.select)
        )

    insert = InsertFromSelect(t1, select([t1]).where(t1.c.x > 5))
    print insert

Produces::

    "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)"

Cross Compiling between SQL and DDL compilers
---------------------------------------------

SQL and DDL constructs are each compiled using different base compilers -
``SQLCompiler`` and ``DDLCompiler``.  A common need is to access the
compilation rules of SQL expressions from within a DDL expression.  The
``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
below where we generate a CHECK constraint that embeds a SQL expression::

    @compiles(MyConstraint)
    def compile_my_constraint(constraint, ddlcompiler, **kw):
        return "CONSTRAINT %s CHECK (%s)" % (
            constraint.name,
            ddlcompiler.sql_compiler.process(constraint.expression)
        )

Changing the default compilation of existing constructs
=======================================================

The compiler extension applies just as well to the existing constructs.  When
overriding the compilation of a built-in SQL construct, the @compiles
decorator is invoked upon the appropriate class (be sure to use the class,
i.e. ``Insert`` or ``Select``, instead of the creation function such as
``insert()`` or ``select()``).

Within the new compilation function, to get at the "original" compilation
routine, use the appropriate ``visit_XXX`` method - this is because
``compiler.process()`` will call upon the overriding routine and cause an
endless loop.  For example, to add a "prefix" to all insert statements::

    from sqlalchemy.sql.expression import Insert

    @compiles(Insert)
    def prefix_inserts(insert, compiler, **kw):
        return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)

The above compiler will prefix all INSERT statements with "some prefix" when
compiled.

Subclassing Guidelines
======================

A big part of using the compiler extension is subclassing SQLAlchemy
expression constructs.  To make this easier, the expression and schema
packages feature a set of "bases" intended for common tasks.  A synopsis is
as follows:

* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
  expression class.  Any SQL expression can be derived from this base, and is
  probably the best choice for longer constructs such as specialized INSERT
  statements.

* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
  "column-like" elements.  Anything that you'd place in the "columns" clause of
  a SELECT statement (as well as order by and group by) can derive from this -
  the object will automatically have Python "comparison" behavior.

  :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
  ``type`` member which is the expression's return type.  This can be
  established at the instance level in the constructor, or at the class level
  if it's generally constant::

      class timestamp(ColumnElement):
          type = TIMESTAMP()

* :class:`~sqlalchemy.sql.expression.FunctionElement` - This is a hybrid of a
  ``ColumnElement`` and a "from clause" like object, and represents a SQL
  function or stored procedure type of call.  Since most databases support
  statements along the line of "SELECT FROM <some function>",
  ``FunctionElement`` adds in the ability to be used in the FROM clause of a
  ``select()`` construct.

* :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions,
  like CREATE TABLE, ALTER TABLE, etc.  Compilation of ``DDLElement``
  subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``.
  ``DDLElement`` also features ``Table`` and ``MetaData`` event hooks via the
  ``execute_at()`` method, allowing the construct to be invoked during CREATE
  TABLE and DROP TABLE sequences.

* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which should be
  used with any expression class that represents a "standalone" SQL statement that
  can be passed directly to an ``execute()`` method.  It is already implicit
  within ``DDLElement`` and ``FunctionElement``.

"""

def compiles(class_, *specs):
    def decorate(fn):
        existing = getattr(class_, '_compiler_dispatcher', None)
        if not existing:
            existing = _dispatcher()

            # TODO: why is the lambda needed ?
            setattr(class_, '_compiler_dispatch',
                    lambda *arg, **kw: existing(*arg, **kw))
            setattr(class_, '_compiler_dispatcher', existing)

        if specs:
            for s in specs:
                existing.specs[s] = fn
        else:
            existing.specs['default'] = fn
        return fn
    return decorate

class _dispatcher(object):
    def __init__(self):
        self.specs = {}

    def __call__(self, element, compiler, **kw):
        # TODO: yes, this could also switch off of DBAPI in use.
        fn = self.specs.get(compiler.dialect.name, None)
        if not fn:
            fn = self.specs['default']
        return fn(element, compiler, **kw)
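
# Editor's note: an illustrative demo, not part of the original commit.
# It sketches a dialect-aware utcnow() function element built with @compiles;
# the utcnow name and the SQL strings are assumptions for the demo.
if __name__ == '__main__':
    from sqlalchemy.sql.expression import FunctionElement
    from sqlalchemy.types import DateTime
    from sqlalchemy import select

    class utcnow(FunctionElement):
        type = DateTime()

    @compiles(utcnow)
    def default_utcnow(element, compiler, **kw):
        # generic fallback, used when no dialect-specific entry matches
        return "CURRENT_TIMESTAMP"

    @compiles(utcnow, 'postgresql')
    def pg_utcnow(element, compiler, **kw):
        return "TIMEZONE('utc', CURRENT_TIMESTAMP)"

    print select([utcnow()])    # default dialect -> CURRENT_TIMESTAMP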
940
sqlalchemy/ext/declarative.py
Normal file
@@ -0,0 +1,940 @@
|
||||
"""
|
||||
Synopsis
|
||||
========
|
||||
|
||||
SQLAlchemy object-relational configuration involves the use of
|
||||
:class:`~sqlalchemy.schema.Table`, :func:`~sqlalchemy.orm.mapper`, and
|
||||
class objects to define the three areas of configuration.
|
||||
:mod:`~sqlalchemy.ext.declarative` allows all three types of
|
||||
configuration to be expressed declaratively on an individual
|
||||
mapped class. Regular SQLAlchemy schema elements and ORM constructs
|
||||
are used in most cases.
|
||||
|
||||
As a simple example::
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class SomeClass(Base):
|
||||
__tablename__ = 'some_table'
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
|
||||
Above, the :func:`declarative_base` callable returns a new base class from which
|
||||
all mapped classes should inherit. When the class definition is completed, a
|
||||
new :class:`~sqlalchemy.schema.Table` and
|
||||
:class:`~sqlalchemy.orm.mapper` will have been generated, accessible
|
||||
via the ``__table__`` and ``__mapper__`` attributes on the ``SomeClass`` class.
|
||||
|
||||
Defining Attributes
|
||||
===================
|
||||
|
||||
In the above example, the :class:`~sqlalchemy.schema.Column` objects are
|
||||
automatically named with the name of the attribute to which they are
|
||||
assigned.
|
||||
|
||||
They can also be explicitly named, and that name does not have to be
|
||||
the same as name assigned on the class.
|
||||
The column will be assigned to the :class:`~sqlalchemy.schema.Table` using the
|
||||
given name, and mapped to the class using the attribute name::
|
||||
|
||||
class SomeClass(Base):
|
||||
__tablename__ = 'some_table'
|
||||
id = Column("some_table_id", Integer, primary_key=True)
|
||||
name = Column("name", String(50))
|
||||
|
||||
Attributes may be added to the class after its construction, and they will be
|
||||
added to the underlying :class:`~sqlalchemy.schema.Table` and
|
||||
:func:`~sqlalchemy.orm.mapper()` definitions as appropriate::
|
||||
|
||||
SomeClass.data = Column('data', Unicode)
|
||||
SomeClass.related = relationship(RelatedInfo)
|
||||
|
||||
Classes which are mapped explicitly using
|
||||
:func:`~sqlalchemy.orm.mapper()` can interact freely with declarative
|
||||
classes.
|
||||
|
||||
It is recommended, though not required, that all tables
|
||||
share the same underlying :class:`~sqlalchemy.schema.MetaData` object,
|
||||
so that string-configured :class:`~sqlalchemy.schema.ForeignKey`
|
||||
references can be resolved without issue.
|
||||
|
||||
Association of Metadata and Engine
|
||||
==================================
|
||||
|
||||
The :func:`declarative_base` base class contains a
|
||||
:class:`~sqlalchemy.schema.MetaData` object where newly
|
||||
defined :class:`~sqlalchemy.schema.Table` objects are collected. This
|
||||
is accessed via the :class:`~sqlalchemy.schema.MetaData` class level
|
||||
accessor, so to create tables we can say::
|
||||
|
||||
engine = create_engine('sqlite://')
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
The :class:`~sqlalchemy.engine.base.Engine` created above may also be
|
||||
directly associated with the declarative base class using the ``bind``
|
||||
keyword argument, where it will be associated with the underlying
|
||||
:class:`~sqlalchemy.schema.MetaData` object and allow SQL operations
|
||||
involving that metadata and its tables to make use of that engine
|
||||
automatically::
|
||||
|
||||
Base = declarative_base(bind=create_engine('sqlite://'))
|
||||
|
||||
Alternatively, by way of the normal
|
||||
:class:`~sqlalchemy.schema.MetaData` behaviour, the ``bind`` attribute
|
||||
of the class level accessor can be assigned at any time as follows::
|
||||
|
||||
Base.metadata.bind = create_engine('sqlite://')
|
||||
|
||||
The :func:`declarative_base` can also receive a pre-created
|
||||
:class:`~sqlalchemy.schema.MetaData` object, which allows a
|
||||
declarative setup to be associated with an already
|
||||
existing traditional collection of :class:`~sqlalchemy.schema.Table`
|
||||
objects::
|
||||
|
||||
mymetadata = MetaData()
|
||||
Base = declarative_base(metadata=mymetadata)
|
||||
|
||||
Configuring Relationships
|
||||
=========================
|
||||
|
||||
Relationships to other classes are done in the usual way, with the added
|
||||
feature that the class specified to :func:`~sqlalchemy.orm.relationship`
|
||||
may be a string name (note that :func:`~sqlalchemy.orm.relationship` is
|
||||
only available as of SQLAlchemy 0.6beta2, and in all prior versions is known
|
||||
as :func:`~sqlalchemy.orm.relation`,
|
||||
including 0.5 and 0.4). The "class registry" associated with ``Base``
|
||||
is used at mapper compilation time to resolve the name into the actual
|
||||
class object, which is expected to have been defined once the mapper
|
||||
configuration is used::
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
addresses = relationship("Address", backref="user")
|
||||
|
||||
class Address(Base):
|
||||
__tablename__ = 'addresses'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
email = Column(String(50))
|
||||
user_id = Column(Integer, ForeignKey('users.id'))
|
||||
|
||||
Column constructs, since they are just that, are immediately usable,
|
||||
as below where we define a primary join condition on the ``Address``
|
||||
class using them::
|
||||
|
||||
class Address(Base):
|
||||
__tablename__ = 'addresses'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
email = Column(String(50))
|
||||
user_id = Column(Integer, ForeignKey('users.id'))
|
||||
user = relationship(User, primaryjoin=user_id == User.id)
|
||||
|
||||
In addition to the main argument for :func:`~sqlalchemy.orm.relationship`,
|
||||
other arguments which depend upon the columns present on an as-yet
|
||||
undefined class may also be specified as strings. These strings are
|
||||
evaluated as Python expressions. The full namespace available within
|
||||
this evaluation includes all classes mapped for this declarative base,
|
||||
as well as the contents of the ``sqlalchemy`` package, including
|
||||
expression functions like :func:`~sqlalchemy.sql.expression.desc` and
|
||||
:attr:`~sqlalchemy.sql.expression.func`::
|
||||
|
||||
class User(Base):
|
||||
# ....
|
||||
addresses = relationship("Address",
|
||||
order_by="desc(Address.email)",
|
||||
primaryjoin="Address.user_id==User.id")
|
||||
|
||||
As an alternative to string-based attributes, attributes may also be
|
||||
defined after all classes have been created. Just add them to the target
|
||||
class after the fact::
|
||||
|
||||
User.addresses = relationship(Address,
|
||||
primaryjoin=Address.user_id==User.id)
|
||||
|
||||
Configuring Many-to-Many Relationships
|
||||
======================================
|
||||
|
||||
There's nothing special about many-to-many with declarative. The
|
||||
``secondary`` argument to :func:`~sqlalchemy.orm.relationship` still
|
||||
requires a :class:`~sqlalchemy.schema.Table` object, not a declarative
|
||||
class. The :class:`~sqlalchemy.schema.Table` should share the same
|
||||
:class:`~sqlalchemy.schema.MetaData` object used by the declarative
|
||||
base::
|
||||
|
||||
keywords = Table(
|
||||
'keywords', Base.metadata,
|
||||
Column('author_id', Integer, ForeignKey('authors.id')),
|
||||
Column('keyword_id', Integer, ForeignKey('keywords.id'))
|
||||
)
|
||||
|
||||
class Author(Base):
|
||||
__tablename__ = 'authors'
|
||||
id = Column(Integer, primary_key=True)
|
||||
keywords = relationship("Keyword", secondary=keywords)
|
||||
|
||||
You should generally **not** map a class and also specify its table in
|
||||
a many-to-many relationship, since the ORM may issue duplicate INSERT and
|
||||
DELETE statements.
|
||||
|
||||
|
||||
Defining Synonyms
|
||||
=================
|
||||
|
||||
Synonyms are introduced in :ref:`synonyms`. To define a getter/setter
|
||||
which proxies to an underlying attribute, use
|
||||
:func:`~sqlalchemy.orm.synonym` with the ``descriptor`` argument::
|
||||
|
||||
class MyClass(Base):
|
||||
__tablename__ = 'sometable'
|
||||
|
||||
_attr = Column('attr', String)
|
||||
|
||||
def _get_attr(self):
|
||||
return self._some_attr
|
||||
def _set_attr(self, attr):
|
||||
self._some_attr = attr
|
||||
attr = synonym('_attr', descriptor=property(_get_attr, _set_attr))
|
||||
|
||||
The above synonym is then usable as an instance attribute as well as a
|
||||
class-level expression construct::
|
||||
|
||||
x = MyClass()
|
||||
x.attr = "some value"
|
||||
session.query(MyClass).filter(MyClass.attr == 'some other value').all()
|
||||
|
||||
For simple getters, the :func:`synonym_for` decorator can be used in
|
||||
conjunction with ``@property``::
|
||||
|
||||
class MyClass(Base):
|
||||
__tablename__ = 'sometable'
|
||||
|
||||
_attr = Column('attr', String)
|
||||
|
||||
@synonym_for('_attr')
|
||||
@property
|
||||
def attr(self):
|
||||
return self._some_attr
|
||||
|
||||
Similarly, :func:`comparable_using` is a front end for the
|
||||
:func:`~sqlalchemy.orm.comparable_property` ORM function::
|
||||
|
||||
class MyClass(Base):
|
||||
__tablename__ = 'sometable'
|
||||
|
||||
name = Column('name', String)
|
||||
|
||||
@comparable_using(MyUpperCaseComparator)
|
||||
@property
|
||||
def uc_name(self):
|
||||
return self.name.upper()
|
||||
|
||||
Table Configuration
|
||||
===================
|
||||
|
||||
Table arguments other than the name, metadata, and mapped Column
|
||||
arguments are specified using the ``__table_args__`` class attribute.
|
||||
This attribute accommodates both positional as well as keyword
|
||||
arguments that are normally sent to the
|
||||
:class:`~sqlalchemy.schema.Table` constructor.
|
||||
The attribute can be specified in one of two forms. One is as a
|
||||
dictionary::
|
||||
|
||||
class MyClass(Base):
|
||||
__tablename__ = 'sometable'
|
||||
__table_args__ = {'mysql_engine':'InnoDB'}
|
||||
|
||||
The other, a tuple of the form
|
||||
``(arg1, arg2, ..., {kwarg1:value, ...})``, which allows positional
|
||||
arguments to be specified as well (usually constraints)::
|
||||
|
||||
class MyClass(Base):
|
||||
__tablename__ = 'sometable'
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(['id'], ['remote_table.id']),
|
||||
UniqueConstraint('foo'),
|
||||
{'autoload':True}
|
||||
)
|
||||
|
||||
Note that the keyword parameters dictionary is required in the tuple
|
||||
form even if empty.
|
||||
|
||||
As an alternative to ``__tablename__``, a direct
|
||||
:class:`~sqlalchemy.schema.Table` construct may be used. The
|
||||
:class:`~sqlalchemy.schema.Column` objects, which in this case require
|
||||
their names, will be added to the mapping just like a regular mapping
|
||||
to a table::
|
||||
|
||||
class MyClass(Base):
|
||||
__table__ = Table('my_table', Base.metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('name', String(50))
|
||||
)
|
||||
|
||||
Mapper Configuration
|
||||
====================
|
||||
|
||||
Configuration of mappers is done with the
|
||||
:func:`~sqlalchemy.orm.mapper` function and all the possible mapper
|
||||
configuration parameters can be found in the documentation for that
|
||||
function.
|
||||
|
||||
:func:`~sqlalchemy.orm.mapper` is still used by declaratively mapped
|
||||
classes and keyword parameters to the function can be passed by
|
||||
placing them in the ``__mapper_args__`` class variable::
|
||||
|
||||
class Widget(Base):
|
||||
__tablename__ = 'widgets'
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
__mapper_args__ = {'extension': MyWidgetExtension()}
|
||||
|
||||
Inheritance Configuration
|
||||
=========================
|
||||
|
||||
Declarative supports all three forms of inheritance as intuitively
|
||||
as possible. The ``inherits`` mapper keyword argument is not needed
|
||||
as declarative will determine this from the class itself. The various
|
||||
"polymorphic" keyword arguments are specified using ``__mapper_args__``.
|
||||
|
||||
Joined Table Inheritance
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Joined table inheritance is defined as a subclass that defines its own
|
||||
table::
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'people'
|
||||
id = Column(Integer, primary_key=True)
|
||||
discriminator = Column('type', String(50))
|
||||
__mapper_args__ = {'polymorphic_on': discriminator}
|
||||
|
||||
class Engineer(Person):
|
||||
__tablename__ = 'engineers'
|
||||
__mapper_args__ = {'polymorphic_identity': 'engineer'}
|
||||
id = Column(Integer, ForeignKey('people.id'), primary_key=True)
|
||||
primary_language = Column(String(50))
|
||||
|
||||
Note that above, the ``Engineer.id`` attribute, since it shares the
|
||||
same attribute name as the ``Person.id`` attribute, will in fact
|
||||
represent the ``people.id`` and ``engineers.id`` columns together, and
|
||||
will render inside a query as ``"people.id"``.
|
||||
To provide the ``Engineer`` class with an attribute that represents
|
||||
only the ``engineers.id`` column, give it a different attribute name::
|
||||
|
||||
class Engineer(Person):
|
||||
__tablename__ = 'engineers'
|
||||
__mapper_args__ = {'polymorphic_identity': 'engineer'}
|
||||
engineer_id = Column('id', Integer, ForeignKey('people.id'), primary_key=True)
|
||||
primary_language = Column(String(50))
|
||||
|
||||
Single Table Inheritance
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Single table inheritance is defined as a subclass that does not have
|
||||
its own table; you just leave out the ``__table__`` and ``__tablename__``
|
||||
attributes::
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'people'
|
||||
id = Column(Integer, primary_key=True)
|
||||
discriminator = Column('type', String(50))
|
||||
__mapper_args__ = {'polymorphic_on': discriminator}
|
||||
|
||||
class Engineer(Person):
|
||||
__mapper_args__ = {'polymorphic_identity': 'engineer'}
|
||||
primary_language = Column(String(50))
|
||||
|
||||
When the above mappers are configured, the ``Person`` class is mapped
|
||||
to the ``people`` table *before* the ``primary_language`` column is
|
||||
defined, and this column will not be included in its own mapping.
|
||||
When ``Engineer`` then defines the ``primary_language`` column, the
|
||||
column is added to the ``people`` table so that it is included in the
|
||||
mapping for ``Engineer`` and is also part of the table's full set of
|
||||
columns. Columns which are not mapped to ``Person`` are also excluded
|
||||
from any other single or joined inheriting classes using the
|
||||
``exclude_properties`` mapper argument. Below, ``Manager`` will have
|
||||
all the attributes of ``Person`` and ``Manager`` but *not* the
|
||||
``primary_language`` attribute of ``Engineer``::
|
||||
|
||||
class Manager(Person):
|
||||
__mapper_args__ = {'polymorphic_identity': 'manager'}
|
||||
golf_swing = Column(String(50))
|
||||
|
||||
The attribute exclusion logic is provided by the
|
||||
``exclude_properties`` mapper argument, and declarative's default
|
||||
behavior can be disabled by passing an explicit ``exclude_properties``
|
||||
collection (empty or otherwise) to the ``__mapper_args__``.

Concrete Table Inheritance
~~~~~~~~~~~~~~~~~~~~~~~~~~

Concrete is defined as a subclass which has its own table and sets the
``concrete`` keyword argument to ``True``::

    class Person(Base):
        __tablename__ = 'people'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    class Engineer(Person):
        __tablename__ = 'engineers'
        __mapper_args__ = {'concrete':True}
        id = Column(Integer, primary_key=True)
        primary_language = Column(String(50))
        name = Column(String(50))

Usage of an abstract base class is a little less straightforward as it
requires usage of :func:`~sqlalchemy.orm.util.polymorphic_union`::

    engineers = Table('engineers', Base.metadata,
                    Column('id', Integer, primary_key=True),
                    Column('name', String(50)),
                    Column('primary_language', String(50))
                )
    managers = Table('managers', Base.metadata,
                    Column('id', Integer, primary_key=True),
                    Column('name', String(50)),
                    Column('golf_swing', String(50))
                )

    punion = polymorphic_union({
        'engineer':engineers,
        'manager':managers
    }, 'type', 'punion')

    class Person(Base):
        __table__ = punion
        __mapper_args__ = {'polymorphic_on':punion.c.type}

    class Engineer(Person):
        __table__ = engineers
        __mapper_args__ = {'polymorphic_identity':'engineer', 'concrete':True}

    class Manager(Person):
        __table__ = managers
        __mapper_args__ = {'polymorphic_identity':'manager', 'concrete':True}


Mix-in Classes
==============

A common need when using :mod:`~sqlalchemy.ext.declarative` is to
share some functionality, often a set of columns, across many
classes. The normal Python idiom would be to put this common code into
a base class and have all the other classes subclass this class.

When using :mod:`~sqlalchemy.ext.declarative`, this need is met by
using a "mix-in class". A mix-in class is one that isn't mapped to a
table and doesn't subclass the declarative :class:`Base`. For example::

    class MyMixin(object):

        __table_args__ = {'mysql_engine':'InnoDB'}
        __mapper_args__ = dict(always_refresh=True)
        id = Column(Integer, primary_key=True)

        def foo(self):
            return 'bar' + str(self.id)

    class MyModel(Base, MyMixin):
        __tablename__ = 'test'
        name = Column(String(1000), nullable=False, index=True)

As the above example shows, ``__table_args__`` and ``__mapper_args__``
can both be abstracted out into a mix-in if you use common values for
these across many classes.

However, particularly in the case of ``__table_args__``, you may want
to combine some parameters from several mix-ins with those you wish to
define on the class itself. To help with this, a
:func:`~sqlalchemy.util.classproperty` decorator is provided that lets
you implement a class property with a function. For example::

    from sqlalchemy.util import classproperty

    class MySQLSettings:
        __table_args__ = {'mysql_engine':'InnoDB'}

    class MyOtherMixin:
        __table_args__ = {'info':'foo'}

    class MyModel(Base, MySQLSettings, MyOtherMixin):
        __tablename__ = 'my_model'

        @classproperty
        def __table_args__(self):
            args = dict()
            args.update(MySQLSettings.__table_args__)
            args.update(MyOtherMixin.__table_args__)
            return args

        id = Column(Integer, primary_key=True)

Class Constructor
=================

As a convenience feature, the :func:`declarative_base` sets a default
constructor on classes which takes keyword arguments, and assigns them
to the named attributes::

    e = Engineer(primary_language='python')

Sessions
========

Note that ``declarative`` does nothing special with sessions, and is
only intended as an easier way to configure mappers and
:class:`~sqlalchemy.schema.Table` objects. A typical application
setup using :func:`~sqlalchemy.orm.scoped_session` might look like::

    engine = create_engine('postgresql://scott:tiger@localhost/test')
    Session = scoped_session(sessionmaker(autocommit=False,
                                          autoflush=False,
                                          bind=engine))
    Base = declarative_base()

Mapped instances then make usage of
:class:`~sqlalchemy.orm.session.Session` in the usual way.
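
For example (an illustrative sketch; ``Engineer`` is the mapped class
from the sections above)::

    session = Session()
    session.add(Engineer(primary_language='python'))
    session.commit()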

"""

from sqlalchemy.schema import Table, Column, MetaData
from sqlalchemy.orm import synonym as _orm_synonym, mapper, comparable_property, class_mapper
from sqlalchemy.orm.interfaces import MapperProperty
from sqlalchemy.orm.properties import RelationshipProperty, ColumnProperty
from sqlalchemy.orm.util import _is_mapped_class
from sqlalchemy import util, exceptions
from sqlalchemy.sql import util as sql_util


__all__ = 'declarative_base', 'synonym_for', 'comparable_using', 'instrument_declarative'

def instrument_declarative(cls, registry, metadata):
    """Given a class, configure the class declaratively,
    using the given registry, which can be any dictionary, and
    MetaData object.

    """
    if '_decl_class_registry' in cls.__dict__:
        raise exceptions.InvalidRequestError(
            "Class %r already has been "
            "instrumented declaratively" % cls)
    cls._decl_class_registry = registry
    cls.metadata = metadata
    _as_declarative(cls, cls.__name__, cls.__dict__)
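
# A minimal usage sketch (the ``Plain`` class and names below are hypothetical,
# not part of this module): instrument a plain class without using
# declarative_base().
#
#     registry = {}
#     metadata = MetaData()
#
#     class Plain(object):
#         __tablename__ = 'plain'
#         id = Column(Integer, primary_key=True)
#
#     instrument_declarative(Plain, registry, metadata)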

def _as_declarative(cls, classname, dict_):

    # dict_ will be a dictproxy, which we can't write to, and we need to!
    dict_ = dict(dict_)

    column_copies = dict()
    unmapped_mixins = False
    for base in cls.__bases__:
        names = dir(base)
        if not _is_mapped_class(base):
            unmapped_mixins = True
            for name in names:
                obj = getattr(base, name, None)
                if isinstance(obj, Column):
                    if obj.foreign_keys:
                        raise exceptions.InvalidRequestError(
                            "Columns with foreign keys to other columns "
                            "are not allowed on declarative mixins at this time."
                        )
                    dict_[name] = column_copies[obj] = obj.copy()
                elif isinstance(obj, RelationshipProperty):
                    raise exceptions.InvalidRequestError(
                        "relationships are not allowed on "
                        "declarative mixins at this time.")

    # doing it this way enables these attributes to be descriptors
    get_mapper_args = '__mapper_args__' in dict_
    get_table_args = '__table_args__' in dict_
    if unmapped_mixins:
        get_mapper_args = get_mapper_args or getattr(cls, '__mapper_args__', None)
        get_table_args = get_table_args or getattr(cls, '__table_args__', None)
        tablename = getattr(cls, '__tablename__', None)
        if tablename:
            # subtle: if tablename is a descriptor here, we actually
            # put the wrong value in, but it serves as a marker to get
            # the right value...
            dict_['__tablename__'] = tablename

    # now that we know whether or not to get these, get them from the class
    # if we should, enabling them to be decorators
    mapper_args = get_mapper_args and cls.__mapper_args__ or {}
    table_args = get_table_args and cls.__table_args__ or None

    # make sure that column copies are used rather than the original columns
    # from any mixins
    for k, v in mapper_args.iteritems():
        mapper_args[k] = column_copies.get(v, v)

    cls._decl_class_registry[classname] = cls
    our_stuff = util.OrderedDict()
    for k in dict_:
        value = dict_[k]
        if (isinstance(value, tuple) and len(value) == 1 and
            isinstance(value[0], (Column, MapperProperty))):
            util.warn("Ignoring declarative-like tuple value of attribute "
                      "%s: possibly a copy-and-paste error with a comma "
                      "left at the end of the line?" % k)
            continue
        if not isinstance(value, (Column, MapperProperty)):
            continue
        prop = _deferred_relationship(cls, value)
        our_stuff[k] = prop

    # set up attributes in the order they were created
    our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)

    # extract columns from the class dict
    cols = []
    for key, c in our_stuff.iteritems():
        if isinstance(c, ColumnProperty):
            for col in c.columns:
                if isinstance(col, Column) and col.table is None:
                    _undefer_column_name(key, col)
                    cols.append(col)
        elif isinstance(c, Column):
            _undefer_column_name(key, c)
            cols.append(c)
            # if the column is the same name as the key,
            # remove it from the explicit properties dict.
            # the normal rules for assigning column-based properties
            # will take over, including precedence of columns
            # in multi-column ColumnProperties.
            if key == c.key:
                del our_stuff[key]

    table = None
    if '__table__' not in dict_:
        if '__tablename__' in dict_:
            # see above: if __tablename__ is a descriptor, this
            # means we get the right value used!
            tablename = cls.__tablename__

            if isinstance(table_args, dict):
                args, table_kw = (), table_args
            elif isinstance(table_args, tuple):
                args = table_args[0:-1]
                table_kw = table_args[-1]
                if len(table_args) < 2 or not isinstance(table_kw, dict):
                    raise exceptions.ArgumentError(
                        "Tuple form of __table_args__ is "
                        "(arg1, arg2, arg3, ..., {'kw1':val1, 'kw2':val2, ...})"
                    )
            else:
                args, table_kw = (), {}

            autoload = dict_.get('__autoload__')
            if autoload:
                table_kw['autoload'] = True

            cls.__table__ = table = Table(tablename, cls.metadata,
                                          *(tuple(cols) + tuple(args)), **table_kw)
    else:
        table = cls.__table__
        if cols:
            for c in cols:
                if not table.c.contains_column(c):
                    raise exceptions.ArgumentError(
                        "Can't add additional column %r when specifying __table__" % key
                    )

    if 'inherits' not in mapper_args:
        for c in cls.__bases__:
            if _is_mapped_class(c):
                mapper_args['inherits'] = cls._decl_class_registry.get(c.__name__, None)
                break

    if hasattr(cls, '__mapper_cls__'):
        mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
    else:
        mapper_cls = mapper

    if table is None and 'inherits' not in mapper_args:
        raise exceptions.InvalidRequestError(
            "Class %r does not have a __table__ or __tablename__ "
            "specified and does not inherit from an existing table-mapped class." % cls
            )
    elif 'inherits' in mapper_args and not mapper_args.get('concrete', False):
        inherited_mapper = class_mapper(mapper_args['inherits'], compile=False)
        inherited_table = inherited_mapper.local_table
        if 'inherit_condition' not in mapper_args and table is not None:
            # figure out the inherit condition with relaxed rules
            # about nonexistent tables, to allow for ForeignKeys to
            # not-yet-defined tables (since we know for sure that our
            # parent table is defined within the same MetaData)
            mapper_args['inherit_condition'] = sql_util.join_condition(
                mapper_args['inherits'].__table__, table,
                ignore_nonexistent_tables=True)

        if table is None:
            # single table inheritance.
            # ensure no table args
            if table_args is not None:
                raise exceptions.ArgumentError(
                    "Can't place __table_args__ on an inherited class with no table."
                    )

            # add any columns declared here to the inherited table.
            for c in cols:
                if c.primary_key:
                    raise exceptions.ArgumentError(
                        "Can't place primary key columns on an inherited class with no table."
                        )
                if c.name in inherited_table.c:
                    raise exceptions.ArgumentError(
                        "Column '%s' on class %s conflicts with existing column '%s'" %
                        (c, cls, inherited_table.c[c.name])
                    )
                inherited_table.append_column(c)

        # single or joined inheritance
        # exclude any cols on the inherited table which are not mapped on the
        # parent class, to avoid
        # mapping columns specific to sibling/nephew classes
        inherited_mapper = class_mapper(mapper_args['inherits'], compile=False)
        inherited_table = inherited_mapper.local_table

        if 'exclude_properties' not in mapper_args:
            mapper_args['exclude_properties'] = exclude_properties = \
                set([c.key for c in inherited_table.c
                     if c not in inherited_mapper._columntoproperty])
            exclude_properties.difference_update([c.key for c in cols])

    cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)

class DeclarativeMeta(type):
    def __init__(cls, classname, bases, dict_):
        if '_decl_class_registry' in cls.__dict__:
            return type.__init__(cls, classname, bases, dict_)

        _as_declarative(cls, classname, cls.__dict__)
        return type.__init__(cls, classname, bases, dict_)

    def __setattr__(cls, key, value):
        if '__mapper__' in cls.__dict__:
            if isinstance(value, Column):
                _undefer_column_name(key, value)
                cls.__table__.append_column(value)
                cls.__mapper__.add_property(key, value)
            elif isinstance(value, ColumnProperty):
                for col in value.columns:
                    if isinstance(col, Column) and col.table is None:
                        _undefer_column_name(key, col)
                        cls.__table__.append_column(col)
                cls.__mapper__.add_property(key, value)
            elif isinstance(value, MapperProperty):
                cls.__mapper__.add_property(key, _deferred_relationship(cls, value))
            else:
                type.__setattr__(cls, key, value)
        else:
            type.__setattr__(cls, key, value)
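
# Illustration (hypothetical ``MyModel``, an already-mapped declarative class):
# assigning a Column to a mapped class at runtime adds it to both the Table
# and the Mapper via __setattr__ above.
#
#     MyModel.extra = Column(String(50))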


class _GetColumns(object):
    def __init__(self, cls):
        self.cls = cls

    def __getattr__(self, key):
        mapper = class_mapper(self.cls, compile=False)
        if mapper:
            prop = mapper.get_property(key)
            if not isinstance(prop, ColumnProperty):
                raise exceptions.InvalidRequestError(
                    "Property %r is not an instance of"
                    " ColumnProperty (i.e. does not correspond"
                    " directly to a Column)." % key)
        return getattr(self.cls, key)


def _deferred_relationship(cls, prop):
    def resolve_arg(arg):
        import sqlalchemy

        def access_cls(key):
            if key in cls._decl_class_registry:
                return _GetColumns(cls._decl_class_registry[key])
            elif key in cls.metadata.tables:
                return cls.metadata.tables[key]
            else:
                return sqlalchemy.__dict__[key]

        d = util.PopulateDict(access_cls)
        def return_cls():
            try:
                x = eval(arg, globals(), d)

                if isinstance(x, _GetColumns):
                    return x.cls
                else:
                    return x
            except NameError, n:
                raise exceptions.InvalidRequestError(
                    "When compiling mapper %s, expression %r failed to locate a name (%r). "
                    "If this is a class name, consider adding this relationship() to the %r "
                    "class after both dependent classes have been defined." % (
                        prop.parent, arg, n.args[0], cls))
        return return_cls

    if isinstance(prop, RelationshipProperty):
        for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
                     'secondary', '_foreign_keys', 'remote_side'):
            v = getattr(prop, attr)
            if isinstance(v, basestring):
                setattr(prop, attr, resolve_arg(v))

        if prop.backref and isinstance(prop.backref, tuple):
            key, kwargs = prop.backref
            for attr in ('primaryjoin', 'secondaryjoin', 'secondary',
                         'foreign_keys', 'remote_side', 'order_by'):
                if attr in kwargs and isinstance(kwargs[attr], basestring):
                    kwargs[attr] = resolve_arg(kwargs[attr])


    return prop

def synonym_for(name, map_column=False):
    """Decorator, make a Python @property a query synonym for a column.

    A decorator version of :func:`~sqlalchemy.orm.synonym`. The function being
    decorated is the 'descriptor', otherwise passes its arguments through
    to synonym()::

        @synonym_for('col')
        @property
        def prop(self):
            return 'special sauce'

    The regular ``synonym()`` is also usable directly in a declarative setting
    and may be convenient for read/write properties::

        prop = synonym('col', descriptor=property(_read_prop, _write_prop))

    """
    def decorate(fn):
        return _orm_synonym(name, map_column=map_column, descriptor=fn)
    return decorate

def comparable_using(comparator_factory):
    """Decorator, allow a Python @property to be used in query criteria.

    This is a decorator front end to
    :func:`~sqlalchemy.orm.comparable_property` that passes
    through the comparator_factory and the function being decorated::

        @comparable_using(MyComparatorType)
        @property
        def prop(self):
            return 'special sauce'

    The regular ``comparable_property()`` is also usable directly in a
    declarative setting and may be convenient for read/write properties::

        prop = comparable_property(MyComparatorType)

    """
    def decorate(fn):
        return comparable_property(comparator_factory, fn)
    return decorate

def _declarative_constructor(self, **kwargs):
    """A simple constructor that allows initialization from kwargs.

    Sets attributes on the constructed instance using the names and
    values in ``kwargs``.

    Only keys that are present as
    attributes of the instance's class are allowed. These could be,
    for example, any mapped columns or relationships.
    """
    for k in kwargs:
        if not hasattr(type(self), k):
            raise TypeError(
                "%r is an invalid keyword argument for %s" %
                (k, type(self).__name__))
        setattr(self, k, kwargs[k])
_declarative_constructor.__name__ = '__init__'
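
# Illustration (hypothetical mapped class ``User``): only keys that exist as
# class attributes are accepted, so a misspelled keyword fails fast.
#
#     User(name='ed')    # ok, sets self.name = 'ed'
#     User(nmae='ed')    # raises TypeError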

def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
                     name='Base', constructor=_declarative_constructor,
                     metaclass=DeclarativeMeta):
    """Construct a base class for declarative class definitions.

    The new base class will be given a metaclass that produces
    appropriate :class:`~sqlalchemy.schema.Table` objects and makes
    the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the
    information provided declaratively in the class and any subclasses
    of the class.

    :param bind: An optional
      :class:`~sqlalchemy.engine.base.Connectable`, will be assigned
      the ``bind`` attribute on the :class:`~sqlalchemy.MetaData`
      instance.

    :param metadata:
      An optional :class:`~sqlalchemy.MetaData` instance.  All
      :class:`~sqlalchemy.schema.Table` objects implicitly declared by
      subclasses of the base will share this MetaData.  A MetaData instance
      will be created if none is provided.  The
      :class:`~sqlalchemy.MetaData` instance will be available via the
      `metadata` attribute of the generated declarative base class.

    :param mapper:
      An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will be
      used to map subclasses to their Tables.

    :param cls:
      Defaults to :class:`object`. A type to use as the base for the generated
      declarative base class. May be a class or tuple of classes.

    :param name:
      Defaults to ``Base``.  The display name for the generated
      class.  Customizing this is not required, but can improve clarity in
      tracebacks and debugging.

    :param constructor:
      Defaults to
      :func:`~sqlalchemy.ext.declarative._declarative_constructor`, an
      __init__ implementation that assigns \**kwargs for declared
      fields and relationships to an instance.  If ``None`` is supplied,
      no __init__ will be provided and construction will fall back to
      cls.__init__ by way of the normal Python semantics.

    :param metaclass:
      Defaults to :class:`DeclarativeMeta`.  A metaclass or __metaclass__
      compatible callable to use as the meta type of the generated
      declarative base class.

    """
    lcl_metadata = metadata or MetaData()
    if bind:
        lcl_metadata.bind = bind

    bases = not isinstance(cls, tuple) and (cls,) or cls
    class_dict = dict(_decl_class_registry=dict(),
                      metadata=lcl_metadata)

    if constructor:
        class_dict['__init__'] = constructor
    if mapper:
        class_dict['__mapper_cls__'] = mapper

    return metaclass(name, bases, class_dict)


def _undefer_column_name(key, column):
    if column.key is None:
        column.key = key
    if column.name is None:
        column.name = key

125
sqlalchemy/ext/horizontal_shard.py
Normal file
@@ -0,0 +1,125 @@
# horizontal_shard.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Horizontal sharding support.

Defines a rudimental 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.

For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.

"""

import sqlalchemy.exceptions as sa_exc
from sqlalchemy import util
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.query import Query

__all__ = ['ShardedSession', 'ShardedQuery']


class ShardedSession(Session):
    def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
        """Construct a ShardedSession.

        :param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a
          SQL clause, returns a shard ID.  This id may be based off of the
          attributes present within the object, or on some round-robin
          scheme. If the scheme is based on a selection, it should set
          whatever state on the instance to mark it in the future as
          participating in that shard.

        :param id_chooser: A callable, passed a query and a tuple of identity values, which
          should return a list of shard ids where the ID might reside.  The
          databases will be queried in the order of this listing.

        :param query_chooser: For a given Query, returns the list of shard_ids where the query
          should be issued.  Results from all shards returned will be combined
          together into a single listing.

        :param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.base.Engine`
          objects.

        """
        super(ShardedSession, self).__init__(**kwargs)
        self.shard_chooser = shard_chooser
        self.id_chooser = id_chooser
        self.query_chooser = query_chooser
        self.__binds = {}
        self._mapper_flush_opts = {'connection_callable':self.connection}
        self._query_cls = ShardedQuery
        if shards is not None:
            for k in shards:
                self.bind_shard(k, shards[k])

    def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
        if shard_id is None:
            shard_id = self.shard_chooser(mapper, instance)

        if self.transaction is not None:
            return self.transaction.connection(mapper, shard_id=shard_id)
        else:
            return self.get_bind(mapper,
                                 shard_id=shard_id,
                                 instance=instance).contextual_connect(**kwargs)

    def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw):
        if shard_id is None:
            shard_id = self.shard_chooser(mapper, instance, clause=clause)
        return self.__binds[shard_id]

    def bind_shard(self, shard_id, bind):
        self.__binds[shard_id] = bind
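
# A minimal construction sketch (the engines, chooser logic, and shard names
# below are hypothetical, not part of this module):
#
#     shards = {'north': create_engine('sqlite://'),
#               'south': create_engine('sqlite://')}
#
#     def shard_chooser(mapper, instance, clause=None):
#         return 'north' if instance is None else instance.region
#
#     def id_chooser(query, ident):
#         return ['north', 'south']   # search all shards for an identity
#
#     def query_chooser(query):
#         return ['north', 'south']   # issue the query against all shards
#
#     sess = ShardedSession(shard_chooser, id_chooser, query_chooser,
#                           shards=shards)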

class ShardedQuery(Query):
    def __init__(self, *args, **kwargs):
        super(ShardedQuery, self).__init__(*args, **kwargs)
        self.id_chooser = self.session.id_chooser
        self.query_chooser = self.session.query_chooser
        self._shard_id = None

    def set_shard(self, shard_id):
        """return a new query, limited to a single shard ID.

        all subsequent operations with the returned query will
        be against the single shard regardless of other state.
        """

        q = self._clone()
        q._shard_id = shard_id
        return q

    def _execute_and_instances(self, context):
        if self._shard_id is not None:
            result = self.session.connection(
                mapper=self._mapper_zero(),
                shard_id=self._shard_id).execute(context.statement, self._params)
            return self.instances(result, context)
        else:
            partial = []
            for shard_id in self.query_chooser(self):
                result = self.session.connection(
                    mapper=self._mapper_zero(),
                    shard_id=shard_id).execute(context.statement, self._params)
                partial = partial + list(self.instances(result, context))

            # if some kind of in memory 'sorting'
            # were done, this is where it would happen
            return iter(partial)

    def get(self, ident, **kwargs):
        if self._shard_id is not None:
            return super(ShardedQuery, self).get(ident)
        else:
            ident = util.to_list(ident)
            for shard_id in self.id_chooser(self, ident):
                o = self.set_shard(shard_id).get(ident, **kwargs)
                if o is not None:
                    return o
            else:
                return None
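
# Query usage sketch (``WeatherLocation`` is a hypothetical mapped class): a
# plain query fans out across the shards named by query_chooser(), while
# set_shard() pins all operations to one shard.
#
#     sess.query(WeatherLocation).all()                     # all shards
#     sess.query(WeatherLocation).set_shard('north').all()  # single shard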

315
sqlalchemy/ext/orderinglist.py
Normal file
@@ -0,0 +1,315 @@
"""A custom list that manages index/position information for its children.

:author: Jason Kirtland

``orderinglist`` is a helper for mutable ordered relationships.  It will intercept
list operations performed on a relationship collection and automatically
synchronize changes in list position with an attribute on the related objects.
(See :ref:`advdatamapping_entitycollections` for more information on the general pattern.)

Example: Two tables that store slides in a presentation.  Each slide
has a number of bullet points, displayed in order by the 'position'
column on the bullets table.  These bullets can be inserted and re-ordered
by your end users, and you need to update the 'position' column of all
affected rows when changes are made.

.. sourcecode:: python+sql

    slides_table = Table('Slides', metadata,
                         Column('id', Integer, primary_key=True),
                         Column('name', String))

    bullets_table = Table('Bullets', metadata,
                          Column('id', Integer, primary_key=True),
                          Column('slide_id', Integer, ForeignKey('Slides.id')),
                          Column('position', Integer),
                          Column('text', String))

    class Slide(object):
        pass
    class Bullet(object):
        pass

    mapper(Slide, slides_table, properties={
        'bullets': relationship(Bullet, order_by=[bullets_table.c.position])
    })
    mapper(Bullet, bullets_table)

The standard relationship mapping will produce a list-like attribute on each Slide
containing all related Bullets, but coping with changes in ordering is totally
your responsibility.  If you insert a Bullet into that list, there is no
magic; it won't have a position attribute unless you assign it one, and
you'll need to manually renumber all the subsequent Bullets in the list to
accommodate the insert.

An ``orderinglist`` can automate this and manage the 'position' attribute on all
related bullets for you.

.. sourcecode:: python+sql

    mapper(Slide, slides_table, properties={
        'bullets': relationship(Bullet,
                                collection_class=ordering_list('position'),
                                order_by=[bullets_table.c.position])
    })
    mapper(Bullet, bullets_table)

    s = Slide()
    s.bullets.append(Bullet())
    s.bullets.append(Bullet())
    s.bullets[1].position
    >>> 1
    s.bullets.insert(1, Bullet())
    s.bullets[2].position
    >>> 2

Use the ``ordering_list`` function to set up the ``collection_class`` on relationships
(as in the mapper example above).  This implementation depends on the list
starting in the proper order, so be SURE to put an order_by on your relationship.

.. warning:: ``ordering_list`` only provides limited functionality when a primary
  key column or unique column is the target of the sort.  Since changing the order of
  entries often means that two rows must trade values, this is not possible when
  the value is constrained by a primary key or unique constraint, since one of the rows
  would temporarily have to point to a third available value so that the other row
  could take its old value.  ``ordering_list`` doesn't do any of this for you,
  nor does SQLAlchemy itself.

``ordering_list`` takes the name of the related object's ordering attribute as
an argument.  By default, the zero-based integer index of the object's
position in the ``ordering_list`` is synchronized with the ordering attribute:
index 0 will get position 0, index 1 position 1, etc.  To start numbering at 1
or some other integer, provide ``count_from=1``.

Ordering values are not limited to incrementing integers.  Almost any scheme
can be implemented by supplying a custom ``ordering_func`` that maps a Python list
index to any value you require.

"""
from sqlalchemy.orm.collections import collection
from sqlalchemy import util

__all__ = [ 'ordering_list' ]


def ordering_list(attr, count_from=None, **kw):
    """Prepares an OrderingList factory for use in mapper definitions.

    Returns an object suitable for use as an argument to a Mapper relationship's
    ``collection_class`` option.  Arguments are:

    attr
      Name of the mapped attribute to use for storage and retrieval of
      ordering information

    count_from (optional)
      Set up an integer-based ordering, starting at ``count_from``.  For
      example, ``ordering_list('pos', count_from=1)`` would create a 1-based
      list in SQL, storing the value in the 'pos' column.  Ignored if
      ``ordering_func`` is supplied.

    Passes along any keyword arguments to ``OrderingList`` constructor.
    """

    kw = _unsugar_count_from(count_from=count_from, **kw)
    return lambda: OrderingList(attr, **kw)

# Ordering utility functions
def count_from_0(index, collection):
    """Numbering function: consecutive integers starting at 0."""

    return index

def count_from_1(index, collection):
    """Numbering function: consecutive integers starting at 1."""

    return index + 1

def count_from_n_factory(start):
    """Numbering function: consecutive integers starting at arbitrary start."""

    def f(index, collection):
        return index + start
    try:
        f.__name__ = 'count_from_%i' % start
    except TypeError:
        pass
    return f

def _unsugar_count_from(**kw):
    """Builds counting functions from keyword arguments.

    Keyword argument filter, prepares a simple ``ordering_func`` from a
    ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
    """

    count_from = kw.pop('count_from', None)
    if kw.get('ordering_func', None) is None and count_from is not None:
        if count_from == 0:
            kw['ordering_func'] = count_from_0
        elif count_from == 1:
            kw['ordering_func'] = count_from_1
        else:
            kw['ordering_func'] = count_from_n_factory(count_from)
    return kw
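
# An illustrative custom ``ordering_func`` (hypothetical; not part of this
# module): store positions as letters 'a', 'b', 'c', ... instead of integers.
#
#     def alpha_ordering(index, collection):
#         return chr(ord('a') + index)
#
#     ordering_list('position', ordering_func=alpha_ordering)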

class OrderingList(list):
    """A custom list that manages position information for its children.

    See the module and __init__ documentation for more details.  The
    ``ordering_list`` factory function is used to configure ``OrderingList``
    collections in ``mapper`` relationship definitions.

    """

    def __init__(self, ordering_attr=None, ordering_func=None,
                 reorder_on_append=False):
        """A custom list that manages position information for its children.

        ``OrderingList`` is a ``collection_class`` list implementation that
        syncs position in a Python list with a position attribute on the
        mapped objects.

        This implementation relies on the list starting in the proper order,
        so be **sure** to put an ``order_by`` on your relationship.

        ordering_attr
          Name of the attribute that stores the object's order in the
          relationship.

        ordering_func
          Optional.  A function that maps the position in the Python list to a
          value to store in the ``ordering_attr``.  Values returned are
          usually (but need not be!) integers.

          An ``ordering_func`` is called with two positional parameters: the
          index of the element in the list, and the list itself.

          If omitted, Python list indexes are used for the attribute values.
          Two basic pre-built numbering functions are provided in this module:
          ``count_from_0`` and ``count_from_1``.  For more exotic examples
          like stepped numbering, alphabetical and Fibonacci numbering, see
          the unit tests.

        reorder_on_append
          Default False.  When appending an object with an existing (non-None)
          ordering value, that value will be left untouched unless
          ``reorder_on_append`` is true.  This is an optimization to avoid a
          variety of dangerous unexpected database writes.

          SQLAlchemy will add instances to the list via append() when your
          object loads.  If for some reason the result set from the database
          skips a step in the ordering (say, row '1' is missing but you get
          '2', '3', and '4'), reorder_on_append=True would immediately
          renumber the items to '1', '2', '3'.  If you have multiple sessions
          making changes, any of whom happen to load this collection even in
          passing, all of the sessions would try to "clean up" the numbering
          in their commits, possibly causing all but one to fail with a
          concurrent modification error.  Spooky action at a distance.

          Recommend leaving this with the default of False, and just call
          ``reorder()`` if you're doing ``append()`` operations with
          previously ordered instances or when doing some housekeeping after
          manual sql operations.

        """
        self.ordering_attr = ordering_attr
        if ordering_func is None:
            ordering_func = count_from_0
        self.ordering_func = ordering_func
        self.reorder_on_append = reorder_on_append

    # More complex serialization schemes (multi column, e.g.) are possible by
    # subclassing and reimplementing these two methods.
    def _get_order_value(self, entity):
        return getattr(entity, self.ordering_attr)

    def _set_order_value(self, entity, value):
        setattr(entity, self.ordering_attr, value)

    def reorder(self):
        """Synchronize ordering for the entire collection.

        Sweeps through the list and ensures that each object has accurate
        ordering information set.

        """
        for index, entity in enumerate(self):
            self._order_entity(index, entity, True)

    # As of 0.5, _reorder is no longer semi-private
    _reorder = reorder

    def _order_entity(self, index, entity, reorder=True):
        have = self._get_order_value(entity)

        # Don't disturb existing ordering if reorder is False
        if have is not None and not reorder:
            return

        should_be = self.ordering_func(index, self)
        if have != should_be:
            self._set_order_value(entity, should_be)

    def append(self, entity):
        super(OrderingList, self).append(entity)
        self._order_entity(len(self) - 1, entity, self.reorder_on_append)

    def _raw_append(self, entity):
        """Append without any ordering behavior."""

        super(OrderingList, self).append(entity)
    _raw_append = collection.adds(1)(_raw_append)

    def insert(self, index, entity):
        super(OrderingList, self).insert(index, entity)
        self._reorder()

    def remove(self, entity):
        super(OrderingList, self).remove(entity)
        self._reorder()

    def pop(self, index=-1):
        entity = super(OrderingList, self).pop(index)
        self._reorder()
        return entity

    def __setitem__(self, index, entity):
        if isinstance(index, slice):
            step = index.step or 1
            start = index.start or 0
            if start < 0:
                start += len(self)
            stop = index.stop or len(self)
            if stop < 0:
                stop += len(self)

            for i in xrange(start, stop, step):
                self.__setitem__(i, entity[i])
        else:
            self._order_entity(index, entity, True)
            super(OrderingList, self).__setitem__(index, entity)

    def __delitem__(self, index):
        super(OrderingList, self).__delitem__(index)
        self._reorder()

    # Py2K
    def __setslice__(self, start, end, values):
        super(OrderingList, self).__setslice__(start, end, values)
        self._reorder()

    def __delslice__(self, start, end):
        super(OrderingList, self).__delslice__(start, end)
        self._reorder()
    # end Py2K

    for func_name, func in locals().items():
        if (util.callable(func) and func.func_name == func_name and
            not func.__doc__ and hasattr(list, func_name)):
            func.__doc__ = getattr(list, func_name).__doc__
    del func_name, func

155
sqlalchemy/ext/serializer.py
Normal file
@@ -0,0 +1,155 @@
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
allowing "contextual" deserialization.

Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
or sqlalchemy.orm.* can be used.  The mappers, Tables, Columns, Session
etc. which are referenced by the structure are not persisted in serialized
form, but are instead re-associated with the query structure
when it is deserialized.

Usage is nearly the same as that of the standard Python pickle module::

    from sqlalchemy.ext.serializer import loads, dumps
    metadata = MetaData(bind=some_engine)
    Session = scoped_session(sessionmaker())

    # ... define mappers

    query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)

    # pickle the query
    serialized = dumps(query)

    # unpickle.  Pass in metadata + scoped_session
    query2 = loads(serialized, metadata, Session)

    print query2.all()

Similar restrictions as when using raw pickle apply; mapped classes must
themselves be pickleable, meaning they are importable from a module-level
namespace.

The serializer module is only appropriate for query structures.  It is not
needed for:

* instances of user-defined classes.   These contain no references to engines,
  sessions or expression constructs in the typical case and can be serialized directly.

* Table metadata that is to be loaded entirely from the serialized structure (i.e. is
  not already declared in the application).   Regular pickle.loads()/dumps() can
  be used to fully dump any ``MetaData`` object, typically one which was reflected
  from an existing database at some previous point in time.  The serializer module
  is specifically for the opposite case, where the Table metadata is already present
  in memory.

"""

from sqlalchemy.orm import class_mapper, Query
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.attributes import QueryableAttribute
from sqlalchemy import Table, Column
from sqlalchemy.engine import Engine
from sqlalchemy.util import pickle
import re
import base64
# Py3K
#from io import BytesIO as byte_buffer
# Py2K
from cStringIO import StringIO as byte_buffer
# end Py2K

# Py3K
#def b64encode(x):
#    return base64.b64encode(x).decode('ascii')
#def b64decode(x):
#    return base64.b64decode(x.encode('ascii'))
# Py2K
b64encode = base64.b64encode
b64decode = base64.b64decode
# end Py2K

__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']


def Serializer(*args, **kw):
    pickler = pickle.Pickler(*args, **kw)

    def persistent_id(obj):
        #print "serializing:", repr(obj)
        if isinstance(obj, QueryableAttribute):
            cls = obj.impl.class_
            key = obj.impl.key
            id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls))
        elif isinstance(obj, Mapper) and not obj.non_primary:
            id = "mapper:" + b64encode(pickle.dumps(obj.class_))
        elif isinstance(obj, Table):
            id = "table:" + str(obj)
        elif isinstance(obj, Column) and isinstance(obj.table, Table):
            id = "column:" + str(obj.table) + ":" + obj.key
        elif isinstance(obj, Session):
            id = "session:"
        elif isinstance(obj, Engine):
            id = "engine:"
        else:
            return None
        return id

    pickler.persistent_id = persistent_id
    return pickler

our_ids = re.compile(r'(mapper|table|column|session|attribute|engine):(.*)')
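
# For illustration (hypothetical names): a mapped column serializes to a token
# such as "column:users:name", and a Session to the bare token "session:"; the
# Deserializer below resolves these tokens back to live objects.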

def Deserializer(file, metadata=None, scoped_session=None, engine=None):
    unpickler = pickle.Unpickler(file)

    def get_engine():
        if engine:
            return engine
        elif scoped_session and scoped_session().bind:
            return scoped_session().bind
        elif metadata and metadata.bind:
            return metadata.bind
        else:
            return None

    def persistent_load(id):
        m = our_ids.match(id)
        if not m:
            return None
        else:
            type_, args = m.group(1, 2)
            if type_ == 'attribute':
                key, clsarg = args.split(":")
                cls = pickle.loads(b64decode(clsarg))
                return getattr(cls, key)
            elif type_ == "mapper":
                cls = pickle.loads(b64decode(args))
                return class_mapper(cls)
            elif type_ == "table":
                return metadata.tables[args]
            elif type_ == "column":
                table, colname = args.split(':')
                return metadata.tables[table].c[colname]
            elif type_ == "session":
                return scoped_session()
            elif type_ == "engine":
                return get_engine()
            else:
                raise Exception("Unknown token: %s" % type_)
    unpickler.persistent_load = persistent_load
    return unpickler

def dumps(obj, protocol=0):
    buf = byte_buffer()
    pickler = Serializer(buf, protocol)
    pickler.dump(obj)
    return buf.getvalue()

def loads(data, metadata=None, scoped_session=None, engine=None):
    buf = byte_buffer(data)
    unpickler = Deserializer(buf, metadata, scoped_session, engine)
    return unpickler.load()

551
sqlalchemy/ext/sqlsoup.py
Normal file
@@ -0,0 +1,551 @@
"""
Introduction
============

SqlSoup provides a convenient way to access existing database tables without
having to declare table or mapper classes ahead of time.  It is built on top
of the SQLAlchemy ORM and provides a super-minimalistic interface to an
existing database.

Suppose we have a database with users, books, and loans tables
(corresponding to the PyWebOff dataset, if you're curious).

Creating a SqlSoup gateway is just like creating an SQLAlchemy
engine::

    >>> from sqlalchemy.ext.sqlsoup import SqlSoup
    >>> db = SqlSoup('sqlite:///:memory:')

or, you can re-use an existing engine::

    >>> db = SqlSoup(engine)

You can optionally specify a schema within the database for your
SqlSoup::

    >>> db.schema = myschemaname

Loading objects
===============

Loading objects is as easy as this::

    >>> users = db.users.all()
    >>> users.sort()
    >>> users
    [MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0), MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)]

Of course, letting the database do the sort is better::

    >>> db.users.order_by(db.users.name).all()
    [MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1), MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0)]

Field access is intuitive::

    >>> users[0].email
    u'student@example.edu'

Of course, you don't want to load all users very often.  Let's add a
WHERE clause.  Let's also switch the order_by to DESC while we're at
it::

    >>> from sqlalchemy import or_, and_, desc
    >>> where = or_(db.users.name=='Bhargan Basepair', db.users.email=='student@example.edu')
    >>> db.users.filter(where).order_by(desc(db.users.name)).all()
    [MappedUsers(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0), MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)]

You can also use .first() (to retrieve only the first object from a query) or
.one() (like .first when you expect exactly one user -- it will raise an
exception if more were returned)::

    >>> db.users.filter(db.users.name=='Bhargan Basepair').one()
    MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)

Since name is the primary key, this is equivalent to

    >>> db.users.get('Bhargan Basepair')
    MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)

This is also equivalent to

    >>> db.users.filter_by(name='Bhargan Basepair').one()
    MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu',password=u'basepair',classname=None,admin=1)

filter_by is like filter, but takes kwargs instead of full clause expressions.
This makes it more concise for simple queries like this, but you can't do
complex queries like the or\_ above or non-equality based comparisons this way.

Full query documentation
------------------------

Get, filter, filter_by, order_by, limit, and the rest of the
query methods are explained in detail in :ref:`ormtutorial_querying`.

Modifying objects
=================

Modifying objects is intuitive::

    >>> user = _
    >>> user.email = 'basepair+nospam@example.edu'
    >>> db.commit()

(SqlSoup leverages the sophisticated SQLAlchemy unit-of-work code, so
multiple updates to a single object will be turned into a single
``UPDATE`` statement when you commit.)

To finish covering the basics, let's insert a new loan, then delete
it::

    >>> book_id = db.books.filter_by(title='Regional Variation in Moss').first().id
    >>> db.loans.insert(book_id=book_id, user_name=user.name)
    MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None)

    >>> loan = db.loans.filter_by(book_id=2, user_name='Bhargan Basepair').one()
    >>> db.delete(loan)
    >>> db.commit()

You can also delete rows that have not been loaded as objects.  Let's
do our insert/delete cycle once more, this time using the loans
table's delete method. (For SQLAlchemy experts: note that no flush()
call is required since this delete acts at the SQL level, not at the
Mapper level.) The same where-clause construction rules apply here as
to the select methods.

::

    >>> db.loans.insert(book_id=book_id, user_name=user.name)
    MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None)
    >>> db.loans.delete(db.loans.book_id==2)

You can similarly update multiple rows at once.  This will change the
book_id to 1 in all loans whose book_id is 2::

    >>> db.loans.update(db.loans.book_id==2, book_id=1)
    >>> db.loans.filter_by(book_id=1).all()
    [MappedLoans(book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))]


Joins
=====

Occasionally, you will want to pull out a lot of data from related
tables all at once.  In this situation, it is far more efficient to
have the database perform the necessary join.  (Here we do not have *a
lot of data* but hopefully the concept is still clear.)  SQLAlchemy is
smart enough to recognize that loans has a foreign key to users, and
uses that as the join condition automatically.

::

    >>> join1 = db.join(db.users, db.loans, isouter=True)
    >>> join1.filter_by(name='Joe Student').all()
    [MappedJoin(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0,book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))]

If you're unfortunate enough to be using MySQL with the default MyISAM
storage engine, you'll have to specify the join condition manually,
since MyISAM does not store foreign keys.  Here's the same join again,
with the join condition explicitly specified::

    >>> db.join(db.users, db.loans, db.users.name==db.loans.user_name, isouter=True)
    <class 'sqlalchemy.ext.sqlsoup.MappedJoin'>

You can compose arbitrarily complex joins by combining Join objects
with tables or other joins.  Here we combine our first join with the
books table::

    >>> join2 = db.join(join1, db.books)
    >>> join2.all()
    [MappedJoin(name=u'Joe Student',email=u'student@example.edu',password=u'student',classname=None,admin=0,book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0),id=1,title=u'Mustards I Have Known',published_year=u'1989',authors=u'Jones')]

If you join tables that have an identical column name, wrap your join
with `with_labels`, to disambiguate columns with their table name
(.c is short for .columns)::

    >>> db.with_labels(join1).c.keys()
    [u'users_name', u'users_email', u'users_password', u'users_classname', u'users_admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']

You can also join directly to a labeled object::

    >>> labeled_loans = db.with_labels(db.loans)
    >>> db.join(db.users, labeled_loans, isouter=True).c.keys()
    [u'name', u'email', u'password', u'classname', u'admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']


Relationships
=============

You can define relationships on SqlSoup classes:

    >>> db.users.relate('loans', db.loans)

These can then be used like a normal SA property:

    >>> db.users.get('Joe Student').loans
    [MappedLoans(book_id=1,user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))]

    >>> db.users.filter(~db.users.loans.any()).all()
    [MappedUsers(name=u'Bhargan Basepair',email='basepair+nospam@example.edu',password=u'basepair',classname=None,admin=1)]

relate can take any options that the relationship function accepts in normal mapper definition:

    >>> del db._cache['users']
    >>> db.users.relate('loans', db.loans, order_by=db.loans.loan_date, cascade='all, delete-orphan')

Advanced Use
============

Sessions, Transactions and Application Integration
--------------------------------------------------

**Note:** please read and understand this section thoroughly before using SqlSoup in any web application.

SqlSoup uses a ScopedSession to provide thread-local sessions.  You
can get a reference to the current one like this::

    >>> session = db.session

The default session is available at the module level in SQLSoup, via::

    >>> from sqlalchemy.ext.sqlsoup import Session

The configuration of this session is ``autoflush=True``, ``autocommit=False``.
This means when you work with the SqlSoup object, you need to call ``db.commit()``
in order to have changes persisted.  You may also call ``db.rollback()`` to
roll things back.

Since the SqlSoup object's Session automatically enters into a transaction as soon
as it's used, it is *essential* that you call ``commit()`` or ``rollback()``
on it when the work within a thread completes.  This means all the guidelines
for web application integration at :ref:`session_lifespan` must be followed.

The SqlSoup object can have any session or scoped session configured onto it.
This is of key importance when integrating with existing code or frameworks
such as Pylons.  If your application already has a ``Session`` configured,
pass it to your SqlSoup object::

    >>> from myapplication import Session
    >>> db = SqlSoup(session=Session)

If the ``Session`` is configured with ``autocommit=True``, use ``flush()``
instead of ``commit()`` to persist changes - in this case, the ``Session``
closes out its transaction immediately and no external management is needed.
``rollback()`` is also not available.  Configuring a new SQLSoup object in
"autocommit" mode looks like::

    >>> from sqlalchemy.orm import scoped_session, sessionmaker
    >>> db = SqlSoup('sqlite://', session=scoped_session(sessionmaker(autoflush=False, expire_on_commit=False, autocommit=True)))


Mapping arbitrary Selectables
-----------------------------

SqlSoup can map any SQLAlchemy ``Selectable`` with the map
method.  Let's map a ``Select`` object that uses an aggregate function;
we'll use the SQLAlchemy ``Table`` that SqlSoup introspected as the
basis.  (Since we're not mapping to a simple table or join, we need to
tell SQLAlchemy how to find the *primary key* which just needs to be
unique within the select, and not necessarily correspond to a *real*
PK in the database.)

::

    >>> from sqlalchemy import select, func
    >>> b = db.books._table
    >>> s = select([b.c.published_year, func.count('*').label('n')], from_obj=[b], group_by=[b.c.published_year])
    >>> s = s.alias('years_with_count')
    >>> years_with_count = db.map(s, primary_key=[s.c.published_year])
    >>> years_with_count.filter_by(published_year='1989').all()
    [MappedBooks(published_year=u'1989',n=1)]

Obviously if we just wanted to get a list of counts associated with
book years once, raw SQL is going to be less work.  The advantage of
mapping a Select is reusability, both standalone and in Joins. (And if
you go to full SQLAlchemy, you can perform mappings like this directly
to your object models.)

An easy way to save mapped selectables like this is to just hang them on
your db object::

    >>> db.years_with_count = years_with_count

Python is flexible like that!


Raw SQL
-------

SqlSoup works fine with SQLAlchemy's text construct, described in :ref:`sqlexpression_text`.
You can also execute textual SQL directly using the `execute()` method,
which corresponds to the `execute()` method on the underlying `Session`.
Statements here are written like ``text()`` constructs, using named parameters
with colons::

>>> rp = db.execute('select name, email from users where name like :name order by name', name='%Bhargan%')
>>> for name, email in rp.fetchall(): print name, email
Bhargan Basepair basepair+nospam@example.edu

Or you can get at the current transaction's connection using `connection()`.
This connection can accept any sort of SQL expression or raw SQL string to be
passed to the database::

>>> conn = db.connection()
>>> conn.execute("select name, email from users where name like ? order by name", '%Bhargan%')


Dynamic table names
-------------------

You can load a table whose name is specified at runtime with the entity() method::

>>> tablename = 'loans'
>>> db.entity(tablename) == db.loans
True

entity() also takes an optional schema argument. If none is specified, the
default schema is used.
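
For example (the ``reporting`` schema name here is purely illustrative)::

>>> db.entity(tablename, schema='reporting')  # doctest: +SKIP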
"""
|
||||
|
||||
from sqlalchemy import Table, MetaData, join
|
||||
from sqlalchemy import schema, sql
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \
|
||||
class_mapper, relationship, session,\
|
||||
object_session
|
||||
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
|
||||
from sqlalchemy.exceptions import SQLAlchemyError, InvalidRequestError, ArgumentError
|
||||
from sqlalchemy.sql import expression
|
||||
|
||||
|
||||
__all__ = ['PKNotFoundError', 'SqlSoup']
|
||||
|
||||
Session = scoped_session(sessionmaker(autoflush=True, autocommit=False))
|
||||
|
||||
class AutoAdd(MapperExtension):
|
||||
def __init__(self, scoped_session):
|
||||
self.scoped_session = scoped_session
|
||||
|
||||
def instrument_class(self, mapper, class_):
|
||||
class_.__init__ = self._default__init__(mapper)
|
||||
|
||||
def _default__init__(ext, mapper):
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.iteritems():
|
||||
setattr(self, key, value)
|
||||
return __init__
|
||||
|
||||
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
|
||||
session = self.scoped_session()
|
||||
session._save_without_cascade(instance)
|
||||
return EXT_CONTINUE
|
||||
|
||||
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
|
||||
sess = object_session(instance)
|
||||
if sess:
|
||||
sess.expunge(instance)
|
||||
return EXT_CONTINUE
|
||||
|
||||
class PKNotFoundError(SQLAlchemyError):
|
||||
pass
|
||||
|
||||
def _ddl_error(cls):
|
||||
msg = 'SQLSoup can only modify mapped Tables (found: %s)' \
|
||||
% cls._table.__class__.__name__
|
||||
raise InvalidRequestError(msg)
|
||||
|
||||
# metaclass is necessary to expose class methods with getattr, e.g.
|
||||
# we want to pass db.users.select through to users._mapper.select
|
||||
class SelectableClassType(type):
|
||||
def insert(cls, **kwargs):
|
||||
_ddl_error(cls)
|
||||
|
||||
def __clause_element__(cls):
|
||||
return cls._table
|
||||
|
||||
def __getattr__(cls, attr):
|
||||
if attr == '_query':
|
||||
# called during mapper init
|
||||
raise AttributeError()
|
||||
return getattr(cls._query, attr)
|
||||
|
||||
class TableClassType(SelectableClassType):
|
||||
def insert(cls, **kwargs):
|
||||
o = cls()
|
||||
o.__dict__.update(kwargs)
|
||||
return o
|
||||
|
||||
def relate(cls, propname, *args, **kwargs):
|
||||
class_mapper(cls)._configure_property(propname, relationship(*args, **kwargs))
|
||||
|
||||
def _is_outer_join(selectable):
|
||||
if not isinstance(selectable, sql.Join):
|
||||
return False
|
||||
if selectable.isouter:
|
||||
return True
|
||||
return _is_outer_join(selectable.left) or _is_outer_join(selectable.right)
|
||||
|
||||
def _selectable_name(selectable):
|
||||
if isinstance(selectable, sql.Alias):
|
||||
return _selectable_name(selectable.element)
|
||||
elif isinstance(selectable, sql.Select):
|
||||
return ''.join(_selectable_name(s) for s in selectable.froms)
|
||||
elif isinstance(selectable, schema.Table):
|
||||
return selectable.name.capitalize()
|
||||
else:
|
||||
x = selectable.__class__.__name__
|
||||
if x[0] == '_':
|
||||
x = x[1:]
|
||||
return x
|
||||
|
||||
def _class_for_table(session, engine, selectable, **mapper_kwargs):
|
||||
selectable = expression._clause_element_as_expr(selectable)
|
||||
mapname = 'Mapped' + _selectable_name(selectable)
|
||||
# Py2K
|
||||
if isinstance(mapname, unicode):
|
||||
engine_encoding = engine.dialect.encoding
|
||||
mapname = mapname.encode(engine_encoding)
|
||||
# end Py2K
|
||||
|
||||
if isinstance(selectable, Table):
|
||||
klass = TableClassType(mapname, (object,), {})
|
||||
else:
|
||||
klass = SelectableClassType(mapname, (object,), {})
|
||||
|
||||
def _compare(self, o):
|
||||
L = list(self.__class__.c.keys())
|
||||
L.sort()
|
||||
t1 = [getattr(self, k) for k in L]
|
||||
try:
|
||||
t2 = [getattr(o, k) for k in L]
|
||||
except AttributeError:
|
||||
raise TypeError('unable to compare with %s' % o.__class__)
|
||||
return t1, t2
|
||||
|
||||
# python2/python3 compatible system of
|
||||
# __cmp__ - __lt__ + __eq__
|
||||
|
||||
def __lt__(self, o):
|
||||
t1, t2 = _compare(self, o)
|
||||
return t1 < t2
|
||||
|
||||
def __eq__(self, o):
|
||||
t1, t2 = _compare(self, o)
|
||||
return t1 == t2
|
||||
|
||||
def __repr__(self):
|
||||
L = ["%s=%r" % (key, getattr(self, key, ''))
|
||||
for key in self.__class__.c.keys()]
|
||||
return '%s(%s)' % (self.__class__.__name__, ','.join(L))
|
||||
|
||||
for m in ['__eq__', '__repr__', '__lt__']:
|
||||
setattr(klass, m, eval(m))
|
||||
klass._table = selectable
|
||||
klass.c = expression.ColumnCollection()
|
||||
mappr = mapper(klass,
|
||||
selectable,
|
||||
extension=AutoAdd(session),
|
||||
**mapper_kwargs)
|
||||
|
||||
for k in mappr.iterate_properties:
|
||||
klass.c[k.key] = k.columns[0]
|
||||
|
||||
klass._query = session.query_property()
|
||||
return klass
|
||||
|
||||
class SqlSoup(object):
|
||||
def __init__(self, engine_or_metadata, **kw):
|
||||
"""Initialize a new ``SqlSoup``.
|
||||
|
||||
`engine_or_metadata` may be an ``Engine``, a connect string suitable
for passing to ``create_engine``, or a ``MetaData`` object.
|
||||
"""
|
||||
|
||||
self.session = kw.pop('session', Session)
|
||||
|
||||
if isinstance(engine_or_metadata, MetaData):
|
||||
self._metadata = engine_or_metadata
|
||||
elif isinstance(engine_or_metadata, (basestring, Engine)):
|
||||
self._metadata = MetaData(engine_or_metadata)
|
||||
else:
|
||||
raise ArgumentError("invalid engine or metadata argument %r" % engine_or_metadata)
|
||||
|
||||
self._cache = {}
|
||||
self.schema = None
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
return self._metadata.bind
|
||||
|
||||
bind = engine
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
self.session.delete(*args, **kwargs)
|
||||
|
||||
def execute(self, stmt, **params):
|
||||
return self.session.execute(sql.text(stmt, bind=self.bind), **params)
|
||||
|
||||
@property
|
||||
def _underlying_session(self):
|
||||
if isinstance(self.session, session.Session):
|
||||
return self.session
|
||||
else:
|
||||
return self.session()
|
||||
|
||||
def connection(self):
|
||||
return self._underlying_session._connection_for_bind(self.bind)
|
||||
|
||||
def flush(self):
|
||||
self.session.flush()
|
||||
|
||||
def rollback(self):
|
||||
self.session.rollback()
|
||||
|
||||
def commit(self):
|
||||
self.session.commit()
|
||||
|
||||
def clear(self):
|
||||
self.session.expunge_all()
|
||||
|
||||
def expunge(self, *args, **kw):
|
||||
self.session.expunge(*args, **kw)
|
||||
|
||||
def expunge_all(self):
|
||||
self.session.expunge_all()
|
||||
|
||||
def map(self, selectable, **kwargs):
|
||||
try:
|
||||
t = self._cache[selectable]
|
||||
except KeyError:
|
||||
t = _class_for_table(self.session, self.engine, selectable, **kwargs)
|
||||
self._cache[selectable] = t
|
||||
return t
|
||||
|
||||
def with_labels(self, item):
|
||||
# TODO give meaningful aliases
|
||||
return self.map(
|
||||
expression._clause_element_as_expr(item).
|
||||
select(use_labels=True).
|
||||
alias('foo'))
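
# e.g. db.with_labels(db.users) maps a select whose columns come out
# prefixed with the table name ("users_id", "users_name", ...), which
# helps when joining tables that have overlapping column names.
# (Illustrative; 'foo' above is just a throwaway alias name.)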
|
||||
|
||||
def join(self, *args, **kwargs):
|
||||
j = join(*args, **kwargs)
|
||||
return self.map(j)
|
||||
|
||||
def entity(self, attr, schema=None):
|
||||
try:
|
||||
t = self._cache[attr]
|
||||
except KeyError:
|
||||
table = Table(attr, self._metadata, autoload=True, autoload_with=self.bind, schema=schema or self.schema)
|
||||
if not table.primary_key.columns:
|
||||
raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys())))
|
||||
if table.columns:
|
||||
t = _class_for_table(self.session, self.engine, table)
|
||||
else:
|
||||
t = None
|
||||
self._cache[attr] = t
|
||||
return t
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return self.entity(attr)
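# e.g. db.users is equivalent to db.entity('users'): attribute access
# reflects the table on first use and caches the resulting mapped class.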
|
||||
|
||||
def __repr__(self):
|
||||
return 'SqlSoup(%r)' % self._metadata
|
||||
|
||||
205
sqlalchemy/interfaces.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# interfaces.py
|
||||
# Copyright (C) 2007 Jason Kirtland jek@discorporate.us
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Interfaces and abstract types."""
|
||||
|
||||
|
||||
class PoolListener(object):
|
||||
"""Hooks into the lifecycle of connections in a ``Pool``.
|
||||
|
||||
Usage::
|
||||
|
||||
class MyListener(PoolListener):
|
||||
def connect(self, dbapi_con, con_record):
|
||||
'''perform connect operations'''
|
||||
# etc.
|
||||
|
||||
# create a new pool with a listener
|
||||
p = QueuePool(..., listeners=[MyListener()])
|
||||
|
||||
# add a listener after the fact
|
||||
p.add_listener(MyListener())
|
||||
|
||||
# usage with create_engine()
|
||||
e = create_engine("url://", listeners=[MyListener()])
|
||||
|
||||
All of the standard connection :class:`~sqlalchemy.pool.Pool` types can
|
||||
accept event listeners for key connection lifecycle events:
|
||||
creation, pool check-out and check-in. There are no events fired
|
||||
when a connection closes.
|
||||
|
||||
For any given DB-API connection, there will be one ``connect``
|
||||
event, `n` number of ``checkout`` events, and either `n` or `n - 1`
|
||||
``checkin`` events. (If a ``Connection`` is detached from its
|
||||
pool via the ``detach()`` method, it won't be checked back in.)
|
||||
|
||||
These are low-level events for low-level objects: raw Python
|
||||
DB-API connections, without the conveniences of the SQLAlchemy
|
||||
``Connection`` wrapper, ``Dialect`` services or ``ClauseElement``
|
||||
execution. If you execute SQL through the connection, explicitly
|
||||
closing all cursors and other resources is recommended.
|
||||
|
||||
Events also receive a ``_ConnectionRecord``, a long-lived internal
|
||||
``Pool`` object that basically represents a "slot" in the
|
||||
connection pool. ``_ConnectionRecord`` objects have one public
|
||||
attribute of note: ``info``, a dictionary whose contents are
|
||||
scoped to the lifetime of the DB-API connection managed by the
|
||||
record. You can use this shared storage area however you like.
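
For example, a listener could track per-connection checkouts there
(a sketch)::

    class CountingListener(PoolListener):
        def checkout(self, dbapi_con, con_record, con_proxy):
            # 'info' persists for the life of the DB-API connection
            con_record.info['checkouts'] = con_record.info.get('checkouts', 0) + 1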
|
||||
|
||||
There is no need to subclass ``PoolListener`` to handle events.
|
||||
Any class that implements one or more of these methods can be used
|
||||
as a pool listener. The ``Pool`` will inspect the methods
|
||||
provided by a listener object and add the listener to one or more
|
||||
internal event queues based on its capabilities. In terms of
|
||||
efficiency and function call overhead, you're much better off only
|
||||
providing implementations for the hooks you'll be using.
|
||||
|
||||
"""
|
||||
|
||||
def connect(self, dbapi_con, con_record):
|
||||
"""Called once for each new DB-API connection or Pool's ``creator()``.
|
||||
|
||||
dbapi_con
|
||||
A newly connected raw DB-API connection (not a SQLAlchemy
|
||||
``Connection`` wrapper).
|
||||
|
||||
con_record
|
||||
The ``_ConnectionRecord`` that persistently manages the connection
|
||||
|
||||
"""
|
||||
|
||||
def first_connect(self, dbapi_con, con_record):
|
||||
"""Called exactly once for the first DB-API connection.
|
||||
|
||||
dbapi_con
|
||||
A newly connected raw DB-API connection (not a SQLAlchemy
|
||||
``Connection`` wrapper).
|
||||
|
||||
con_record
|
||||
The ``_ConnectionRecord`` that persistently manages the connection
|
||||
|
||||
"""
|
||||
|
||||
def checkout(self, dbapi_con, con_record, con_proxy):
|
||||
"""Called when a connection is retrieved from the Pool.
|
||||
|
||||
dbapi_con
|
||||
A raw DB-API connection
|
||||
|
||||
con_record
|
||||
The ``_ConnectionRecord`` that persistently manages the connection
|
||||
|
||||
con_proxy
|
||||
The ``_ConnectionFairy`` which manages the connection for the span of
|
||||
the current checkout.
|
||||
|
||||
If you raise an ``exc.DisconnectionError``, the current
|
||||
connection will be disposed and a fresh connection retrieved.
|
||||
Processing of all checkout listeners will abort and restart
|
||||
using the new connection.
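
E.g., a pessimistic disconnect check (a sketch; ``ping()`` is specific
to certain DB-API drivers such as MySQLdb)::

    class PingListener(PoolListener):
        def checkout(self, dbapi_con, con_record, con_proxy):
            try:
                dbapi_con.ping()
            except Exception:
                # triggers disposal of this connection and a retry
                raise exc.DisconnectionError()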
|
||||
"""
|
||||
|
||||
def checkin(self, dbapi_con, con_record):
|
||||
"""Called when a connection returns to the pool.
|
||||
|
||||
Note that the connection may be closed, and may be None if the
|
||||
connection has been invalidated. ``checkin`` will not be called
|
||||
for detached connections. (They do not return to the pool.)
|
||||
|
||||
dbapi_con
|
||||
A raw DB-API connection
|
||||
|
||||
con_record
|
||||
The ``_ConnectionRecord`` that persistently manages the connection
|
||||
|
||||
"""
|
||||
|
||||
class ConnectionProxy(object):
|
||||
"""Allows interception of statement execution by Connections.
|
||||
|
||||
Either or both of the ``execute()`` and ``cursor_execute()``
methods may be implemented to intercept compiled statement and
|
||||
cursor level executions, e.g.::
|
||||
|
||||
class MyProxy(ConnectionProxy):
|
||||
def execute(self, conn, execute, clauseelement, *multiparams, **params):
|
||||
print "compiled statement:", clauseelement
|
||||
return execute(clauseelement, *multiparams, **params)
|
||||
|
||||
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
|
||||
print "raw statement:", statement
|
||||
return execute(cursor, statement, parameters, context)
|
||||
|
||||
The ``execute`` argument is a function that will fulfill the default
|
||||
execution behavior for the operation. The signature illustrated
|
||||
in the example should be used.
|
||||
|
||||
The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via
|
||||
the ``proxy`` argument::
|
||||
|
||||
e = create_engine('someurl://', proxy=MyProxy())
|
||||
|
||||
"""
|
||||
def execute(self, conn, execute, clauseelement, *multiparams, **params):
|
||||
"""Intercept high level execute() events."""
|
||||
|
||||
return execute(clauseelement, *multiparams, **params)
|
||||
|
||||
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
|
||||
"""Intercept low-level cursor execute() events."""
|
||||
|
||||
return execute(cursor, statement, parameters, context)
|
||||
|
||||
def begin(self, conn, begin):
|
||||
"""Intercept begin() events."""
|
||||
|
||||
return begin()
|
||||
|
||||
def rollback(self, conn, rollback):
|
||||
"""Intercept rollback() events."""
|
||||
|
||||
return rollback()
|
||||
|
||||
def commit(self, conn, commit):
|
||||
"""Intercept commit() events."""
|
||||
|
||||
return commit()
|
||||
|
||||
def savepoint(self, conn, savepoint, name=None):
|
||||
"""Intercept savepoint() events."""
|
||||
|
||||
return savepoint(name=name)
|
||||
|
||||
def rollback_savepoint(self, conn, rollback_savepoint, name, context):
|
||||
"""Intercept rollback_savepoint() events."""
|
||||
|
||||
return rollback_savepoint(name, context)
|
||||
|
||||
def release_savepoint(self, conn, release_savepoint, name, context):
|
||||
"""Intercept release_savepoint() events."""
|
||||
|
||||
return release_savepoint(name, context)
|
||||
|
||||
def begin_twophase(self, conn, begin_twophase, xid):
|
||||
"""Intercept begin_twophase() events."""
|
||||
|
||||
return begin_twophase(xid)
|
||||
|
||||
def prepare_twophase(self, conn, prepare_twophase, xid):
|
||||
"""Intercept prepare_twophase() events."""
|
||||
|
||||
return prepare_twophase(xid)
|
||||
|
||||
def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared):
|
||||
"""Intercept rollback_twophase() events."""
|
||||
|
||||
return rollback_twophase(xid, is_prepared)
|
||||
|
||||
def commit_twophase(self, conn, commit_twophase, xid, is_prepared):
|
||||
"""Intercept commit_twophase() events."""
|
||||
|
||||
return commit_twophase(xid, is_prepared)
|
||||
|
||||
119
sqlalchemy/log.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# log.py - adapt python logging module to SQLAlchemy
|
||||
# Copyright (C) 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Logging control and utilities.
|
||||
|
||||
Control of logging for SA can be performed from the regular python logging
|
||||
module. The regular dotted module namespace is used, starting at
|
||||
'sqlalchemy'. For class-level logging, the class name is appended.
|
||||
|
||||
The "echo" keyword parameter which is available on SQLA ``Engine``
|
||||
and ``Pool`` objects corresponds to a logger specific to that
|
||||
instance only.
|
||||
|
||||
E.g.::
|
||||
|
||||
engine.echo = True
|
||||
|
||||
is equivalent to::
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine)))
|
||||
logger.setLevel(logging.DEBUG)
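
To turn on logging for an entire namespace rather than one instance, set
the level on the dotted-name logger directly; e.g., to log SQL emitted by
all engines::

    import logging
    logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)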
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from sqlalchemy import util
|
||||
|
||||
rootlogger = logging.getLogger('sqlalchemy')
|
||||
if rootlogger.level == logging.NOTSET:
|
||||
rootlogger.setLevel(logging.WARN)
|
||||
|
||||
default_enabled = False
|
||||
def default_logging(name):
|
||||
global default_enabled
|
||||
if logging.getLogger(name).getEffectiveLevel() < logging.WARN:
|
||||
default_enabled = True
|
||||
if not default_enabled:
|
||||
default_enabled = True
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s %(levelname)s %(name)s %(message)s'))
|
||||
rootlogger.addHandler(handler)
|
||||
|
||||
_logged_classes = set()
|
||||
def class_logger(cls, enable=False):
|
||||
logger = logging.getLogger(cls.__module__ + "." + cls.__name__)
|
||||
if enable == 'debug':
|
||||
logger.setLevel(logging.DEBUG)
|
||||
elif enable == 'info':
|
||||
logger.setLevel(logging.INFO)
|
||||
cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
|
||||
cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
|
||||
cls.logger = logger
|
||||
_logged_classes.add(cls)
|
||||
|
||||
|
||||
class Identified(object):
|
||||
@util.memoized_property
|
||||
def logging_name(self):
|
||||
# limit the number of loggers by chopping off the hex(id).
|
||||
# some novice users unfortunately create an unlimited number
|
||||
# of Engines in their applications which would otherwise
|
||||
# cause the app to run out of memory.
|
||||
return "0x...%s" % hex(id(self))[-4:]
|
||||
|
||||
|
||||
def instance_logger(instance, echoflag=None):
|
||||
"""create a logger for an instance that implements :class:`Identified`.
|
||||
|
||||
Warning: this is an expensive call which also results in a permanent
|
||||
increase in memory overhead for each call. Use only for
|
||||
low-volume, long-time-spanning objects.
|
||||
|
||||
"""
|
||||
|
||||
name = "%s.%s.%s" % (instance.__class__.__module__,
|
||||
instance.__class__.__name__, instance.logging_name)
|
||||
|
||||
if echoflag is not None:
|
||||
l = logging.getLogger(name)
|
||||
if echoflag == 'debug':
|
||||
default_logging(name)
|
||||
l.setLevel(logging.DEBUG)
|
||||
elif echoflag is True:
|
||||
default_logging(name)
|
||||
l.setLevel(logging.INFO)
|
||||
elif echoflag is False:
|
||||
l.setLevel(logging.WARN)
|
||||
else:
|
||||
l = logging.getLogger(name)
|
||||
instance._should_log_debug = lambda: l.isEnabledFor(logging.DEBUG)
|
||||
instance._should_log_info = lambda: l.isEnabledFor(logging.INFO)
|
||||
return l
|
||||
|
||||
class echo_property(object):
|
||||
__doc__ = """\
|
||||
When ``True``, enable log output for this element.
|
||||
|
||||
This has the effect of setting the Python logging level for the namespace
|
||||
of this element's class and object reference. A value of boolean ``True``
|
||||
indicates that the loglevel ``logging.INFO`` will be set for the logger,
|
||||
whereas the string value ``debug`` will set the loglevel to
|
||||
``logging.DEBUG``.
|
||||
"""
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
return instance._should_log_debug() and 'debug' or \
|
||||
(instance._should_log_info() and True or False)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance_logger(instance, echoflag=value)
|
||||
1176
sqlalchemy/orm/__init__.py
Normal file
File diff suppressed because it is too large
1708
sqlalchemy/orm/attributes.py
Normal file
File diff suppressed because it is too large
1438
sqlalchemy/orm/collections.py
Normal file
File diff suppressed because it is too large
575
sqlalchemy/orm/dependency.py
Normal file
@@ -0,0 +1,575 @@
|
||||
# orm/dependency.py
|
||||
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Relationship dependencies.
|
||||
|
||||
Bridges the ``PropertyLoader`` (i.e. a ``relationship()``) and the
|
||||
``UOWTransaction`` together to allow processing of relationship()-based
|
||||
dependencies at flush time.
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import sql, util
|
||||
import sqlalchemy.exceptions as sa_exc
|
||||
from sqlalchemy.orm import attributes, exc, sync
|
||||
from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
|
||||
|
||||
|
||||
def create_dependency_processor(prop):
|
||||
types = {
|
||||
ONETOMANY : OneToManyDP,
|
||||
MANYTOONE: ManyToOneDP,
|
||||
MANYTOMANY : ManyToManyDP,
|
||||
}
|
||||
return types[prop.direction](prop)
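
# For example: for a relationship() whose direction is ONETOMANY,
# create_dependency_processor(prop) returns a OneToManyDP bound to that
# property's mappers, key and cascade settings.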
|
||||
|
||||
class DependencyProcessor(object):
|
||||
has_dependencies = True
|
||||
|
||||
def __init__(self, prop):
|
||||
self.prop = prop
|
||||
self.cascade = prop.cascade
|
||||
self.mapper = prop.mapper
|
||||
self.parent = prop.parent
|
||||
self.secondary = prop.secondary
|
||||
self.direction = prop.direction
|
||||
self.post_update = prop.post_update
|
||||
self.passive_deletes = prop.passive_deletes
|
||||
self.passive_updates = prop.passive_updates
|
||||
self.enable_typechecks = prop.enable_typechecks
|
||||
self.key = prop.key
|
||||
self.dependency_marker = MapperStub(self.parent, self.mapper, self.key)
|
||||
if not self.prop.synchronize_pairs:
|
||||
raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relationship %s. "
|
||||
"No target attributes to populate between parent and child are present" % self.prop)
|
||||
|
||||
def _get_instrumented_attribute(self):
|
||||
"""Return the ``InstrumentedAttribute`` handled by this
|
||||
``DependencyProcessor``.
|
||||
|
||||
"""
|
||||
return self.parent.class_manager.get_impl(self.key)
|
||||
|
||||
def hasparent(self, state):
|
||||
"""return True if the given object instance has a parent,
|
||||
according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.
|
||||
|
||||
"""
|
||||
# TODO: use correct API for this
|
||||
return self._get_instrumented_attribute().hasparent(state)
|
||||
|
||||
def register_dependencies(self, uowcommit):
|
||||
"""Tell a ``UOWTransaction`` what mappers are dependent on
|
||||
which, with regards to the two or three mappers handled by
|
||||
this ``DependencyProcessor``.
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
def register_processors(self, uowcommit):
|
||||
"""Tell a ``UOWTransaction`` about this object as a processor,
|
||||
which will be executed after that mapper's objects have been
|
||||
saved or before they've been deleted. The process operation
|
||||
manages attributes and dependent operations between two mappers.
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def whose_dependent_on_who(self, state1, state2):
|
||||
"""Given an object pair assuming `obj2` is a child of `obj1`,
|
||||
return a tuple with the dependent object second, or None if
|
||||
there is no dependency.
|
||||
|
||||
"""
|
||||
if state1 is state2:
|
||||
return None
|
||||
elif self.direction == ONETOMANY:
|
||||
return (state1, state2)
|
||||
else:
|
||||
return (state2, state1)
|
||||
|
||||
def process_dependencies(self, task, deplist, uowcommit, delete = False):
|
||||
"""This method is called during a flush operation to
|
||||
synchronize data between a parent and child object.
|
||||
|
||||
It is called within the context of the various mappers and
|
||||
sometimes individual objects sorted according to their
|
||||
insert/update/delete order (topological sort).
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
|
||||
"""Used before the flushes' topological sort to traverse
|
||||
through related objects and ensure every instance which will
|
||||
require save/update/delete is properly added to the
|
||||
UOWTransaction.
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _verify_canload(self, state):
|
||||
if state is not None and \
|
||||
not self.mapper._canload(state, allow_subtypes=not self.enable_typechecks):
|
||||
if self.mapper._canload(state, allow_subtypes=True):
|
||||
raise exc.FlushError(
|
||||
"Attempting to flush an item of type %s on collection '%s', "
|
||||
"which is not the expected type %s. Configure mapper '%s' to "
|
||||
"load this subtype polymorphically, or set "
|
||||
"enable_typechecks=False to allow subtypes. "
|
||||
"Mismatched typeloading may cause bi-directional relationships "
|
||||
"(backrefs) to not function properly." %
|
||||
(state.class_, self.prop, self.mapper.class_, self.mapper))
|
||||
else:
|
||||
raise exc.FlushError(
|
||||
"Attempting to flush an item of type %s on collection '%s', "
|
||||
"whose mapper does not inherit from that of %s." %
|
||||
(state.class_, self.prop, self.mapper.class_))
|
||||
|
||||
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
|
||||
"""Called during a flush to synchronize primary key identifier
|
||||
values between a parent/child object, as well as to an
|
||||
associationrow in the case of many-to-many.
|
||||
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _check_reverse_action(self, uowcommit, parent, child, action):
|
||||
"""Determine if an action has been performed by the 'reverse' property of this property.
|
||||
|
||||
this is used to ensure that only one side of a bidirectional relationship
|
||||
issues a certain operation for a parent/child pair.
|
||||
|
||||
"""
|
||||
for r in self.prop._reverse_property:
|
||||
if not r.viewonly and (r._dependency_processor, action, parent, child) in uowcommit.attributes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _performed_action(self, uowcommit, parent, child, action):
|
||||
"""Establish that an action has been performed for a certain parent/child pair.
|
||||
|
||||
Used only for actions that are sensitive to bidirectional double-action,
|
||||
i.e. manytomany, post_update.
|
||||
|
||||
"""
|
||||
uowcommit.attributes[(self, action, parent, child)] = True
|
||||
|
||||
def _conditional_post_update(self, state, uowcommit, related):
|
||||
"""Execute a post_update call.
|
||||
|
||||
For relationships that contain the post_update flag, an additional
|
||||
``UPDATE`` statement may be associated after an ``INSERT`` or
|
||||
before a ``DELETE`` in order to resolve circular row
|
||||
dependencies.
|
||||
|
||||
This method will check for the post_update flag being set on a
|
||||
particular relationship, and given a target object and list of
|
||||
one or more related objects, and execute the ``UPDATE`` if the
|
||||
given related object list contains ``INSERT``s or ``DELETE``s.
|
||||
|
||||
"""
|
||||
if state is not None and self.post_update:
|
||||
for x in related:
|
||||
if x is not None and not self._check_reverse_action(uowcommit, x, state, "postupdate"):
|
||||
uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs])
|
||||
self._performed_action(uowcommit, x, state, "postupdate")
|
||||
break
|
||||
|
||||
def _pks_changed(self, uowcommit, state):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (self.__class__.__name__, self.prop)
|
||||
|
||||
class OneToManyDP(DependencyProcessor):
|
||||
def register_dependencies(self, uowcommit):
|
||||
if self.post_update:
|
||||
uowcommit.register_dependency(self.mapper, self.dependency_marker)
|
||||
uowcommit.register_dependency(self.parent, self.dependency_marker)
|
||||
else:
|
||||
uowcommit.register_dependency(self.parent, self.mapper)
|
||||
|
||||
def register_processors(self, uowcommit):
|
||||
if self.post_update:
|
||||
uowcommit.register_processor(self.dependency_marker, self, self.parent)
|
||||
else:
|
||||
uowcommit.register_processor(self.parent, self, self.parent)
|
||||
|
||||
def process_dependencies(self, task, deplist, uowcommit, delete = False):
|
||||
if delete:
|
||||
# head object is being deleted, and we manage its list of child objects
|
||||
# the child objects have to have their foreign key to the parent set to NULL
|
||||
# this phase can be called safely for any cascade but is unnecessary if delete cascade
|
||||
# is on.
|
||||
if self.post_update or not self.passive_deletes == 'all':
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
|
||||
if history:
|
||||
for child in history.deleted:
|
||||
if child is not None and self.hasparent(child) is False:
|
||||
self._synchronize(state, child, None, True, uowcommit)
|
||||
self._conditional_post_update(child, uowcommit, [state])
|
||||
if self.post_update or not self.cascade.delete:
|
||||
for child in history.unchanged:
|
||||
if child is not None:
|
||||
self._synchronize(state, child, None, True, uowcommit)
|
||||
self._conditional_post_update(child, uowcommit, [state])
|
||||
else:
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=True)
|
||||
if history:
|
||||
for child in history.added:
|
||||
self._synchronize(state, child, None, False, uowcommit)
|
||||
if child is not None:
|
||||
self._conditional_post_update(child, uowcommit, [state])
|
||||
|
||||
for child in history.deleted:
|
||||
if not self.cascade.delete_orphan and not self.hasparent(child):
|
||||
self._synchronize(state, child, None, True, uowcommit)
|
||||
|
||||
if self._pks_changed(uowcommit, state):
|
||||
for child in history.unchanged:
|
||||
self._synchronize(state, child, None, False, uowcommit)
|
||||
|
||||
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
|
||||
if delete:
|
||||
# head object is being deleted, and we manage its list of child objects
|
||||
# the child objects have to have their foreign key to the parent set to NULL
|
||||
if not self.post_update:
|
||||
should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all'
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(
|
||||
state, self.key, passive=self.passive_deletes)
|
||||
if history:
|
||||
for child in history.deleted:
|
||||
if child is not None and self.hasparent(child) is False:
|
||||
if self.cascade.delete_orphan:
|
||||
uowcommit.register_object(child, isdelete=True)
|
||||
else:
|
||||
uowcommit.register_object(child)
|
||||
if should_null_fks:
|
||||
for child in history.unchanged:
|
||||
if child is not None:
|
||||
uowcommit.register_object(child)
|
||||
else:
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=True)
|
||||
if history:
|
||||
for child in history.added:
|
||||
if child is not None:
|
||||
uowcommit.register_object(child)
|
||||
for child in history.deleted:
|
||||
if not self.cascade.delete_orphan:
|
||||
uowcommit.register_object(child, isdelete=False)
|
||||
elif self.hasparent(child) is False:
|
||||
uowcommit.register_object(child, isdelete=True)
|
||||
for c, m in self.mapper.cascade_iterator('delete', child):
|
||||
uowcommit.register_object(
|
||||
attributes.instance_state(c),
|
||||
isdelete=True)
|
||||
if self._pks_changed(uowcommit, state):
|
||||
if not history:
|
||||
history = uowcommit.get_attribute_history(
|
||||
state, self.key, passive=self.passive_updates)
|
||||
if history:
|
||||
for child in history.unchanged:
|
||||
if child is not None:
|
||||
uowcommit.register_object(child)
|
||||
|
||||
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
|
||||
source = state
|
||||
dest = child
|
||||
if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
|
||||
return
|
||||
self._verify_canload(child)
|
||||
if clearkeys:
|
||||
sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
|
||||
else:
|
||||
sync.populate(source, self.parent, dest, self.mapper,
|
||||
self.prop.synchronize_pairs, uowcommit,
|
||||
self.passive_updates)
|
||||
|
||||
def _pks_changed(self, uowcommit, state):
|
||||
return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
|
||||
|
||||
class DetectKeySwitch(DependencyProcessor):
|
||||
"""a special DP that works for many-to-one relationships, fires off for
|
||||
child items who have changed their referenced key."""
|
||||
|
||||
has_dependencies = False
|
||||
|
||||
def register_dependencies(self, uowcommit):
|
||||
pass
|
||||
|
||||
def register_processors(self, uowcommit):
|
||||
uowcommit.register_processor(self.parent, self, self.mapper)
|
||||
|
||||
def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
|
||||
# for non-passive updates, register in the preprocess stage
|
||||
# so that mapper save_obj() gets a hold of changes
|
||||
if not delete and not self.passive_updates:
|
||||
self._process_key_switches(deplist, uowcommit)
|
||||
|
||||
def process_dependencies(self, task, deplist, uowcommit, delete=False):
|
||||
# for passive updates, register objects in the process stage
|
||||
# so that we avoid ManyToOneDP's registering the object without
|
||||
# the listonly flag in its own preprocess stage (results in UPDATE)
|
||||
# statements being emitted
|
||||
if not delete and self.passive_updates:
|
||||
self._process_key_switches(deplist, uowcommit)
|
||||
|
||||
def _process_key_switches(self, deplist, uowcommit):
|
||||
switchers = set(s for s in deplist if self._pks_changed(uowcommit, s))
|
||||
if switchers:
|
||||
# yes, we're doing a linear search right now through the UOW. only
|
||||
# takes effect when primary key values have actually changed.
|
||||
# a possible optimization might be to enhance the "hasparents" capability of
|
||||
# attributes to actually store all parent references, but this introduces
|
||||
# more complicated attribute accounting.
|
||||
for s in [elem for elem in uowcommit.session.identity_map.all_states()
|
||||
if issubclass(elem.class_, self.parent.class_) and
|
||||
self.key in elem.dict and
|
||||
elem.dict[self.key] is not None and
|
||||
attributes.instance_state(elem.dict[self.key]) in switchers
|
||||
]:
|
||||
uowcommit.register_object(s)
|
||||
sync.populate(
|
||||
attributes.instance_state(s.dict[self.key]),
|
||||
self.mapper, s, self.parent, self.prop.synchronize_pairs,
|
||||
uowcommit, self.passive_updates)
|
||||
|
||||
def _pks_changed(self, uowcommit, state):
|
||||
return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
|
||||
|
||||
class ManyToOneDP(DependencyProcessor):
|
||||
def __init__(self, prop):
|
||||
DependencyProcessor.__init__(self, prop)
|
||||
self.mapper._dependency_processors.append(DetectKeySwitch(prop))
|
||||
|
||||
def register_dependencies(self, uowcommit):
|
||||
if self.post_update:
|
||||
uowcommit.register_dependency(self.mapper, self.dependency_marker)
|
||||
uowcommit.register_dependency(self.parent, self.dependency_marker)
|
||||
else:
|
||||
uowcommit.register_dependency(self.mapper, self.parent)
|
||||
|
||||
def register_processors(self, uowcommit):
|
||||
if self.post_update:
|
||||
uowcommit.register_processor(self.dependency_marker, self, self.parent)
|
||||
else:
|
||||
uowcommit.register_processor(self.mapper, self, self.parent)
|
||||
|
||||
def process_dependencies(self, task, deplist, uowcommit, delete=False):
|
||||
if delete:
|
||||
if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all':
|
||||
# post_update means we have to update our row to not reference the child object
|
||||
# before we can DELETE the row
|
||||
for state in deplist:
|
||||
self._synchronize(state, None, None, True, uowcommit)
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
|
||||
if history:
|
||||
self._conditional_post_update(state, uowcommit, history.sum())
|
||||
else:
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=True)
|
||||
if history:
|
||||
for child in history.added:
|
||||
self._synchronize(state, child, None, False, uowcommit)
|
||||
self._conditional_post_update(state, uowcommit, history.sum())
|
||||
|
||||
def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
|
||||
if self.post_update:
|
||||
return
|
||||
if delete:
|
||||
if self.cascade.delete or self.cascade.delete_orphan:
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
|
||||
if history:
|
||||
if self.cascade.delete_orphan:
|
||||
todelete = history.sum()
|
||||
else:
|
||||
todelete = history.non_deleted()
|
||||
for child in todelete:
|
||||
if child is None:
|
||||
continue
|
||||
uowcommit.register_object(child, isdelete=True)
|
||||
for c, m in self.mapper.cascade_iterator('delete', child):
|
||||
uowcommit.register_object(
|
||||
attributes.instance_state(c), isdelete=True)
|
||||
else:
|
||||
for state in deplist:
|
||||
uowcommit.register_object(state)
|
||||
if self.cascade.delete_orphan:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
|
||||
if history:
|
||||
for child in history.deleted:
|
||||
if self.hasparent(child) is False:
|
||||
uowcommit.register_object(child, isdelete=True)
|
||||
for c, m in self.mapper.cascade_iterator('delete', child):
|
||||
uowcommit.register_object(
|
||||
attributes.instance_state(c),
|
||||
isdelete=True)
|
||||
|
||||
|
||||
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
|
||||
if state is None or (not self.post_update and uowcommit.is_deleted(state)):
|
||||
return
|
||||
|
||||
if clearkeys or child is None:
|
||||
sync.clear(state, self.parent, self.prop.synchronize_pairs)
|
||||
else:
|
||||
self._verify_canload(child)
|
||||
sync.populate(child, self.mapper, state,
|
||||
self.parent, self.prop.synchronize_pairs, uowcommit,
|
||||
self.passive_updates
|
||||
)
|
||||
|
||||
class ManyToManyDP(DependencyProcessor):
|
||||
def register_dependencies(self, uowcommit):
|
||||
# many-to-many. create a "Stub" mapper to represent the
|
||||
# "middle table" in the relationship. This stub mapper doesnt save
|
||||
# or delete any objects, but just marks a dependency on the two
|
||||
# related mappers. its dependency processor then populates the
|
||||
# association table.
|
||||
|
||||
uowcommit.register_dependency(self.parent, self.dependency_marker)
|
||||
uowcommit.register_dependency(self.mapper, self.dependency_marker)
|
||||
|
||||
def register_processors(self, uowcommit):
|
||||
uowcommit.register_processor(self.dependency_marker, self, self.parent)
|
||||
|
||||
def process_dependencies(self, task, deplist, uowcommit, delete = False):
|
||||
connection = uowcommit.transaction.connection(self.mapper)
|
||||
secondary_delete = []
|
||||
secondary_insert = []
|
||||
secondary_update = []
|
||||
|
||||
if delete:
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
|
||||
if history:
|
||||
for child in history.non_added():
|
||||
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
|
||||
continue
|
||||
associationrow = {}
|
||||
self._synchronize(state, child, associationrow, False, uowcommit)
|
||||
secondary_delete.append(associationrow)
|
||||
self._performed_action(uowcommit, state, child, "manytomany")
|
||||
else:
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key)
|
||||
if history:
|
||||
for child in history.added:
|
||||
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
|
||||
continue
|
||||
associationrow = {}
|
||||
self._synchronize(state, child, associationrow, False, uowcommit)
|
||||
self._performed_action(uowcommit, state, child, "manytomany")
|
||||
secondary_insert.append(associationrow)
|
||||
for child in history.deleted:
|
||||
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
|
||||
continue
|
||||
associationrow = {}
|
||||
self._synchronize(state, child, associationrow, False, uowcommit)
|
||||
self._performed_action(uowcommit, state, child, "manytomany")
|
||||
secondary_delete.append(associationrow)
|
||||
|
||||
if not self.passive_updates and self._pks_changed(uowcommit, state):
|
||||
if not history:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=False)
|
||||
|
||||
for child in history.unchanged:
|
||||
associationrow = {}
|
||||
sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs)
|
||||
sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs)
|
||||
|
||||
#self.syncrules.update(associationrow, state, child, "old_")
|
||||
secondary_update.append(associationrow)
|
||||
|
||||
if secondary_delete:
|
||||
statement = self.secondary.delete(sql.and_(*[
|
||||
c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow
|
||||
]))
|
||||
result = connection.execute(statement, secondary_delete)
|
||||
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete):
|
||||
raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of "
|
||||
"secondary table rows deleted from table '%s': %d" %
|
||||
(result.rowcount, self.secondary.description, len(secondary_delete)))
|
||||
|
||||
if secondary_update:
|
||||
statement = self.secondary.update(sql.and_(*[
|
||||
c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow
|
||||
]))
|
||||
result = connection.execute(statement, secondary_update)
|
||||
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update):
|
||||
raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of "
|
||||
"secondary table rows updated from table '%s': %d" %
|
||||
(result.rowcount, self.secondary.description, len(secondary_update)))
|
||||
|
||||
if secondary_insert:
|
||||
statement = self.secondary.insert()
|
||||
connection.execute(statement, secondary_insert)
|
||||
|
||||
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
|
||||
if not delete:
|
||||
for state in deplist:
|
||||
history = uowcommit.get_attribute_history(state, self.key, passive=True)
|
||||
if history:
|
||||
for child in history.deleted:
|
||||
if self.cascade.delete_orphan and self.hasparent(child) is False:
|
||||
uowcommit.register_object(child, isdelete=True)
|
||||
for c, m in self.mapper.cascade_iterator('delete', child):
|
||||
uowcommit.register_object(
|
||||
attributes.instance_state(c), isdelete=True)
|
||||
|
||||
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
|
||||
if associationrow is None:
|
||||
return
|
||||
self._verify_canload(child)
|
||||
|
||||
sync.populate_dict(state, self.parent, associationrow,
|
||||
self.prop.synchronize_pairs)
|
||||
sync.populate_dict(child, self.mapper, associationrow,
|
||||
self.prop.secondary_synchronize_pairs)
|
||||
|
||||
def _pks_changed(self, uowcommit, state):
|
||||
return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
|
||||
|
||||
class MapperStub(object):
|
||||
"""Represent a many-to-many dependency within a flush
|
||||
context.
|
||||
|
||||
The UOWTransaction tracks dependencies in terms of mappers.
MapperStub takes the place of the "association table"
so that a dependency can be recorded against it.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, parent, mapper, key):
|
||||
self.mapper = mapper
|
||||
self.base_mapper = self
|
||||
self.class_ = mapper.class_
|
||||
self._inheriting_mappers = []
|
||||
|
||||
def polymorphic_iterator(self):
|
||||
return iter((self,))
|
||||
|
||||
def _register_dependencies(self, uowcommit):
|
||||
pass
|
||||
|
||||
def _register_processors(self, uowcommit):
|
||||
pass
|
||||
|
||||
def _save_obj(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _delete_obj(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def primary_mapper(self):
|
||||
return self
|
||||
293
sqlalchemy/orm/dynamic.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# dynamic.py
|
||||
# Copyright (C) the SQLAlchemy authors and contributors
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Dynamic collection API.
|
||||
|
||||
Dynamic collections act like Query() objects for read operations and support
|
||||
basic add/delete mutation.
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import log, util
|
||||
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as orm_exc
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.orm import (
|
||||
attributes, object_session, util as mapperutil, strategies, object_mapper
|
||||
)
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.orm.util import _state_has_identity, has_identity
|
||||
from sqlalchemy.orm import collections
|
||||
|
||||
class DynaLoader(strategies.AbstractRelationshipLoader):
|
||||
def init_class_attribute(self, mapper):
|
||||
self.is_class_level = True
|
||||
|
||||
strategies._register_attribute(self,
|
||||
mapper,
|
||||
useobject=True,
|
||||
impl_class=DynamicAttributeImpl,
|
||||
target_mapper=self.parent_property.mapper,
|
||||
order_by=self.parent_property.order_by,
|
||||
query_class=self.parent_property.query_class
|
||||
)
|
||||
|
||||
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
|
||||
return (None, None)
|
||||
|
||||
log.class_logger(DynaLoader)
|
||||
|
||||
class DynamicAttributeImpl(attributes.AttributeImpl):
|
||||
uses_objects = True
|
||||
accepts_scalar_loader = False
|
||||
|
||||
def __init__(self, class_, key, typecallable,
|
||||
target_mapper, order_by, query_class=None, **kwargs):
|
||||
super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs)
|
||||
self.target_mapper = target_mapper
|
||||
self.order_by = order_by
|
||||
if not query_class:
|
||||
self.query_class = AppenderQuery
|
||||
elif AppenderMixin in query_class.mro():
|
||||
self.query_class = query_class
|
||||
else:
|
||||
self.query_class = mixin_user_query(query_class)
|
||||
|
||||
def get(self, state, dict_, passive=False):
|
||||
if passive:
|
||||
return self._get_collection_history(state, passive=True).added_items
|
||||
else:
|
||||
return self.query_class(self, state)
|
||||
|
||||
def get_collection(self, state, dict_, user_data=None, passive=True):
|
||||
if passive:
|
||||
return self._get_collection_history(state, passive=passive).added_items
|
||||
else:
|
||||
history = self._get_collection_history(state, passive=passive)
|
||||
return history.added_items + history.unchanged_items
|
||||
|
||||
def fire_append_event(self, state, dict_, value, initiator):
|
||||
collection_history = self._modified_event(state, dict_)
|
||||
collection_history.added_items.append(value)
|
||||
|
||||
for ext in self.extensions:
|
||||
ext.append(state, value, initiator or self)
|
||||
|
||||
if self.trackparent and value is not None:
|
||||
self.sethasparent(attributes.instance_state(value), True)
|
||||
|
||||
def fire_remove_event(self, state, dict_, value, initiator):
|
||||
collection_history = self._modified_event(state, dict_)
|
||||
collection_history.deleted_items.append(value)
|
||||
|
||||
if self.trackparent and value is not None:
|
||||
self.sethasparent(attributes.instance_state(value), False)
|
||||
|
||||
for ext in self.extensions:
|
||||
ext.remove(state, value, initiator or self)
|
||||
|
||||
def _modified_event(self, state, dict_):
|
||||
|
||||
if self.key not in state.committed_state:
|
||||
state.committed_state[self.key] = CollectionHistory(self, state)
|
||||
|
||||
state.modified_event(dict_,
|
||||
self,
|
||||
False,
|
||||
attributes.NEVER_SET,
|
||||
passive=attributes.PASSIVE_NO_INITIALIZE)
|
||||
|
||||
# this is a hack to allow the _base.ComparableEntity fixture
|
||||
# to work
|
||||
dict_[self.key] = True
|
||||
return state.committed_state[self.key]
|
||||
|
||||
def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF):
|
||||
if initiator is self:
|
||||
return
|
||||
|
||||
self._set_iterable(state, dict_, value)
|
||||
|
||||
def _set_iterable(self, state, dict_, iterable, adapter=None):
|
||||
|
||||
collection_history = self._modified_event(state, dict_)
|
||||
new_values = list(iterable)
|
||||
|
||||
if _state_has_identity(state):
|
||||
old_collection = list(self.get(state, dict_))
|
||||
else:
|
||||
old_collection = []
|
||||
|
||||
collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values))
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_history(self, state, dict_, passive=False):
|
||||
c = self._get_collection_history(state, passive)
|
||||
return attributes.History(c.added_items, c.unchanged_items, c.deleted_items)
|
||||
|
||||
def _get_collection_history(self, state, passive=False):
|
||||
if self.key in state.committed_state:
|
||||
c = state.committed_state[self.key]
|
||||
else:
|
||||
c = CollectionHistory(self, state)
|
||||
|
||||
if not passive:
|
||||
return CollectionHistory(self, state, apply_to=c)
|
||||
else:
|
||||
return c
|
||||
|
||||
def append(self, state, dict_, value, initiator, passive=False):
|
||||
if initiator is not self:
|
||||
self.fire_append_event(state, dict_, value, initiator)
|
||||
|
||||
def remove(self, state, dict_, value, initiator, passive=False):
|
||||
if initiator is not self:
|
||||
self.fire_remove_event(state, dict_, value, initiator)
|
||||
|
||||
class DynCollectionAdapter(object):
|
||||
"""the dynamic analogue to orm.collections.CollectionAdapter"""
|
||||
|
||||
def __init__(self, attr, owner_state, data):
|
||||
self.attr = attr
|
||||
self.state = owner_state
|
||||
self.data = data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data)
|
||||
|
||||
def append_with_event(self, item, initiator=None):
|
||||
self.attr.append(self.state, self.state.dict, item, initiator)
|
||||
|
||||
def remove_with_event(self, item, initiator=None):
|
||||
self.attr.remove(self.state, self.state.dict, item, initiator)
|
||||
|
||||
def append_without_event(self, item):
|
||||
pass
|
||||
|
||||
def remove_without_event(self, item):
|
||||
pass
|
||||
|
||||
class AppenderMixin(object):
|
||||
query_class = None
|
||||
|
||||
def __init__(self, attr, state):
|
||||
Query.__init__(self, attr.target_mapper, None)
|
||||
self.instance = instance = state.obj()
|
||||
self.attr = attr
|
||||
|
||||
mapper = object_mapper(instance)
|
||||
prop = mapper.get_property(self.attr.key, resolve_synonyms=True)
|
||||
self._criterion = prop.compare(
|
||||
operators.eq,
|
||||
instance,
|
||||
value_is_parent=True,
|
||||
alias_secondary=False)
|
||||
|
||||
if self.attr.order_by:
|
||||
self._order_by = self.attr.order_by
|
||||
|
||||
def __session(self):
|
||||
sess = object_session(self.instance)
|
||||
if sess is not None and self.autoflush and sess.autoflush and self.instance in sess:
|
||||
sess.flush()
|
||||
if not has_identity(self.instance):
|
||||
return None
|
||||
else:
|
||||
return sess
|
||||
|
||||
def session(self):
|
||||
return self.__session()
|
||||
session = property(session, lambda s, x:None)
|
||||
|
||||
def __iter__(self):
|
||||
sess = self.__session()
|
||||
if sess is None:
|
||||
return iter(self.attr._get_collection_history(
|
||||
attributes.instance_state(self.instance),
|
||||
passive=True).added_items)
|
||||
else:
|
||||
return iter(self._clone(sess))
|
||||
|
||||
def __getitem__(self, index):
|
||||
sess = self.__session()
|
||||
if sess is None:
|
||||
return self.attr._get_collection_history(
|
||||
attributes.instance_state(self.instance),
|
||||
passive=True).added_items.__getitem__(index)
|
||||
else:
|
||||
return self._clone(sess).__getitem__(index)
|
||||
|
||||
def count(self):
|
||||
sess = self.__session()
|
||||
if sess is None:
|
||||
return len(self.attr._get_collection_history(
|
||||
attributes.instance_state(self.instance),
|
||||
passive=True).added_items)
|
||||
else:
|
||||
return self._clone(sess).count()
|
||||
|
||||
def _clone(self, sess=None):
|
||||
# note we're returning an entirely new Query class instance
|
||||
# here without any assignment capabilities; the class of this
|
||||
# query is determined by the session.
|
||||
instance = self.instance
|
||||
if sess is None:
|
||||
sess = object_session(instance)
|
||||
if sess is None:
|
||||
raise orm_exc.DetachedInstanceError(
|
||||
"Parent instance %s is not bound to a Session, and no "
|
||||
"contextual session is established; lazy load operation "
|
||||
"of attribute '%s' cannot proceed" % (
|
||||
mapperutil.instance_str(instance), self.attr.key))
|
||||
|
||||
if self.query_class:
|
||||
query = self.query_class(self.attr.target_mapper, session=sess)
|
||||
else:
|
||||
query = sess.query(self.attr.target_mapper)
|
||||
|
||||
query._criterion = self._criterion
|
||||
query._order_by = self._order_by
|
||||
|
||||
return query
|
||||
|
||||
def append(self, item):
|
||||
self.attr.append(
|
||||
attributes.instance_state(self.instance),
|
||||
attributes.instance_dict(self.instance), item, None)
|
||||
|
||||
def remove(self, item):
|
||||
self.attr.remove(
|
||||
attributes.instance_state(self.instance),
|
||||
attributes.instance_dict(self.instance), item, None)
|
||||
|
||||
|
||||
class AppenderQuery(AppenderMixin, Query):
|
||||
"""A dynamic query that supports basic collection storage operations."""
|
||||
|
||||
|
||||
def mixin_user_query(cls):
|
||||
"""Return a new class with AppenderQuery functionality layered over."""
|
||||
name = 'Appender' + cls.__name__
|
||||
return type(name, (AppenderMixin, cls), {'query_class': cls})
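
# e.g. mixin_user_query(MyQuery) produces AppenderMyQuery(AppenderMixin,
# MyQuery), suitable for passing as the query_class argument of a dynamic
# relationship() (illustrative usage).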
|
||||
|
||||
class CollectionHistory(object):
|
||||
"""Overrides AttributeHistory to receive append/remove events directly."""
|
||||
|
||||
def __init__(self, attr, state, apply_to=None):
|
||||
if apply_to:
|
||||
deleted = util.IdentitySet(apply_to.deleted_items)
|
||||
added = apply_to.added_items
|
||||
coll = AppenderQuery(attr, state).autoflush(False)
|
||||
self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted]
|
||||
self.added_items = apply_to.added_items
|
||||
self.deleted_items = apply_to.deleted_items
|
||||
else:
|
||||
self.deleted_items = []
|
||||
self.added_items = []
|
||||
self.unchanged_items = []
|
||||
|
||||
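# --- Illustrative sketch, not part of this commit. mixin_user_query() above
# composes a new class at runtime with the three-argument type() call. A
# self-contained analogue of that pattern (LoggingMixin/BaseQuery are
# hypothetical names invented for this example):

class LoggingMixin(object):
    def fetch(self):
        print("fetching via %s" % type(self).__name__)
        return super(LoggingMixin, self).fetch()

class BaseQuery(object):
    def fetch(self):
        return [1, 2, 3]

# Same shape as mixin_user_query: derived name, bases tuple, class dict.
LoggingQuery = type('Logging' + BaseQuery.__name__,
                    (LoggingMixin, BaseQuery), {})

assert LoggingQuery().fetch() == [1, 2, 3]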
104
sqlalchemy/orm/evaluator.py
Normal file
@@ -0,0 +1,104 @@
import operator
from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql


class UnevaluatableError(Exception):
    pass

_straight_ops = set(getattr(operators, op)
                    for op in ('add', 'mul', 'sub',
                               # Py2K
                               'div',
                               # end Py2K
                               'mod', 'truediv',
                               'lt', 'le', 'ne', 'gt', 'ge', 'eq'))


_notimplemented_ops = set(getattr(operators, op)
                          for op in ('like_op', 'notlike_op', 'ilike_op',
                                     'notilike_op', 'between_op', 'in_op',
                                     'notin_op', 'endswith_op', 'concat_op'))

class EvaluatorCompiler(object):
    def process(self, clause):
        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
        if not meth:
            raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__)
        return meth(clause)

    def visit_grouping(self, clause):
        return self.process(clause.element)

    def visit_null(self, clause):
        return lambda obj: None

    def visit_column(self, clause):
        if 'parentmapper' in clause._annotations:
            key = clause._annotations['parentmapper']._get_col_to_prop(clause).key
        else:
            key = clause.key
        get_corresponding_attr = operator.attrgetter(key)
        return lambda obj: get_corresponding_attr(obj)

    def visit_clauselist(self, clause):
        evaluators = map(self.process, clause.clauses)
        if clause.operator is operators.or_:
            def evaluate(obj):
                has_null = False
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value:
                        return True
                    has_null = has_null or value is None
                if has_null:
                    return None
                return False
        elif clause.operator is operators.and_:
            def evaluate(obj):
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if not value:
                        if value is None:
                            return None
                        return False
                return True
        else:
            raise UnevaluatableError("Cannot evaluate clauselist with operator %s" % clause.operator)

        return evaluate

    def visit_binary(self, clause):
        eval_left, eval_right = map(self.process, [clause.left, clause.right])
        operator = clause.operator
        if operator is operators.is_:
            def evaluate(obj):
                return eval_left(obj) == eval_right(obj)
        elif operator is operators.isnot:
            def evaluate(obj):
                return eval_left(obj) != eval_right(obj)
        elif operator in _straight_ops:
            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                # use the values already computed above rather than
                # re-invoking the sub-evaluators
                return operator(left_val, right_val)
        else:
            raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
        return evaluate

    def visit_unary(self, clause):
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:
            def evaluate(obj):
                value = eval_inner(obj)
                if value is None:
                    return None
                return not value
            return evaluate
        raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))

    def visit_bindparam(self, clause):
        val = clause.value
        return lambda obj: val
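# --- Illustrative sketch, not part of this commit. The OR/AND evaluators in
# visit_clauselist() above implement SQL three-valued logic, with Python's
# None standing in for SQL NULL. The same truth rules rendered as a
# standalone, runnable snippet:

def three_valued_or(values):
    has_null = False
    for v in values:
        if v:
            return True
        has_null = has_null or v is None
    return None if has_null else False

def three_valued_and(values):
    for v in values:
        if not v:
            return None if v is None else False
    return True

assert three_valued_or([False, None]) is None   # FALSE OR NULL  -> NULL
assert three_valued_or([None, True]) is True    # NULL OR TRUE   -> TRUE
assert three_valued_and([True, None]) is None   # TRUE AND NULL  -> NULL
assert three_valued_and([True, False]) is False # TRUE AND FALSE -> FALSE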
98
sqlalchemy/orm/exc.py
Normal file
@@ -0,0 +1,98 @@
# exc.py - ORM exceptions
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""SQLAlchemy ORM exceptions."""

import sqlalchemy as sa


NO_STATE = (AttributeError, KeyError)
"""Exception types that may be raised by instrumentation implementations."""

class ConcurrentModificationError(sa.exc.SQLAlchemyError):
    """Rows have been modified outside of the unit of work."""


class FlushError(sa.exc.SQLAlchemyError):
    """An invalid condition was detected during flush()."""


class UnmappedError(sa.exc.InvalidRequestError):
    """TODO"""

class DetachedInstanceError(sa.exc.SQLAlchemyError):
    """An attempt to access unloaded attributes on a mapped instance that is detached."""

class UnmappedInstanceError(UnmappedError):
    """A mapping operation was requested for an unknown instance."""

    def __init__(self, obj, msg=None):
        if not msg:
            try:
                mapper = sa.orm.class_mapper(type(obj))
                name = _safe_cls_name(type(obj))
                msg = ("Class %r is mapped, but this instance lacks "
                       "instrumentation.  This occurs when the instance is created "
                       "before sqlalchemy.orm.mapper(%s) was called." % (name, name))
            except UnmappedClassError:
                msg = _default_unmapped(type(obj))
                if isinstance(obj, type):
                    msg += (
                        '; was a class (%s) supplied where an instance was '
                        'required?' % _safe_cls_name(obj))
        UnmappedError.__init__(self, msg)


class UnmappedClassError(UnmappedError):
    """A mapping operation was requested for an unknown class."""

    def __init__(self, cls, msg=None):
        if not msg:
            msg = _default_unmapped(cls)
        UnmappedError.__init__(self, msg)


class ObjectDeletedError(sa.exc.InvalidRequestError):
    """A refresh() operation failed to re-retrieve an object's row."""


class UnmappedColumnError(sa.exc.InvalidRequestError):
    """Mapping operation was requested on an unknown column."""


class NoResultFound(sa.exc.InvalidRequestError):
    """A database result was required but none was found."""


class MultipleResultsFound(sa.exc.InvalidRequestError):
    """A single database result was required but more than one were found."""


# Legacy compat until 0.6.
sa.exc.ConcurrentModificationError = ConcurrentModificationError
sa.exc.FlushError = FlushError
sa.exc.UnmappedColumnError = UnmappedColumnError

def _safe_cls_name(cls):
    try:
        cls_name = '.'.join((cls.__module__, cls.__name__))
    except AttributeError:
        cls_name = getattr(cls, '__name__', None)
        if cls_name is None:
            cls_name = repr(cls)
    return cls_name

def _default_unmapped(cls):
    try:
        mappers = sa.orm.attributes.manager_of_class(cls).mappers
    except NO_STATE:
        mappers = {}
    except TypeError:
        mappers = {}
    name = _safe_cls_name(cls)

    if not mappers:
        return "Class '%s' is not mapped" % name
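# --- Illustrative sketch, not part of this commit. The "legacy compat" block
# above re-exports the new exception classes onto the old exceptions module,
# so code written against the old location keeps catching the same class
# object. A minimal standalone analogue of that aliasing pattern (the
# 'legacy'/NewError/OldError names are invented for this example):

import types

legacy = types.ModuleType('legacy')

class NewError(Exception):
    pass

legacy.OldError = NewError        # an alias to the same class, not a copy

try:
    raise NewError("boom")
except legacy.OldError:           # old-style handlers still match
    caught = True
assert caught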
251
sqlalchemy/orm/identity.py
Normal file
@@ -0,0 +1,251 @@
# identity.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import weakref

from sqlalchemy import util as base_util
from sqlalchemy.orm import attributes


class IdentityMap(dict):
    def __init__(self):
        self._mutable_attrs = set()
        self._modified = set()
        self._wr = weakref.ref(self)

    def replace(self, state):
        raise NotImplementedError()

    def add(self, state):
        raise NotImplementedError()

    def remove(self, state):
        raise NotImplementedError()

    def update(self, dict):
        raise NotImplementedError("IdentityMap uses add() to insert data")

    def clear(self):
        raise NotImplementedError("IdentityMap uses remove() to remove data")

    def _manage_incoming_state(self, state):
        state._instance_dict = self._wr

        if state.modified:
            self._modified.add(state)
        if state.manager.mutable_attributes:
            self._mutable_attrs.add(state)

    def _manage_removed_state(self, state):
        del state._instance_dict
        self._mutable_attrs.discard(state)
        self._modified.discard(state)

    def _dirty_states(self):
        return self._modified.union(s for s in self._mutable_attrs.copy()
                                    if s.modified)

    def check_modified(self):
        """return True if any InstanceStates present have been marked as 'modified'."""

        if self._modified:
            return True
        else:
            for state in self._mutable_attrs.copy():
                if state.modified:
                    return True
        return False

    def has_key(self, key):
        return key in self

    def popitem(self):
        raise NotImplementedError("IdentityMap uses remove() to remove data")

    def pop(self, key, *args):
        raise NotImplementedError("IdentityMap uses remove() to remove data")

    def setdefault(self, key, default=None):
        raise NotImplementedError("IdentityMap uses add() to insert data")

    def copy(self):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        raise NotImplementedError("IdentityMap uses add() to insert data")

    def __delitem__(self, key):
        raise NotImplementedError("IdentityMap uses remove() to remove data")

class WeakInstanceDict(IdentityMap):

    def __getitem__(self, key):
        state = dict.__getitem__(self, key)
        o = state.obj()
        if o is None:
            o = state._is_really_none()
        if o is None:
            raise KeyError, key
        return o

    def __contains__(self, key):
        try:
            if dict.__contains__(self, key):
                state = dict.__getitem__(self, key)
                o = state.obj()
                if o is None:
                    o = state._is_really_none()
            else:
                return False
        except KeyError:
            return False
        else:
            return o is not None

    def contains_state(self, state):
        return dict.get(self, state.key) is state

    def replace(self, state):
        if dict.__contains__(self, state.key):
            existing = dict.__getitem__(self, state.key)
            if existing is not state:
                self._manage_removed_state(existing)
            else:
                return

        dict.__setitem__(self, state.key, state)
        self._manage_incoming_state(state)

    def add(self, state):
        if state.key in self:
            if dict.__getitem__(self, state.key) is not state:
                raise AssertionError("A conflicting state is already "
                                     "present in the identity map for key %r"
                                     % (state.key, ))
        else:
            dict.__setitem__(self, state.key, state)
            self._manage_incoming_state(state)

    def remove_key(self, key):
        state = dict.__getitem__(self, key)
        self.remove(state)

    def remove(self, state):
        if dict.pop(self, state.key) is not state:
            raise AssertionError("State %s is not present in this identity map" % state)
        self._manage_removed_state(state)

    def discard(self, state):
        if self.contains_state(state):
            dict.__delitem__(self, state.key)
            self._manage_removed_state(state)

    def get(self, key, default=None):
        state = dict.get(self, key, default)
        if state is default:
            return default
        o = state.obj()
        if o is None:
            o = state._is_really_none()
        if o is None:
            return default
        return o

    # Py2K
    def items(self):
        return list(self.iteritems())

    def iteritems(self):
        for state in dict.itervalues(self):
    # end Py2K
    # Py3K
    #def items(self):
    #    for state in dict.values(self):
            value = state.obj()
            if value is not None:
                yield state.key, value

    # Py2K
    def values(self):
        return list(self.itervalues())

    def itervalues(self):
        for state in dict.itervalues(self):
    # end Py2K
    # Py3K
    #def values(self):
    #    for state in dict.values(self):
            instance = state.obj()
            if instance is not None:
                yield instance

    def all_states(self):
        # Py3K
        # return list(dict.values(self))

        # Py2K
        return dict.values(self)
        # end Py2K

    def prune(self):
        return 0

class StrongInstanceDict(IdentityMap):
    def all_states(self):
        return [attributes.instance_state(o) for o in self.itervalues()]

    def contains_state(self, state):
        return state.key in self and attributes.instance_state(self[state.key]) is state

    def replace(self, state):
        if dict.__contains__(self, state.key):
            existing = dict.__getitem__(self, state.key)
            existing = attributes.instance_state(existing)
            if existing is not state:
                self._manage_removed_state(existing)
            else:
                return

        dict.__setitem__(self, state.key, state.obj())
        self._manage_incoming_state(state)

    def add(self, state):
        if state.key in self:
            if attributes.instance_state(dict.__getitem__(self, state.key)) is not state:
                raise AssertionError("A conflicting state is already present in the identity map for key %r" % (state.key, ))
        else:
            dict.__setitem__(self, state.key, state.obj())
            self._manage_incoming_state(state)

    def remove(self, state):
        if attributes.instance_state(dict.pop(self, state.key)) is not state:
            raise AssertionError("State %s is not present in this identity map" % state)
        self._manage_removed_state(state)

    def discard(self, state):
        if self.contains_state(state):
            dict.__delitem__(self, state.key)
            self._manage_removed_state(state)

    def remove_key(self, key):
        state = attributes.instance_state(dict.__getitem__(self, key))
        self.remove(state)

    def prune(self):
        """prune unreferenced, non-dirty states."""

        ref_count = len(self)
        dirty = [s.obj() for s in self.all_states() if s.modified]

        # work around http://bugs.python.org/issue6149
        keepers = weakref.WeakValueDictionary()
        keepers.update(self)

        dict.clear(self)
        dict.update(self, keepers)
        self.modified = bool(dirty)
        return ref_count - len(self)
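# --- Illustrative sketch, not part of this commit. WeakInstanceDict above
# holds weak references so that unreferenced, clean instances fall out of the
# identity map once garbage collected. The core mechanics, shown with the
# stdlib directly:

import gc
import weakref

class Thing(object):
    pass

identity_map = weakref.WeakValueDictionary()

t = Thing()
identity_map[('thing', 1)] = t
assert ('thing', 1) in identity_map

del t            # drop the only strong reference
gc.collect()     # CPython frees immediately; collect() helps other runtimes
assert ('thing', 1) not in identity_map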
1098
sqlalchemy/orm/interfaces.py
Normal file
File diff suppressed because it is too large
1958
sqlalchemy/orm/mapper.py
Normal file
File diff suppressed because it is too large
1205
sqlalchemy/orm/properties.py
Normal file
File diff suppressed because it is too large
2469
sqlalchemy/orm/query.py
Normal file
File diff suppressed because it is too large
205
sqlalchemy/orm/scoping.py
Normal file
@@ -0,0 +1,205 @@
# scoping.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import sqlalchemy.exceptions as sa_exc
from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \
    to_list, get_cls_kwargs, deprecated
from sqlalchemy.orm import (
    EXT_CONTINUE, MapperExtension, class_mapper, object_session
    )
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm.session import Session


__all__ = ['ScopedSession']


class ScopedSession(object):
    """Provides thread-local management of Sessions.

    Usage::

      Session = scoped_session(sessionmaker(autoflush=True))

      ... use session normally.

    """

    def __init__(self, session_factory, scopefunc=None):
        self.session_factory = session_factory
        if scopefunc:
            self.registry = ScopedRegistry(session_factory, scopefunc)
        else:
            self.registry = ThreadLocalRegistry(session_factory)
        self.extension = _ScopedExt(self)

    def __call__(self, **kwargs):
        if kwargs:
            scope = kwargs.pop('scope', False)
            if scope is not None:
                if self.registry.has():
                    raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
                else:
                    sess = self.session_factory(**kwargs)
                    self.registry.set(sess)
                    return sess
            else:
                return self.session_factory(**kwargs)
        else:
            return self.registry()

    def remove(self):
        """Dispose of the current contextual session."""

        if self.registry.has():
            self.registry().close()
        self.registry.clear()

    @deprecated("Session.mapper is deprecated.  "
                "Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper "
                "for information on how to replicate its behavior.")
    def mapper(self, *args, **kwargs):
        """return a mapper() function which associates this ScopedSession with the Mapper.

        DEPRECATED.

        """

        from sqlalchemy.orm import mapper

        extension_args = dict((arg, kwargs.pop(arg))
                              for arg in get_cls_kwargs(_ScopedExt)
                              if arg in kwargs)

        kwargs['extension'] = extension = to_list(kwargs.get('extension', []))
        if extension_args:
            extension.append(self.extension.configure(**extension_args))
        else:
            extension.append(self.extension)
        return mapper(*args, **kwargs)

    def configure(self, **kwargs):
        """reconfigure the sessionmaker used by this ScopedSession."""

        self.session_factory.configure(**kwargs)

    def query_property(self, query_cls=None):
        """return a class property which produces a `Query` object against the
        class when called.

        e.g.::

            Session = scoped_session(sessionmaker())

            class MyClass(object):
                query = Session.query_property()

            # after mappers are defined
            result = MyClass.query.filter(MyClass.name=='foo').all()

        Produces instances of the session's configured query class by
        default.  To override and use a custom implementation, provide
        a ``query_cls`` callable.  The callable will be invoked with
        the class's mapper as a positional argument and a session
        keyword argument.

        There is no limit to the number of query properties placed on
        a class.

        """
        class query(object):
            def __get__(s, instance, owner):
                try:
                    mapper = class_mapper(owner)
                    if mapper:
                        if query_cls:
                            # custom query class
                            return query_cls(mapper, session=self.registry())
                        else:
                            # session's configured query class
                            return self.registry().query(mapper)
                except orm_exc.UnmappedClassError:
                    return None
        return query()

def instrument(name):
    def do(self, *args, **kwargs):
        return getattr(self.registry(), name)(*args, **kwargs)
    return do
for meth in Session.public_methods:
    setattr(ScopedSession, meth, instrument(meth))

def makeprop(name):
    def set(self, attr):
        setattr(self.registry(), name, attr)
    def get(self):
        return getattr(self.registry(), name)
    return property(get, set)
for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'):
    setattr(ScopedSession, prop, makeprop(prop))

def clslevel(name):
    def do(cls, *args, **kwargs):
        return getattr(Session, name)(*args, **kwargs)
    return classmethod(do)
for prop in ('close_all', 'object_session', 'identity_key'):
    setattr(ScopedSession, prop, clslevel(prop))

class _ScopedExt(MapperExtension):
    def __init__(self, context, validate=False, save_on_init=True):
        self.context = context
        self.validate = validate
        self.save_on_init = save_on_init
        self.set_kwargs_on_init = True

    def validating(self):
        return _ScopedExt(self.context, validate=True)

    def configure(self, **kwargs):
        return _ScopedExt(self.context, **kwargs)

    def instrument_class(self, mapper, class_):
        class query(object):
            def __getattr__(s, key):
                return getattr(self.context.registry().query(class_), key)
            def __call__(s):
                return self.context.registry().query(class_)
            def __get__(self, instance, cls):
                return self

        if not 'query' in class_.__dict__:
            class_.query = query()

        if self.set_kwargs_on_init and class_.__init__ is object.__init__:
            class_.__init__ = self._default__init__(mapper)

    def _default__init__(ext, mapper):
        def __init__(self, **kwargs):
            for key, value in kwargs.iteritems():
                if ext.validate:
                    if not mapper.get_property(key, resolve_synonyms=False,
                                               raiseerr=False):
                        raise sa_exc.ArgumentError(
                            "Invalid __init__ argument: '%s'" % key)
                setattr(self, key, value)
        return __init__

    def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
        if self.save_on_init:
            session = kwargs.pop('_sa_session', None)
            if session is None:
                session = self.context.registry()
            session._save_without_cascade(instance)
        return EXT_CONTINUE

    def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
        sess = object_session(instance)
        if sess:
            sess.expunge(instance)
        return EXT_CONTINUE

    def dispose_class(self, mapper, class_):
        if hasattr(class_, 'query'):
            delattr(class_, 'query')
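# --- Illustrative sketch, not part of this commit. ScopedSession above
# delegates everything to a per-scope registry: calling the scoped object
# returns the one session belonging to the current scope, creating it on
# first use. A minimal standalone registry analogue (the Registry name and
# list-based "current scope" are invented for this example):

class Registry(object):
    def __init__(self, createfunc, scopefunc):
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self):
        key = self.scopefunc()
        if key not in self.registry:
            self.registry[key] = self.createfunc()
        return self.registry[key]

    def clear(self):
        self.registry.pop(self.scopefunc(), None)

current_scope = ['a']
reg = Registry(createfunc=object, scopefunc=lambda: current_scope[0])

s1 = reg()
assert reg() is s1           # same scope -> same object every call
current_scope[0] = 'b'
assert reg() is not s1       # a new scope gets its own object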
1604
sqlalchemy/orm/session.py
Normal file
File diff suppressed because it is too large
15
sqlalchemy/orm/shard.py
Normal file
@@ -0,0 +1,15 @@
# shard.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy import util

util.warn_deprecated(
    "Horizontal sharding is now importable via "
    "'import sqlalchemy.ext.horizontal_shard'"
)

from sqlalchemy.ext.horizontal_shard import *
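# --- Illustrative sketch, not part of this commit. shard.py above is a
# deprecation shim: emit a warning, then re-export the relocated module's
# names. The warning half of that pattern, standalone (load_legacy_module is
# an invented name for this example):

import warnings

def load_legacy_module():
    warnings.warn("module moved; import newmodule instead",
                  DeprecationWarning, stacklevel=2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    load_legacy_module()

assert caught[0].category is DeprecationWarning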
527
sqlalchemy/orm/state.py
Normal file
@@ -0,0 +1,527 @@
from sqlalchemy.util import EMPTY_SET
import weakref
from sqlalchemy import util
from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, PASSIVE_OFF, \
    NEVER_SET, NO_VALUE, manager_of_class, \
    ATTR_WAS_SET
from sqlalchemy.orm import attributes, exc as orm_exc, interfaces

import sys
attributes.state = sys.modules['sqlalchemy.orm.state']

class InstanceState(object):
    """tracks state information at the instance level."""

    session_id = None
    key = None
    runid = None
    load_options = EMPTY_SET
    load_path = ()
    insert_order = None
    mutable_dict = None
    _strong_obj = None
    modified = False
    expired = False

    def __init__(self, obj, manager):
        self.class_ = obj.__class__
        self.manager = manager
        self.obj = weakref.ref(obj, self._cleanup)

    @util.memoized_property
    def committed_state(self):
        return {}

    @util.memoized_property
    def parents(self):
        return {}

    @util.memoized_property
    def pending(self):
        return {}

    @util.memoized_property
    def callables(self):
        return {}

    def detach(self):
        if self.session_id:
            try:
                del self.session_id
            except AttributeError:
                pass

    def dispose(self):
        self.detach()
        del self.obj

    def _cleanup(self, ref):
        instance_dict = self._instance_dict()
        if instance_dict:
            try:
                instance_dict.remove(self)
            except AssertionError:
                pass
        # remove possible cycles
        self.__dict__.pop('callables', None)
        self.dispose()

    def obj(self):
        return None

    @property
    def dict(self):
        o = self.obj()
        if o is not None:
            return attributes.instance_dict(o)
        else:
            return {}

    @property
    def sort_key(self):
        return self.key and self.key[1] or (self.insert_order, )

    def initialize_instance(*mixed, **kwargs):
        self, instance, args = mixed[0], mixed[1], mixed[2:]
        manager = self.manager

        for fn in manager.events.on_init:
            fn(self, instance, args, kwargs)

        # LESSTHANIDEAL:
        # adjust for the case where the InstanceState was created before
        # mapper compilation, and this actually needs to be a MutableAttrInstanceState
        if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState:
            self.__class__ = MutableAttrInstanceState
            self.obj = weakref.ref(self.obj(), self._cleanup)
            self.mutable_dict = {}

        try:
            return manager.events.original_init(*mixed[1:], **kwargs)
        except:
            for fn in manager.events.on_init_failure:
                fn(self, instance, args, kwargs)
            raise

    def get_history(self, key, **kwargs):
        return self.manager.get_impl(key).get_history(self, self.dict, **kwargs)

    def get_impl(self, key):
        return self.manager.get_impl(key)

    def get_pending(self, key):
        if key not in self.pending:
            self.pending[key] = PendingCollection()
        return self.pending[key]

    def value_as_iterable(self, key, passive=PASSIVE_OFF):
        """return an InstanceState attribute as a list,
        regardless of it being a scalar or collection-based
        attribute.

        returns None if passive is not PASSIVE_OFF and the getter returns
        PASSIVE_NO_RESULT.
        """

        impl = self.get_impl(key)
        dict_ = self.dict
        x = impl.get(self, dict_, passive=passive)
        if x is PASSIVE_NO_RESULT:
            return None
        elif hasattr(impl, 'get_collection'):
            return impl.get_collection(self, dict_, x, passive=passive)
        else:
            return [x]

    def _run_on_load(self, instance):
        self.manager.events.run('on_load', instance)

    def __getstate__(self):
        d = {'instance': self.obj()}

        d.update(
            (k, self.__dict__[k]) for k in (
                'committed_state', 'pending', 'parents', 'modified', 'expired',
                'callables', 'key', 'load_options', 'mutable_dict'
            ) if k in self.__dict__
        )
        if self.load_path:
            d['load_path'] = interfaces.serialize_path(self.load_path)
        return d

    def __setstate__(self, state):
        self.obj = weakref.ref(state['instance'], self._cleanup)
        self.class_ = state['instance'].__class__
        self.manager = manager = manager_of_class(self.class_)
        if manager is None:
            raise orm_exc.UnmappedInstanceError(
                state['instance'],
                "Cannot deserialize object of type %r - no mapper() has"
                " been configured for this class within the current Python process!" %
                self.class_)
        elif manager.mapper and not manager.mapper.compiled:
            manager.mapper.compile()

        self.committed_state = state.get('committed_state', {})
        self.pending = state.get('pending', {})
        self.parents = state.get('parents', {})
        self.modified = state.get('modified', False)
        self.expired = state.get('expired', False)
        self.callables = state.get('callables', {})

        if self.modified:
            self._strong_obj = state['instance']

        self.__dict__.update([
            (k, state[k]) for k in (
                'key', 'load_options', 'mutable_dict'
            ) if k in state
        ])

        if 'load_path' in state:
            self.load_path = interfaces.deserialize_path(state['load_path'])

    def initialize(self, key):
        """Set this attribute to an empty value or collection,
        based on the AttributeImpl in use."""

        self.manager.get_impl(key).initialize(self, self.dict)

    def reset(self, dict_, key):
        """Remove the given attribute and any
        callables associated with it."""

        dict_.pop(key, None)
        self.callables.pop(key, None)

    def expire_attribute_pre_commit(self, dict_, key):
        """a fast expire that can be called by column loaders during a load.

        The additional bookkeeping is finished up in commit_all().

        This method is actually called a lot with joined-table
        loading, when the second table isn't present in the result.

        """
        dict_.pop(key, None)
        self.callables[key] = self

    def set_callable(self, dict_, key, callable_):
        """Remove the given attribute and set the given callable
        as a loader."""

        dict_.pop(key, None)
        self.callables[key] = callable_

    def expire_attributes(self, dict_, attribute_names, instance_dict=None):
        """Expire all or a group of attributes.

        If all attributes are expired, the "expired" flag is set to True.

        """
        if attribute_names is None:
            attribute_names = self.manager.keys()
            self.expired = True
            if self.modified:
                if not instance_dict:
                    instance_dict = self._instance_dict()
                    if instance_dict:
                        instance_dict._modified.discard(self)
                else:
                    instance_dict._modified.discard(self)

            self.modified = False
            filter_deferred = True
        else:
            filter_deferred = False

        to_clear = (
            self.__dict__.get('pending', None),
            self.__dict__.get('committed_state', None),
            self.mutable_dict
        )

        for key in attribute_names:
            impl = self.manager[key].impl
            if impl.accepts_scalar_loader and \
                (not filter_deferred or impl.expire_missing or key in dict_):
                self.callables[key] = self
            dict_.pop(key, None)

            for d in to_clear:
                if d is not None:
                    d.pop(key, None)

    def __call__(self, **kw):
        """__call__ allows the InstanceState to act as a deferred
        callable for loading expired attributes, which is also
        serializable (picklable).

        """

        if kw.get('passive') is attributes.PASSIVE_NO_FETCH:
            return attributes.PASSIVE_NO_RESULT

        toload = self.expired_attributes.\
                    intersection(self.unmodified)

        self.manager.deferred_scalar_loader(self, toload)

        # if the loader failed, or this
        # instance state didn't have an identity,
        # the attributes still might be in the callables
        # dict.  ensure they are removed.
        for k in toload.intersection(self.callables):
            del self.callables[k]

        return ATTR_WAS_SET

    @property
    def unmodified(self):
        """Return the set of keys which have no uncommitted changes"""

        return set(self.manager).difference(self.committed_state)

    @property
    def unloaded(self):
        """Return the set of keys which do not have a loaded value.

        This includes expired attributes and any other attribute that
        was never populated or modified.

        """
        return set(self.manager).\
                difference(self.committed_state).\
                difference(self.dict)

    @property
    def expired_attributes(self):
        """Return the set of keys which are 'expired' to be loaded by
        the manager's deferred scalar loader, assuming no pending
        changes.

        see also the ``unmodified`` collection which is intersected
        against this set when a refresh operation occurs.

        """
        return set([k for k, v in self.callables.items() if v is self])

    def _instance_dict(self):
        return None

    def _is_really_none(self):
        return self.obj()

    def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF):
        needs_committed = attr.key not in self.committed_state

        if needs_committed:
            if previous is NEVER_SET:
                if passive:
                    if attr.key in dict_:
                        previous = dict_[attr.key]
                else:
                    previous = attr.get(self, dict_)

            if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
                previous = attr.copy(previous)

            if needs_committed:
                self.committed_state[attr.key] = previous

        if not self.modified:
            instance_dict = self._instance_dict()
            if instance_dict:
                instance_dict._modified.add(self)

        self.modified = True
        if self._strong_obj is None:
            self._strong_obj = self.obj()

    def commit(self, dict_, keys):
        """Commit attributes.

        This is used by a partial-attribute load operation to mark committed
        those attributes which were refreshed from the database.

        Attributes marked as "expired" can potentially remain "expired" after
        this step if a value was not populated in state.dict.

        """
        class_manager = self.manager
        for key in keys:
            if key in dict_ and key in class_manager.mutable_attributes:
                self.committed_state[key] = self.manager[key].impl.copy(dict_[key])
            else:
                self.committed_state.pop(key, None)

        self.expired = False

        for key in set(self.callables).\
                        intersection(keys).\
                        intersection(dict_):
            del self.callables[key]

    def commit_all(self, dict_, instance_dict=None):
        """commit all attributes unconditionally.

        This is used after a flush() or a full load/refresh
        to remove all pending state from the instance.

        - all attributes are marked as "committed"
        - the "strong dirty reference" is removed
        - the "modified" flag is set to False
        - any "expired" markers/callables for attributes loaded are removed.

        Attributes marked as "expired" can potentially remain "expired" after this step
        if a value was not populated in state.dict.

        """

        self.__dict__.pop('committed_state', None)
        self.__dict__.pop('pending', None)

        if 'callables' in self.__dict__:
            callables = self.callables
            for key in list(callables):
                if key in dict_ and callables[key] is self:
                    del callables[key]

        for key in self.manager.mutable_attributes:
            if key in dict_:
                self.committed_state[key] = self.manager[key].impl.copy(dict_[key])

        if instance_dict and self.modified:
            instance_dict._modified.discard(self)

        self.modified = self.expired = False
        self._strong_obj = None
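# --- Illustrative sketch, not part of this commit. InstanceState above holds
# only a weakref to its object, registered with a cleanup callback;
# MutableAttrInstanceState below swaps in a "resurrect" callable when dirty
# state would otherwise be lost. The weakref-with-callback mechanics,
# standalone:

import gc
import weakref

class Obj(object):
    pass

cleaned = []

def cleanup(ref):
    cleaned.append(ref)      # runs once the referent has been collected

o = Obj()
r = weakref.ref(o, cleanup)
assert r() is o              # live: dereferencing returns the object

del o
gc.collect()
assert r() is None           # dead: dereferencing returns None
assert cleaned               # the callback fired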
class MutableAttrInstanceState(InstanceState):
    """InstanceState implementation for objects that reference 'mutable'
    attributes.

    Has a more involved "cleanup" handler that checks mutable attributes
    for changes upon dereference, resurrecting if needed.

    """

    @util.memoized_property
    def mutable_dict(self):
        return {}

    def _get_modified(self, dict_=None):
        if self.__dict__.get('modified', False):
            return True
        else:
            if dict_ is None:
                dict_ = self.dict
            for key in self.manager.mutable_attributes:
                if self.manager[key].impl.check_mutable_modified(self, dict_):
                    return True
            else:
                return False

    def _set_modified(self, value):
        self.__dict__['modified'] = value

    modified = property(_get_modified, _set_modified)

    @property
    def unmodified(self):
        """a set of keys which have no uncommitted changes"""

        dict_ = self.dict

        return set([
            key for key in self.manager
            if (key not in self.committed_state or
                (key in self.manager.mutable_attributes and
                 not self.manager[key].impl.check_mutable_modified(self, dict_)))])

    def _is_really_none(self):
        """do a check modified/resurrect.

        This would be called in the extremely rare
        race condition that the weakref returned None but
        the cleanup handler had not yet established the
        __resurrect callable as its replacement.

        """
        if self.modified:
            self.obj = self.__resurrect
            return self.obj()
        else:
            return None

    def reset(self, dict_, key):
        self.mutable_dict.pop(key, None)
        InstanceState.reset(self, dict_, key)

    def _cleanup(self, ref):
        """weakref callback.

        This method may be called by an asynchronous
        gc.

        If the state shows pending changes, the weakref
        is replaced by the __resurrect callable which will
        re-establish an object reference on next access,
        else removes this InstanceState from the owning
        identity map, if any.

        """
        if self._get_modified(self.mutable_dict):
            self.obj = self.__resurrect
        else:
            instance_dict = self._instance_dict()
            if instance_dict:
                try:
                    instance_dict.remove(self)
                except AssertionError:
                    pass
            self.dispose()

    def __resurrect(self):
        """A substitute for the obj() weakref function which resurrects."""

        # store strong ref'ed version of the object; will revert
        # to weakref when changes are persisted

        obj = self.manager.new_instance(state=self)
        self.obj = weakref.ref(obj, self._cleanup)
        self._strong_obj = obj
        obj.__dict__.update(self.mutable_dict)

        # re-establishes identity attributes from the key
        self.manager.events.run('on_resurrect', self, obj)

        # TODO: don't really think we should run this here.
        # resurrect is only meant to preserve the minimal state needed to
        # do an UPDATE, not to produce a fully usable object
        self._run_on_load(obj)

        return obj

class PendingCollection(object):
    """A writable placeholder for an unloaded collection.

    Stores items appended to and removed from a collection that has not yet
    been loaded. When the collection is loaded, the changes stored in
    PendingCollection are applied to it to produce the final result.

    """
    def __init__(self):
        self.deleted_items = util.IdentitySet()
        self.added_items = util.OrderedIdentitySet()

    def append(self, value):
        if value in self.deleted_items:
            self.deleted_items.remove(value)
        self.added_items.add(value)

    def remove(self, value):
        if value in self.added_items:
            self.added_items.remove(value)
        self.deleted_items.add(value)
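# --- Illustrative sketch, not part of this commit. PendingCollection above
# buffers mutations against a collection that was never loaded so they can
# be replayed once the real collection arrives. The same bookkeeping with
# plain sets (the Pending name is invented for this example):

class Pending(object):
    def __init__(self):
        self.added, self.deleted = set(), set()

    def append(self, value):
        self.deleted.discard(value)   # an append cancels a prior remove
        self.added.add(value)

    def remove(self, value):
        self.added.discard(value)     # a remove cancels a prior append
        self.deleted.add(value)

p = Pending()
p.append('x')
p.remove('y')
p.remove('x')                         # append then remove cancels out
assert p.added == set()
assert p.deleted == set(['x', 'y'])

# replaying against a freshly loaded collection:
loaded = set(['y', 'z'])
final = (loaded - p.deleted) | p.added
assert final == set(['z'])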
1229
sqlalchemy/orm/strategies.py
Normal file
File diff suppressed because it is too large
98
sqlalchemy/orm/sync.py
Normal file
@@ -0,0 +1,98 @@
# mapper/sync.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""private module containing functions used for copying data
between instances based on join conditions.
"""

from sqlalchemy.orm import exc, util as mapperutil

def populate(source, source_mapper, dest, dest_mapper,
             synchronize_pairs, uowcommit, passive_updates):
    for l, r in synchronize_pairs:
        try:
            value = source_mapper._get_state_attr_by_column(source, l)
        except exc.UnmappedColumnError:
            _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)

        try:
            dest_mapper._set_state_attr_by_column(dest, r, value)
        except exc.UnmappedColumnError:
            _raise_col_to_prop(True, source_mapper, l, dest_mapper, r)

        # technically the "r.primary_key" check isn't
        # needed here, but we check for this condition to limit
        # how often this logic is invoked for memory/performance
        # reasons, since we only need this info for a primary key
        # destination.
        if l.primary_key and r.primary_key and \
            r.references(l) and passive_updates:
            uowcommit.attributes[("pk_cascaded", dest, r)] = True

def clear(dest, dest_mapper, synchronize_pairs):
    for l, r in synchronize_pairs:
        if r.primary_key:
            raise AssertionError(
                "Dependency rule tried to blank-out primary key "
                "column '%s' on instance '%s'" %
                (r, mapperutil.state_str(dest))
            )
        try:
            dest_mapper._set_state_attr_by_column(dest, r, None)
        except exc.UnmappedColumnError:
            _raise_col_to_prop(True, None, l, dest_mapper, r)

def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
    for l, r in synchronize_pairs:
        try:
            oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
            value = source_mapper._get_state_attr_by_column(source, l)
        except exc.UnmappedColumnError:
            _raise_col_to_prop(False, source_mapper, l, None, r)
        dest[r.key] = value
        dest[old_prefix + r.key] = oldvalue

def populate_dict(source, source_mapper, dict_, synchronize_pairs):
    for l, r in synchronize_pairs:
        try:
            value = source_mapper._get_state_attr_by_column(source, l)
        except exc.UnmappedColumnError:
            _raise_col_to_prop(False, source_mapper, l, None, r)

        dict_[r.key] = value

def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
    """return true if the source object has changes from an old to a
    new value on the given synchronize pairs

    """
    for l, r in synchronize_pairs:
        try:
            prop = source_mapper._get_col_to_prop(l)
        except exc.UnmappedColumnError:
            _raise_col_to_prop(False, source_mapper, l, None, r)
        history = uowcommit.get_attribute_history(source, prop.key, passive=True)
        if len(history.deleted):
            return True
    else:
        return False

def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
    if isdest:
        raise exc.UnmappedColumnError(
            "Can't execute sync rule for destination column '%s'; "
            "mapper '%s' does not map this column.  Try using an explicit"
            " `foreign_keys` collection which does not include this column "
            "(or use a viewonly=True relation)." % (dest_column, source_mapper)
        )
    else:
        raise exc.UnmappedColumnError(
            "Can't execute sync rule for source column '%s'; mapper '%s' "
            "does not map this column.  Try using an explicit `foreign_keys`"
            " collection which does not include destination column '%s' (or "
            "use a viewonly=True relation)." %
            (source_column, source_mapper, dest_column)
        )
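# --- Illustrative sketch, not part of this commit. The sync functions above
# copy values along (left column, right column) join pairs from a parent row
# to a child. A standalone analogue of populate_dict() over plain dicts
# (parent/parent_id/parent_region are invented names for this example):

def sync_populate_dict(source, synchronize_pairs):
    dest = {}
    for l, r in synchronize_pairs:
        dest[r] = source[l]      # copy the parent-side value to the child-side key
    return dest

parent = {'id': 7, 'region': 'us-east'}
pairs = [('id', 'parent_id'), ('region', 'parent_region')]
assert sync_populate_dict(parent, pairs) == \
    {'parent_id': 7, 'parent_region': 'us-east'}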
781
sqlalchemy/orm/unitofwork.py
Normal file
@@ -0,0 +1,781 @@
# orm/unitofwork.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""The internals for the Unit Of Work system.

Includes hooks into the attributes package enabling the routing of
change events to Unit Of Work objects, as well as the flush()
mechanism which creates a dependency structure that executes change
operations.

A Unit of Work is essentially a system of maintaining a graph of
in-memory objects and their modified state.  Objects are maintained as
unique against their primary key identity using an *identity map*
pattern.  The Unit of Work then maintains lists of objects that are
new, dirty, or deleted and provides the capability to flush all those
changes at once.

"""

from sqlalchemy import util, log, topological
from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.mapper import _state_mapper

# Load lazily
object_session = None
_state_session = None

class UOWEventHandler(interfaces.AttributeExtension):
    """An event handler added to all relationship attributes which handles
    session cascade operations.
    """

    active_history = False

    def __init__(self, key):
        self.key = key

    def append(self, state, item, initiator):
        # process "save_update" cascade rules for when an instance is appended to the list of another instance
        sess = _state_session(state)
        if sess:
            prop = _state_mapper(state).get_property(self.key)
            if prop.cascade.save_update and item not in sess:
                sess.add(item)
        return item

    def remove(self, state, item, initiator):
        sess = _state_session(state)
        if sess:
            prop = _state_mapper(state).get_property(self.key)
            # expunge pending orphans
            if prop.cascade.delete_orphan and \
                item in sess.new and \
                prop.mapper._is_orphan(attributes.instance_state(item)):
                sess.expunge(item)

    def set(self, state, newvalue, oldvalue, initiator):
        # process "save_update" cascade rules for when an instance is attached to another instance
        if oldvalue is newvalue:
            return newvalue
        sess = _state_session(state)
        if sess:
            prop = _state_mapper(state).get_property(self.key)
            if newvalue is not None and prop.cascade.save_update and newvalue not in sess:
                sess.add(newvalue)
            if prop.cascade.delete_orphan and oldvalue in sess.new and \
                prop.mapper._is_orphan(attributes.instance_state(oldvalue)):
                sess.expunge(oldvalue)
        return newvalue


class UOWTransaction(object):
    """Handles the details of organizing and executing transaction
    tasks during a UnitOfWork object's flush() operation.

    The central operation is to form a graph of nodes represented by the
    ``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object
    that issues SQL and instance-synchronizing operations via the related
    packages.
    """

    def __init__(self, session):
        self.session = session
        self.mapper_flush_opts = session._mapper_flush_opts

        # stores tuples of mapper/dependent mapper pairs,
        # representing a partial ordering fed into topological sort
        self.dependencies = set()

        # dictionary of mappers to UOWTasks
        self.tasks = {}

        # dictionary used by external actors to store arbitrary state
        # information.
        self.attributes = {}

        self.processors = set()

    def get_attribute_history(self, state, key, passive=True):
        hashkey = ("history", state, key)

        # cache the objects, not the states; the strong reference here
        # prevents newly loaded objects from being dereferenced during the
        # flush process
        if hashkey in self.attributes:
            (history, cached_passive) = self.attributes[hashkey]
            # if the cached lookup was "passive" and now we want non-passive, do a non-passive
            # lookup and re-cache
            if cached_passive and not passive:
                history = attributes.get_state_history(state, key, passive=False)
                self.attributes[hashkey] = (history, passive)
        else:
            history = attributes.get_state_history(state, key, passive=passive)
            self.attributes[hashkey] = (history, passive)

        if not history or not state.get_impl(key).uses_objects:
            return history
        else:
            return history.as_state()

    def register_object(self, state, isdelete=False,
                            listonly=False, postupdate=False, post_update_cols=None):

        # if object is not in the overall session, do nothing
        if not self.session._contains_state(state):
            return

        mapper = _state_mapper(state)

        task = self.get_task_by_mapper(mapper)
        if postupdate:
            task.append_postupdate(state, post_update_cols)
        else:
            task.append(state, listonly=listonly, isdelete=isdelete)

        # ensure the mapper for this object has had its
        # DependencyProcessors added.
        if mapper not in self.processors:
            mapper._register_processors(self)
            self.processors.add(mapper)

            if mapper.base_mapper not in self.processors:
                mapper.base_mapper._register_processors(self)
                self.processors.add(mapper.base_mapper)

    def set_row_switch(self, state):
        """mark a deleted object as a 'row switch'.

        this indicates that an INSERT statement elsewhere corresponds to this DELETE;
        the INSERT is converted to an UPDATE and the DELETE does not occur.

        """
        mapper = _state_mapper(state)
        task = self.get_task_by_mapper(mapper)
        taskelement = task._objects[state]
        taskelement.isdelete = "rowswitch"

    def is_deleted(self, state):
        """return true if the given state is marked as deleted within this UOWTransaction."""

        mapper = _state_mapper(state)
        task = self.get_task_by_mapper(mapper)
        return task.is_deleted(state)

    def get_task_by_mapper(self, mapper, dontcreate=False):
        """return UOWTask element corresponding to the given mapper.

        Will create a new UOWTask, including a UOWTask corresponding to the
        "base" inherited mapper, if needed, unless the dontcreate flag is True.

        """
        try:
            return self.tasks[mapper]
        except KeyError:
            if dontcreate:
                return None

            base_mapper = mapper.base_mapper
            if base_mapper in self.tasks:
                base_task = self.tasks[base_mapper]
            else:
                self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper)
                base_mapper._register_dependencies(self)

            if mapper not in self.tasks:
                self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task)
                mapper._register_dependencies(self)
            else:
                task = self.tasks[mapper]

            return task

    def register_dependency(self, mapper, dependency):
        """register a dependency between two mappers.

        Called by ``mapper.PropertyLoader`` to register the objects
        handled by one mapper being dependent on the objects handled
        by another.

        """
        # correct for primary mapper
        # also convert to the "base mapper", the parentmost task at the top of an inheritance chain
        # dependency sorting is done via non-inheriting mappers only, dependencies between mappers
        # in the same inheritance chain is done at the per-object level
        mapper = mapper.primary_mapper().base_mapper
        dependency = dependency.primary_mapper().base_mapper

        self.dependencies.add((mapper, dependency))

    def register_processor(self, mapper, processor, mapperfrom):
        """register a dependency processor, corresponding to
        operations which occur between two mappers.

        """
        # correct for primary mapper
        mapper = mapper.primary_mapper()
        mapperfrom = mapperfrom.primary_mapper()

        task = self.get_task_by_mapper(mapper)
        targettask = self.get_task_by_mapper(mapperfrom)
        up = UOWDependencyProcessor(processor, targettask)
        task.dependencies.add(up)

    def execute(self):
        """Execute this UOWTransaction.

        This will organize all collected UOWTasks into a dependency-sorted
        list which is then traversed using the traversal scheme
        encoded in the UOWExecutor class.  Operations to mappers and dependency
        processors are fired off in order to issue SQL to the database and
        synchronize instance attributes with database values and related
        foreign key values."""

        # pre-execute dependency processors.  this process may
        # result in new tasks, objects and/or dependency processors being added,
        # particularly with 'delete-orphan' cascade rules.
        # keep running through the full list of tasks until all
        # objects have been processed.
        while True:
            ret = False
            for task in self.tasks.values():
                for up in list(task.dependencies):
                    if up.preexecute(self):
                        ret = True
            if not ret:
                break

        tasks = self._sort_dependencies()
        if self._should_log_info():
            self.logger.info("Task dump:\n%s", self._dump(tasks))
        UOWExecutor().execute(self, tasks)
        self.logger.info("Execute Complete")

    def _dump(self, tasks):
        from uowdumper import UOWDumper
        return UOWDumper.dump(tasks)

    @property
    def elements(self):
        """Iterate UOWTaskElements."""

        for task in self.tasks.itervalues():
            for elem in task.elements:
                yield elem

    def finalize_flush_changes(self):
        """mark processed objects as clean / deleted after a successful flush().

        this method is called within the flush() method after the
        execute() method has succeeded and the transaction has been committed.
        """

        for elem in self.elements:
            if elem.isdelete:
                self.session._remove_newly_deleted(elem.state)
            elif not elem.listonly:
                self.session._register_newly_persistent(elem.state)

    def _sort_dependencies(self):
        nodes = topological.sort_with_cycles(self.dependencies,
            [t.mapper for t in self.tasks.itervalues() if t.base_task is t]
        )

        ret = []
        for item, cycles in nodes:
            task = self.get_task_by_mapper(item)
            if cycles:
                for t in task._sort_circular_dependencies(
                    self,
                    [self.get_task_by_mapper(i) for i in cycles]
                ):
                    ret.append(t)
            else:
                ret.append(task)

        return ret

log.class_logger(UOWTransaction)
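# --- Illustrative sketch, not part of this commit. _sort_dependencies()
# above feeds (parent, dependent) mapper pairs into a topological sort so
# parents flush before children. A self-contained Kahn's-algorithm sketch
# over the same shape of input (the string "mapper" names are invented):

def toposort(dependencies, nodes):
    # dependencies: set of (parent, child) pairs; parents must come first
    indegree = dict((n, 0) for n in nodes)
    for parent, child in dependencies:
        if child in indegree:
            indegree[child] += 1
    queue = [n for n in nodes if indegree[n] == 0]
    order = []
    while queue:
        n = queue.pop(0)
        order.append(n)
        for parent, child in dependencies:
            if parent == n and child in indegree:
                indegree[child] -= 1
                if indegree[child] == 0:
                    queue.append(child)
    return order

deps = set([('User', 'Address'), ('User', 'Order')])
order = toposort(deps, ['Address', 'User', 'Order'])
assert order.index('User') < order.index('Address')
assert order.index('User') < order.index('Order')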
||||
class UOWTask(object):
    """A collection of mapped states corresponding to a particular mapper."""

    def __init__(self, uowtransaction, mapper, base_task=None):
        self.uowtransaction = uowtransaction

        # base_task is the UOWTask which represents the "base mapper"
        # in our mapper's inheritance chain.  if the mapper does not
        # inherit from any other mapper, the base_task is self.
        # the _inheriting_tasks dictionary is a dictionary present only
        # on the "base_task"-holding UOWTask, which maps all mappers within
        # an inheritance hierarchy to their corresponding UOWTask instances.
        if base_task is None:
            self.base_task = self
            self._inheriting_tasks = {mapper: self}
        else:
            self.base_task = base_task
            base_task._inheriting_tasks[mapper] = self

        # the Mapper which this UOWTask corresponds to
        self.mapper = mapper

        # mapping of InstanceState -> UOWTaskElement
        self._objects = {}

        self.dependent_tasks = []
        self.dependencies = set()
        self.cyclical_dependencies = set()

    @util.memoized_property
    def inheriting_mappers(self):
        return list(self.mapper.polymorphic_iterator())

    @property
    def polymorphic_tasks(self):
        """Return an iterator of UOWTask objects corresponding to the
        inheritance sequence of this UOWTask's mapper.

        e.g. if mapper B and mapper C inherit from mapper A, and
        mapper D inherits from B:

            mapperA -> mapperB -> mapperD
                    -> mapperC

        the inheritance sequence starting at mapper A is a depth-first
        traversal:

            [mapperA, mapperB, mapperD, mapperC]

        this method will therefore return

            [UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD),
             UOWTask(mapperC)]

        The concept of "polymorphic iteration" is adapted into
        several property-based iterators which return object
        instances, UOWTaskElements and UOWDependencyProcessors in an
        order corresponding to this sequence of parent UOWTasks.  This
        is used to issue operations related to inheritance-chains of
        mappers in the proper order based on dependencies between
        those mappers.

        """
        for mapper in self.inheriting_mappers:
            t = self.base_task._inheriting_tasks.get(mapper, None)
            if t is not None:
                yield t

    def is_empty(self):
        """return True if this UOWTask is 'empty', meaning it has no child items.

        used only for debugging output.
        """

        return not self._objects and not self.dependencies

    def append(self, state, listonly=False, isdelete=False):
        if state not in self._objects:
            self._objects[state] = rec = UOWTaskElement(state)
        else:
            rec = self._objects[state]

        rec.update(listonly, isdelete)

    def append_postupdate(self, state, post_update_cols):
        """issue a 'post update' UPDATE statement via this object's mapper
        immediately.

        this operation is used only with relationships that specify the
        `post_update=True` flag.
        """

        # postupdates are UPDATEd immediately (for now)
        # convert post_update_cols list to a Set so that __hash__() is used
        # to compare columns instead of __eq__()
        self.mapper._save_obj([state], self.uowtransaction, postupdate=True,
                                post_update_cols=set(post_update_cols))

    def __contains__(self, state):
        """return True if the given object is contained within this UOWTask
        or inheriting tasks."""

        for task in self.polymorphic_tasks:
            if state in task._objects:
                return True
        else:
            return False

    def is_deleted(self, state):
        """return True if the given object is marked as to be deleted within
        this UOWTask."""

        try:
            return self._objects[state].isdelete
        except KeyError:
            return False

    def _polymorphic_collection(fn):
        """return a property that will adapt the collection returned by the
        given callable into a polymorphic traversal."""

        @property
        def collection(self):
            for task in self.polymorphic_tasks:
                for rec in fn(task):
                    yield rec
        return collection

    def _polymorphic_collection_filtered(fn):

        def collection(self, mappers):
            for task in self.polymorphic_tasks:
                if task.mapper in mappers:
                    for rec in fn(task):
                        yield rec
        return collection

    @property
    def elements(self):
        return self._objects.values()

    @_polymorphic_collection
    def polymorphic_elements(self):
        return self.elements

    @_polymorphic_collection_filtered
    def filter_polymorphic_elements(self):
        return self.elements

    @property
    def polymorphic_tosave_elements(self):
        return [rec for rec in self.polymorphic_elements
                if not rec.isdelete]

    @property
    def polymorphic_todelete_elements(self):
        return [rec for rec in self.polymorphic_elements
                if rec.isdelete]

    @property
    def polymorphic_tosave_objects(self):
        return [
            rec.state for rec in self.polymorphic_elements
            if rec.state is not None and not rec.listonly and rec.isdelete is False
        ]

    @property
    def polymorphic_todelete_objects(self):
        return [
            rec.state for rec in self.polymorphic_elements
            if rec.state is not None and not rec.listonly and rec.isdelete is True
        ]

    @_polymorphic_collection
    def polymorphic_dependencies(self):
        return self.dependencies

    @_polymorphic_collection
    def polymorphic_cyclical_dependencies(self):
        return self.cyclical_dependencies

    def _sort_circular_dependencies(self, trans, cycles):
        """Topologically sort individual entities with row-level dependencies.

        Builds a modified UOWTask structure, and is invoked when the
        per-mapper topological structure is found to have cycles.

        """

        dependencies = {}
        def set_processor_for_state(state, depprocessor, target_state, isdelete):
            if state not in dependencies:
                dependencies[state] = {}
            tasks = dependencies[state]
            if depprocessor not in tasks:
                tasks[depprocessor] = UOWDependencyProcessor(
                                            depprocessor.processor,
                                            UOWTask(self.uowtransaction,
                                                    depprocessor.targettask.mapper))
            tasks[depprocessor].targettask.append(target_state, isdelete=isdelete)

        cycles = set(cycles)
        def dependency_in_cycles(dep):
            proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True)
            targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True)
            return targettask in cycles and (proctask is not None and proctask in cycles)

        deps_by_targettask = {}
        extradeplist = []
        for task in cycles:
            for dep in task.polymorphic_dependencies:
                if not dependency_in_cycles(dep):
                    extradeplist.append(dep)
                for t in dep.targettask.polymorphic_tasks:
                    l = deps_by_targettask.setdefault(t, [])
                    l.append(dep)

        object_to_original_task = {}
        tuples = []

        for task in cycles:
            for subtask in task.polymorphic_tasks:
                for taskelement in subtask.elements:
                    state = taskelement.state
                    object_to_original_task[state] = subtask
                    if subtask not in deps_by_targettask:
                        continue
                    for dep in deps_by_targettask[subtask]:
                        if not dep.processor.has_dependencies or \
                                not dependency_in_cycles(dep):
                            continue
                        (processor, targettask) = (dep.processor, dep.targettask)
                        isdelete = taskelement.isdelete

                        # list of dependent objects from this object
                        (added, unchanged, deleted) = \
                            dep.get_object_dependencies(state, trans, passive=True)
                        if not added and not unchanged and not deleted:
                            continue

                        # the task corresponding to saving/deleting of those
                        # dependent objects
                        childtask = trans.get_task_by_mapper(processor.mapper)

                        childlist = added + unchanged + deleted

                        for o in childlist:
                            if o is None:
                                continue

                            if o not in childtask:
                                childtask.append(o, listonly=True)
                                object_to_original_task[o] = childtask

                            whosdep = dep.whose_dependent_on_who(state, o)
                            if whosdep is not None:
                                tuples.append(whosdep)

                                if whosdep[0] is state:
                                    set_processor_for_state(whosdep[0], dep,
                                                whosdep[0], isdelete=isdelete)
                                else:
                                    set_processor_for_state(whosdep[0], dep,
                                                whosdep[1], isdelete=isdelete)
                            else:
                                # TODO: no test coverage here
                                set_processor_for_state(state, dep, state,
                                                isdelete=isdelete)

        t = UOWTask(self.uowtransaction, self.mapper)
        t.dependencies.update(extradeplist)

        used_tasks = set()

        # rationale for "tree" sort as opposed to a straight
        # dependency - keep non-dependent objects
        # grouped together, so that insert ordering as determined
        # by session.add() is maintained.
        # An alternative might be to represent the "insert order"
        # as part of the topological sort itself, which would
        # eliminate the need for this step (but may make the original
        # topological sort more expensive)
        head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys())
        if head is not None:
            original_to_tasks = {}
            stack = [(head, t)]
            while stack:
                ((state, cycles, children), parenttask) = stack.pop()

                originating_task = object_to_original_task[state]
                used_tasks.add(originating_task)

                if (parenttask, originating_task) not in original_to_tasks:
                    task = UOWTask(self.uowtransaction, originating_task.mapper)
                    original_to_tasks[(parenttask, originating_task)] = task
                    parenttask.dependent_tasks.append(task)
                else:
                    task = original_to_tasks[(parenttask, originating_task)]

                task.append(state,
                            originating_task._objects[state].listonly,
                            isdelete=originating_task._objects[state].isdelete)

                if state in dependencies:
                    task.cyclical_dependencies.update(dependencies[state].itervalues())

                stack += [(n, task) for n in children]

        ret = [t]

        # add tasks that were in the cycle, but didn't get assembled
        # into the cyclical tree, to the start of the list
        for t2 in cycles:
            if t2 not in used_tasks and t2 is not self:
                localtask = UOWTask(self.uowtransaction, t2.mapper)
                # iterate the UOWTaskElements and use each element's own
                # flags; the elements themselves are not InstanceStates.
                for elem in t2.elements:
                    localtask.append(elem.state, elem.listonly,
                                     isdelete=elem.isdelete)
                for dep in t2.dependencies:
                    localtask.dependencies.add(dep)
                ret.insert(0, localtask)

        return ret

    def __repr__(self):
        return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper))

class UOWTaskElement(object):
    """Corresponds to a single InstanceState to be saved, deleted,
    or otherwise marked as having dependencies.  A collection of
    UOWTaskElements are held by a UOWTask.

    """
    def __init__(self, state):
        self.state = state
        self.listonly = True
        self.isdelete = False
        self.preprocessed = set()

    def update(self, listonly, isdelete):
        if not listonly and self.listonly:
            self.listonly = False
            self.preprocessed.clear()
        if isdelete and not self.isdelete:
            self.isdelete = True
            self.preprocessed.clear()

    def __repr__(self):
        return "UOWTaskElement/%d: %s/%d %s" % (
            id(self),
            self.state.class_.__name__,
            id(self.state.obj()),
            (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save'))
        )

class UOWDependencyProcessor(object):
    """In between the saving and deleting of objects, process
    dependent data, such as filling in a foreign key on a child item
    from a new primary key, or deleting association rows before a
    delete.  This object acts as a proxy to a DependencyProcessor.

    """
    def __init__(self, processor, targettask):
        self.processor = processor
        self.targettask = targettask
        prop = processor.prop

        # define a set of mappers which
        # will filter the lists of entities
        # this UOWDP processes.  this allows
        # MapperProperties to be overridden
        # at least for concrete mappers.
        self._mappers = set([
            m
            for m in self.processor.parent.polymorphic_iterator()
            if m._props[prop.key] is prop
        ]).union(self.processor.mapper.polymorphic_iterator())

    def __repr__(self):
        return "UOWDependencyProcessor(%s, %s)" % (
                    str(self.processor), str(self.targettask))

    def __eq__(self, other):
        return other.processor is self.processor and \
                other.targettask is self.targettask

    def __hash__(self):
        return hash((self.processor, self.targettask))

    def preexecute(self, trans):
        """preprocess all objects contained within this
        ``UOWDependencyProcessor``'s target task.

        This may locate additional objects which should be part of the
        transaction, such as those affected by deletes, orphans to be
        deleted, etc.

        Once an object is preprocessed, its ``UOWTaskElement`` is marked as
        processed.  If subsequent changes occur to the ``UOWTaskElement``,
        its processed flag is reset, and it will require processing again.

        Return True if any objects were preprocessed, or False if no
        objects were preprocessed.  If True is returned, the parent
        ``UOWTransaction`` will ultimately call ``preexecute()`` again on
        all processors until no new objects are processed.
        """

        def getobj(elem):
            elem.preprocessed.add(self)
            return elem.state

        ret = False
        elements = [getobj(elem) for elem in
                    self.targettask.filter_polymorphic_elements(self._mappers)
                    if self not in elem.preprocessed and not elem.isdelete]
        if elements:
            ret = True
            self.processor.preprocess_dependencies(self.targettask, elements,
                                                   trans, delete=False)

        elements = [getobj(elem) for elem in
                    self.targettask.filter_polymorphic_elements(self._mappers)
                    if self not in elem.preprocessed and elem.isdelete]
        if elements:
            ret = True
            self.processor.preprocess_dependencies(self.targettask, elements,
                                                   trans, delete=True)
        return ret

    def execute(self, trans, delete):
        """process all objects contained within this
        ``UOWDependencyProcessor``'s target task."""

        elements = [e for e in
                    self.targettask.filter_polymorphic_elements(self._mappers)
                    if bool(e.isdelete) == delete]

        self.processor.process_dependencies(
            self.targettask,
            [elem.state for elem in elements],
            trans,
            delete=delete)

    def get_object_dependencies(self, state, trans, passive):
        return trans.get_attribute_history(state, self.processor.key,
                                           passive=passive)

    def whose_dependent_on_who(self, state1, state2):
        """establish which object is operationally dependent amongst a
        parent/child pair, using the semantics stated by the dependency
        processor.

        This method is used to establish a partial ordering (set of
        dependency tuples) when topologically sorting on a per-instance
        basis.

        """
        return self.processor.whose_dependent_on_who(state1, state2)

class UOWExecutor(object):
    """Encapsulates the execution traversal of a UOWTransaction structure."""

    def execute(self, trans, tasks, isdelete=None):
        if isdelete is not True:
            for task in tasks:
                self.execute_save_steps(trans, task)
        if isdelete is not False:
            for task in reversed(tasks):
                self.execute_delete_steps(trans, task)

    def save_objects(self, trans, task):
        task.mapper._save_obj(task.polymorphic_tosave_objects, trans)

    def delete_objects(self, trans, task):
        task.mapper._delete_obj(task.polymorphic_todelete_objects, trans)

    def execute_dependency(self, trans, dep, isdelete):
        dep.execute(trans, isdelete)

    def execute_save_steps(self, trans, task):
        self.save_objects(trans, task)
        for dep in task.polymorphic_cyclical_dependencies:
            self.execute_dependency(trans, dep, False)
        for dep in task.polymorphic_cyclical_dependencies:
            self.execute_dependency(trans, dep, True)
        self.execute_cyclical_dependencies(trans, task, False)
        self.execute_dependencies(trans, task)

    def execute_delete_steps(self, trans, task):
        self.execute_cyclical_dependencies(trans, task, True)
        self.delete_objects(trans, task)

    def execute_dependencies(self, trans, task):
        polymorphic_dependencies = list(task.polymorphic_dependencies)
        for dep in polymorphic_dependencies:
            self.execute_dependency(trans, dep, False)
        for dep in reversed(polymorphic_dependencies):
            self.execute_dependency(trans, dep, True)

    def execute_cyclical_dependencies(self, trans, task, isdelete):
        for t in task.dependent_tasks:
            self.execute(trans, [t], isdelete)
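
# A hedged illustration of the traversal above (task names invented):
# given three dependency-sorted tasks where A must precede B and B must
# precede C,
#
#     save steps run as    A, B, C
#     delete steps run as  C, B, A
#
# i.e. inserts honor dependencies front-to-back, deletes in reverse.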
101
sqlalchemy/orm/uowdumper.py
Normal file
@@ -0,0 +1,101 @@
# orm/uowdumper.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Dumps out a string representation of a UOWTask structure"""

from sqlalchemy.orm import unitofwork
from sqlalchemy.orm import util as mapperutil
import StringIO

class UOWDumper(unitofwork.UOWExecutor):
    def __init__(self, tasks, buf):
        self.indent = 0
        self.tasks = tasks
        self.buf = buf
        self.execute(None, tasks)

    @classmethod
    def dump(cls, tasks):
        buf = StringIO.StringIO()
        UOWDumper(tasks, buf)
        return buf.getvalue()

    def execute(self, trans, tasks, isdelete=None):
        if isdelete is not True:
            for task in tasks:
                self._execute(trans, task, False)
        if isdelete is not False:
            for task in reversed(tasks):
                self._execute(trans, task, True)

    def _execute(self, trans, task, isdelete):
        try:
            i = self._indent()
            if i:
                i = i[:-1] + "+-"
            self.buf.write(i + " " + self._repr_task(task))
            self.buf.write(" (" + (isdelete and "delete " or "save/update ") + "phase) \n")
            self.indent += 1
            super(UOWDumper, self).execute(trans, [task], isdelete)
        finally:
            self.indent -= 1

    def save_objects(self, trans, task):
        for rec in sorted(task.polymorphic_tosave_elements,
                          key=lambda a: a.state.sort_key):
            if rec.listonly:
                continue
            self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n")

    def delete_objects(self, trans, task):
        for rec in task.polymorphic_todelete_elements:
            if rec.listonly:
                continue
            self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n")

    def execute_dependency(self, transaction, dep, isdelete):
        self._dump_processor(dep, isdelete)

    def _dump_processor(self, proc, deletes):
        if deletes:
            val = proc.targettask.polymorphic_todelete_elements
        else:
            val = proc.targettask.polymorphic_tosave_elements

        for v in val:
            self.buf.write(self._indent() + "   +- " +
                           self._repr_task_element(v, proc.processor.key, process=True) +
                           "\n")

    def _repr_task_element(self, te, attribute=None, process=False):
        if getattr(te, 'state', None) is None:
            objid = "(placeholder)"
        else:
            if attribute is not None:
                objid = "%s.%s" % (mapperutil.state_str(te.state), attribute)
            else:
                objid = mapperutil.state_str(te.state)
        if process:
            return "Process %s" % (objid)
        else:
            return "%s %s" % ((te.isdelete and "Delete" or "Save"), objid)

    def _repr_task(self, task):
        if task.mapper is not None:
            if task.mapper.__class__.__name__ == 'Mapper':
                name = task.mapper.class_.__name__ + "/" + task.mapper.local_table.description
            else:
                name = repr(task.mapper)
        else:
            name = '(none)'
        return ("UOWTask(%s, %s)" % (hex(id(task)), name))

    def _repr_task_class(self, task):
        if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper':
            return task.mapper.class_.__name__
        else:
            return '(none)'

    def _indent(self):
        return "   |" * self.indent
668
sqlalchemy/orm/util.py
Normal file
@@ -0,0 +1,668 @@
# mapper/util.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import sqlalchemy.exceptions as sa_exc
from sqlalchemy import sql, util
from sqlalchemy.sql import expression, util as sql_util, operators
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, \
    MapperProperty, AttributeExtension
from sqlalchemy.orm import attributes, exc

mapperlib = None

all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
                          "expunge", "save-update", "refresh-expire",
                          "none"))

_INSTRUMENTOR = ('mapper', 'instrumentor')

class CascadeOptions(object):
    """Keeps track of the options sent to relationship().cascade"""

    def __init__(self, arg=""):
        if not arg:
            values = set()
        else:
            values = set(c.strip() for c in arg.split(','))
        self.delete_orphan = "delete-orphan" in values
        self.delete = "delete" in values or "all" in values
        self.save_update = "save-update" in values or "all" in values
        self.merge = "merge" in values or "all" in values
        self.expunge = "expunge" in values or "all" in values
        self.refresh_expire = "refresh-expire" in values or "all" in values

        if self.delete_orphan and not self.delete:
            util.warn("The 'delete-orphan' cascade option requires "
                      "'delete'.  This will raise an error in 0.6.")

        for x in values:
            if x not in all_cascades:
                raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)

    def __contains__(self, item):
        return getattr(self, item.replace("-", "_"), False)

    def __repr__(self):
        # note: these are attribute names, so underscores are required;
        # the hyphenated form 'refresh-expire' would never match via getattr().
        return "CascadeOptions(%s)" % repr(",".join(
            [x for x in ['delete', 'save_update', 'merge', 'expunge',
                         'delete_orphan', 'refresh_expire']
             if getattr(self, x, False) is True]))

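# A minimal usage sketch for CascadeOptions (the cascade string below is
# an arbitrary example, not taken from this module):
#
#     opts = CascadeOptions("save-update, merge, delete, delete-orphan")
#     assert opts.delete_orphan
#     assert "delete-orphan" in opts   # __contains__ maps '-' to '_'
#     repr(opts)  # roughly "CascadeOptions('delete,save_update,merge,delete_orphan')"
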
class Validator(AttributeExtension):
    """Runs a validation method on an attribute value to be set or appended.

    The Validator class is used by the :func:`~sqlalchemy.orm.validates`
    decorator, and direct access is usually not needed.

    """

    def __init__(self, key, validator):
        """Construct a new Validator.

        key - name of the attribute to be validated;
        will be passed as the second argument to
        the validation method (the first is the object instance itself).

        validator - a function or instance method which accepts
        three arguments: an instance (usually just 'self' for a method),
        the key name of the attribute, and the value.  The function should
        return the same value given, unless it wishes to modify it.

        """
        self.key = key
        self.validator = validator

    def append(self, state, value, initiator):
        return self.validator(state.obj(), self.key, value)

    def set(self, state, value, oldvalue, initiator):
        return self.validator(state.obj(), self.key, value)

def polymorphic_union(table_map, typecolname, aliasname='p_union'):
    """Create a ``UNION`` statement used by a polymorphic mapper.

    See :ref:`concrete_inheritance` for an example of how
    this is used.
    """

    colnames = set()
    colnamemaps = {}
    types = {}
    for key in table_map.keys():
        table = table_map[key]

        # mysql doesn't like selecting from a select;
        # make it an alias of the select
        if isinstance(table, sql.Select):
            table = table.alias()
            table_map[key] = table

        m = {}
        for c in table.c:
            colnames.add(c.key)
            m[c.key] = c
            types[c.key] = c.type
        colnamemaps[table] = m

    def col(name, table):
        try:
            return colnamemaps[table][name]
        except KeyError:
            return sql.cast(sql.null(), types[name]).label(name)

    result = []
    for type, table in table_map.iteritems():
        if typecolname is not None:
            result.append(sql.select([col(name, table) for name in colnames] +
                                     [sql.literal_column("'%s'" % type).label(typecolname)],
                                     from_obj=[table]))
        else:
            result.append(sql.select([col(name, table) for name in colnames],
                                     from_obj=[table]))
    return sql.union_all(*result).alias(aliasname)

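# A hedged usage sketch for polymorphic_union(); the table names below
# are invented for illustration:
#
#     punion = polymorphic_union({
#         'engineer': engineers_table,
#         'manager': managers_table,
#     }, 'type', 'pjoin')
#
# Each SELECT in the resulting UNION supplies NULL-casts for columns the
# table lacks, plus a literal 'engineer' / 'manager' in the 'type' column.
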
def identity_key(*args, **kwargs):
    """Get an identity key.

    Valid call signatures:

    * ``identity_key(class, ident)``

      class
          mapped class (must be a positional argument)

      ident
          primary key, if the key is composite this is a tuple

    * ``identity_key(instance=instance)``

      instance
          object instance (must be given as a keyword arg)

    * ``identity_key(class, row=row)``

      class
          mapped class (must be a positional argument)

      row
          result proxy row (must be given as a keyword arg)

    """
    if args:
        if len(args) == 1:
            class_ = args[0]
            try:
                row = kwargs.pop("row")
            except KeyError:
                ident = kwargs.pop("ident")
        elif len(args) == 2:
            class_, ident = args
        elif len(args) == 3:
            # any third positional argument is ignored
            class_, ident = args[0], args[1]
        else:
            raise sa_exc.ArgumentError("expected up to three "
                "positional arguments, got %s" % len(args))
        if kwargs:
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                % ", ".join(kwargs.keys()))
        mapper = class_mapper(class_)
        if "ident" in locals():
            return mapper.identity_key_from_primary_key(ident)
        return mapper.identity_key_from_row(row)
    instance = kwargs.pop("instance")
    if kwargs:
        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
            % ", ".join(kwargs.keys()))
    mapper = object_mapper(instance)
    return mapper.identity_key_from_instance(instance)

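# Illustrative calls (``User`` is a hypothetical mapped class, not part
# of this module):
#
#     identity_key(User, 1)             # roughly -> (User, (1,))
#     identity_key(instance=some_user)  # key derived via the instance's mapper
#     identity_key(User, row=row)       # key extracted from a result row
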
class ExtensionCarrier(dict):
    """Fronts an ordered collection of MapperExtension objects.

    Bundles multiple MapperExtensions into a unified callable unit,
    encapsulating ordering, looping and EXT_CONTINUE logic.  The
    ExtensionCarrier implements the MapperExtension interface, e.g.::

        carrier.after_insert(...args...)

    The dictionary interface provides containment for implemented
    method names mapped to a callable which executes that method
    for participating extensions.

    """

    interface = set(method for method in dir(MapperExtension)
                    if not method.startswith('_'))

    def __init__(self, extensions=None):
        self._extensions = []
        for ext in extensions or ():
            self.append(ext)

    def copy(self):
        return ExtensionCarrier(self._extensions)

    def push(self, extension):
        """Insert a MapperExtension at the beginning of the collection."""
        self._register(extension)
        self._extensions.insert(0, extension)

    def append(self, extension):
        """Append a MapperExtension at the end of the collection."""
        self._register(extension)
        self._extensions.append(extension)

    def __iter__(self):
        """Iterate over MapperExtensions in the collection."""
        return iter(self._extensions)

    def _register(self, extension):
        """Register callable fronts for overridden interface methods."""

        for method in self.interface.difference(self):
            impl = getattr(extension, method, None)
            if impl and impl is not getattr(MapperExtension, method):
                self[method] = self._create_do(method)

    def _create_do(self, method):
        """Return a closure that loops over impls of the named method."""

        def _do(*args, **kwargs):
            for ext in self._extensions:
                ret = getattr(ext, method)(*args, **kwargs)
                if ret is not EXT_CONTINUE:
                    return ret
            else:
                return EXT_CONTINUE
        _do.__name__ = method
        return _do

    @staticmethod
    def _pass(*args, **kwargs):
        return EXT_CONTINUE

    def __getattr__(self, key):
        """Delegate MapperExtension methods to bundled fronts."""

        if key not in self.interface:
            raise AttributeError(key)
        return self.get(key, self._pass)

class ORMAdapter(sql_util.ColumnAdapter):
    """Extends ColumnAdapter to accept ORM entities.

    The selectable is extracted from the given entity,
    and the AliasedClass if any is referenced.

    """
    def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
        self.mapper, selectable, is_aliased_class = _entity_info(entity)
        if is_aliased_class:
            self.aliased_class = entity
        else:
            self.aliased_class = None
        sql_util.ColumnAdapter.__init__(self, selectable, equivalents,
                                        chain_to, adapt_required=adapt_required)

    def replace(self, elem):
        entity = elem._annotations.get('parentmapper', None)
        if not entity or entity.isa(self.mapper):
            return sql_util.ColumnAdapter.replace(self, elem)
        else:
            return None

class AliasedClass(object):
    """Represents an "aliased" form of a mapped class for usage with Query.

    The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias`
    construct, this object mimics the mapped class using a
    __getattr__ scheme and maintains a reference to a
    real :class:`~sqlalchemy.sql.expression.Alias` object.

    Usage is via the :class:`~sqlalchemy.orm.aliased()` synonym::

        # find all pairs of users with the same name
        user_alias = aliased(User)
        session.query(User, user_alias).\\
                join((user_alias, User.id > user_alias.id)).\\
                filter(User.name==user_alias.name)

    """
    def __init__(self, cls, alias=None, name=None):
        self.__mapper = _class_to_mapper(cls)
        self.__target = self.__mapper.class_
        if alias is None:
            alias = self.__mapper._with_polymorphic_selectable.alias()
        self.__adapter = sql_util.ClauseAdapter(alias,
                                equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        # used to assign a name to the RowTuple object
        # returned by Query.
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __getstate__(self):
        return {'mapper': self.__mapper, 'alias': self.__alias,
                'name': self._sa_label_name}

    def __setstate__(self, state):
        self.__mapper = state['mapper']
        self.__target = self.__mapper.class_
        alias = state['alias']
        self.__adapter = sql_util.ClauseAdapter(alias,
                                equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        name = state['name']
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __adapt_element(self, elem):
        return self.__adapter.traverse(elem).\
                    _annotate({'parententity': self,
                               'parentmapper': self.__mapper})

    def __adapt_prop(self, prop):
        existing = getattr(self.__target, prop.key)
        comparator = existing.comparator.adapted(self.__adapt_element)

        queryattr = attributes.QueryableAttribute(prop.key,
            impl=existing.impl, parententity=self, comparator=comparator)
        setattr(self, prop.key, queryattr)
        return queryattr

    def __getattr__(self, key):
        prop = self.__mapper._get_property(key, raiseerr=False)
        if prop:
            return self.__adapt_prop(prop)

        for base in self.__target.__mro__:
            try:
                attr = object.__getattribute__(base, key)
            except AttributeError:
                continue
            else:
                break
        else:
            raise AttributeError(key)

        if hasattr(attr, 'func_code'):
            is_method = getattr(self.__target, key, None)
            if is_method and is_method.im_self is not None:
                return util.types.MethodType(attr.im_func, self, self)
            else:
                return None
        elif hasattr(attr, '__get__'):
            return attr.__get__(None, self)
        else:
            return attr

    def __repr__(self):
        return '<AliasedClass at 0x%x; %s>' % (
            id(self), self.__target.__name__)

def _orm_annotate(element, exclude=None):
    """Deep copy the given ClauseElement, annotating each element with the
    "_orm_adapt" flag.

    Elements within the exclude collection will be cloned but not annotated.

    """
    return sql_util._deep_annotate(element, {'_orm_adapt': True}, exclude)

_orm_deannotate = sql_util._deep_deannotate

class _ORMJoin(expression.Join):
    """Extend Join to support ORM constructs as input."""

    __visit_name__ = expression.Join.__visit_name__

    def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
        adapt_from = None

        if hasattr(left, '_orm_mappers'):
            left_mapper = left._orm_mappers[1]
            if join_to_left:
                adapt_from = left.right
        else:
            left_mapper, left, left_is_aliased = _entity_info(left)
            if join_to_left and (left_is_aliased or not left_mapper):
                adapt_from = left

        right_mapper, right, right_is_aliased = _entity_info(right)
        if right_is_aliased:
            adapt_to = right
        else:
            adapt_to = None

        if left_mapper or right_mapper:
            self._orm_mappers = (left_mapper, right_mapper)

            if isinstance(onclause, basestring):
                prop = left_mapper.get_property(onclause)
            elif isinstance(onclause, attributes.QueryableAttribute):
                if adapt_from is None:
                    adapt_from = onclause.__clause_element__()
                prop = onclause.property
            elif isinstance(onclause, MapperProperty):
                prop = onclause
            else:
                prop = None

            if prop:
                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
                    source_selectable=adapt_from,
                    dest_selectable=adapt_to,
                    source_polymorphic=True,
                    dest_polymorphic=True,
                    of_type=right_mapper)

                if sj is not None:
                    left = sql.join(left, secondary, pj, isouter)
                    onclause = sj
                else:
                    onclause = pj
                self._target_adapter = target_adapter

        expression.Join.__init__(self, left, right, onclause, isouter)

    def join(self, right, onclause=None, isouter=False, join_to_left=True):
        return _ORMJoin(self, right, onclause, isouter, join_to_left)

    def outerjoin(self, right, onclause=None, join_to_left=True):
        return _ORMJoin(self, right, onclause, True, join_to_left)

def join(left, right, onclause=None, isouter=False, join_to_left=True):
    """Produce an inner join between left and right clauses.

    In addition to the interface provided by
    :func:`~sqlalchemy.sql.expression.join()`, left and right may be mapped
    classes or AliasedClass instances.  The onclause may be a
    string name of a relationship(), or a class-bound descriptor
    representing a relationship.

    join_to_left indicates to attempt aliasing the ON clause,
    in whatever form it is passed, to the selectable
    passed as the left side.  If False, the onclause
    is used as is.

    """
    return _ORMJoin(left, right, onclause, isouter, join_to_left)

def outerjoin(left, right, onclause=None, join_to_left=True):
    """Produce a left outer join between left and right clauses.

    In addition to the interface provided by
    :func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be
    mapped classes or AliasedClass instances.  The onclause may be a
    string name of a relationship(), or a class-bound descriptor
    representing a relationship.

    """
    return _ORMJoin(left, right, onclause, True, join_to_left)

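# A hedged usage sketch for the ORM-level join(); ``User``/``Address``
# and the 'addresses' relationship are invented for illustration:
#
#     j = join(User, Address, User.addresses)
#     session.query(User).select_from(j).\
#             filter(Address.email_address == 'foo@bar.com')
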
def with_parent(instance, prop):
    """Return criterion which selects instances with a given parent.

    instance
      a parent instance, which should be persistent or detached.

    prop
      a class-attached descriptor, MapperProperty or string property name
      attached to the parent instance.

    """
    if isinstance(prop, basestring):
        mapper = object_mapper(instance)
        prop = mapper.get_property(prop, resolve_synonyms=True)
    elif isinstance(prop, attributes.QueryableAttribute):
        prop = prop.property

    return prop.compare(operators.eq, instance, value_is_parent=True)


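# Illustrative use (the names are hypothetical):
#
#     session.query(Address).filter(with_parent(someuser, 'addresses'))
#
# roughly: criterion matching Address rows whose foreign key points at
# ``someuser``'s primary key, without joining to the parent table.
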
def _entity_info(entity, compile=True):
    """Return mapping information given a class, mapper, or AliasedClass.

    Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this
    is an aliased() construct.

    If the given entity is not a mapper, mapped class, or aliased construct,
    returns None, the entity, False.  This is typically used to allow
    unmapped selectables through.

    """
    if isinstance(entity, AliasedClass):
        return entity._AliasedClass__mapper, entity._AliasedClass__alias, True

    global mapperlib
    if mapperlib is None:
        from sqlalchemy.orm import mapperlib

    if isinstance(entity, mapperlib.Mapper):
        mapper = entity

    elif isinstance(entity, type):
        class_manager = attributes.manager_of_class(entity)

        if class_manager is None:
            return None, entity, False

        mapper = class_manager.mapper
    else:
        return None, entity, False

    if compile:
        mapper = mapper.compile()
    return mapper, mapper._with_polymorphic_selectable, False

def _entity_descriptor(entity, key):
    """Return attribute/property information given an entity and string name.

    Returns a 2-tuple representing InstrumentedAttribute/MapperProperty.

    """
    if isinstance(entity, AliasedClass):
        try:
            desc = getattr(entity, key)
            return desc, desc.property
        except AttributeError:
            raise sa_exc.InvalidRequestError(
                "Entity '%s' has no property '%s'" % (entity, key))

    elif isinstance(entity, type):
        try:
            desc = attributes.manager_of_class(entity)[key]
            return desc, desc.property
        except KeyError:
            raise sa_exc.InvalidRequestError(
                "Entity '%s' has no property '%s'" % (entity, key))

    else:
        try:
            desc = entity.class_manager[key]
            return desc, desc.property
        except KeyError:
            raise sa_exc.InvalidRequestError(
                "Entity '%s' has no property '%s'" % (entity, key))

def _orm_columns(entity):
    mapper, selectable, is_aliased_class = _entity_info(entity)
    if isinstance(selectable, expression.Selectable):
        return [c for c in selectable.c]
    else:
        return [selectable]

def _orm_selectable(entity):
    mapper, selectable, is_aliased_class = _entity_info(entity)
    return selectable

def _is_aliased_class(entity):
    return isinstance(entity, AliasedClass)

def _state_mapper(state):
    return state.manager.mapper

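# Hedged examples of _entity_info() return values (``User`` and the
# selectables are hypothetical):
#
#     _entity_info(User)           # -> (mapper, mapped selectable, False)
#     _entity_info(aliased(User))  # -> (mapper, Alias object, True)
#     _entity_info(some_select)    # -> (None, some_select, False)
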
def object_mapper(instance):
    """Given an object, return the primary Mapper associated with the object
    instance.

    Raises UnmappedInstanceError if no mapping is configured.

    """
    try:
        state = attributes.instance_state(instance)
        if not state.manager.mapper:
            raise exc.UnmappedInstanceError(instance)
        return state.manager.mapper
    except exc.NO_STATE:
        raise exc.UnmappedInstanceError(instance)

def class_mapper(class_, compile=True):
    """Given a class, return the primary Mapper associated with the class.

    Raises UnmappedClassError if no mapping is configured.

    """
    try:
        class_manager = attributes.manager_of_class(class_)
        mapper = class_manager.mapper

        # HACK until [ticket:1142] is complete
        if mapper is None:
            raise AttributeError

    except exc.NO_STATE:
        raise exc.UnmappedClassError(class_)

    if compile:
        mapper = mapper.compile()
    return mapper

def _class_to_mapper(class_or_mapper, compile=True):
    if _is_aliased_class(class_or_mapper):
        return class_or_mapper._AliasedClass__mapper
    elif isinstance(class_or_mapper, type):
        return class_mapper(class_or_mapper, compile=compile)
    elif hasattr(class_or_mapper, 'compile'):
        if compile:
            return class_or_mapper.compile()
        else:
            return class_or_mapper
    else:
        raise exc.UnmappedClassError(class_or_mapper)

def has_identity(object):
    state = attributes.instance_state(object)
    return _state_has_identity(state)

def _state_has_identity(state):
    return bool(state.key)

def _is_mapped_class(cls):
    global mapperlib
    if mapperlib is None:
        from sqlalchemy.orm import mapperlib
    if isinstance(cls, (AliasedClass, mapperlib.Mapper)):
        return True
    if isinstance(cls, expression.ClauseElement):
        return False
    if isinstance(cls, type):
        manager = attributes.manager_of_class(cls)
        return manager and _INSTRUMENTOR in manager.info
    return False

def instance_str(instance):
    """Return a string describing an instance."""

    return state_str(attributes.instance_state(instance))

def state_str(state):
    """Return a string describing an instance via its InstanceState."""

    if state is None:
        return "None"
    else:
        return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))

def attribute_str(instance, attribute):
    return instance_str(instance) + "." + attribute

def state_attribute_str(state, attribute):
    return state_str(state) + "." + attribute

def identity_equal(a, b):
    if a is b:
        return True
    if a is None or b is None:
        return False
    try:
        state_a = attributes.instance_state(a)
        state_b = attributes.instance_state(b)
    except exc.NO_STATE:
        return False
    if state_a.key is None or state_b.key is None:
        return False
    return state_a.key == state_b.key


# TODO: Avoid circular import.
attributes.identity_equal = identity_equal
attributes._is_aliased_class = _is_aliased_class
attributes._entity_info = _entity_info
913
sqlalchemy/pool.py
Normal file
@@ -0,0 +1,913 @@
# pool.py - Connection pooling for SQLAlchemy
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php


"""Connection pooling for DB-API connections.

Provides a number of connection pool implementations for a variety of
usage scenarios and thread behavior requirements imposed by the
application, DB-API or database itself.

Also provides a DB-API 2.0 connection proxying mechanism allowing
regular DB-API connect() methods to be transparently managed by a
SQLAlchemy connection pool.
"""

import weakref, time, threading

from sqlalchemy import exc, log
from sqlalchemy import queue as sqla_queue
from sqlalchemy.util import threading, pickle, as_interface, memoized_property

proxies = {}

def manage(module, **params):
    """Return a proxy for a DB-API module that automatically
    pools connections.

    Given a DB-API 2.0 module and pool management parameters, returns
    a proxy for the module that will automatically pool connections,
    creating new connection pools for each distinct set of connection
    arguments sent to the decorated module's connect() function.

    :param module: a DB-API 2.0 database module

    :param poolclass: the class used by the pool module to provide
      pooling.  Defaults to :class:`QueuePool`.

    :param \*\*params: will be passed through to *poolclass*

    """
    try:
        return proxies[module]
    except KeyError:
        return proxies.setdefault(module, _DBProxy(module, **params))

def clear_managers():
    """Remove all current DB-API 2.0 managers.

    All pools and connections are disposed.
    """

    for manager in proxies.itervalues():
        manager.close()
    proxies.clear()

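# A short usage sketch for manage(); sqlite3 is just an example DB-API
# module, any DB-API 2.0 module works the same way:
#
#     import sqlite3
#     sqlite = manage(sqlite3)
#     conn = sqlite.connect('some.db')   # transparently pooled
#     conn.close()                       # returns the connection to its pool
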
class Pool(log.Identified):
    """Abstract base class for connection pools."""

    def __init__(self,
                 creator, recycle=-1, echo=None,
                 use_threadlocal=False,
                 logging_name=None,
                 reset_on_return=True, listeners=None):
        """
        Construct a Pool.

        :param creator: a callable function that returns a DB-API
          connection object.  The function will be called without
          parameters.

        :param recycle: If set to non -1, number of seconds between
          connection recycling, which means upon checkout, if this
          timeout is surpassed the connection will be closed and
          replaced with a newly opened connection. Defaults to -1.

        :param logging_name:  String identifier which will be used within
          the "name" field of logging records generated within the
          "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
          id.

        :param echo: If True, connections being pulled and retrieved
          from the pool will be logged to the standard output, as well
          as pool sizing information.  Echoing can also be achieved by
          enabling logging for the "sqlalchemy.pool"
          namespace. Defaults to False.

        :param use_threadlocal: If set to True, repeated calls to
          :meth:`connect` within the same application thread will be
          guaranteed to return the same connection object, if one has
          already been retrieved from the pool and has not been
          returned yet.  Offers a slight performance advantage at the
          cost of individual transactions by default.  The
          :meth:`unique_connection` method is provided to bypass the
          threadlocal behavior installed into :meth:`connect`.

        :param reset_on_return: If true, reset the database state of
          connections returned to the pool.  This is typically a
          ROLLBACK to release locks and transaction resources.
          Disable at your own peril.  Defaults to True.

        :param listeners: A list of
          :class:`~sqlalchemy.interfaces.PoolListener`-like objects or
          dictionaries of callables that receive events when DB-API
          connections are created, checked out and checked in to the
          pool.

        """
        if logging_name:
            self.logging_name = logging_name
        self.logger = log.instance_logger(self, echoflag=echo)
        self._threadconns = threading.local()
        self._creator = creator
        self._recycle = recycle
        self._use_threadlocal = use_threadlocal
        self._reset_on_return = reset_on_return
        self.echo = echo
        self.listeners = []
        self._on_connect = []
        self._on_first_connect = []
        self._on_checkout = []
        self._on_checkin = []

        if listeners:
            for l in listeners:
                self.add_listener(l)

    def unique_connection(self):
        return _ConnectionFairy(self).checkout()

    def create_connection(self):
        return _ConnectionRecord(self)

    def recreate(self):
        """Return a new instance with identical creation arguments."""

        raise NotImplementedError()

    def dispose(self):
        """Dispose of this pool.

        This method leaves the possibility of checked-out connections
        remaining open.  It is advised to not reuse the pool once dispose()
        is called, and to instead use a new pool constructed by the
        recreate() method.
        """

        raise NotImplementedError()

    def connect(self):
        if not self._use_threadlocal:
            return _ConnectionFairy(self).checkout()

        try:
            rec = self._threadconns.current()
            if rec:
                return rec.checkout()
        except AttributeError:
            pass

        agent = _ConnectionFairy(self)
        self._threadconns.current = weakref.ref(agent)
        return agent.checkout()

    def return_conn(self, record):
        if self._use_threadlocal and hasattr(self._threadconns, "current"):
            del self._threadconns.current
        self.do_return_conn(record)

    def get(self):
        return self.do_get()

    def do_get(self):
        raise NotImplementedError()

    def do_return_conn(self, conn):
        raise NotImplementedError()

    def status(self):
        raise NotImplementedError()

    def add_listener(self, listener):
        """Add a ``PoolListener``-like object to this pool.

        ``listener`` may be an object that implements some or all of
        PoolListener, or a dictionary of callables containing implementations
        of some or all of the named methods in PoolListener.

        """

        listener = as_interface(listener,
            methods=('connect', 'first_connect', 'checkout', 'checkin'))

        self.listeners.append(listener)
        if hasattr(listener, 'connect'):
            self._on_connect.append(listener)
        if hasattr(listener, 'first_connect'):
            self._on_first_connect.append(listener)
        if hasattr(listener, 'checkout'):
            self._on_checkout.append(listener)
        if hasattr(listener, 'checkin'):
            self._on_checkin.append(listener)

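# add_listener() accepts plain dictionaries as well as PoolListener
# objects; a hedged sketch (the logging behavior is illustrative):
#
#     def on_checkout(dbapi_conn, conn_record, conn_proxy):
#         print "connection checked out:", dbapi_conn
#
#     some_pool.add_listener({'checkout': on_checkout})
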
class _ConnectionRecord(object):
    def __init__(self, pool):
        self.__pool = pool
        self.connection = self.__connect()
        self.info = {}
        ls = pool.__dict__.pop('_on_first_connect', None)
        if ls is not None:
            for l in ls:
                l.first_connect(self.connection, self)
        if pool._on_connect:
            for l in pool._on_connect:
                l.connect(self.connection, self)

    def close(self):
        if self.connection is not None:
            self.__pool.logger.debug("Closing connection %r", self.connection)
            try:
                self.connection.close()
            except (SystemExit, KeyboardInterrupt):
                raise
            except:
                self.__pool.logger.debug("Exception closing connection %r",
                                         self.connection)

    def invalidate(self, e=None):
        if e is not None:
            self.__pool.logger.info("Invalidate connection %r (reason: %s:%s)",
                                    self.connection, e.__class__.__name__, e)
        else:
            self.__pool.logger.info("Invalidate connection %r", self.connection)
        self.__close()
        self.connection = None

    def get_connection(self):
        if self.connection is None:
            self.connection = self.__connect()
            self.info.clear()
            if self.__pool._on_connect:
                for l in self.__pool._on_connect:
                    l.connect(self.connection, self)
        elif self.__pool._recycle > -1 and \
                time.time() - self.starttime > self.__pool._recycle:
            self.__pool.logger.info("Connection %r exceeded timeout; recycling",
                                    self.connection)
            self.__close()
            self.connection = self.__connect()
            self.info.clear()
            if self.__pool._on_connect:
                for l in self.__pool._on_connect:
                    l.connect(self.connection, self)
        return self.connection

    def __close(self):
        try:
            self.__pool.logger.debug("Closing connection %r", self.connection)
            self.connection.close()
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception, e:
            self.__pool.logger.debug("Connection %r threw an error on close: %s",
                                     self.connection, e)

    def __connect(self):
        try:
            self.starttime = time.time()
            connection = self.__pool._creator()
            self.__pool.logger.debug("Created new connection %r", connection)
            return connection
        except Exception, e:
            self.__pool.logger.debug("Error on connect(): %s", e)
            raise


def _finalize_fairy(connection, connection_record, pool, ref=None):
    _refs.discard(connection_record)

    if ref is not None and \
            (connection_record.fairy is not ref or
             isinstance(pool, AssertionPool)):
        return

    if connection is not None:
        try:
            if pool._reset_on_return:
                connection.rollback()
            # Immediately close detached instances
            if connection_record is None:
                connection.close()
        except Exception, e:
            if connection_record is not None:
                connection_record.invalidate(e=e)
            if isinstance(e, (SystemExit, KeyboardInterrupt)):
                raise

    if connection_record is not None:
        connection_record.fairy = None
        pool.logger.debug("Connection %r being returned to pool", connection)
        if pool._on_checkin:
            for l in pool._on_checkin:
                l.checkin(connection, connection_record)
        pool.return_conn(connection_record)

_refs = set()

class _ConnectionFairy(object):
    """Proxies a DB-API connection and provides return-on-dereference support."""

    __slots__ = '_pool', '__counter', 'connection', \
                '_connection_record', '__weakref__', '_detached_info'

    def __init__(self, pool):
        self._pool = pool
        self.__counter = 0
        try:
            rec = self._connection_record = pool.get()
            conn = self.connection = self._connection_record.get_connection()
            rec.fairy = weakref.ref(
                            self,
                            lambda ref: _finalize_fairy(conn, rec, pool, ref))
            _refs.add(rec)
        except:
            self.connection = None  # helps with endless __getattr__ loops later on
            self._connection_record = None
            raise
        self._pool.logger.debug("Connection %r checked out from pool",
                                self.connection)

    @property
    def _logger(self):
        return self._pool.logger

    @property
    def is_valid(self):
        return self.connection is not None

    @property
    def info(self):
        """An info collection unique to this DB-API connection."""

        try:
            return self._connection_record.info
        except AttributeError:
            if self.connection is None:
                raise exc.InvalidRequestError("This connection is closed")
            try:
                return self._detached_info
            except AttributeError:
                self._detached_info = value = {}
                return value

    def invalidate(self, e=None):
        """Mark this connection as invalidated.

        The connection will be immediately closed.  The containing
        ConnectionRecord will create a new connection when next used.
        """

        if self.connection is None:
            raise exc.InvalidRequestError("This connection is closed")
        if self._connection_record is not None:
            self._connection_record.invalidate(e=e)
        self.connection = None
        self._close()

    def cursor(self, *args, **kwargs):
        try:
            c = self.connection.cursor(*args, **kwargs)
            return _CursorFairy(self, c)
        except Exception, e:
            self.invalidate(e=e)
            raise

    def __getattr__(self, key):
        return getattr(self.connection, key)

    def checkout(self):
        if self.connection is None:
            raise exc.InvalidRequestError("This connection is closed")
        self.__counter += 1

        if not self._pool._on_checkout or self.__counter != 1:
            return self

        # Pool listeners can trigger a reconnection on checkout
        attempts = 2
        while attempts > 0:
            try:
                for l in self._pool._on_checkout:
                    l.checkout(self.connection, self._connection_record, self)
                return self
            except exc.DisconnectionError, e:
                self._pool.logger.info(
                    "Disconnection detected on checkout: %s", e)
                self._connection_record.invalidate(e)
                self.connection = self._connection_record.get_connection()
                attempts -= 1

        self._pool.logger.info("Reconnection attempts exhausted on checkout")
        self.invalidate()
        raise exc.InvalidRequestError("This connection is closed")

    def detach(self):
        """Separate this connection from its Pool.

        This means that the connection will no longer be returned to the
        pool when closed, and will instead be literally closed.  The
        containing ConnectionRecord is separated from the DB-API connection,
        and will create a new connection when next used.

        Note that any overall connection limiting constraints imposed by a
        Pool implementation may be violated after a detach, as the detached
        connection is removed from the pool's knowledge and control.
        """

        if self._connection_record is not None:
            _refs.remove(self._connection_record)
            self._connection_record.fairy = None
            self._connection_record.connection = None
            self._pool.do_return_conn(self._connection_record)
            self._detached_info = \
                self._connection_record.info.copy()
            self._connection_record = None

    def close(self):
        self.__counter -= 1
        if self.__counter == 0:
            self._close()

    def _close(self):
        _finalize_fairy(self.connection, self._connection_record, self._pool)
        self.connection = None
        self._connection_record = None

class _CursorFairy(object):
|
||||
__slots__ = '_parent', 'cursor', 'execute'
|
||||
|
||||
def __init__(self, parent, cursor):
|
||||
self._parent = parent
|
||||
self.cursor = cursor
|
||||
self.execute = cursor.execute
|
||||
|
||||
def invalidate(self, e=None):
|
||||
self._parent.invalidate(e=e)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.cursor.close()
|
||||
except Exception, e:
|
||||
try:
|
||||
ex_text = str(e)
|
||||
except TypeError:
|
||||
ex_text = repr(e)
|
||||
            self._parent._logger.warn("Error closing cursor: %s", ex_text)

            if isinstance(e, (SystemExit, KeyboardInterrupt)):
                raise

    def __setattr__(self, key, value):
        if key in self.__slots__:
            object.__setattr__(self, key, value)
        else:
            setattr(self.cursor, key, value)

    def __getattr__(self, key):
        return getattr(self.cursor, key)

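checkout() above gives each pool listener one chance per checkout to veto a stale connection by raising DisconnectionError, which triggers invalidation and a retry with a fresh connection. A hedged sketch of such a listener, assuming the PoolListener interface from sqlalchemy.interfaces:

    from sqlalchemy import exc, interfaces

    class PingListener(interfaces.PoolListener):
        def checkout(self, dbapi_con, con_record, con_proxy):
            # Ping the connection; on failure ask the pool to reconnect.
            try:
                dbapi_con.cursor().execute("SELECT 1")
            except Exception:
                raise exc.DisconnectionError("connection appears to be stale")

Such listeners are passed through the listeners parameter documented on QueuePool below.
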
class SingletonThreadPool(Pool):
    """A Pool that maintains one connection per thread.

    Maintains one connection per thread, never moving a connection to a
    thread other than the one in which it was created.

    This is used for SQLite, which both does not handle multithreading by
    default, and also requires a singleton connection if a :memory: database
    is being used.

    Options are the same as those of :class:`Pool`, as well as:

    :param pool_size: The number of threads in which to maintain connections
      at once.  Defaults to five.

    """

    def __init__(self, creator, pool_size=5, **kw):
        kw['use_threadlocal'] = True
        Pool.__init__(self, creator, **kw)
        self._conn = threading.local()
        self._all_conns = set()
        self.size = pool_size

    def recreate(self):
        self.logger.info("Pool recreating")
        return SingletonThreadPool(self._creator,
                                   pool_size=self.size,
                                   recycle=self._recycle,
                                   echo=self.echo,
                                   use_threadlocal=self._use_threadlocal,
                                   listeners=self.listeners)

    def dispose(self):
        """Dispose of this pool."""

        for conn in self._all_conns:
            try:
                conn.close()
            except (SystemExit, KeyboardInterrupt):
                raise
            except:
                # pysqlite won't even let you close a conn from a thread
                # that didn't create it
                pass

        self._all_conns.clear()

    def dispose_local(self):
        if hasattr(self._conn, 'current'):
            conn = self._conn.current()
            self._all_conns.discard(conn)
            del self._conn.current

    def cleanup(self):
        while len(self._all_conns) > self.size:
            self._all_conns.pop()

    def status(self):
        return "SingletonThreadPool id:%d size: %d" % (id(self), len(self._all_conns))

    def do_return_conn(self, conn):
        pass

    def do_get(self):
        try:
            c = self._conn.current()
            if c:
                return c
        except AttributeError:
            pass
        c = self.create_connection()
        self._conn.current = weakref.ref(c)
        self._all_conns.add(c)
        if len(self._all_conns) > self.size:
            self.cleanup()
        return c

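The SQLite dialect normally selects this pool automatically, but it can also be constructed directly; a minimal sketch:

    # One connection per thread; each thread reuses its own connection.
    import sqlite3
    from sqlalchemy.pool import SingletonThreadPool

    pool = SingletonThreadPool(lambda: sqlite3.connect(":memory:"))
    c1 = pool.connect()
    c2 = pool.connect()   # same underlying DB-API connection in this thread
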

class QueuePool(Pool):
    """A Pool that imposes a limit on the number of open connections."""

    def __init__(self, creator, pool_size=5, max_overflow=10, timeout=30,
                 **kw):
        """
        Construct a QueuePool.

        :param creator: a callable function that returns a DB-API
          connection object.  The function will be called without
          parameters.

        :param pool_size: The size of the pool to be maintained.  This
          is the largest number of connections that will be kept
          persistently in the pool.  Note that the pool begins with no
          connections; once this number of connections is requested,
          that number of connections will remain.  Defaults to 5.

        :param max_overflow: The maximum overflow size of the
          pool.  When the number of checked-out connections reaches the
          size set in pool_size, additional connections will be
          returned up to this limit.  When those additional connections
          are returned to the pool, they are disconnected and
          discarded.  It follows then that the total number of
          simultaneous connections the pool will allow is pool_size +
          `max_overflow`, and the total number of "sleeping"
          connections the pool will allow is pool_size.  `max_overflow`
          can be set to -1 to indicate no overflow limit; no limit
          will be placed on the total number of concurrent
          connections.  Defaults to 10.

        :param timeout: The number of seconds to wait before giving up
          on returning a connection.  Defaults to 30.

        :param recycle: If set to a value other than -1, the number of
          seconds between connection recycling: upon checkout, if this
          timeout has been surpassed the connection will be closed and
          replaced with a newly opened connection.  Defaults to -1.

        :param echo: If True, connections being pulled and retrieved
          from the pool will be logged to the standard output, as well
          as pool sizing information.  Echoing can also be achieved by
          enabling logging for the "sqlalchemy.pool"
          namespace.  Defaults to False.

        :param use_threadlocal: If set to True, repeated calls to
          :meth:`connect` within the same application thread will be
          guaranteed to return the same connection object, if one has
          already been retrieved from the pool and has not been
          returned yet.  Offers a slight performance advantage at the
          cost of individual transactions by default.  The
          :meth:`unique_connection` method is provided to bypass the
          threadlocal behavior installed into :meth:`connect`.

        :param reset_on_return: If true, reset the database state of
          connections returned to the pool.  This is typically a
          ROLLBACK to release locks and transaction resources.
          Disable at your own peril.  Defaults to True.

        :param listeners: A list of
          :class:`~sqlalchemy.interfaces.PoolListener`-like objects or
          dictionaries of callables that receive events when DB-API
          connections are created, checked out and checked in to the
          pool.

        """
        Pool.__init__(self, creator, **kw)
        self._pool = sqla_queue.Queue(pool_size)
        self._overflow = 0 - pool_size
        self._max_overflow = max_overflow
        self._timeout = timeout
        self._overflow_lock = self._max_overflow > -1 and threading.Lock() or None

    def recreate(self):
        self.logger.info("Pool recreating")
        return QueuePool(self._creator, pool_size=self._pool.maxsize,
                         max_overflow=self._max_overflow, timeout=self._timeout,
                         recycle=self._recycle, echo=self.echo,
                         use_threadlocal=self._use_threadlocal, listeners=self.listeners)

    def do_return_conn(self, conn):
        try:
            self._pool.put(conn, False)
        except sqla_queue.Full:
            if self._overflow_lock is None:
                self._overflow -= 1
            else:
                self._overflow_lock.acquire()
                try:
                    self._overflow -= 1
                finally:
                    self._overflow_lock.release()

    def do_get(self):
        try:
            wait = self._max_overflow > -1 and self._overflow >= self._max_overflow
            return self._pool.get(wait, self._timeout)
        except sqla_queue.Empty:
            if self._max_overflow > -1 and self._overflow >= self._max_overflow:
                if not wait:
                    return self.do_get()
                else:
                    raise exc.TimeoutError(
                        "QueuePool limit of size %d overflow %d reached, "
                        "connection timed out, timeout %d" %
                        (self.size(), self.overflow(), self._timeout))

            if self._overflow_lock is not None:
                self._overflow_lock.acquire()

            if self._max_overflow > -1 and self._overflow >= self._max_overflow:
                if self._overflow_lock is not None:
                    self._overflow_lock.release()
                return self.do_get()

            try:
                con = self.create_connection()
                self._overflow += 1
            finally:
                if self._overflow_lock is not None:
                    self._overflow_lock.release()
            return con

    def dispose(self):
        while True:
            try:
                conn = self._pool.get(False)
                conn.close()
            except sqla_queue.Empty:
                break

        self._overflow = 0 - self.size()
        self.logger.info("Pool disposed. %s", self.status())

    def status(self):
        return "Pool size: %d Connections in pool: %d "\
               "Current Overflow: %d Current Checked out "\
               "connections: %d" % (self.size(),
                                    self.checkedin(),
                                    self.overflow(),
                                    self.checkedout())

    def size(self):
        return self._pool.maxsize

    def checkedin(self):
        return self._pool.qsize()

    def overflow(self):
        return self._overflow

    def checkedout(self):
        return self._pool.maxsize - self._pool.qsize() + self._overflow
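A minimal usage sketch of QueuePool standing alone (the file name is a placeholder); create_engine constructs one of these by default for most dialects:

    # At most pool_size + max_overflow = 8 connections may exist at once;
    # a 9th concurrent checkout waits up to `timeout` seconds, then
    # raises TimeoutError.
    import sqlite3
    from sqlalchemy.pool import QueuePool

    pool = QueuePool(lambda: sqlite3.connect("example.db"),
                     pool_size=5, max_overflow=3, timeout=30)

    conn = pool.connect()
    print pool.status()   # pool size, checked-in, overflow, checked-out
    conn.close()          # returns the connection to the pool
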

class NullPool(Pool):
    """A Pool which does not pool connections.

    Instead it literally opens and closes the underlying DB-API connection
    for each connection open/close.

    Reconnect-related functions such as ``recycle`` and connection
    invalidation are not supported by this Pool implementation, since
    no connections are held persistently.

    """

    def status(self):
        return "NullPool"

    def do_return_conn(self, conn):
        conn.close()

    def do_return_invalid(self, conn):
        pass

    def do_get(self):
        return self.create_connection()

    def recreate(self):
        self.logger.info("Pool recreating")

        return NullPool(self._creator,
                        recycle=self._recycle,
                        echo=self.echo,
                        use_threadlocal=self._use_threadlocal,
                        listeners=self.listeners)

    def dispose(self):
        pass

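NullPool is usually selected via create_engine's poolclass argument when pooling is unwanted; a brief sketch:

    # Every checkout opens a fresh DB-API connection and close() really
    # closes it; nothing is held between uses.
    from sqlalchemy import create_engine
    from sqlalchemy.pool import NullPool

    engine = create_engine("sqlite:///example.db", poolclass=NullPool)
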

class StaticPool(Pool):
    """A Pool of exactly one connection, used for all requests.

    Reconnect-related functions such as ``recycle`` and connection
    invalidation (which is also used to support auto-reconnect) are not
    currently supported by this Pool implementation but may be implemented
    in a future release.

    """

    @memoized_property
    def _conn(self):
        return self._creator()

    @memoized_property
    def connection(self):
        return _ConnectionRecord(self)

    def status(self):
        return "StaticPool"

    def dispose(self):
        if '_conn' in self.__dict__:
            self._conn.close()
            self._conn = None

    def recreate(self):
        self.logger.info("Pool recreating")
        return self.__class__(creator=self._creator,
                              recycle=self._recycle,
                              use_threadlocal=self._use_threadlocal,
                              reset_on_return=self._reset_on_return,
                              echo=self.echo,
                              listeners=self.listeners)

    def create_connection(self):
        return self._conn

    def do_return_conn(self, conn):
        pass

    def do_return_invalid(self, conn):
        pass

    def do_get(self):
        return self.connection
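Since both _conn and connection are memoized, every checkout sees the same DB-API connection. A hedged sketch of a typical use, a single shared in-memory SQLite database:

    from sqlalchemy import create_engine
    from sqlalchemy.pool import StaticPool

    # All engine-level connections share one underlying DB-API connection,
    # so the :memory: database's contents are visible everywhere.
    engine = create_engine("sqlite://", poolclass=StaticPool)
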

class AssertionPool(Pool):
    """A Pool that allows at most one checked out connection at any given time.

    This will raise an exception if more than one connection is checked out
    at a time.  Useful for debugging code that is using more connections
    than desired.

    """

    def __init__(self, *args, **kw):
        self._conn = None
        self._checked_out = False
        Pool.__init__(self, *args, **kw)

    def status(self):
        return "AssertionPool"

    def do_return_conn(self, conn):
        if not self._checked_out:
            raise AssertionError("connection is not checked out")
        self._checked_out = False
        assert conn is self._conn

    def do_return_invalid(self, conn):
        self._conn = None
        self._checked_out = False

    def dispose(self):
        self._checked_out = False
        if self._conn:
            self._conn.close()

    def recreate(self):
        self.logger.info("Pool recreating")
        return AssertionPool(self._creator, echo=self.echo,
                             listeners=self.listeners)

    def do_get(self):
        if self._checked_out:
            raise AssertionError("connection is already checked out")

        if not self._conn:
            self._conn = self.create_connection()

        self._checked_out = True
        return self._conn
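A quick debugging sketch: with AssertionPool the second concurrent checkout fails loudly instead of silently opening another connection:

    import sqlite3
    from sqlalchemy.pool import AssertionPool

    pool = AssertionPool(lambda: sqlite3.connect(":memory:"))
    c1 = pool.connect()
    c2 = pool.connect()   # raises AssertionError: already checked out
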

class _DBProxy(object):
    """Layers connection pooling behavior on top of a standard DB-API module.

    Proxies a DB-API 2.0 connect() call to a connection pool keyed to the
    specific connect parameters.  Other functions and attributes are delegated
    to the underlying DB-API module.
    """

    def __init__(self, module, poolclass=QueuePool, **kw):
        """Initializes a new proxy.

        module
          a DB-API 2.0 module

        poolclass
          a Pool class, defaulting to QueuePool

        Other parameters are sent to the Pool object's constructor.

        """

        self.module = module
        self.kw = kw
        self.poolclass = poolclass
        self.pools = {}
        self._create_pool_mutex = threading.Lock()

    def close(self):
        for key in self.pools.keys():
            del self.pools[key]

    def __del__(self):
        self.close()

    def __getattr__(self, key):
        return getattr(self.module, key)

    def get_pool(self, *args, **kw):
        key = self._serialize(*args, **kw)
        try:
            return self.pools[key]
        except KeyError:
            self._create_pool_mutex.acquire()
            try:
                if key not in self.pools:
                    pool = self.poolclass(lambda: self.module.connect(*args, **kw), **self.kw)
                    self.pools[key] = pool
                    return pool
                else:
                    return self.pools[key]
            finally:
                self._create_pool_mutex.release()

    def connect(self, *args, **kw):
        """Activate a connection to the database.

        Connect to the database using this DBProxy's module and the given
        connect arguments.  If the arguments match an existing pool, the
        connection will be returned from the pool's current thread-local
        connection instance, or if there is no thread-local connection
        instance it will be checked out from the set of pooled connections.

        If the pool has no available connections and allows new connections
        to be created, a new database connection will be made.

        """

        return self.get_pool(*args, **kw).connect()

    def dispose(self, *args, **kw):
        """Dispose the pool referenced by the given connect arguments."""

        key = self._serialize(*args, **kw)
        try:
            del self.pools[key]
        except KeyError:
            pass

    def _serialize(self, *args, **kw):
        return pickle.dumps([args, kw])

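_DBProxy backs the module-level manage() helper defined near the top of this module, which is the public route to transparent pooling over a raw DB-API module; a sketch:

    # Identical connect() arguments are served from the same pool, keyed
    # by the pickled argument list (_serialize above).
    import sqlite3
    from sqlalchemy import pool

    sqlite = pool.manage(sqlite3, poolclass=pool.QueuePool)
    conn = sqlite.connect("example.db")   # pooled DB-API connection
    conn.close()                          # checked back in, not closed
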
101
sqlalchemy/processors.py
Normal file
@@ -0,0 +1,101 @@
# processors.py
# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""defines generic type conversion functions, as used in bind and result
processors.

They all share one common characteristic: None is passed through unchanged.

"""

import codecs
import re
import datetime

def str_to_datetime_processor_factory(regexp, type_):
    rmatch = regexp.match
    # Even on python2.6, datetime.strptime is both slower than this code
    # and does not support microseconds.
    def process(value):
        if value is None:
            return None
        else:
            return type_(*map(int, rmatch(value).groups(0)))
    return process

try:
    from sqlalchemy.cprocessors import UnicodeResultProcessor, \
                                       DecimalResultProcessor, \
                                       to_float, to_str, int_to_boolean, \
                                       str_to_datetime, str_to_time, \
                                       str_to_date

    def to_unicode_processor_factory(encoding, errors=None):
        # this is cumbersome but it would be even more so on the C side
        if errors is not None:
            return UnicodeResultProcessor(encoding, errors).process
        else:
            return UnicodeResultProcessor(encoding).process

    def to_decimal_processor_factory(target_class, scale=10):
        # Note that the scale argument is not taken into account for integer
        # values in the C implementation while it is in the Python one.
        # For example, the Python implementation might return
        # Decimal('5.00000') whereas the C implementation will
        # return Decimal('5').  These are equivalent of course.
        return DecimalResultProcessor(target_class, "%%.%df" % scale).process

except ImportError:
    def to_unicode_processor_factory(encoding, errors=None):
        decoder = codecs.getdecoder(encoding)

        def process(value):
            if value is None:
                return None
            else:
                # decoder returns a tuple: (value, len).  Simply dropping the
                # len part is safe: it is done that way in the normal
                # 'xx'.decode(encoding) code path.
                return decoder(value, errors)[0]
        return process

    def to_decimal_processor_factory(target_class, scale=10):
        fstring = "%%.%df" % scale

        def process(value):
            if value is None:
                return None
            else:
                return target_class(fstring % value)
        return process

    def to_float(value):
        if value is None:
            return None
        else:
            return float(value)

    def to_str(value):
        if value is None:
            return None
        else:
            return str(value)

    def int_to_boolean(value):
        if value is None:
            return None
        else:
            return value and True or False

    DATETIME_RE = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
    TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
    DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)")

    str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE,
                                                        datetime.datetime)
    str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time)
    str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date)

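A short sketch of the processor contract these factories implement (None passes through; a missing fractional-seconds group defaults to 0 via groups(0)):

    from sqlalchemy import processors

    print processors.str_to_date("2010-03-27")
    # -> datetime.date(2010, 3, 27)
    print processors.str_to_datetime("2010-03-27 12:34:56")
    # -> datetime.datetime(2010, 3, 27, 12, 34, 56)
    print processors.int_to_boolean(None)
    # -> None
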
183
sqlalchemy/queue.py
Normal file
@@ -0,0 +1,183 @@
"""An adaptation of Py2.3/2.4's Queue module which supports reentrant
behavior, using RLock instead of Lock for its mutex object.

This is to support the connection pool's usage of weakref callbacks to return
connections to the underlying Queue, which can in extremely
rare cases be invoked within the ``get()`` method of the Queue itself,
producing a ``put()`` inside the ``get()`` and therefore a reentrant
condition."""

from collections import deque
from time import time as _time
from sqlalchemy.util import threading

__all__ = ['Empty', 'Full', 'Queue']

class Empty(Exception):
    "Exception raised by Queue.get(block=0)/get_nowait()."

    pass

class Full(Exception):
    "Exception raised by Queue.put(block=0)/put_nowait()."

    pass

class Queue:
    def __init__(self, maxsize=0):
        """Initialize a queue object with a given maximum size.

        If `maxsize` is <= 0, the queue size is infinite.
        """

        self._init(maxsize)
        # mutex must be held whenever the queue is mutating.  All methods
        # that acquire mutex must release it before returning.  mutex
        # is shared between the two conditions, so acquiring and
        # releasing the conditions also acquires and releases mutex.
        self.mutex = threading.RLock()
        # Notify not_empty whenever an item is added to the queue; a
        # thread waiting to get is notified then.
        self.not_empty = threading.Condition(self.mutex)
        # Notify not_full whenever an item is removed from the queue;
        # a thread waiting to put is notified then.
        self.not_full = threading.Condition(self.mutex)

    def qsize(self):
        """Return the approximate size of the queue (not reliable!)."""

        self.mutex.acquire()
        n = self._qsize()
        self.mutex.release()
        return n

    def empty(self):
        """Return True if the queue is empty, False otherwise (not reliable!)."""

        self.mutex.acquire()
        n = self._empty()
        self.mutex.release()
        return n

    def full(self):
        """Return True if the queue is full, False otherwise (not reliable!)."""

        self.mutex.acquire()
        n = self._full()
        self.mutex.release()
        return n

    def put(self, item, block=True, timeout=None):
        """Put an item into the queue.

        If optional arg `block` is True and `timeout` is None (the
        default), block if necessary until a free slot is
        available.  If `timeout` is a positive number, it blocks at
        most `timeout` seconds and raises the ``Full`` exception if no
        free slot was available within that time.  Otherwise (`block`
        is false), put an item on the queue if a free slot is
        immediately available, else raise the ``Full`` exception
        (`timeout` is ignored in that case).
        """

        self.not_full.acquire()
        try:
            if not block:
                if self._full():
                    raise Full
            elif timeout is None:
                while self._full():
                    self.not_full.wait()
            else:
                if timeout < 0:
                    raise ValueError("'timeout' must be a positive number")
                endtime = _time() + timeout
                while self._full():
                    remaining = endtime - _time()
                    if remaining <= 0.0:
                        raise Full
                    self.not_full.wait(remaining)
            self._put(item)
            self.not_empty.notify()
        finally:
            self.not_full.release()

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        Only enqueue the item if a free slot is immediately available.
        Otherwise raise the ``Full`` exception.
        """
        return self.put(item, False)

    def get(self, block=True, timeout=None):
        """Remove and return an item from the queue.

        If optional arg `block` is True and `timeout` is None (the
        default), block if necessary until an item is available.  If
        `timeout` is a positive number, it blocks at most `timeout`
        seconds and raises the ``Empty`` exception if no item was
        available within that time.  Otherwise (`block` is false),
        return an item if one is immediately available, else raise the
        ``Empty`` exception (`timeout` is ignored in that case).
        """

        self.not_empty.acquire()
        try:
            if not block:
                if self._empty():
                    raise Empty
            elif timeout is None:
                while self._empty():
                    self.not_empty.wait()
            else:
                if timeout < 0:
                    raise ValueError("'timeout' must be a positive number")
                endtime = _time() + timeout
                while self._empty():
                    remaining = endtime - _time()
                    if remaining <= 0.0:
                        raise Empty
                    self.not_empty.wait(remaining)
            item = self._get()
            self.not_full.notify()
            return item
        finally:
            self.not_empty.release()

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        Only get an item if one is immediately available.  Otherwise
        raise the ``Empty`` exception.
        """

        return self.get(False)

    # Override these methods to implement other queue organizations
    # (e.g. stack or priority queue).
    # These will only be called with appropriate locks held

    # Initialize the queue representation
    def _init(self, maxsize):
        self.maxsize = maxsize
        self.queue = deque()

    def _qsize(self):
        return len(self.queue)

    # Check whether the queue is empty
    def _empty(self):
        return not self.queue

    # Check whether the queue is full
    def _full(self):
        return self.maxsize > 0 and len(self.queue) == self.maxsize

    # Put a new item in the queue
    def _put(self, item):
        self.queue.append(item)

    # Get an item from the queue
    def _get(self):
        return self.queue.popleft()
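A small sketch of the non-blocking API; the RLock-based mutex is what allows a put() fired from a weakref callback to run safely while a get() on the same thread holds the lock:

    from sqlalchemy import queue as sqla_queue

    q = sqla_queue.Queue(maxsize=2)
    q.put_nowait("a")
    q.put_nowait("b")
    try:
        q.put_nowait("c")
    except sqla_queue.Full:
        print "queue full"

    print q.get_nowait()   # -> 'a'; FIFO order via deque.popleft()
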
2386
sqlalchemy/schema.py
Normal file
File diff suppressed because it is too large
58
sqlalchemy/sql/__init__.py
Normal file
@@ -0,0 +1,58 @@
from sqlalchemy.sql.expression import (
    Alias,
    ClauseElement,
    ColumnCollection,
    ColumnElement,
    CompoundSelect,
    Delete,
    FromClause,
    Insert,
    Join,
    Select,
    Selectable,
    TableClause,
    Update,
    alias,
    and_,
    asc,
    between,
    bindparam,
    case,
    cast,
    collate,
    column,
    delete,
    desc,
    distinct,
    except_,
    except_all,
    exists,
    extract,
    func,
    insert,
    intersect,
    intersect_all,
    join,
    label,
    literal,
    literal_column,
    modifier,
    not_,
    null,
    or_,
    outerjoin,
    outparam,
    select,
    subquery,
    table,
    text,
    tuple_,
    union,
    union_all,
    update,
    )

from sqlalchemy.sql.visitors import ClauseVisitor

__tmp = locals().keys()
__all__ = sorted([i for i in __tmp if not i.startswith('__')])
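These exports make sqlalchemy.sql usable as a standalone expression toolkit; a brief sketch with hypothetical table and column names:

    from sqlalchemy.sql import table, column, select, and_

    users = table("users", column("id"), column("name"))
    stmt = select([users.c.name],
                  and_(users.c.id > 5, users.c.name != None))
    print stmt   # SELECT users.name FROM users WHERE users.id > :id_1
                 # AND users.name IS NOT NULL
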
Some files were not shown because too many files have changed in this diff