Updated SqlAlchemy + the new files

This commit is contained in:
Christoffer Viken 2017-04-15 16:33:29 +00:00
parent e3267d4bda
commit 4669737fe3
134 changed files with 66374 additions and 4528 deletions

View File

@ -1,6 +1,7 @@
/* /*
processors.c processors.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com Copyright (C) 2010-2017 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -9,26 +10,30 @@ the MIT License: http://www.opensource.org/licenses/mit-license.php
#include <Python.h> #include <Python.h>
#include <datetime.h> #include <datetime.h>
#define MODULE_NAME "cprocessors"
#define MODULE_DOC "Module containing C versions of data processing functions."
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
#endif
static PyObject * static PyObject *
int_to_boolean(PyObject *self, PyObject *arg) int_to_boolean(PyObject *self, PyObject *arg)
{ {
long l = 0; int l = 0;
PyObject *res; PyObject *res;
if (arg == Py_None) if (arg == Py_None)
Py_RETURN_NONE; Py_RETURN_NONE;
l = PyInt_AsLong(arg); l = PyObject_IsTrue(arg);
if (l == 0) { if (l == 0) {
res = Py_False; res = Py_False;
} else if (l == 1) { } else if (l == 1) {
res = Py_True; res = Py_True;
} else if ((l == -1) && PyErr_Occurred()) {
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else { } else {
PyErr_SetString(PyExc_ValueError,
"int_to_boolean only accepts None, 0 or 1");
return NULL; return NULL;
} }
@ -57,15 +62,51 @@ to_float(PyObject *self, PyObject *arg)
static PyObject * static PyObject *
str_to_datetime(PyObject *self, PyObject *arg) str_to_datetime(PyObject *self, PyObject *arg)
{ {
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str; const char *str;
int numparsed;
unsigned int year, month, day, hour, minute, second, microsecond = 0; unsigned int year, month, day, hour, minute, second, microsecond = 0;
PyObject *err_repr;
if (arg == Py_None) if (arg == Py_None)
Py_RETURN_NONE; Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg); str = PyString_AsString(arg);
if (str == NULL) #endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL; return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
/* microseconds are optional */ /* microseconds are optional */
/* /*
@ -73,9 +114,31 @@ str_to_datetime(PyObject *self, PyObject *arg)
not accept "2000-01-01 00:00:00.". I don't know which is better, but they not accept "2000-01-01 00:00:00.". I don't know which is better, but they
should be coherent. should be coherent.
*/ */
if (sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day, numparsed = sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
&hour, &minute, &second, &microsecond) < 6) { &hour, &minute, &second, &microsecond);
PyErr_SetString(PyExc_ValueError, "Couldn't parse datetime string."); #if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 6) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL; return NULL;
} }
return PyDateTime_FromDateAndTime(year, month, day, return PyDateTime_FromDateAndTime(year, month, day,
@ -85,25 +148,82 @@ str_to_datetime(PyObject *self, PyObject *arg)
static PyObject * static PyObject *
str_to_time(PyObject *self, PyObject *arg) str_to_time(PyObject *self, PyObject *arg)
{ {
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str; const char *str;
int numparsed;
unsigned int hour, minute, second, microsecond = 0; unsigned int hour, minute, second, microsecond = 0;
PyObject *err_repr;
if (arg == Py_None) if (arg == Py_None)
Py_RETURN_NONE; Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg); str = PyString_AsString(arg);
if (str == NULL) #endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL; return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
/* microseconds are optional */ /* microseconds are optional */
/* /*
TODO: this is slightly less picky than the Python version which would TODO: this is slightly less picky than the Python version which would
not accept "00:00:00.". I don't know which is better, but they should be not accept "00:00:00.". I don't know which is better, but they should be
coherent. coherent.
*/ */
if (sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second, numparsed = sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
&microsecond) < 3) { &microsecond);
PyErr_SetString(PyExc_ValueError, "Couldn't parse time string."); #if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL; return NULL;
} }
return PyTime_FromTime(hour, minute, second, microsecond); return PyTime_FromTime(hour, minute, second, microsecond);
@ -112,18 +232,74 @@ str_to_time(PyObject *self, PyObject *arg)
static PyObject * static PyObject *
str_to_date(PyObject *self, PyObject *arg) str_to_date(PyObject *self, PyObject *arg)
{ {
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str; const char *str;
int numparsed;
unsigned int year, month, day; unsigned int year, month, day;
PyObject *err_repr;
if (arg == Py_None) if (arg == Py_None)
Py_RETURN_NONE; Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg); str = PyString_AsString(arg);
if (str == NULL) #endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL; return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
if (sscanf(str, "%4u-%2u-%2u", &year, &month, &day) != 3) { numparsed = sscanf(str, "%4u-%2u-%2u", &year, &month, &day);
PyErr_SetString(PyExc_ValueError, "Couldn't parse date string."); #if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed != 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL; return NULL;
} }
return PyDate_FromDate(year, month, day); return PyDate_FromDate(year, month, day);
@ -159,17 +335,35 @@ UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
PyObject *encoding, *errors = NULL; PyObject *encoding, *errors = NULL;
static char *kwlist[] = {"encoding", "errors", NULL}; static char *kwlist[] = {"encoding", "errors", NULL};
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTupleAndKeywords(args, kwds, "U|U:__init__", kwlist,
&encoding, &errors))
return -1;
#else
if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist,
&encoding, &errors)) &encoding, &errors))
return -1; return -1;
#endif
#if PY_MAJOR_VERSION >= 3
encoding = PyUnicode_AsASCIIString(encoding);
#else
Py_INCREF(encoding); Py_INCREF(encoding);
#endif
self->encoding = encoding; self->encoding = encoding;
if (errors) { if (errors) {
#if PY_MAJOR_VERSION >= 3
errors = PyUnicode_AsASCIIString(errors);
#else
Py_INCREF(errors); Py_INCREF(errors);
#endif
} else { } else {
#if PY_MAJOR_VERSION >= 3
errors = PyBytes_FromString("strict");
#else
errors = PyString_FromString("strict"); errors = PyString_FromString("strict");
#endif
if (errors == NULL) if (errors == NULL)
return -1; return -1;
} }
@ -188,28 +382,88 @@ UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
if (value == Py_None) if (value == Py_None)
Py_RETURN_NONE; Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
if (PyBytes_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyBytes_AS_STRING(self->encoding);
errors = PyBytes_AS_STRING(self->errors);
#else
if (PyString_AsStringAndSize(value, &str, &len)) if (PyString_AsStringAndSize(value, &str, &len))
return NULL; return NULL;
encoding = PyString_AS_STRING(self->encoding); encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors); errors = PyString_AS_STRING(self->errors);
#endif
return PyUnicode_Decode(str, len, encoding, errors); return PyUnicode_Decode(str, len, encoding, errors);
} }
static PyObject *
UnicodeResultProcessor_conditional_process(UnicodeResultProcessor *self, PyObject *value)
{
const char *encoding, *errors;
char *str;
Py_ssize_t len;
if (value == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(value) == 1) {
Py_INCREF(value);
return value;
}
if (PyBytes_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyBytes_AS_STRING(self->encoding);
errors = PyBytes_AS_STRING(self->errors);
#else
if (PyUnicode_Check(value) == 1) {
Py_INCREF(value);
return value;
}
if (PyString_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors);
#endif
return PyUnicode_Decode(str, len, encoding, errors);
}
static void
UnicodeResultProcessor_dealloc(UnicodeResultProcessor *self)
{
Py_XDECREF(self->encoding);
Py_XDECREF(self->errors);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
}
static PyMethodDef UnicodeResultProcessor_methods[] = { static PyMethodDef UnicodeResultProcessor_methods[] = {
{"process", (PyCFunction)UnicodeResultProcessor_process, METH_O, {"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
"The value processor itself."}, "The value processor itself."},
{"conditional_process", (PyCFunction)UnicodeResultProcessor_conditional_process, METH_O,
"Conditional version of the value processor."},
{NULL} /* Sentinel */ {NULL} /* Sentinel */
}; };
static PyTypeObject UnicodeResultProcessorType = { static PyTypeObject UnicodeResultProcessorType = {
PyObject_HEAD_INIT(NULL) PyVarObject_HEAD_INIT(NULL, 0)
0, /* ob_size */
"sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */ "sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */
sizeof(UnicodeResultProcessor), /* tp_basicsize */ sizeof(UnicodeResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
0, /* tp_dealloc */ (destructor)UnicodeResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */ 0, /* tp_print */
0, /* tp_getattr */ 0, /* tp_getattr */
0, /* tp_setattr */ 0, /* tp_setattr */
@ -255,7 +509,11 @@ DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
{ {
PyObject *type, *format; PyObject *type, *format;
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "OU", &type, &format))
#else
if (!PyArg_ParseTuple(args, "OS", &type, &format)) if (!PyArg_ParseTuple(args, "OS", &type, &format))
#endif
return -1; return -1;
Py_INCREF(type); Py_INCREF(type);
@ -275,22 +533,40 @@ DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
if (value == Py_None) if (value == Py_None)
Py_RETURN_NONE; Py_RETURN_NONE;
if (PyFloat_CheckExact(value)) {
/* Decimal does not accept float values directly */ /* Decimal does not accept float values directly */
/* SQLite can also give us an integer here (see [ticket:2432]) */
/* XXX: starting with Python 3.1, we could use Decimal.from_float(f),
but the result wouldn't be the same */
args = PyTuple_Pack(1, value); args = PyTuple_Pack(1, value);
if (args == NULL) if (args == NULL)
return NULL; return NULL;
#if PY_MAJOR_VERSION >= 3
str = PyUnicode_Format(self->format, args);
#else
str = PyString_Format(self->format, args); str = PyString_Format(self->format, args);
#endif
Py_DECREF(args);
if (str == NULL) if (str == NULL)
return NULL; return NULL;
result = PyObject_CallFunctionObjArgs(self->type, str, NULL); result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
Py_DECREF(str); Py_DECREF(str);
return result; return result;
} else {
return PyObject_CallFunctionObjArgs(self->type, value, NULL);
} }
static void
DecimalResultProcessor_dealloc(DecimalResultProcessor *self)
{
Py_XDECREF(self->type);
Py_XDECREF(self->format);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
} }
static PyMethodDef DecimalResultProcessor_methods[] = { static PyMethodDef DecimalResultProcessor_methods[] = {
@ -300,12 +576,11 @@ static PyMethodDef DecimalResultProcessor_methods[] = {
}; };
static PyTypeObject DecimalResultProcessorType = { static PyTypeObject DecimalResultProcessorType = {
PyObject_HEAD_INIT(NULL) PyVarObject_HEAD_INIT(NULL, 0)
0, /* ob_size */
"sqlalchemy.DecimalResultProcessor", /* tp_name */ "sqlalchemy.DecimalResultProcessor", /* tp_name */
sizeof(DecimalResultProcessor), /* tp_basicsize */ sizeof(DecimalResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
0, /* tp_dealloc */ (destructor)DecimalResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */ 0, /* tp_print */
0, /* tp_getattr */ 0, /* tp_getattr */
0, /* tp_setattr */ 0, /* tp_setattr */
@ -341,11 +616,6 @@ static PyTypeObject DecimalResultProcessorType = {
0, /* tp_new */ 0, /* tp_new */
}; };
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
static PyMethodDef module_methods[] = { static PyMethodDef module_methods[] = {
{"int_to_boolean", int_to_boolean, METH_O, {"int_to_boolean", int_to_boolean, METH_O,
"Convert an integer to a boolean."}, "Convert an integer to a boolean."},
@ -362,23 +632,53 @@ static PyMethodDef module_methods[] = {
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
}; };
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#define INITERROR return NULL
PyMODINIT_FUNC
PyInit_cprocessors(void)
#else
#define INITERROR return
PyMODINIT_FUNC PyMODINIT_FUNC
initcprocessors(void) initcprocessors(void)
#endif
{ {
PyObject *m; PyObject *m;
UnicodeResultProcessorType.tp_new = PyType_GenericNew; UnicodeResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&UnicodeResultProcessorType) < 0) if (PyType_Ready(&UnicodeResultProcessorType) < 0)
return; INITERROR;
DecimalResultProcessorType.tp_new = PyType_GenericNew; DecimalResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&DecimalResultProcessorType) < 0) if (PyType_Ready(&DecimalResultProcessorType) < 0)
return; INITERROR;
m = Py_InitModule3("cprocessors", module_methods, #if PY_MAJOR_VERSION >= 3
"Module containing C versions of data processing functions."); m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL) if (m == NULL)
return; INITERROR;
PyDateTime_IMPORT; PyDateTime_IMPORT;
@ -389,5 +689,8 @@ initcprocessors(void)
Py_INCREF(&DecimalResultProcessorType); Py_INCREF(&DecimalResultProcessorType);
PyModule_AddObject(m, "DecimalResultProcessor", PyModule_AddObject(m, "DecimalResultProcessor",
(PyObject *)&DecimalResultProcessorType); (PyObject *)&DecimalResultProcessorType);
}
#if PY_MAJOR_VERSION >= 3
return m;
#endif
}

View File

@ -1,6 +1,7 @@
/* /*
resultproxy.c resultproxy.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com Copyright (C) 2010-2017 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -8,6 +9,18 @@ the MIT License: http://www.opensource.org/licenses/mit-license.php
#include <Python.h> #include <Python.h>
#define MODULE_NAME "cresultproxy"
#define MODULE_DOC "Module containing C versions of core ResultProxy classes."
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
typedef Py_ssize_t (*lenfunc)(PyObject *);
#define PyInt_FromSsize_t(x) PyInt_FromLong(x)
typedef intargfunc ssizeargfunc;
#endif
/*********** /***********
* Structs * * Structs *
@ -69,8 +82,8 @@ BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
Py_INCREF(parent); Py_INCREF(parent);
self->parent = parent; self->parent = parent;
if (!PyTuple_CheckExact(row)) { if (!PySequence_Check(row)) {
PyErr_SetString(PyExc_TypeError, "row must be a tuple"); PyErr_SetString(PyExc_TypeError, "row must be a sequence");
return -1; return -1;
} }
Py_INCREF(row); Py_INCREF(row);
@ -112,7 +125,7 @@ BaseRowProxy_reduce(PyObject *self)
if (state == NULL) if (state == NULL)
return NULL; return NULL;
module = PyImport_ImportModule("sqlalchemy.engine.base"); module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL) if (module == NULL)
return NULL; return NULL;
@ -140,7 +153,11 @@ BaseRowProxy_dealloc(BaseRowProxy *self)
Py_XDECREF(self->row); Py_XDECREF(self->row);
Py_XDECREF(self->processors); Py_XDECREF(self->processors);
Py_XDECREF(self->keymap); Py_XDECREF(self->keymap);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject *)self);
#else
self->ob_type->tp_free((PyObject *)self); self->ob_type->tp_free((PyObject *)self);
#endif
} }
static PyObject * static PyObject *
@ -148,13 +165,15 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
{ {
Py_ssize_t num_values, num_processors; Py_ssize_t num_values, num_processors;
PyObject **valueptr, **funcptr, **resultptr; PyObject **valueptr, **funcptr, **resultptr;
PyObject *func, *result, *processed_value; PyObject *func, *result, *processed_value, *values_fastseq;
num_values = Py_SIZE(values); num_values = PySequence_Length(values);
num_processors = Py_SIZE(processors); num_processors = PyList_Size(processors);
if (num_values != num_processors) { if (num_values != num_processors) {
PyErr_SetString(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"number of values in row differ from number of column processors"); "number of values in row (%d) differ from number of column "
"processors (%d)",
(int)num_values, (int)num_processors);
return NULL; return NULL;
} }
@ -166,9 +185,11 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
if (result == NULL) if (result == NULL)
return NULL; return NULL;
/* we don't need to use PySequence_Fast as long as values, processors and values_fastseq = PySequence_Fast(values, "row must be a sequence");
* result are simple tuple or lists. */ if (values_fastseq == NULL)
valueptr = PySequence_Fast_ITEMS(values); return NULL;
valueptr = PySequence_Fast_ITEMS(values_fastseq);
funcptr = PySequence_Fast_ITEMS(processors); funcptr = PySequence_Fast_ITEMS(processors);
resultptr = PySequence_Fast_ITEMS(result); resultptr = PySequence_Fast_ITEMS(result);
while (--num_values >= 0) { while (--num_values >= 0) {
@ -177,6 +198,7 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
processed_value = PyObject_CallFunctionObjArgs(func, *valueptr, processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
NULL); NULL);
if (processed_value == NULL) { if (processed_value == NULL) {
Py_DECREF(values_fastseq);
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
} }
@ -189,6 +211,7 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
funcptr++; funcptr++;
resultptr++; resultptr++;
} }
Py_DECREF(values_fastseq);
return result; return result;
} }
@ -199,19 +222,12 @@ BaseRowProxy_values(BaseRowProxy *self)
self->processors, 0); self->processors, 0);
} }
static PyTupleObject *
BaseRowProxy_tuplevalues(BaseRowProxy *self)
{
return (PyTupleObject *)BaseRowProxy_processvalues(self->row,
self->processors, 1);
}
static PyObject * static PyObject *
BaseRowProxy_iter(BaseRowProxy *self) BaseRowProxy_iter(BaseRowProxy *self)
{ {
PyObject *values, *result; PyObject *values, *result;
values = (PyObject *)BaseRowProxy_tuplevalues(self); values = BaseRowProxy_processvalues(self->row, self->processors, 1);
if (values == NULL) if (values == NULL)
return NULL; return NULL;
@ -226,26 +242,39 @@ BaseRowProxy_iter(BaseRowProxy *self)
static Py_ssize_t static Py_ssize_t
BaseRowProxy_length(BaseRowProxy *self) BaseRowProxy_length(BaseRowProxy *self)
{ {
return Py_SIZE(self->row); return PySequence_Length(self->row);
} }
static PyObject * static PyObject *
BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
{ {
PyObject *processors, *values; PyObject *processors, *values;
PyObject *processor, *value; PyObject *processor, *value, *processed_value;
PyObject *record, *result, *indexobject; PyObject *row, *record, *result, *indexobject;
PyObject *exc_module, *exception; PyObject *exc_module, *exception, *cstr_obj;
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
#endif
char *cstr_key; char *cstr_key;
long index; long index;
int key_fallback = 0;
int tuple_check = 0;
#if PY_MAJOR_VERSION < 3
if (PyInt_CheckExact(key)) { if (PyInt_CheckExact(key)) {
index = PyInt_AS_LONG(key); index = PyInt_AS_LONG(key);
} else if (PyLong_CheckExact(key)) { if (index < 0)
index += BaseRowProxy_length(self);
} else
#endif
if (PyLong_CheckExact(key)) {
index = PyLong_AsLong(key); index = PyLong_AsLong(key);
if ((index == -1) && PyErr_Occurred()) if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */ /* -1 can be either the actual value, or an error flag. */
return NULL; return NULL;
if (index < 0)
index += BaseRowProxy_length(self);
} else if (PySlice_Check(key)) { } else if (PySlice_Check(key)) {
values = PyObject_GetItem(self->row, key); values = PyObject_GetItem(self->row, key);
if (values == NULL) if (values == NULL)
@ -268,12 +297,17 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
"O", key); "O", key);
if (record == NULL) if (record == NULL)
return NULL; return NULL;
key_fallback = 1;
} }
indexobject = PyTuple_GetItem(record, 1); indexobject = PyTuple_GetItem(record, 2);
if (indexobject == NULL) if (indexobject == NULL)
return NULL; return NULL;
if (key_fallback) {
Py_DECREF(record);
}
if (indexobject == Py_None) { if (indexobject == Py_None) {
exc_module = PyImport_ImportModule("sqlalchemy.exc"); exc_module = PyImport_ImportModule("sqlalchemy.exc");
if (exc_module == NULL) if (exc_module == NULL)
@ -285,17 +319,47 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
if (exception == NULL) if (exception == NULL)
return NULL; return NULL;
cstr_key = PyString_AsString(key); cstr_obj = PyTuple_GetItem(record, 1);
if (cstr_key == NULL) if (cstr_obj == NULL)
return NULL; return NULL;
cstr_obj = PyObject_Str(cstr_obj);
if (cstr_obj == NULL)
return NULL;
/*
FIXME: raise encoding error exception (in both versions below)
if the key contains non-ascii chars, instead of an
InvalidRequestError without any message like in the
python version.
*/
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(cstr_obj);
if (bytes == NULL)
return NULL;
cstr_key = PyBytes_AS_STRING(bytes);
#else
cstr_key = PyString_AsString(cstr_obj);
#endif
if (cstr_key == NULL) {
Py_DECREF(cstr_obj);
return NULL;
}
Py_DECREF(cstr_obj);
PyErr_Format(exception, PyErr_Format(exception,
"Ambiguous column name '%s' in result set! " "Ambiguous column name '%.200s' in "
"try 'use_labels' option on select statement.", cstr_key); "result set column descriptions", cstr_key);
return NULL; return NULL;
} }
#if PY_MAJOR_VERSION >= 3
index = PyLong_AsLong(indexobject);
#else
index = PyInt_AsLong(indexobject); index = PyInt_AsLong(indexobject);
#endif
if ((index == -1) && PyErr_Occurred()) if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */ /* -1 can be either the actual value, or an error flag. */
return NULL; return NULL;
@ -304,22 +368,53 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
if (processor == NULL) if (processor == NULL)
return NULL; return NULL;
value = PyTuple_GetItem(self->row, index); row = self->row;
if (PyTuple_CheckExact(row)) {
value = PyTuple_GetItem(row, index);
tuple_check = 1;
}
else {
value = PySequence_GetItem(row, index);
tuple_check = 0;
}
if (value == NULL) if (value == NULL)
return NULL; return NULL;
if (processor != Py_None) { if (processor != Py_None) {
return PyObject_CallFunctionObjArgs(processor, value, NULL); processed_value = PyObject_CallFunctionObjArgs(processor, value, NULL);
if (!tuple_check) {
Py_DECREF(value);
}
return processed_value;
} else { } else {
if (tuple_check) {
Py_INCREF(value); Py_INCREF(value);
}
return value; return value;
} }
} }
static PyObject *
BaseRowProxy_getitem(PyObject *self, Py_ssize_t i)
{
PyObject *index;
#if PY_MAJOR_VERSION >= 3
index = PyLong_FromSsize_t(i);
#else
index = PyInt_FromSsize_t(i);
#endif
return BaseRowProxy_subscript((BaseRowProxy*)self, index);
}
static PyObject * static PyObject *
BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name) BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
{ {
PyObject *tmp; PyObject *tmp;
#if PY_MAJOR_VERSION >= 3
PyObject *err_bytes;
#endif
if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) { if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) if (!PyErr_ExceptionMatches(PyExc_AttributeError))
@ -329,7 +424,28 @@ BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
else else
return tmp; return tmp;
return BaseRowProxy_subscript(self, name); tmp = BaseRowProxy_subscript(self, name);
if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) {
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(name);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyBytes_AS_STRING(err_bytes)
);
#else
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyString_AsString(name)
);
#endif
return NULL;
}
return tmp;
} }
/*********************** /***********************
@ -354,7 +470,7 @@ BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
return -1; return -1;
} }
module = PyImport_ImportModule("sqlalchemy.engine.base"); module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL) if (module == NULL)
return -1; return -1;
@ -393,9 +509,9 @@ BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
return -1; return -1;
} }
if (!PyTuple_CheckExact(value)) { if (!PySequence_Check(value)) {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"The 'row' attribute value must be a tuple"); "The 'row' attribute value must be a sequence");
return -1; return -1;
} }
@ -496,7 +612,7 @@ static PySequenceMethods BaseRowProxy_as_sequence = {
(lenfunc)BaseRowProxy_length, /* sq_length */ (lenfunc)BaseRowProxy_length, /* sq_length */
0, /* sq_concat */ 0, /* sq_concat */
0, /* sq_repeat */ 0, /* sq_repeat */
0, /* sq_item */ (ssizeargfunc)BaseRowProxy_getitem, /* sq_item */
0, /* sq_slice */ 0, /* sq_slice */
0, /* sq_ass_item */ 0, /* sq_ass_item */
0, /* sq_ass_slice */ 0, /* sq_ass_slice */
@ -512,8 +628,7 @@ static PyMappingMethods BaseRowProxy_as_mapping = {
}; };
static PyTypeObject BaseRowProxyType = { static PyTypeObject BaseRowProxyType = {
PyObject_HEAD_INIT(NULL) PyVarObject_HEAD_INIT(NULL, 0)
0, /* ob_size */
"sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */ "sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */
sizeof(BaseRowProxy), /* tp_basicsize */ sizeof(BaseRowProxy), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
@ -553,34 +668,60 @@ static PyTypeObject BaseRowProxyType = {
0 /* tp_new */ 0 /* tp_new */
}; };
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
static PyMethodDef module_methods[] = { static PyMethodDef module_methods[] = {
{"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS, {"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS,
"reconstruct a RowProxy instance from its pickled form."}, "reconstruct a RowProxy instance from its pickled form."},
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
}; };
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#define INITERROR return NULL
PyMODINIT_FUNC
PyInit_cresultproxy(void)
#else
#define INITERROR return
PyMODINIT_FUNC PyMODINIT_FUNC
initcresultproxy(void) initcresultproxy(void)
#endif
{ {
PyObject *m; PyObject *m;
BaseRowProxyType.tp_new = PyType_GenericNew; BaseRowProxyType.tp_new = PyType_GenericNew;
if (PyType_Ready(&BaseRowProxyType) < 0) if (PyType_Ready(&BaseRowProxyType) < 0)
return; INITERROR;
m = Py_InitModule3("cresultproxy", module_methods, #if PY_MAJOR_VERSION >= 3
"Module containing C versions of core ResultProxy classes."); m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL) if (m == NULL)
return; INITERROR;
Py_INCREF(&BaseRowProxyType); Py_INCREF(&BaseRowProxyType);
PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType); PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);
#if PY_MAJOR_VERSION >= 3
return m;
#endif
} }

View File

@ -0,0 +1,225 @@
/*
utils.c
Copyright (C) 2012-2017 the SQLAlchemy authors and contributors <see AUTHORS file>
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#define MODULE_NAME "cutils"
#define MODULE_DOC "Module containing C versions of utility functions."
/*
Given arguments from the calling form *multiparams, **params,
return a list of bind parameter structures, usually a list of
dictionaries.
In the case of 'raw' execution which accepts positional parameters,
it may be a list of tuples or lists.
*/
static PyObject *
distill_params(PyObject *self, PyObject *args)
{
PyObject *multiparams, *params;
PyObject *enclosing_list, *double_enclosing_list;
PyObject *zero_element, *zero_element_item;
Py_ssize_t multiparam_size, zero_element_length;
if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, &params)) {
return NULL;
}
if (multiparams != Py_None) {
multiparam_size = PyTuple_Size(multiparams);
if (multiparam_size < 0) {
return NULL;
}
}
else {
multiparam_size = 0;
}
if (multiparam_size == 0) {
if (params != Py_None && PyDict_Size(params) != 0) {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(params);
if (PyList_SetItem(enclosing_list, 0, params) == -1) {
Py_DECREF(params);
Py_DECREF(enclosing_list);
return NULL;
}
}
else {
enclosing_list = PyList_New(0);
if (enclosing_list == NULL) {
return NULL;
}
}
return enclosing_list;
}
else if (multiparam_size == 1) {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) {
zero_element_length = PySequence_Length(zero_element);
if (zero_element_length != 0) {
zero_element_item = PySequence_GetItem(zero_element, 0);
if (zero_element_item == NULL) {
return NULL;
}
}
else {
zero_element_item = NULL;
}
if (zero_element_length == 0 ||
(
PyObject_HasAttrString(zero_element_item, "__iter__") &&
!PyObject_HasAttrString(zero_element_item, "strip")
)
) {
/*
* execute(stmt, [{}, {}, {}, ...])
* execute(stmt, [(), (), (), ...])
*/
Py_XDECREF(zero_element_item);
Py_INCREF(zero_element);
return zero_element;
}
else {
/*
* execute(stmt, ("value", "value"))
*/
Py_XDECREF(zero_element_item);
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
else if (PyObject_HasAttrString(zero_element, "keys")) {
/*
* execute(stmt, {"key":"value"})
*/
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
} else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
double_enclosing_list = PyList_New(1);
if (double_enclosing_list == NULL) {
Py_DECREF(enclosing_list);
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
return double_enclosing_list;
}
}
else {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyObject_HasAttrString(zero_element, "__iter__") &&
!PyObject_HasAttrString(zero_element, "strip")
) {
Py_INCREF(multiparams);
return multiparams;
}
else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(multiparams);
if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) {
Py_DECREF(multiparams);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
}
static PyMethodDef module_methods[] = {
{"_distill_params", distill_params, METH_VARARGS,
"Distill an execute() parameter structure."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#endif
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
PyInit_cutils(void)
#else
PyMODINIT_FUNC
initcutils(void)
#endif
{
PyObject *m;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
#if PY_MAJOR_VERSION >= 3
if (m == NULL)
return NULL;
return m;
#else
if (m == NULL)
return;
#endif
}

View File

@ -1,27 +1,26 @@
# __init__.py # databases/__init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.sqlite import base as sqlite """Include imports from the sqlalchemy.dialects package for backwards
from sqlalchemy.dialects.postgresql import base as postgresql compatibility with pre 0.6 versions.
"""
from ..dialects.sqlite import base as sqlite
from ..dialects.postgresql import base as postgresql
postgres = postgresql postgres = postgresql
from sqlalchemy.dialects.mysql import base as mysql from ..dialects.mysql import base as mysql
from sqlalchemy.dialects.oracle import base as oracle from ..dialects.oracle import base as oracle
from sqlalchemy.dialects.firebird import base as firebird from ..dialects.firebird import base as firebird
from sqlalchemy.dialects.maxdb import base as maxdb from ..dialects.mssql import base as mssql
from sqlalchemy.dialects.informix import base as informix from ..dialects.sybase import base as sybase
from sqlalchemy.dialects.mssql import base as mssql
from sqlalchemy.dialects.access import base as access
from sqlalchemy.dialects.sybase import base as sybase
__all__ = ( __all__ = (
'access',
'firebird', 'firebird',
'informix',
'maxdb',
'mssql', 'mssql',
'mysql', 'mysql',
'postgresql', 'postgresql',

View File

@ -1,6 +1,13 @@
from sqlalchemy.dialects.firebird import base, kinterbasdb # firebird/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
base.dialect = kinterbasdb.dialect from sqlalchemy.dialects.firebird import base, kinterbasdb, fdb
base.dialect = fdb.dialect
from sqlalchemy.dialects.firebird.base import \ from sqlalchemy.dialects.firebird.base import \
SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \ SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \
@ -12,5 +19,3 @@ __all__ = (
'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB', 'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
'dialect' 'dialect'
) )

View File

@ -1,16 +1,17 @@
# firebird.py # firebird/base.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
""" r"""
Support for the Firebird database.
Connectivity is usually supplied via the kinterbasdb_ DBAPI module. .. dialect:: firebird
:name: Firebird
Dialects Firebird Dialects
~~~~~~~~ -----------------
Firebird offers two distinct dialects_ (not to be confused with a Firebird offers two distinct dialects_ (not to be confused with a
SQLAlchemy ``Dialect``): SQLAlchemy ``Dialect``):
@ -27,7 +28,7 @@ support for dialect 1 is not well tested and probably has
incompatibilities. incompatibilities.
Locking Behavior Locking Behavior
~~~~~~~~~~~~~~~~ ----------------
Firebird locks tables aggressively. For this reason, a DROP TABLE may Firebird locks tables aggressively. For this reason, a DROP TABLE may
hang until other transactions are released. SQLAlchemy does its best hang until other transactions are released. SQLAlchemy does its best
@ -47,20 +48,20 @@ The above use case can be alleviated by calling ``first()`` on the
all remaining cursor/connection resources. all remaining cursor/connection resources.
RETURNING support RETURNING support
~~~~~~~~~~~~~~~~~ -----------------
Firebird 2.0 supports returning a result set from inserts, and 2.1 Firebird 2.0 supports returning a result set from inserts, and 2.1
extends that to deletes and updates. This is generically exposed by extends that to deletes and updates. This is generically exposed by
the SQLAlchemy ``returning()`` method, such as:: the SQLAlchemy ``returning()`` method, such as::
# INSERT..RETURNING # INSERT..RETURNING
result = table.insert().returning(table.c.col1, table.c.col2).\\ result = table.insert().returning(table.c.col1, table.c.col2).\
values(name='foo') values(name='foo')
print result.fetchall() print result.fetchall()
# UPDATE..RETURNING # UPDATE..RETURNING
raises = empl.update().returning(empl.c.id, empl.c.salary).\\ raises = empl.update().returning(empl.c.id, empl.c.salary).\
where(empl.c.sales>100).\\ where(empl.c.sales>100).\
values(dict(salary=empl.c.salary * 1.1)) values(dict(salary=empl.c.salary * 1.1))
print raises.fetchall() print raises.fetchall()
@ -69,18 +70,17 @@ the SQLAlchemy ``returning()`` method, such as::
""" """
import datetime, re import datetime
from sqlalchemy import schema as sa_schema from sqlalchemy import schema as sa_schema
from sqlalchemy import exc, types as sqltypes, sql, util from sqlalchemy import exc, types as sqltypes, sql, util
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
from sqlalchemy.engine import base, default, reflection from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler from sqlalchemy.sql import compiler
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC,
from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE, SMALLINT, TEXT, TIME, TIMESTAMP, Integer)
FLOAT, INTEGER, NUMERIC, SMALLINT,
TEXT, TIME, TIMESTAMP, VARCHAR)
RESERVED_WORDS = set([ RESERVED_WORDS = set([
@ -123,18 +123,52 @@ RESERVED_WORDS = set([
]) ])
class _StringType(sqltypes.String):
"""Base for Firebird string types."""
def __init__(self, charset=None, **kw):
self.charset = charset
super(_StringType, self).__init__(**kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""Firebird VARCHAR type"""
__visit_name__ = 'VARCHAR'
def __init__(self, length=None, **kwargs):
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""Firebird CHAR type"""
__visit_name__ = 'CHAR'
def __init__(self, length=None, **kwargs):
super(CHAR, self).__init__(length=length, **kwargs)
class _FBDateTime(sqltypes.DateTime):
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
else:
return value
return process
colspecs = { colspecs = {
sqltypes.DateTime: _FBDateTime
} }
ischema_names = { ischema_names = {
'SHORT': SMALLINT, 'SHORT': SMALLINT,
'LONG': BIGINT, 'LONG': INTEGER,
'QUAD': FLOAT, 'QUAD': FLOAT,
'FLOAT': FLOAT, 'FLOAT': FLOAT,
'DATE': DATE, 'DATE': DATE,
'TIME': TIME, 'TIME': TIME,
'TEXT': TEXT, 'TEXT': TEXT,
'INT64': NUMERIC, 'INT64': BIGINT,
'DOUBLE': FLOAT, 'DOUBLE': FLOAT,
'TIMESTAMP': TIMESTAMP, 'TIMESTAMP': TIMESTAMP,
'VARYING': VARCHAR, 'VARYING': VARCHAR,
@ -143,41 +177,86 @@ ischema_names = {
} }
# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc. # TODO: date conversion types (should be implemented as _FBDateTime,
# as bind/result functionality is required) # _FBDate, etc. as bind/result functionality is required)
class FBTypeCompiler(compiler.GenericTypeCompiler): class FBTypeCompiler(compiler.GenericTypeCompiler):
def visit_boolean(self, type_): def visit_boolean(self, type_, **kw):
return self.visit_SMALLINT(type_) return self.visit_SMALLINT(type_, **kw)
def visit_datetime(self, type_): def visit_datetime(self, type_, **kw):
return self.visit_TIMESTAMP(type_) return self.visit_TIMESTAMP(type_, **kw)
def visit_TEXT(self, type_): def visit_TEXT(self, type_, **kw):
return "BLOB SUB_TYPE 1" return "BLOB SUB_TYPE 1"
def visit_BLOB(self, type_): def visit_BLOB(self, type_, **kw):
return "BLOB SUB_TYPE 0" return "BLOB SUB_TYPE 0"
def _extend_string(self, type_, basic):
charset = getattr(type_, 'charset', None)
if charset is None:
return basic
else:
return '%s CHARACTER SET %s' % (basic, charset)
def visit_CHAR(self, type_, **kw):
basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
return self._extend_string(type_, basic)
def visit_VARCHAR(self, type_, **kw):
if not type_.length:
raise exc.CompileError(
"VARCHAR requires a length on dialect %s" %
self.dialect.name)
basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
return self._extend_string(type_, basic)
class FBCompiler(sql.compiler.SQLCompiler): class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosincrasies""" """Firebird specific idiosyncrasies"""
def visit_mod(self, binary, **kw): ansi_bind_rules = True
# Firebird lacks a builtin modulo operator, but there is
# an equivalent function in the ib_udf library. # def visit_contains_op_binary(self, binary, operator, **kw):
return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) # cant use CONTAINING b.c. it's case insensitive.
# def visit_notcontains_op_binary(self, binary, operator, **kw):
# cant use NOT CONTAINING b.c. it's case insensitive.
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
def visit_startswith_op_binary(self, binary, operator, **kw):
return '%s STARTING WITH %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw))
def visit_notstartswith_op_binary(self, binary, operator, **kw):
return '%s NOT STARTING WITH %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw))
def visit_mod_binary(self, binary, operator, **kw):
return "mod(%s, %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw))
def visit_alias(self, alias, asfrom=False, **kwargs): def visit_alias(self, alias, asfrom=False, **kwargs):
if self.dialect._version_two: if self.dialect._version_two:
return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs) return super(FBCompiler, self).\
visit_alias(alias, asfrom=asfrom, **kwargs)
else: else:
# Override to not use the AS keyword which FB 1.5 does not like # Override to not use the AS keyword which FB 1.5 does not like
if asfrom: if asfrom:
alias_name = isinstance(alias.name, expression._generated_label) and \ alias_name = isinstance(alias.name,
self._truncated_identifier("alias", alias.name) or alias.name expression._truncated_label) and \
self._truncated_identifier("alias",
alias.name) or alias.name
return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \ return self.process(
alias.original, asfrom=asfrom, **kwargs) + \
" " + \
self.preparer.format_alias(alias, alias_name) self.preparer.format_alias(alias, alias_name)
else: else:
return self.process(alias.original, **kwargs) return self.process(alias.original, **kwargs)
@ -200,8 +279,12 @@ class FBCompiler(sql.compiler.SQLCompiler):
visit_char_length_func = visit_length_func visit_char_length_func = visit_length_func
def function_argspec(self, func, **kw): def function_argspec(self, func, **kw):
# TODO: this probably will need to be
# narrowed to a fixed list, some no-arg functions
# may require parens - see similar example in the oracle
# dialect
if func.clauses is not None and len(func.clauses): if func.clauses is not None and len(func.clauses):
return self.process(func.clause_expr) return self.process(func.clause_expr, **kw)
else: else:
return "" return ""
@ -211,41 +294,37 @@ class FBCompiler(sql.compiler.SQLCompiler):
def visit_sequence(self, seq): def visit_sequence(self, seq):
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
def get_select_precolumns(self, select): def get_select_precolumns(self, select, **kw):
"""Called when building a ``SELECT`` statement, position is just """Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right before column list Firebird puts the limit and offset right
after the ``SELECT``... after the ``SELECT``...
""" """
result = "" result = ""
if select._limit: if select._limit_clause is not None:
result += "FIRST %d " % select._limit result += "FIRST %s " % self.process(select._limit_clause, **kw)
if select._offset: if select._offset_clause is not None:
result +="SKIP %d " % select._offset result += "SKIP %s " % self.process(select._offset_clause, **kw)
if select._distinct: if select._distinct:
result += "DISTINCT " result += "DISTINCT "
return result return result
def limit_clause(self, select): def limit_clause(self, select, **kw):
"""Already taken care of in the `get_select_precolumns` method.""" """Already taken care of in the `get_select_precolumns` method."""
return "" return ""
def returning_clause(self, stmt, returning_cols): def returning_clause(self, stmt, returning_cols):
columns = [ columns = [
self.process( self._label_select_column(None, c, True, False, {})
self.label_select_column(None, c, asfrom=False),
within_columns_clause=True,
result_map=self.result_map
)
for c in expression._select_iterables(returning_cols) for c in expression._select_iterables(returning_cols)
] ]
return 'RETURNING ' + ', '.join(columns) return 'RETURNING ' + ', '.join(columns)
class FBDDLCompiler(sql.compiler.DDLCompiler): class FBDDLCompiler(sql.compiler.DDLCompiler):
"""Firebird syntactic idiosincrasies""" """Firebird syntactic idiosyncrasies"""
def visit_create_sequence(self, create): def visit_create_sequence(self, create):
"""Generate a ``CREATE GENERATOR`` statement for the sequence.""" """Generate a ``CREATE GENERATOR`` statement for the sequence."""
@ -253,39 +332,50 @@ class FBDDLCompiler(sql.compiler.DDLCompiler):
# no syntax for these # no syntax for these
# http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
if create.element.start is not None: if create.element.start is not None:
raise NotImplemented("Firebird SEQUENCE doesn't support START WITH") raise NotImplemented(
"Firebird SEQUENCE doesn't support START WITH")
if create.element.increment is not None: if create.element.increment is not None:
raise NotImplemented("Firebird SEQUENCE doesn't support INCREMENT BY") raise NotImplemented(
"Firebird SEQUENCE doesn't support INCREMENT BY")
if self.dialect._version_two: if self.dialect._version_two:
return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) return "CREATE SEQUENCE %s" % \
self.preparer.format_sequence(create.element)
else: else:
return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element) return "CREATE GENERATOR %s" % \
self.preparer.format_sequence(create.element)
def visit_drop_sequence(self, drop): def visit_drop_sequence(self, drop):
"""Generate a ``DROP GENERATOR`` statement for the sequence.""" """Generate a ``DROP GENERATOR`` statement for the sequence."""
if self.dialect._version_two: if self.dialect._version_two:
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) return "DROP SEQUENCE %s" % \
self.preparer.format_sequence(drop.element)
else: else:
return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element) return "DROP GENERATOR %s" % \
self.preparer.format_sequence(drop.element)
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
"""Install Firebird specific reserved words.""" """Install Firebird specific reserved words."""
reserved_words = RESERVED_WORDS reserved_words = RESERVED_WORDS
illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(
['_'])
def __init__(self, dialect): def __init__(self, dialect):
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
class FBExecutionContext(default.DefaultExecutionContext): class FBExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq): def fire_sequence(self, seq, type_):
"""Get the next value from the sequence using ``gen_id()``.""" """Get the next value from the sequence using ``gen_id()``."""
return self._execute_scalar("SELECT gen_id(%s, 1) FROM rdb$database" % \ return self._execute_scalar(
self.dialect.identifier_preparer.format_sequence(seq)) "SELECT gen_id(%s, 1) FROM rdb$database" %
self.dialect.identifier_preparer.format_sequence(seq),
type_
)
class FBDialect(default.DefaultDialect): class FBDialect(default.DefaultDialect):
@ -305,7 +395,6 @@ class FBDialect(default.DefaultDialect):
requires_name_normalize = True requires_name_normalize = True
supports_empty_insert = False supports_empty_insert = False
statement_compiler = FBCompiler statement_compiler = FBCompiler
ddl_compiler = FBDDLCompiler ddl_compiler = FBDDLCompiler
preparer = FBIdentifierPreparer preparer = FBIdentifierPreparer
@ -315,6 +404,8 @@ class FBDialect(default.DefaultDialect):
colspecs = colspecs colspecs = colspecs
ischema_names = ischema_names ischema_names = ischema_names
construct_arguments = []
# defaults to dialect ver. 3, # defaults to dialect ver. 3,
# will be autodetected off upon # will be autodetected off upon
# first connect # first connect
@ -322,7 +413,13 @@ class FBDialect(default.DefaultDialect):
def initialize(self, connection): def initialize(self, connection):
super(FBDialect, self).initialize(connection) super(FBDialect, self).initialize(connection)
self._version_two = self.server_version_info > (2, ) self._version_two = ('firebird' in self.server_version_info and
self.server_version_info >= (2, )
) or \
('interbase' in self.server_version_info and
self.server_version_info >= (6, )
)
if not self._version_two: if not self._version_two:
# TODO: whatever other pre < 2.0 stuff goes here # TODO: whatever other pre < 2.0 stuff goes here
self.ischema_names = ischema_names.copy() self.ischema_names = ischema_names.copy()
@ -330,8 +427,9 @@ class FBDialect(default.DefaultDialect):
self.colspecs = { self.colspecs = {
sqltypes.DateTime: sqltypes.DATE sqltypes.DateTime: sqltypes.DATE
} }
else:
self.implicit_returning = True self.implicit_returning = self._version_two and \
self.__dict__.get('implicit_returning', True)
def normalize_name(self, name): def normalize_name(self, name):
# Remove trailing spaces: FB uses a CHAR() type, # Remove trailing spaces: FB uses a CHAR() type,
@ -342,6 +440,8 @@ class FBDialect(default.DefaultDialect):
elif name.upper() == name and \ elif name.upper() == name and \
not self.identifier_preparer._requires_quotes(name.lower()): not self.identifier_preparer._requires_quotes(name.lower()):
return name.lower() return name.lower()
elif name.lower() == name:
return quoted_name(name, quote=True)
else: else:
return name return name
@ -355,10 +455,11 @@ class FBDialect(default.DefaultDialect):
return name return name
def has_table(self, connection, table_name, schema=None): def has_table(self, connection, table_name, schema=None):
"""Return ``True`` if the given table exists, ignoring the `schema`.""" """Return ``True`` if the given table exists, ignoring
the `schema`."""
tblqry = """ tblqry = """
SELECT 1 FROM rdb$database SELECT 1 AS has_table FROM rdb$database
WHERE EXISTS (SELECT rdb$relation_name WHERE EXISTS (SELECT rdb$relation_name
FROM rdb$relations FROM rdb$relations
WHERE rdb$relation_name=?) WHERE rdb$relation_name=?)
@ -370,7 +471,7 @@ class FBDialect(default.DefaultDialect):
"""Return ``True`` if the given sequence (generator) exists.""" """Return ``True`` if the given sequence (generator) exists."""
genqry = """ genqry = """
SELECT 1 FROM rdb$database SELECT 1 AS has_sequence FROM rdb$database
WHERE EXISTS (SELECT rdb$generator_name WHERE EXISTS (SELECT rdb$generator_name
FROM rdb$generators FROM rdb$generators
WHERE rdb$generator_name=?) WHERE rdb$generator_name=?)
@ -380,18 +481,34 @@ class FBDialect(default.DefaultDialect):
@reflection.cache @reflection.cache
def get_table_names(self, connection, schema=None, **kw): def get_table_names(self, connection, schema=None, **kw):
# there are two queries commonly mentioned for this.
# this one, using view_blr, is at the Firebird FAQ among other places:
# http://www.firebirdfaq.org/faq174/
s = """ s = """
SELECT DISTINCT rdb$relation_name select rdb$relation_name
FROM rdb$relation_fields from rdb$relations
WHERE rdb$system_flag=0 AND rdb$view_context IS NULL where rdb$view_blr is null
and (rdb$system_flag is null or rdb$system_flag = 0);
""" """
# the other query is this one. It's not clear if there's really
# any difference between these two. This link:
# http://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
# states them as interchangeable. Some discussion at [ticket:2898]
# SELECT DISTINCT rdb$relation_name
# FROM rdb$relation_fields
# WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
return [self.normalize_name(row[0]) for row in connection.execute(s)] return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache @reflection.cache
def get_view_names(self, connection, schema=None, **kw): def get_view_names(self, connection, schema=None, **kw):
# see http://www.firebirdfaq.org/faq174/
s = """ s = """
SELECT distinct rdb$view_name select rdb$relation_name
FROM rdb$view_relations from rdb$relations
where rdb$view_blr is not null
and (rdb$system_flag is null or rdb$system_flag = 0);
""" """
return [self.normalize_name(row[0]) for row in connection.execute(s)] return [self.normalize_name(row[0]) for row in connection.execute(s)]
@ -410,7 +527,7 @@ class FBDialect(default.DefaultDialect):
return None return None
@reflection.cache @reflection.cache
def get_primary_keys(self, connection, table_name, schema=None, **kw): def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# Query to extract the PK/FK constrained fields of the given table # Query to extract the PK/FK constrained fields of the given table
keyqry = """ keyqry = """
SELECT se.rdb$field_name AS fname SELECT se.rdb$field_name AS fname
@ -422,10 +539,12 @@ class FBDialect(default.DefaultDialect):
# get primary key fields # get primary key fields
c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
return pkfields return {'constrained_columns': pkfields, 'name': None}
@reflection.cache @reflection.cache
def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw): def get_column_sequence(self, connection,
table_name, column_name,
schema=None, **kw):
tablename = self.denormalize_name(table_name) tablename = self.denormalize_name(table_name)
colname = self.denormalize_name(column_name) colname = self.denormalize_name(column_name)
# Heuristic-query to determine the generator associated to a PK field # Heuristic-query to determine the generator associated to a PK field
@ -436,7 +555,8 @@ class FBDialect(default.DefaultDialect):
ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
AND trigdep.rdb$depended_on_type=14 AND trigdep.rdb$depended_on_type=14
AND trigdep.rdb$dependent_type=2 AND trigdep.rdb$dependent_type=2
JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name JOIN rdb$triggers trig ON
trig.rdb$trigger_name=tabdep.rdb$dependent_name
WHERE tabdep.rdb$depended_on_name=? WHERE tabdep.rdb$depended_on_name=?
AND tabdep.rdb$depended_on_type=0 AND tabdep.rdb$depended_on_type=0
AND trig.rdb$trigger_type=1 AND trig.rdb$trigger_type=1
@ -453,24 +573,29 @@ class FBDialect(default.DefaultDialect):
def get_columns(self, connection, table_name, schema=None, **kw): def get_columns(self, connection, table_name, schema=None, **kw):
# Query to extract the details of all the fields of the given table # Query to extract the details of all the fields of the given table
tblqry = """ tblqry = """
SELECT DISTINCT r.rdb$field_name AS fname, SELECT r.rdb$field_name AS fname,
r.rdb$null_flag AS null_flag, r.rdb$null_flag AS null_flag,
t.rdb$type_name AS ftype, t.rdb$type_name AS ftype,
f.rdb$field_sub_type AS stype, f.rdb$field_sub_type AS stype,
f.rdb$field_length/COALESCE(cs.rdb$bytes_per_character,1) AS flen, f.rdb$field_length/
COALESCE(cs.rdb$bytes_per_character,1) AS flen,
f.rdb$field_precision AS fprec, f.rdb$field_precision AS fprec,
f.rdb$field_scale AS fscale, f.rdb$field_scale AS fscale,
COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault COALESCE(r.rdb$default_source,
f.rdb$default_source) AS fdefault
FROM rdb$relation_fields r FROM rdb$relation_fields r
JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
JOIN rdb$types t JOIN rdb$types t
ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE' ON t.rdb$type=f.rdb$field_type AND
LEFT JOIN rdb$character_sets cs ON f.rdb$character_set_id=cs.rdb$character_set_id t.rdb$field_name='RDB$FIELD_TYPE'
LEFT JOIN rdb$character_sets cs ON
f.rdb$character_set_id=cs.rdb$character_set_id
WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
ORDER BY r.rdb$field_position ORDER BY r.rdb$field_position
""" """
# get the PK, used to determine the eventual associated sequence # get the PK, used to determine the eventual associated sequence
pkey_cols = self.get_primary_keys(connection, table_name) pk_constraint = self.get_pk_constraint(connection, table_name)
pkey_cols = pk_constraint['constrained_columns']
tablename = self.denormalize_name(table_name) tablename = self.denormalize_name(table_name)
# get all of the fields for this table # get all of the fields for this table
@ -490,8 +615,10 @@ class FBDialect(default.DefaultDialect):
util.warn("Did not recognize type '%s' of column '%s'" % util.warn("Did not recognize type '%s' of column '%s'" %
(colspec, name)) (colspec, name))
coltype = sqltypes.NULLTYPE coltype = sqltypes.NULLTYPE
elif colspec == 'INT64': elif issubclass(coltype, Integer) and row['fprec'] != 0:
coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1) coltype = NUMERIC(
precision=row['fprec'],
scale=row['fscale'] * -1)
elif colspec in ('VARYING', 'CSTRING'): elif colspec in ('VARYING', 'CSTRING'):
coltype = coltype(row['flen']) coltype = coltype(row['flen'])
elif colspec == 'TEXT': elif colspec == 'TEXT':
@ -502,16 +629,19 @@ class FBDialect(default.DefaultDialect):
else: else:
coltype = BLOB() coltype = BLOB()
else: else:
coltype = coltype(row) coltype = coltype()
# does it have a default value? # does it have a default value?
defvalue = None defvalue = None
if row['fdefault'] is not None: if row['fdefault'] is not None:
# the value comes down as "DEFAULT 'value'": there may be # the value comes down as "DEFAULT 'value'": there may be
# more than one whitespace around the "DEFAULT" keyword # more than one whitespace around the "DEFAULT" keyword
# and it may also be lower case
# (see also http://tracker.firebirdsql.org/browse/CORE-356) # (see also http://tracker.firebirdsql.org/browse/CORE-356)
defexpr = row['fdefault'].lstrip() defexpr = row['fdefault'].lstrip()
assert defexpr[:8].rstrip()=='DEFAULT', "Unrecognized default value: %s" % defexpr assert defexpr[:8].rstrip().upper() == \
'DEFAULT', "Unrecognized default value: %s" % \
defexpr
defvalue = defexpr[8:].strip() defvalue = defexpr[8:].strip()
if defvalue == 'NULL': if defvalue == 'NULL':
# Redundant # Redundant
@ -520,7 +650,8 @@ class FBDialect(default.DefaultDialect):
'name': name, 'name': name,
'type': coltype, 'type': coltype,
'nullable': not bool(row['null_flag']), 'nullable': not bool(row['null_flag']),
'default' : defvalue 'default': defvalue,
'autoincrement': 'auto',
} }
if orig_colname.lower() == orig_colname: if orig_colname.lower() == orig_colname:
@ -547,7 +678,8 @@ class FBDialect(default.DefaultDialect):
FROM rdb$relation_constraints rc FROM rdb$relation_constraints rc
JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name JOIN rdb$index_segments cse ON
cse.rdb$index_name=ix1.rdb$index_name
JOIN rdb$index_segments se JOIN rdb$index_segments se
ON se.rdb$index_name=ix2.rdb$index_name ON se.rdb$index_name=ix2.rdb$index_name
AND se.rdb$field_position=cse.rdb$field_position AND se.rdb$field_position=cse.rdb$field_position
@ -571,10 +703,11 @@ class FBDialect(default.DefaultDialect):
if not fk['name']: if not fk['name']:
fk['name'] = cname fk['name'] = cname
fk['referred_table'] = self.normalize_name(row['targetrname']) fk['referred_table'] = self.normalize_name(row['targetrname'])
fk['constrained_columns'].append(self.normalize_name(row['fname'])) fk['constrained_columns'].append(
self.normalize_name(row['fname']))
fk['referred_columns'].append( fk['referred_columns'].append(
self.normalize_name(row['targetfname'])) self.normalize_name(row['targetfname']))
return fks.values() return list(fks.values())
@reflection.cache @reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw): def get_indexes(self, connection, table_name, schema=None, **kw):
@ -586,10 +719,11 @@ class FBDialect(default.DefaultDialect):
JOIN rdb$index_segments ic JOIN rdb$index_segments ic
ON ix.rdb$index_name=ic.rdb$index_name ON ix.rdb$index_name=ic.rdb$index_name
LEFT OUTER JOIN rdb$relation_constraints LEFT OUTER JOIN rdb$relation_constraints
ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name ON rdb$relation_constraints.rdb$index_name =
ic.rdb$index_name
WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
AND rdb$relation_constraints.rdb$constraint_type IS NULL AND rdb$relation_constraints.rdb$constraint_type IS NULL
ORDER BY index_name, field_name ORDER BY index_name, ic.rdb$field_position
""" """
c = connection.execute(qry, [self.denormalize_name(table_name)]) c = connection.execute(qry, [self.denormalize_name(table_name)])
@ -601,19 +735,7 @@ class FBDialect(default.DefaultDialect):
indexrec['column_names'] = [] indexrec['column_names'] = []
indexrec['unique'] = bool(row['unique_flag']) indexrec['unique'] = bool(row['unique_flag'])
indexrec['column_names'].append(self.normalize_name(row['field_name'])) indexrec['column_names'].append(
self.normalize_name(row['field_name']))
return indexes.values() return list(indexes.values())
def do_execute(self, cursor, statement, parameters, **kwargs):
# kinterbase does not accept a None, but wants an empty list
# when there are no arguments.
cursor.execute(statement, parameters or [])
def do_rollback(self, connection):
# Use the retaining feature, that keeps the transaction going
connection.rollback(True)
def do_commit(self, connection):
# Use the retaining feature, that keeps the transaction going
connection.commit(True)

View File

@ -0,0 +1,118 @@
# firebird/fdb.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird+fdb
:name: fdb
:dbapi: pyodbc
:connectstring: firebird+fdb://user:password@host:port/path/to/db\
[?key=value&key=value...]
:url: http://pypi.python.org/pypi/fdb/
fdb is a kinterbasdb compatible DBAPI for Firebird.
.. versionadded:: 0.8 - Support for the fdb Firebird driver.
.. versionchanged:: 0.9 - The fdb dialect is now the default dialect
under the ``firebird://`` URL space, as ``fdb`` is now the official
Python driver for Firebird.
Arguments
----------
The ``fdb`` dialect is based on the
:mod:`sqlalchemy.dialects.firebird.kinterbasdb` dialect, however does not
accept every argument that Kinterbasdb does.
* ``enable_rowcount`` - True by default, setting this to False disables
the usage of "cursor.rowcount" with the
Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
ResultProxy will return -1 for result.rowcount. The rationale here is
that Kinterbasdb requires a second round trip to the database when
.rowcount is called - since SQLA's resultproxy automatically closes
the cursor after a non-result-returning statement, rowcount must be
called, if at all, before the result object is returned. Additionally,
cursor.rowcount may not return correct results with older versions
of Firebird, and setting this flag to False will also cause the
SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
per-execution basis using the ``enable_rowcount`` option with
:meth:`.Connection.execution_options`::
conn = engine.connect().execution_options(enable_rowcount=True)
r = conn.execute(stmt)
print r.rowcount
* ``retaining`` - False by default. Setting this to True will pass the
``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
methods of the DBAPI connection, which can improve performance in some
situations, but apparently with significant caveats.
Please read the fdb and/or kinterbasdb DBAPI documentation in order to
understand the implications of this flag.
.. versionadded:: 0.8.2 - ``retaining`` keyword argument specifying
transaction retaining behavior - in 0.8 it defaults to ``True``
for backwards compatibility.
.. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
In 0.8 it defaulted to ``True``.
.. seealso::
http://pythonhosted.org/fdb/usage-guide.html#retaining-transactions
- information on the "retaining" flag.
"""
from .kinterbasdb import FBDialect_kinterbasdb
from ... import util
class FBDialect_fdb(FBDialect_kinterbasdb):
def __init__(self, enable_rowcount=True,
retaining=False, **kwargs):
super(FBDialect_fdb, self).__init__(
enable_rowcount=enable_rowcount,
retaining=retaining, **kwargs)
@classmethod
def dbapi(cls):
return __import__('fdb')
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if opts.get('port'):
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
util.coerce_kw_type(opts, 'type_conv', int)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of (`major`, `minor`, `build`), three integers
representing the version of the attached server.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature.
isc_info_firebird_version = 103
fbconn = connection.connection
version = fbconn.db_info(isc_info_firebird_version)
return self._parse_version_info(version)
dialect = FBDialect_fdb

View File

@ -1,69 +1,119 @@
# kinterbasdb.py # firebird/kinterbasdb.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
The most common way to connect to a Firebird engine is implemented by .. dialect:: firebird+kinterbasdb
kinterbasdb__, currently maintained__ directly by the Firebird people. :name: kinterbasdb
:dbapi: kinterbasdb
:connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db\
[?key=value&key=value...]
:url: http://firebirdsql.org/index.php?op=devel&sub=python
The connection URL is of the form Arguments
``firebird[+kinterbasdb]://user:password@host:port/path/to/db[?key=value&key=value...]``. ----------
Kinterbasedb backend specific keyword arguments are: The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect.
In addition, it also accepts the following:
type_conv * ``type_conv`` - select the kind of mapping done on the types: by default
select the kind of mapping done on the types: by default SQLAlchemy SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
uses 200 with Unicode, datetime and decimal support (see details__). the linked documents below for further information.
concurrency_level * ``concurrency_level`` - set the backend policy with regards to threading
set the backend policy with regards to threading issues: by default issues: by default SQLAlchemy uses policy 1. See the linked documents
SQLAlchemy uses policy 1 (see details__). below for further information.
.. seealso::
http://sourceforge.net/projects/kinterbasdb
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
__ http://sourceforge.net/projects/kinterbasdb
__ http://firebirdsql.org/index.php?op=devel&sub=python
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
""" """
from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler from .base import FBDialect, FBExecutionContext
from sqlalchemy import util, types as sqltypes from ... import util, types as sqltypes
from re import match
import decimal
class _FBNumeric_kinterbasdb(sqltypes.Numeric):
class _kinterbasdb_numeric(object):
def bind_processor(self, dialect): def bind_processor(self, dialect):
def process(value): def process(value):
if value is not None: if isinstance(value, decimal.Decimal):
return str(value) return str(value)
else: else:
return value return value
return process return process
class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
pass
class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
pass
class FBExecutionContext_kinterbasdb(FBExecutionContext):
@property
def rowcount(self):
if self.execution_options.get('enable_rowcount',
self.dialect.enable_rowcount):
return self.cursor.rowcount
else:
return -1
class FBDialect_kinterbasdb(FBDialect): class FBDialect_kinterbasdb(FBDialect):
driver = 'kinterbasdb' driver = 'kinterbasdb'
supports_sane_rowcount = False supports_sane_rowcount = False
supports_sane_multi_rowcount = False supports_sane_multi_rowcount = False
execution_ctx_cls = FBExecutionContext_kinterbasdb
supports_native_decimal = True supports_native_decimal = True
colspecs = util.update_copy( colspecs = util.update_copy(
FBDialect.colspecs, FBDialect.colspecs,
{ {
sqltypes.Numeric:_FBNumeric_kinterbasdb sqltypes.Numeric: _FBNumeric_kinterbasdb,
sqltypes.Float: _FBFloat_kinterbasdb,
} }
) )
def __init__(self, type_conv=200, concurrency_level=1, **kwargs): def __init__(self, type_conv=200, concurrency_level=1,
enable_rowcount=True,
retaining=False, **kwargs):
super(FBDialect_kinterbasdb, self).__init__(**kwargs) super(FBDialect_kinterbasdb, self).__init__(**kwargs)
self.enable_rowcount = enable_rowcount
self.type_conv = type_conv self.type_conv = type_conv
self.concurrency_level = concurrency_level self.concurrency_level = concurrency_level
self.retaining = retaining
if enable_rowcount:
self.supports_sane_rowcount = True
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
k = __import__('kinterbasdb') return __import__('kinterbasdb')
return k
def do_execute(self, cursor, statement, parameters, context=None):
# kinterbase does not accept a None, but wants an empty list
# when there are no arguments.
cursor.execute(statement, parameters or [])
def do_rollback(self, dbapi_connection):
dbapi_connection.rollback(self.retaining)
def do_commit(self, dbapi_connection):
dbapi_connection.commit(self.retaining)
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
@ -72,17 +122,22 @@ class FBDialect_kinterbasdb(FBDialect):
del opts['port'] del opts['port']
opts.update(url.query) opts.update(url.query)
util.coerce_kw_type(opts, 'type_conv', int)
type_conv = opts.pop('type_conv', self.type_conv) type_conv = opts.pop('type_conv', self.type_conv)
concurrency_level = opts.pop('concurrency_level', self.concurrency_level) concurrency_level = opts.pop('concurrency_level',
self.concurrency_level)
if self.dbapi is not None: if self.dbapi is not None:
initialized = getattr(self.dbapi, 'initialized', None) initialized = getattr(self.dbapi, 'initialized', None)
if initialized is None: if initialized is None:
# CVS rev 1.96 changed the name of the attribute: # CVS rev 1.96 changed the name of the attribute:
# http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96 # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
# Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
initialized = getattr(self.dbapi, '_initialized', False) initialized = getattr(self.dbapi, '_initialized', False)
if not initialized: if not initialized:
self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) self.dbapi.init(type_conv=type_conv,
concurrency_level=concurrency_level)
return ([], opts) return ([], opts)
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
@ -96,24 +151,33 @@ class FBDialect_kinterbasdb(FBDialect):
# that for backward compatibility reasons returns a string like # that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0 # LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old # where the first version is a fake one resembling the old
# Interbase signature. This is more than enough for our purposes, # Interbase signature.
# as this is mainly (only?) used by the testsuite.
from re import match
fbconn = connection.connection fbconn = connection.connection
version = fbconn.server_version version = fbconn.server_version
m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
if not m:
raise AssertionError("Could not determine version from string '%s'" % version)
return tuple([int(x) for x in m.group(5, 6, 4)])
def is_disconnect(self, e): return self._parse_version_info(version)
if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)):
def _parse_version_info(self, version):
m = match(
r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version)
if not m:
raise AssertionError(
"Could not determine version from string '%s'" % version)
if m.group(5) != None:
return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird'])
else:
return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase'])
def is_disconnect(self, e, connection, cursor):
if isinstance(e, (self.dbapi.OperationalError,
self.dbapi.ProgrammingError)):
msg = str(e) msg = str(e)
return ('Unable to complete network request to host' in msg or return ('Unable to complete network request to host' in msg or
'Invalid connection state' in msg or 'Invalid connection state' in msg or
'Invalid cursor state' in msg) 'Invalid cursor state' in msg or
'connection shutdown' in msg)
else: else:
return False return False

View File

@ -1,4 +1,12 @@
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc # mssql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, \
pymssql, zxjdbc, mxodbc
base.dialect = pyodbc.dialect base.dialect = pyodbc.dialect

View File

@ -1,15 +1,34 @@
""" # mssql/adodbapi.py
The adodbapi dialect is not implemented for 0.6 at this time. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
.. dialect:: mssql+adodbapi
:name: adodbapi
:dbapi: adodbapi
:connectstring: mssql+adodbapi://<username>:<password>@<dsnname>
:url: http://adodbapi.sourceforge.net/
.. note::
The adodbapi dialect is not implemented SQLAlchemy versions 0.6 and
above at this time.
"""
import datetime
from sqlalchemy import types as sqltypes, util from sqlalchemy import types as sqltypes, util
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
import sys import sys
class MSDateTime_adodbapi(MSDateTime): class MSDateTime_adodbapi(MSDateTime):
def result_processor(self, dialect, coltype): def result_processor(self, dialect, coltype):
def process(value): def process(value):
# adodbapi will return datetimes with empty time values as datetime.date() objects. # adodbapi will return datetimes with empty time
# values as datetime.date() objects.
# Promote them back to full datetime.datetime() # Promote them back to full datetime.datetime()
if type(value) is datetime.date: if type(value) is datetime.date:
return datetime.datetime(value.year, value.month, value.day) return datetime.datetime(value.year, value.month, value.day)
@ -37,11 +56,19 @@ class MSDialect_adodbapi(MSDialect):
) )
def create_connect_args(self, url): def create_connect_args(self, url):
keys = url.query def check_quote(token):
if ";" in str(token):
token = "'%s'" % token
return token
keys = dict(
(k, check_quote(v)) for k, v in url.query.items()
)
connectors = ["Provider=SQLOLEDB"] connectors = ["Provider=SQLOLEDB"]
if 'port' in keys: if 'port' in keys:
connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) connectors.append("Data Source=%s, %s" %
(keys.get("host"), keys.get("port")))
else: else:
connectors.append("Data Source=%s" % keys.get("host")) connectors.append("Data Source=%s" % keys.get("host"))
connectors.append("Initial Catalog=%s" % keys.get("database")) connectors.append("Initial Catalog=%s" % keys.get("database"))
@ -53,7 +80,8 @@ class MSDialect_adodbapi(MSDialect):
connectors.append("Integrated Security=SSPI") connectors.append("Integrated Security=SSPI")
return [[";".join(connectors)], {}] return [[";".join(connectors)], {}]
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
"'connection failure'" in str(e)
dialect = MSDialect_adodbapi dialect = MSDialect_adodbapi

File diff suppressed because it is too large Load Diff

View File

@ -1,16 +1,48 @@
from sqlalchemy import Table, MetaData, Column, ForeignKey # mssql/information_schema.py
from sqlalchemy.types import String, Unicode, Integer, TypeDecorator # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
# TODO: should be using the sys. catalog with SQL Server, not information
# schema
from ... import Table, MetaData, Column
from ...types import String, Unicode, UnicodeText, Integer, TypeDecorator
from ... import cast
from ... import util
from ...sql import expression
from ...ext.compiler import compiles
ischema = MetaData() ischema = MetaData()
class CoerceUnicode(TypeDecorator): class CoerceUnicode(TypeDecorator):
impl = Unicode impl = Unicode
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
if isinstance(value, str): if util.py2k and isinstance(value, util.binary_type):
value = value.decode(dialect.encoding) value = value.decode(dialect.encoding)
return value return value
def bind_expression(self, bindvalue):
return _cast_on_2005(bindvalue)
class _cast_on_2005(expression.ColumnElement):
def __init__(self, bindvalue):
self.bindvalue = bindvalue
@compiles(_cast_on_2005)
def _compile(element, compiler, **kw):
from . import base
if compiler.dialect.server_version_info < base.MS_2005_VERSION:
return compiler.process(element.bindvalue, **kw)
else:
return compiler.process(cast(element.bindvalue, Unicode), **kw)
schemata = Table("SCHEMATA", ischema, schemata = Table("SCHEMATA", ischema,
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
@ -21,7 +53,9 @@ tables = Table("TABLES", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"), Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"), Column(
"TABLE_TYPE", String(convert_unicode=True),
key="table_type"),
schema="INFORMATION_SCHEMA") schema="INFORMATION_SCHEMA")
columns = Table("COLUMNS", ischema, columns = Table("COLUMNS", ischema,
@ -31,7 +65,8 @@ columns = Table("COLUMNS", ischema,
Column("IS_NULLABLE", Integer, key="is_nullable"), Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"), Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"), Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), Column("CHARACTER_MAXIMUM_LENGTH", Integer,
key="character_maximum_length"),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"), Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"), Column("COLUMN_DEFAULT", Integer, key="column_default"),
@ -41,32 +76,51 @@ columns = Table("COLUMNS", ischema,
constraints = Table("TABLE_CONSTRAINTS", ischema, constraints = Table("TABLE_CONSTRAINTS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"), Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), Column("CONSTRAINT_NAME", CoerceUnicode,
Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"), key="constraint_name"),
Column("CONSTRAINT_TYPE", String(
convert_unicode=True), key="constraint_type"),
schema="INFORMATION_SCHEMA") schema="INFORMATION_SCHEMA")
column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), Column("TABLE_SCHEMA", CoerceUnicode,
Column("TABLE_NAME", CoerceUnicode, key="table_name"), key="table_schema"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"), Column("TABLE_NAME", CoerceUnicode,
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), key="table_name"),
Column("COLUMN_NAME", CoerceUnicode,
key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode,
key="constraint_name"),
schema="INFORMATION_SCHEMA") schema="INFORMATION_SCHEMA")
key_constraints = Table("KEY_COLUMN_USAGE", ischema, key_constraints = Table("KEY_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), Column("TABLE_SCHEMA", CoerceUnicode,
Column("TABLE_NAME", CoerceUnicode, key="table_name"), key="table_schema"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"), Column("TABLE_NAME", CoerceUnicode,
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), key="table_name"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"), Column("COLUMN_NAME", CoerceUnicode,
key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode,
key="constraint_name"),
Column("ORDINAL_POSITION", Integer,
key="ordinal_position"),
schema="INFORMATION_SCHEMA") schema="INFORMATION_SCHEMA")
ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), Column("CONSTRAINT_CATALOG", CoerceUnicode,
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), key="constraint_catalog"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), Column("CONSTRAINT_SCHEMA", CoerceUnicode,
Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"), # TODO: is CATLOG misspelled ? key="constraint_schema"),
Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"), Column("CONSTRAINT_NAME", CoerceUnicode,
Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"), key="constraint_name"),
# TODO: is CATLOG misspelled ?
Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode,
key="unique_constraint_catalog"),
Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode,
key="unique_constraint_schema"),
Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode,
key="unique_constraint_name"),
Column("MATCH_OPTION", String, key="match_option"), Column("MATCH_OPTION", String, key="match_option"),
Column("UPDATE_RULE", String, key="update_rule"), Column("UPDATE_RULE", String, key="update_rule"),
Column("DELETE_RULE", String, key="delete_rule"), Column("DELETE_RULE", String, key="delete_rule"),
@ -80,4 +134,3 @@ views = Table("VIEWS", ischema,
Column("CHECK_OPTION", String, key="check_option"), Column("CHECK_OPTION", String, key="check_option"),
Column("IS_UPDATABLE", String, key="is_updatable"), Column("IS_UPDATABLE", String, key="is_updatable"),
schema="INFORMATION_SCHEMA") schema="INFORMATION_SCHEMA")

View File

@ -1,55 +1,105 @@
# mssql/mxodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
Support for MS-SQL via mxODBC. .. dialect:: mssql+mxodbc
:name: mxODBC
mxODBC is available at: :dbapi: mxodbc
:connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
http://www.egenix.com/ :url: http://www.egenix.com/
This was tested with mxODBC 3.1.2 and the SQL Server Native
Client connected to MSSQL 2005 and 2008 Express Editions.
Connecting
~~~~~~~~~~
Connection is via DSN::
mssql+mxodbc://<username>:<password>@<dsnname>
Execution Modes Execution Modes
~~~~~~~~~~~~~~~ ---------------
mxODBC features two styles of statement execution, using the ``cursor.execute()`` mxODBC features two styles of statement execution, using the
and ``cursor.executedirect()`` methods (the second being an extension to the ``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
DBAPI specification). The former makes use of the native an extension to the DBAPI specification). The former makes use of a particular
parameter binding services of the ODBC driver, while the latter uses string escaping. API call specific to the SQL Server Native Client ODBC driver known
The primary advantage to native parameter binding is that the same statement, when SQLDescribeParam, while the latter does not.
executed many times, is only prepared once. Whereas the primary advantage to the
latter is that the rules for bind parameter placement are relaxed. MS-SQL has very mxODBC apparently only makes repeated use of a single prepared statement
strict rules for native binds, including that they cannot be placed within the argument when SQLDescribeParam is used. The advantage to prepared statement reuse is
lists of function calls, anywhere outside the FROM, or even within subqueries within the one of performance. The disadvantage is that SQLDescribeParam has a limited
FROM clause - making the usage of bind parameters within SELECT statements impossible for set of scenarios in which bind parameters are understood, including that they
all but the most simplistic statements. For this reason, the mxODBC dialect uses the cannot be placed within the argument lists of function calls, anywhere outside
"native" mode by default only for INSERT, UPDATE, and DELETE statements, and uses the the FROM, or even within subqueries within the FROM clause - making the usage
escaped string mode for all other statements. This behavior can be controlled completely of bind parameters within SELECT statements impossible for all but the most
via :meth:`~sqlalchemy.sql.expression.Executable.execution_options` simplistic statements.
using the ``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a value of
``True`` will unconditionally use native bind parameters and a value of ``False`` will For this reason, the mxODBC dialect uses the "native" mode by default only for
uncondtionally use string-escaped parameters. INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
all other statements.
This behavior can be controlled via
:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
value of ``True`` will unconditionally use native bind parameters and a value
of ``False`` will unconditionally use string-escaped parameters.
""" """
import re
import sys
from sqlalchemy import types as sqltypes from ... import types as sqltypes
from sqlalchemy import util from ...connectors.mxodbc import MxODBCConnector
from sqlalchemy.connectors.mxodbc import MxODBCConnector from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc
from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc from .base import (MSDialect,
from sqlalchemy.dialects.mssql.base import (MSExecutionContext, MSDialect, MSSQLStrictCompiler,
MSSQLCompiler, MSSQLStrictCompiler, VARBINARY,
_MSDateTime, _MSDate, TIME) _MSDateTime, _MSDate, _MSTime)
class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
"""Include pyodbc's numeric processor.
"""
class _MSDate_mxodbc(_MSDate):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s-%s-%s" % (value.year, value.month, value.day)
else:
return None
return process
class _MSTime_mxodbc(_MSTime):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s:%s:%s" % (value.hour, value.minute, value.second)
else:
return None
return process
class _VARBINARY_mxodbc(VARBINARY):
"""
mxODBC Support for VARBINARY column types.
This handles the special case for null VARBINARY values,
which maps None values to the mx.ODBC.Manager.BinaryNull symbol.
"""
def bind_processor(self, dialect):
if dialect.dbapi is None:
return None
DBAPIBinary = dialect.dbapi.Binary
def process(value):
if value is not None:
return DBAPIBinary(value)
else:
# should pull from mx.ODBC.Manager.BinaryNull
return dialect.dbapi.BinaryNull
return process
class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
""" """
@ -61,23 +111,29 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
# is really only being used in cases where OUTPUT # is really only being used in cases where OUTPUT
# won't work. # won't work.
class MSDialect_mxodbc(MxODBCConnector, MSDialect): class MSDialect_mxodbc(MxODBCConnector, MSDialect):
# TODO: may want to use this only if FreeTDS is not in use, # this is only needed if "native ODBC" mode is used,
# since FreeTDS doesn't seem to use native binds. # which is now disabled by default.
statement_compiler = MSSQLStrictCompiler # statement_compiler = MSSQLStrictCompiler
execution_ctx_cls = MSExecutionContext_mxodbc execution_ctx_cls = MSExecutionContext_mxodbc
# flag used by _MSNumeric_mxodbc
_need_decimal_fix = True
colspecs = { colspecs = {
#sqltypes.Numeric : _MSNumeric, sqltypes.Numeric: _MSNumeric_mxodbc,
sqltypes.DateTime: _MSDateTime, sqltypes.DateTime: _MSDateTime,
sqltypes.Date : _MSDate, sqltypes.Date: _MSDate_mxodbc,
sqltypes.Time : TIME, sqltypes.Time: _MSTime_mxodbc,
VARBINARY: _VARBINARY_mxodbc,
sqltypes.LargeBinary: _VARBINARY_mxodbc,
} }
def __init__(self, description_encoding=None, **params):
def __init__(self, description_encoding='latin-1', **params):
super(MSDialect_mxodbc, self).__init__(**params) super(MSDialect_mxodbc, self).__init__(**params)
self.description_encoding = description_encoding self.description_encoding = description_encoding
dialect = MSDialect_mxodbc dialect = MSDialect_mxodbc

View File

@ -1,41 +1,27 @@
""" # mssql/pymssql.py
Support for the pymssql dialect. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
This dialect supports pymssql 1.0 and greater. #
# This module is part of SQLAlchemy and is released under
pymssql is available at: # the MIT License: http://www.opensource.org/licenses/mit-license.php
http://pymssql.sourceforge.net/
Connecting
^^^^^^^^^^
Sample connect string::
mssql+pymssql://<username>:<password>@<freetds_name>
Adding "?charset=utf8" or similar will cause pymssql to return
strings as Python unicode objects. This can potentially improve
performance in some scenarios as decoding of strings is
handled natively.
Limitations
^^^^^^^^^^^
pymssql inherits a lot of limitations from FreeTDS, including:
* no support for multibyte schema identifiers
* poor support for large decimals
* poor support for binary fields
* poor support for VARCHAR/CHAR fields over 255 characters
Please consult the pymssql documentation for further information.
""" """
from sqlalchemy.dialects.mssql.base import MSDialect .. dialect:: mssql+pymssql
from sqlalchemy import types as sqltypes, util, processors :name: pymssql
:dbapi: pymssql
:connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?\
charset=utf8
:url: http://pymssql.org/
pymssql is a Python module that provides a Python DBAPI interface around
`FreeTDS <http://www.freetds.org/>`_. Compatible builds are available for
Linux, MacOSX and Windows platforms.
"""
from .base import MSDialect
from ... import types as sqltypes, util, processors
import re import re
import decimal
class _MSNumeric_pymssql(sqltypes.Numeric): class _MSNumeric_pymssql(sqltypes.Numeric):
def result_processor(self, dialect, type_): def result_processor(self, dialect, type_):
@ -44,9 +30,9 @@ class _MSNumeric_pymssql(sqltypes.Numeric):
else: else:
return sqltypes.Numeric.result_processor(self, dialect, type_) return sqltypes.Numeric.result_processor(self, dialect, type_)
class MSDialect_pymssql(MSDialect): class MSDialect_pymssql(MSDialect):
supports_sane_rowcount = False supports_sane_rowcount = False
max_identifier_length = 30
driver = 'pymssql' driver = 'pymssql'
colspecs = util.update_copy( colspecs = util.update_copy(
@ -56,14 +42,16 @@ class MSDialect_pymssql(MSDialect):
sqltypes.Float: sqltypes.Float, sqltypes.Float: sqltypes.Float,
} }
) )
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
module = __import__('pymssql') module = __import__('pymssql')
# pymmsql doesn't have a Binary method. we use string # pymmsql < 2.1.1 doesn't have a Binary method. we use string
# TODO: monkeypatching here is less than ideal
module.Binary = str
client_ver = tuple(int(x) for x in module.__version__.split(".")) client_ver = tuple(int(x) for x in module.__version__.split("."))
if client_ver < (2, 1, 1):
# TODO: monkeypatching here is less than ideal
module.Binary = lambda x: x if hasattr(x, 'decode') else str(x)
if client_ver < (1, ): if client_ver < (1, ):
util.warn("The pymssql dialect expects at least " util.warn("The pymssql dialect expects at least "
"the 1.0 series of the pymssql DBAPI.") "the 1.0 series of the pymssql DBAPI.")
@ -75,7 +63,8 @@ class MSDialect_pymssql(MSDialect):
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version") vers = connection.scalar("select @@version")
m = re.match(r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers) m = re.match(
r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers)
if m: if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4)) return tuple(int(x) for x in m.group(1, 2, 3, 4))
else: else:
@ -84,14 +73,21 @@ class MSDialect_pymssql(MSDialect):
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
opts.update(url.query) opts.update(url.query)
opts.pop('port', None) port = opts.pop('port', None)
if port and 'host' in opts:
opts['host'] = "%s:%s" % (opts['host'], port)
return [[], opts] return [[], opts]
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
for msg in ( for msg in (
"Adaptive Server connection timed out",
"Net-Lib error during Connection reset by peer",
"message 20003", # connection timeout
"Error 10054", "Error 10054",
"Not connected to any MS SQL server", "Not connected to any MS SQL server",
"Connection is closed" "Connection is closed",
"message 20006", # Write to the server failed
"message 20017", # Unexpected EOF from the server
): ):
if msg in str(e): if msg in str(e):
return True return True

View File

@ -1,94 +1,130 @@
""" # mssql/pyodbc.py
Support for MS-SQL via pyodbc. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
pyodbc is available at: r"""
.. dialect:: mssql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
:url: http://pypi.python.org/pypi/pyodbc/
http://pypi.python.org/pypi/pyodbc/ Connecting to PyODBC
--------------------
Connecting The URL here is to be translated to PyODBC connection strings, as
^^^^^^^^^^ detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_.
Examples of pyodbc connection string URLs: DSN Connections
^^^^^^^^^^^^^^^
* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``. A DSN-based connection is **preferred** overall when using ODBC. A
The connection string that is created will appear like:: basic DSN-based connection looks like::
dsn=mydsn;Trusted_Connection=Yes engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")
* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named Which above, will pass the following connection string to PyODBC::
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
connection string that is created will appear like::
dsn=mydsn;UID=user;PWD=pass dsn=mydsn;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects If the username and password are omitted, the DSN form will also add
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` the ``Trusted_Connection=yes`` directive to the ODBC string.
information, plus the additional connection configuration option
``LANGUAGE``. The connection string that is created will appear
like::
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english Hostname Connections
^^^^^^^^^^^^^^^^^^^^
* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection string Hostname-based connections are **not preferred**, however are supported.
dynamically created that would appear like:: The ODBC driver name must be explicitly specified::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=SQL+Server+Native+Client+10.0")
* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection .. versionchanged:: 1.0.0 Hostname-based PyODBC connections now require the
string that is dynamically created, which also includes the port SQL Server driver name specified explicitly. SQLAlchemy cannot
information using the comma syntax. If your connection string choose an optimal default here as it varies based on platform
requires the port information to be passed as a ``port`` keyword and installed drivers.
see the next example. This will create the following connection
string::
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass Other keywords interpreted by the Pyodbc dialect to be passed to
``pyodbc.connect()`` in both the DSN and hostname cases include:
``odbc_autotranslate``, ``ansi``, ``unicode_results``, ``autocommit``.
* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection Pass through exact Pyodbc string
string that is dynamically created that includes the port ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
information as a separate ``port`` keyword. This will create the
following connection string::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 A PyODBC connection string can also be sent exactly as specified in
`ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_
into the driver using the parameter ``odbc_connect``. The delimeters must be URL escaped, however,
as illustrated below using ``urllib.quote_plus``::
If you require a connection string that is outside the options import urllib
presented above, use the ``odbc_connect`` keyword to pass in a params = urllib.quote_plus("DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password")
urlencoded connection string. What gets passed in will be urldecoded
and passed directly.
For example:: engine = create_engine("mssql+pyodbc:///?odbc_connect=%s" % params)
mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
would create the following connection string:: Unicode Binds
-------------
dsn=mydsn;Database=db The current state of PyODBC on a unix backend with FreeTDS and/or
EasySoft is poor regarding unicode; different OS platforms and versions of
UnixODBC versus IODBC versus FreeTDS/EasySoft versus PyODBC itself
dramatically alter how strings are received. The PyODBC dialect attempts to
use all the information it knows to determine whether or not a Python unicode
literal can be passed directly to the PyODBC driver or not; while SQLAlchemy
can encode these to bytestrings first, some users have reported that PyODBC
mis-handles bytestrings for certain encodings and requires a Python unicode
object, while the author has observed widespread cases where a Python unicode
is completely misinterpreted by PyODBC, particularly when dealing with
the information schema tables used in table reflection, and the value
must first be encoded to a bytestring.
Encoding your connection string can be easily accomplished through It is for this reason that whether or not unicode literals for bound
the python shell. For example:: parameters be sent to PyODBC can be controlled using the
``supports_unicode_binds`` parameter to ``create_engine()``. When
left at its default of ``None``, the PyODBC dialect will use its
best guess as to whether or not the driver deals with unicode literals
well. When ``False``, unicode literals will be encoded first, and when
``True`` unicode literals will be passed straight through. This is an interim
flag that hopefully should not be needed when the unicode situation stabilizes
for unix + PyODBC.
>>> import urllib .. versionadded:: 0.7.7
>>> urllib.quote_plus('dsn=mydsn;Database=db') ``supports_unicode_binds`` parameter to ``create_engine()``\ .
'dsn%3Dmydsn%3BDatabase%3Ddb'
Rowcount Support
----------------
Pyodbc only has partial support for rowcount. See the notes at
:ref:`mssql_rowcount_versioning` for important notes when using ORM
versioning.
""" """
from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect from .base import MSExecutionContext, MSDialect, VARBINARY
from sqlalchemy.connectors.pyodbc import PyODBCConnector from ...connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes, util from ... import types as sqltypes, util, exc
import decimal import decimal
import re
class _ms_numeric_pyodbc(object):
class _MSNumeric_pyodbc(sqltypes.Numeric):
"""Turns Decimals with adjusted() < 0 or > 7 into strings. """Turns Decimals with adjusted() < 0 or > 7 into strings.
This is the only method that is proven to work with Pyodbc+MSSQL The routines here are needed for older pyodbc versions
without crashing (floats can be used but seem to cause sporadic as well as current mxODBC versions.
crashes).
""" """
def bind_processor(self, dialect): def bind_processor(self, dialect):
super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect)
super_process = super(_ms_numeric_pyodbc, self).\
bind_processor(dialect)
if not dialect._need_decimal_fix:
return super_process
def process(value): def process(value):
if self.asdecimal and \ if self.asdecimal and \
@ -106,36 +142,68 @@ class _MSNumeric_pyodbc(sqltypes.Numeric):
return value return value
return process return process
# these routines needed for older versions of pyodbc.
# as of 2.1.8 this logic is integrated.
def _small_dec_to_string(self, value): def _small_dec_to_string(self, value):
return "%s0.%s%s" % ( return "%s0.%s%s" % (
(value < 0 and '-' or ''), (value < 0 and '-' or ''),
'0' * (abs(value.adjusted()) - 1), '0' * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value._int])) "".join([str(nint) for nint in value.as_tuple()[1]]))
def _large_dec_to_string(self, value): def _large_dec_to_string(self, value):
_int = value.as_tuple()[1]
if 'E' in str(value): if 'E' in str(value):
result = "%s%s%s" % ( result = "%s%s%s" % (
(value < 0 and '-' or ''), (value < 0 and '-' or ''),
"".join([str(s) for s in value._int]), "".join([str(s) for s in _int]),
"0" * (value.adjusted() - (len(value._int)-1))) "0" * (value.adjusted() - (len(_int) - 1)))
else: else:
if (len(value._int) - 1) > value.adjusted(): if (len(_int) - 1) > value.adjusted():
result = "%s%s.%s" % ( result = "%s%s.%s" % (
(value < 0 and '-' or ''), (value < 0 and '-' or ''),
"".join([str(s) for s in value._int][0:value.adjusted() + 1]), "".join(
"".join([str(s) for s in value._int][value.adjusted() + 1:])) [str(s) for s in _int][0:value.adjusted() + 1]),
"".join(
[str(s) for s in _int][value.adjusted() + 1:]))
else: else:
result = "%s%s" % ( result = "%s%s" % (
(value < 0 and '-' or ''), (value < 0 and '-' or ''),
"".join([str(s) for s in value._int][0:value.adjusted() + 1])) "".join(
[str(s) for s in _int][0:value.adjusted() + 1]))
return result return result
class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
pass
class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
pass
class _VARBINARY_pyodbc(VARBINARY):
def bind_processor(self, dialect):
if dialect.dbapi is None:
return None
DBAPIBinary = dialect.dbapi.Binary
def process(value):
if value is not None:
return DBAPIBinary(value)
else:
# pyodbc-specific
return dialect.dbapi.BinaryNull
return process
class MSExecutionContext_pyodbc(MSExecutionContext): class MSExecutionContext_pyodbc(MSExecutionContext):
_embedded_scope_identity = False _embedded_scope_identity = False
def pre_exec(self): def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same statement. """where appropriate, issue "select scope_identity()" in the same
statement.
Background on why "scope_identity()" is preferable to "@@identity": Background on why "scope_identity()" is preferable to "@@identity":
http://msdn.microsoft.com/en-us/library/ms190315.aspx http://msdn.microsoft.com/en-us/library/ms190315.aspx
@ -148,7 +216,8 @@ class MSExecutionContext_pyodbc(MSExecutionContext):
super(MSExecutionContext_pyodbc, self).pre_exec() super(MSExecutionContext_pyodbc, self).pre_exec()
# don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES" # don't embed the scope_identity select into an
# "INSERT .. DEFAULT VALUES"
if self._select_lastrowid and \ if self._select_lastrowid and \
self.dialect.use_scope_identity and \ self.dialect.use_scope_identity and \
len(self.parameters[0]): len(self.parameters[0]):
@ -159,14 +228,15 @@ class MSExecutionContext_pyodbc(MSExecutionContext):
def post_exec(self): def post_exec(self):
if self._embedded_scope_identity: if self._embedded_scope_identity:
# Fetch the last inserted id from the manipulated statement # Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with no data (due to triggers, etc.) # We may have to skip over a number of result sets with
# no data (due to triggers, etc.)
while True: while True:
try: try:
# fetchall() ensures the cursor is consumed # fetchall() ensures the cursor is consumed
# without closing it (FreeTDS particularly) # without closing it (FreeTDS particularly)
row = self.cursor.fetchall()[0] row = self.cursor.fetchall()[0]
break break
except self.dialect.dbapi.Error, e: except self.dialect.dbapi.Error as e:
# no way around this - nextset() consumes the previous set # no way around this - nextset() consumes the previous set
# so we need to just keep flipping # so we need to just keep flipping
self.cursor.nextset() self.cursor.nextset()
@ -180,18 +250,43 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
execution_ctx_cls = MSExecutionContext_pyodbc execution_ctx_cls = MSExecutionContext_pyodbc
pyodbc_driver_name = 'SQL Server'
colspecs = util.update_copy( colspecs = util.update_copy(
MSDialect.colspecs, MSDialect.colspecs,
{ {
sqltypes.Numeric:_MSNumeric_pyodbc sqltypes.Numeric: _MSNumeric_pyodbc,
sqltypes.Float: _MSFloat_pyodbc,
VARBINARY: _VARBINARY_pyodbc,
sqltypes.LargeBinary: _VARBINARY_pyodbc,
} }
) )
def __init__(self, description_encoding='latin-1', **params): def __init__(self, description_encoding=None, **params):
if 'description_encoding' in params:
self.description_encoding = params.pop('description_encoding')
super(MSDialect_pyodbc, self).__init__(**params) super(MSDialect_pyodbc, self).__init__(**params)
self.description_encoding = description_encoding self.use_scope_identity = self.use_scope_identity and \
self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') self.dbapi and \
hasattr(self.dbapi.Cursor, 'nextset')
self._need_decimal_fix = self.dbapi and \
self._dbapi_version() < (2, 1, 8)
def _get_server_version_info(self, connection):
try:
raw = connection.scalar("SELECT SERVERPROPERTY('ProductVersion')")
except exc.DBAPIError:
# SQL Server docs indicate this function isn't present prior to
# 2008; additionally, unknown combinations of pyodbc aren't
# able to run this query.
return super(MSDialect_pyodbc, self).\
_get_server_version_info(connection)
else:
version = []
r = re.compile(r'[.\-]')
for n in r.split(raw):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
dialect = MSDialect_pyodbc dialect = MSDialect_pyodbc

View File

@ -1,26 +1,26 @@
"""Support for the Microsoft SQL Server database via the zxjdbc JDBC # mssql/zxjdbc.py
connector. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
JDBC Driver #
----------- # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Requires the jTDS driver, available from: http://jtds.sourceforge.net/
Connecting
----------
URLs are of the standard form of
``mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]``.
Additional arguments which may be specified either as query string
arguments on the URL, or as keyword arguments to
:func:`~sqlalchemy.create_engine()` will be passed as Connection
properties to the underlying JDBC driver.
""" """
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector .. dialect:: mssql+zxjdbc
from sqlalchemy.dialects.mssql.base import MSDialect, MSExecutionContext :name: zxJDBC for Jython
from sqlalchemy.engine import base :dbapi: zxjdbc
:connectstring: mssql+zxjdbc://user:pass@host:port/dbname\
[?key=value&key=value...]
:driverurl: http://jtds.sourceforge.net/
.. note:: Jython is not supported by current versions of SQLAlchemy. The
zxjdbc dialect should be considered as experimental.
"""
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import MSDialect, MSExecutionContext
from ... import engine
class MSExecutionContext_zxjdbc(MSExecutionContext): class MSExecutionContext_zxjdbc(MSExecutionContext):
@ -40,15 +40,17 @@ class MSExecutionContext_zxjdbc(MSExecutionContext):
try: try:
row = self.cursor.fetchall()[0] row = self.cursor.fetchall()[0]
break break
except self.dialect.dbapi.Error, e: except self.dialect.dbapi.Error:
self.cursor.nextset() self.cursor.nextset()
self._lastrowid = int(row[0]) self._lastrowid = int(row[0])
if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning: if (self.isinsert or self.isupdate or self.isdelete) and \
self._result_proxy = base.FullyBufferedResultProxy(self) self.compiled.returning:
self._result_proxy = engine.FullyBufferedResultProxy(self)
if self._enable_identity_insert: if self._enable_identity_insert:
table = self.dialect.identifier_preparer.format_table(self.compiled.statement.table) table = self.dialect.identifier_preparer.format_table(
self.compiled.statement.table)
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table) self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)
@ -59,6 +61,9 @@ class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
execution_ctx_cls = MSExecutionContext_zxjdbc execution_ctx_cls = MSExecutionContext_zxjdbc
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
return tuple(int(x) for x in connection.connection.dbversion.split('.')) return tuple(
int(x)
for x in connection.connection.dbversion.split('.')
)
dialect = MSDialect_zxjdbc dialect = MSDialect_zxjdbc

View File

@ -1,17 +1,31 @@
from sqlalchemy.dialects.mysql import base, mysqldb, oursql, pyodbc, zxjdbc, mysqlconnector # mysql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import base, mysqldb, oursql, \
pyodbc, zxjdbc, mysqlconnector, pymysql,\
gaerdbms, cymysql
# default dialect # default dialect
base.dialect = mysqldb.dialect base.dialect = mysqldb.dialect
from sqlalchemy.dialects.mysql.base import \ from .base import \
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, DOUBLE, ENUM, DECIMAL,\ BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \
FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, MEDIUMTEXT, NCHAR, \ DECIMAL, DOUBLE, ENUM, DECIMAL,\
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT,\ FLOAT, INTEGER, INTEGER, JSON, LONGBLOB, LONGTEXT, MEDIUMBLOB, \
MEDIUMINT, MEDIUMTEXT, NCHAR, \
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \
TINYBLOB, TINYINT, TINYTEXT,\
VARBINARY, VARCHAR, YEAR, dialect VARBINARY, VARCHAR, YEAR, dialect
__all__ = ( __all__ = (
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE', 'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME',
'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', 'DECIMAL', 'DOUBLE', 'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER',
'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP', 'JSON', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', 'MEDIUMTEXT',
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect' 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME',
'TIMESTAMP', 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR',
'YEAR', 'dialect'
) )

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,87 @@
# mysql/cymysql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+cymysql
:name: CyMySQL
:dbapi: cymysql
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>\
[?<options>]
:url: https://github.com/nakagami/CyMySQL
"""
import re
from .mysqldb import MySQLDialect_mysqldb
from .base import (BIT, MySQLDialect)
from ... import util
class _cymysqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
"""
def process(value):
if value is not None:
v = 0
for i in util.iterbytes(value):
v = v << 8 | i
return v
return value
return process
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
driver = 'cymysql'
description_encoding = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_unicode_statements = True
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
BIT: _cymysqlBIT,
}
)
@classmethod
def dbapi(cls):
return __import__('cymysql')
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.server_version):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in \
(2006, 2013, 2014, 2045, 2055)
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True
else:
return False
dialect = MySQLDialect_cymysql

View File

@ -0,0 +1,311 @@
# mysql/enumerated.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import re
from .types import _StringType
from ... import exc, sql, util
from ... import types as sqltypes
class _EnumeratedValues(_StringType):
def _init_values(self, values, kw):
self.quoting = kw.pop('quoting', 'auto')
if self.quoting == 'auto' and len(values):
# What quoting character are we using?
q = None
for e in values:
if len(e) == 0:
self.quoting = 'unquoted'
break
elif q is None:
q = e[0]
if len(e) == 1 or e[0] != q or e[-1] != q:
self.quoting = 'unquoted'
break
else:
self.quoting = 'quoted'
if self.quoting == 'quoted':
util.warn_deprecated(
'Manually quoting %s value literals is deprecated. Supply '
'unquoted values and use the quoting= option in cases of '
'ambiguity.' % self.__class__.__name__)
values = self._strip_values(values)
self._enumerated_values = values
length = max([len(v) for v in values] + [0])
return values, length
@classmethod
def _strip_values(cls, values):
strip_values = []
for a in values:
if a[0:1] == '"' or a[0:1] == "'":
# strip enclosing quotes and unquote interior
a = a[1:-1].replace(a[0] * 2, a[0])
strip_values.append(a)
return strip_values
class ENUM(sqltypes.Enum, _EnumeratedValues):
"""MySQL ENUM type."""
__visit_name__ = 'ENUM'
def __init__(self, *enums, **kw):
"""Construct an ENUM.
E.g.::
Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values will be
quoted when generating the schema according to the quoting flag (see
below). This object may also be a PEP-435-compliant enumerated
type.
.. versionadded: 1.1 added support for PEP-435-compliant enumerated
types.
:param strict: This flag has no effect.
.. versionchanged:: The MySQL ENUM type as well as the base Enum
type now validates all Python data values.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
:param quoting: Defaults to 'auto': automatically determine enum value
quoting. If all enum values are surrounded by the same quoting
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.
'quoted': values in enums are already quoted, they will be used
directly when generating the schema - this usage is deprecated.
'unquoted': values in enums are not quoted, they will be escaped and
surrounded by single quotes when generating the schema.
Previous versions of this type always required manually quoted
values to be supplied; future versions will always quote the string
literals for you. This is a transitional option.
"""
kw.pop('strict', None)
validate_strings = kw.pop("validate_strings", False)
sqltypes.Enum.__init__(
self, validate_strings=validate_strings, *enums)
kw.pop('metadata', None)
kw.pop('schema', None)
kw.pop('name', None)
kw.pop('quote', None)
kw.pop('native_enum', None)
kw.pop('inherit_schema', None)
kw.pop('_create_events', None)
_StringType.__init__(self, length=self.length, **kw)
def _setup_for_values(self, values, objects, kw):
values, length = self._init_values(values, kw)
return sqltypes.Enum._setup_for_values(self, values, objects, kw)
def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
# the blank string. Return it straight.
if elem == "":
return elem
else:
return super(ENUM, self)._object_value_for_elem(elem)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum])
def adapt(self, cls, **kw):
return sqltypes.Enum.adapt(self, cls, **kw)
class SET(_EnumeratedValues):
"""MySQL SET type."""
__visit_name__ = 'SET'
def __init__(self, *values, **kw):
"""Construct a SET.
E.g.::
Column('myset', SET("foo", "bar", "baz"))
The list of potential values is required in the case that this
set will be used to generate DDL for a table, or if the
:paramref:`.SET.retrieve_as_bitwise` flag is set to True.
:param values: The range of valid values for this SET.
:param convert_unicode: Same flag as that of
:paramref:`.String.convert_unicode`.
:param collation: same as that of :paramref:`.String.collation`
:param charset: same as that of :paramref:`.VARCHAR.charset`.
:param ascii: same as that of :paramref:`.VARCHAR.ascii`.
:param unicode: same as that of :paramref:`.VARCHAR.unicode`.
:param binary: same as that of :paramref:`.VARCHAR.binary`.
:param quoting: Defaults to 'auto': automatically determine set value
quoting. If all values are surrounded by the same quoting
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.
'quoted': values in enums are already quoted, they will be used
directly when generating the schema - this usage is deprecated.
'unquoted': values in enums are not quoted, they will be escaped and
surrounded by single quotes when generating the schema.
Previous versions of this type always required manually quoted
values to be supplied; future versions will always quote the string
literals for you. This is a transitional option.
.. versionadded:: 0.9.0
:param retrieve_as_bitwise: if True, the data for the set type will be
persisted and selected using an integer value, where a set is coerced
into a bitwise mask for persistence. MySQL allows this mode which
has the advantage of being able to store values unambiguously,
such as the blank string ``''``. The datatype will appear
as the expression ``col + 0`` in a SELECT statement, so that the
value is coerced into an integer value in result sets.
This flag is required if one wishes
to persist a set that can store the blank string ``''`` as a value.
.. warning::
When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
essential that the list of set values is expressed in the
**exact same order** as exists on the MySQL database.
.. versionadded:: 1.0.0
"""
self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False)
values, length = self._init_values(values, kw)
self.values = tuple(values)
if not self.retrieve_as_bitwise and '' in values:
raise exc.ArgumentError(
"Can't use the blank value '' in a SET without "
"setting retrieve_as_bitwise=True")
if self.retrieve_as_bitwise:
self._bitmap = dict(
(value, 2 ** idx)
for idx, value in enumerate(self.values)
)
self._bitmap.update(
(2 ** idx, value)
for idx, value in enumerate(self.values)
)
kw.setdefault('length', length)
super(SET, self).__init__(**kw)
def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0,
self
)
else:
return colexpr
def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:
def process(value):
if value is not None:
value = int(value)
return set(
util.map_bits(self._bitmap.__getitem__, value)
)
else:
return None
else:
super_convert = super(SET, self).result_processor(dialect, coltype)
def process(value):
if isinstance(value, util.string_types):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
return set(re.findall(r'[^,]+', value))
else:
# mysql-connector-python does a naive
# split(",") which throws in an empty string
if value is not None:
value.discard('')
return value
return process
def bind_processor(self, dialect):
super_convert = super(SET, self).bind_processor(dialect)
if self.retrieve_as_bitwise:
def process(value):
if value is None:
return None
elif isinstance(value, util.int_types + util.string_types):
if super_convert:
return super_convert(value)
else:
return value
else:
int_value = 0
for v in value:
int_value |= self._bitmap[v]
return int_value
else:
def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(
value, util.int_types + util.string_types):
value = ",".join(value)
if super_convert:
return super_convert(value)
else:
return value
return process
def adapt(self, impltype, **kw):
kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise
return util.constructor_copy(
self, impltype,
*self.values,
**kw
)

View File

@ -0,0 +1,102 @@
# mysql/gaerdbms.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+gaerdbms
:name: Google Cloud SQL
:dbapi: rdbms
:connectstring: mysql+gaerdbms:///<dbname>?instance=<instancename>
:url: https://developers.google.com/appengine/docs/python/cloud-sql/\
developers-guide
This dialect is based primarily on the :mod:`.mysql.mysqldb` dialect with
minimal changes.
.. versionadded:: 0.7.8
.. deprecated:: 1.0 This dialect is **no longer necessary** for
Google Cloud SQL; the MySQLdb dialect can be used directly.
Cloud SQL now recommends creating connections via the
mysql dialect using the URL format
``mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>``
Pooling
-------
Google App Engine connections appear to be randomly recycled,
so the dialect does not pool connections. The :class:`.NullPool`
implementation is installed within the :class:`.Engine` by
default.
"""
import os
from .mysqldb import MySQLDialect_mysqldb
from ...pool import NullPool
import re
from sqlalchemy.util import warn_deprecated
def _is_dev_environment():
return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/')
class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
@classmethod
def dbapi(cls):
warn_deprecated(
"Google Cloud SQL now recommends creating connections via the "
"MySQLdb dialect directly, using the URL format "
"mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/"
"<projectid>:<instancename>"
)
# from django:
# http://code.google.com/p/googleappengine/source/
# browse/trunk/python/google/storage/speckle/
# python/django/backend/base.py#118
# see also [ticket:2649]
# see also http://stackoverflow.com/q/14224679/34549
from google.appengine.api import apiproxy_stub_map
if _is_dev_environment():
from google.appengine.api import rdbms_mysqldb
return rdbms_mysqldb
elif apiproxy_stub_map.apiproxy.GetStub('rdbms'):
from google.storage.speckle.python.api import rdbms_apiproxy
return rdbms_apiproxy
else:
from google.storage.speckle.python.api import rdbms_googleapi
return rdbms_googleapi
@classmethod
def get_pool_class(cls, url):
# Cloud SQL connections die at any moment
return NullPool
def create_connect_args(self, url):
opts = url.translate_connect_args()
if not _is_dev_environment():
# 'dsn' and 'instance' are because we are skipping
# the traditional google.api.rdbms wrapper
opts['dsn'] = ''
opts['instance'] = url.query['instance']
return [], opts
def _extract_error_code(self, exception):
match = re.compile(r"^(\d+)L?:|^\((\d+)L?,").match(str(exception))
# The rdbms api will wrap then re-raise some types of errors
# making this regex return no matches.
code = match.group(1) or match.group(2) if match else None
if code:
return int(code)
dialect = MySQLDialect_gaerdbms

View File

@ -0,0 +1,79 @@
# mysql/json.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import
import json
from ...sql import elements
from ... import types as sqltypes
from ... import util
class JSON(sqltypes.JSON):
"""MySQL JSON type.
MySQL supports JSON as of version 5.7. Note that MariaDB does **not**
support JSON at the time of this writing.
The :class:`.mysql.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`.types.JSON`
datatype, by adapting the operations to render the ``JSON_EXTRACT``
function at the database level.
.. versionadded:: 1.1
"""
pass
class _FormatTypeMixin(object):
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
"".join([
"[%s]" % elem if isinstance(elem, int)
else '."%s"' % elem for elem in value
])
)

View File

@ -1,28 +1,34 @@
"""Support for the MySQL database via the MySQL Connector/Python adapter. # mysql/mysqlconnector.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
MySQL Connector/Python is available at: """
.. dialect:: mysql+mysqlconnector
:name: MySQL Connector/Python
:dbapi: myconnpy
:connectstring: mysql+mysqlconnector://<user>:<password>@\
<host>[:<port>]/<dbname>
:url: http://dev.mysql.com/downloads/connector/python/
https://launchpad.net/myconnpy
Connecting Unicode
----------- -------
Connect string format:: Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
""" """
import re from .base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer,
from sqlalchemy.dialects.mysql.base import (MySQLDialect,
MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
BIT) BIT)
from sqlalchemy.engine import base as engine_base, default from ... import util
from sqlalchemy.sql import operators as sql_operators import re
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
@ -31,17 +37,36 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
class MySQLCompiler_mysqlconnector(MySQLCompiler): class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod(self, binary, **kw): def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right) if self.dialect._mysqlconnector_double_percents:
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
else:
return self.process(binary.left, **kw) + " % " + \
self.process(binary.right, **kw)
def post_process_text(self, text): def post_process_text(self, text):
if self.dialect._mysqlconnector_double_percents:
return text.replace('%', '%%') return text.replace('%', '%%')
else:
return text
def escape_literal_column(self, text):
if self.dialect._mysqlconnector_double_percents:
return text.replace('%', '%%')
else:
return text
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
def _escape_identifier(self, value): def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote) value = value.replace(self.escape_quote, self.escape_to_quote)
if self.dialect._mysqlconnector_double_percents:
return value.replace("%", "%%") return value.replace("%", "%%")
else:
return value
class _myconnpyBIT(BIT): class _myconnpyBIT(BIT):
def result_processor(self, dialect, coltype): def result_processor(self, dialect, coltype):
@ -49,10 +74,12 @@ class _myconnpyBIT(BIT):
return None return None
class MySQLDialect_mysqlconnector(MySQLDialect): class MySQLDialect_mysqlconnector(MySQLDialect):
driver = 'mysqlconnector' driver = 'mysqlconnector'
supports_unicode_statements = True
supports_unicode_binds = True supports_unicode_binds = True
supports_sane_rowcount = True supports_sane_rowcount = True
supports_sane_multi_rowcount = True supports_sane_multi_rowcount = True
@ -71,6 +98,10 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
} }
) )
@util.memoized_property
def supports_unicode_statements(self):
return util.py3k or self._mysqlconnector_version_info > (2, 0)
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
from mysql import connector from mysql import connector
@ -78,48 +109,75 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
opts.update(url.query) opts.update(url.query)
util.coerce_kw_type(opts, 'allow_local_infile', bool)
util.coerce_kw_type(opts, 'autocommit', bool)
util.coerce_kw_type(opts, 'buffered', bool) util.coerce_kw_type(opts, 'buffered', bool)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'connection_timeout', int)
util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'consume_results', bool)
util.coerce_kw_type(opts, 'force_ipv6', bool)
util.coerce_kw_type(opts, 'get_warnings', bool)
util.coerce_kw_type(opts, 'pool_reset_session', bool)
util.coerce_kw_type(opts, 'pool_size', int)
util.coerce_kw_type(opts, 'raise_on_warnings', bool) util.coerce_kw_type(opts, 'raise_on_warnings', bool)
opts['buffered'] = True util.coerce_kw_type(opts, 'raw', bool)
opts['raise_on_warnings'] = True util.coerce_kw_type(opts, 'ssl_verify_cert', bool)
util.coerce_kw_type(opts, 'use_pure', bool)
util.coerce_kw_type(opts, 'use_unicode', bool)
# unfortunately, MySQL/connector python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
opts.setdefault('buffered', True)
# FOUND_ROWS must be set in ClientFlag to enable # FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount. # supports_sane_rowcount.
if self.dbapi is not None: if self.dbapi is not None:
try: try:
from mysql.connector.constants import ClientFlag from mysql.connector.constants import ClientFlag
client_flags = opts.get('client_flags', ClientFlag.get_default()) client_flags = opts.get(
'client_flags', ClientFlag.get_default())
client_flags |= ClientFlag.FOUND_ROWS client_flags |= ClientFlag.FOUND_ROWS
opts['client_flags'] = client_flags opts['client_flags'] = client_flags
except: except Exception:
pass pass
return [[], opts] return [[], opts]
@util.memoized_property
def _mysqlconnector_version_info(self):
if self.dbapi and hasattr(self.dbapi, '__version__'):
m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?',
self.dbapi.__version__)
if m:
return tuple(
int(x)
for x in m.group(1, 2, 3)
if x is not None)
@util.memoized_property
def _mysqlconnector_double_percents(self):
return not util.py3k and self._mysqlconnector_version_info < (2, 0)
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
dbapi_con = connection.connection dbapi_con = connection.connection
from mysql.connector.constants import ClientFlag
dbapi_con.set_client_flag(ClientFlag.FOUND_ROWS)
version = dbapi_con.get_server_version() version = dbapi_con.get_server_version()
return tuple(version) return tuple(version)
def _detect_charset(self, connection): def _detect_charset(self, connection):
return connection.connection.get_characterset_info() return connection.connection.charset
def _extract_error_code(self, exception): def _extract_error_code(self, exception):
try: return exception.errno
return exception.orig.errno
except AttributeError:
return None
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048) errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions): if isinstance(e, exceptions):
return e.errno in errnos return e.errno in errnos or \
"MySQL Connection not available." in str(e)
else: else:
return False return False
@ -129,4 +187,17 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
def _compat_fetchone(self, rp, charset=None): def _compat_fetchone(self, rp, charset=None):
return rp.fetchone() return rp.fetchone()
_isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
'READ COMMITTED', 'REPEATABLE READ',
'AUTOCOMMIT'])
def _set_isolation_level(self, connection, level):
if level == 'AUTOCOMMIT':
connection.autocommit = True
else:
connection.autocommit = False
super(MySQLDialect_mysqlconnector, self)._set_isolation_level(
connection, level)
dialect = MySQLDialect_mysqlconnector dialect = MySQLDialect_mysqlconnector

View File

@ -1,57 +1,57 @@
"""Support for the MySQL database via the MySQL-python adapter. # mysql/mysqldb.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
MySQL-Python is available at: # <see AUTHORS file>
#
http://sourceforge.net/projects/mysql-python # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
At least version 1.2.1 or 1.2.2 should be used.
Connecting
-----------
Connect string format::
mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
Character Sets
--------------
Many MySQL server installations default to a ``latin1`` encoding for client
connections. All data sent through the connection will be converted into
``latin1``, even if you have ``utf8`` or another character set on your tables
and columns. With versions 4.1 and higher, you can change the connection
character set either through server configuration or by including the
``charset`` parameter in the URL used for ``create_engine``. The ``charset``
option is passed through to MySQL-Python and has the side-effect of also
enabling ``use_unicode`` in the driver by default. For regular encoded
strings, also pass ``use_unicode=0`` in the connection arguments::
# set client encoding to utf8; all strings come back as unicode
create_engine('mysql+mysqldb:///mydb?charset=utf8')
# set client encoding to utf8; all strings come back as utf8 str
create_engine('mysql+mysqldb:///mydb?charset=utf8&use_unicode=0')
Known Issues
-------------
MySQL-python at least as of version 1.2.2 has a serious memory leak related
to unicode conversion, a feature which is disabled via ``use_unicode=0``.
The recommended connection form with SQLAlchemy is::
engine = create_engine('mysql://scott:tiger@localhost/test?charset=utf8&use_unicode=0', pool_recycle=3600)
""" """
.. dialect:: mysql+mysqldb
:name: MySQL-Python
:dbapi: mysqldb
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://sourceforge.net/projects/mysql-python
.. _mysqldb_unicode:
Unicode
-------
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
Py3K Support
------------
Currently, MySQLdb only runs on Python 2 and development has been stopped.
`mysqlclient`_ is fork of MySQLdb and provides Python 3 support as well
as some bugfixes.
.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python
Using MySQLdb with Google Cloud SQL
-----------------------------------
Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
using a URL like the following::
mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
Server Side Cursors
-------------------
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
"""
from .base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from .base import TEXT
from ... import sql
from ... import util
import re import re
from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors
class MySQLExecutionContext_mysqldb(MySQLExecutionContext): class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
@ -64,8 +64,9 @@ class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
class MySQLCompiler_mysqldb(MySQLCompiler): class MySQLCompiler_mysqldb(MySQLCompiler):
def visit_mod(self, binary, **kw): def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right) return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text): def post_process_text(self, text):
return text.replace('%', '%%') return text.replace('%', '%%')
@ -77,9 +78,10 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer):
value = value.replace(self.escape_quote, self.escape_to_quote) value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%") return value.replace("%", "%%")
class MySQLDialect_mysqldb(MySQLDialect): class MySQLDialect_mysqldb(MySQLDialect):
driver = 'mysqldb' driver = 'mysqldb'
supports_unicode_statements = False supports_unicode_statements = True
supports_sane_rowcount = True supports_sane_rowcount = True
supports_sane_multi_rowcount = True supports_sane_multi_rowcount = True
@ -90,11 +92,18 @@ class MySQLDialect_mysqldb(MySQLDialect):
statement_compiler = MySQLCompiler_mysqldb statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer_mysqldb preparer = MySQLIdentifierPreparer_mysqldb
colspecs = util.update_copy( def __init__(self, server_side_cursors=False, **kwargs):
MySQLDialect.colspecs, super(MySQLDialect_mysqldb, self).__init__(**kwargs)
{ self.server_side_cursors = server_side_cursors
}
) @util.langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
cursors = __import__('MySQLdb.cursors').cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
return False
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
@ -105,6 +114,30 @@ class MySQLDialect_mysqldb(MySQLDialect):
if context is not None: if context is not None:
context._rowcount = rowcount context._rowcount = rowcount
def _check_unicode_returns(self, connection):
# work around issue fixed in
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8_bin collation and unicode returns
has_utf8_bin = self.server_version_info > (5, ) and \
connection.scalar(
"show collation where %s = 'utf8' and %s = 'utf8_bin'"
% (
self.identifier_preparer.quote("Charset"),
self.identifier_preparer.quote("Collation")
))
if has_utf8_bin:
additional_tests = [
sql.collate(sql.cast(
sql.literal_column(
"'test collated returns'"),
TEXT(charset='utf8')), "utf8_bin")
]
else:
additional_tests = []
return super(MySQLDialect_mysqldb, self)._check_unicode_returns(
connection, additional_tests)
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user', opts = url.translate_connect_args(database='db', username='user',
password='passwd') password='passwd')
@ -112,11 +145,12 @@ class MySQLDialect_mysqldb(MySQLDialect):
util.coerce_kw_type(opts, 'compress', bool) util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'connect_timeout', int) util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'read_timeout', int)
util.coerce_kw_type(opts, 'client_flag', int) util.coerce_kw_type(opts, 'client_flag', int)
util.coerce_kw_type(opts, 'local_infile', int) util.coerce_kw_type(opts, 'local_infile', int)
# Note: using either of the below will cause all strings to be returned # Note: using either of the below will cause all strings to be
# as Unicode, both in raw SQL operations and with column types like # returned as Unicode, both in raw SQL operations and with column
# String and MSString. # types like String and MSString.
util.coerce_kw_type(opts, 'use_unicode', bool) util.coerce_kw_type(opts, 'use_unicode', bool)
util.coerce_kw_type(opts, 'charset', str) util.coerce_kw_type(opts, 'charset', str)
@ -124,7 +158,8 @@ class MySQLDialect_mysqldb(MySQLDialect):
# query string. # query string.
ssl = {} ssl = {}
for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']
for key in keys:
if key in opts: if key in opts:
ssl[key[4:]] = opts[key] ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str) util.coerce_kw_type(ssl, key[4:], str)
@ -137,17 +172,19 @@ class MySQLDialect_mysqldb(MySQLDialect):
client_flag = opts.get('client_flag', 0) client_flag = opts.get('client_flag', 0)
if self.dbapi is not None: if self.dbapi is not None:
try: try:
from MySQLdb.constants import CLIENT as CLIENT_FLAGS CLIENT_FLAGS = __import__(
self.dbapi.__name__ + '.constants.CLIENT'
).constants.CLIENT
client_flag |= CLIENT_FLAGS.FOUND_ROWS client_flag |= CLIENT_FLAGS.FOUND_ROWS
except: except (AttributeError, ImportError):
pass self.supports_sane_rowcount = False
opts['client_flag'] = client_flag opts['client_flag'] = client_flag
return [[], opts] return [[], opts]
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
dbapi_con = connection.connection dbapi_con = connection.connection
version = [] version = []
r = re.compile('[.\-]') r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.get_server_info()): for n in r.split(dbapi_con.get_server_info()):
try: try:
version.append(int(n)) version.append(int(n))
@ -156,47 +193,36 @@ class MySQLDialect_mysqldb(MySQLDialect):
return tuple(version) return tuple(version)
def _extract_error_code(self, exception): def _extract_error_code(self, exception):
try: return exception.args[0]
return exception.orig.args[0]
except AttributeError:
return None
def _detect_charset(self, connection): def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results.""" """Sniff out the character set in use for connection results."""
# Note: MySQL-python 1.2.1c7 seems to ignore changes made
# on a connection via set_character_set()
if self.server_version_info < (4, 1, 0):
try: try:
return connection.connection.character_set_name() # note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
cset_name = connection.connection.character_set_name
except AttributeError: except AttributeError:
# < 1.2.1 final MySQL-python drivers have no charset support.
# a query is needed.
pass
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
if 'character_set_results' in opts:
return opts['character_set_results']
try:
return connection.connection.character_set_name()
except AttributeError:
# Still no charset on < 1.2.1 final...
if 'character_set' in opts:
return opts['character_set']
else:
util.warn( util.warn(
"Could not detect the connection character set with this " "No 'character_set_name' can be detected with "
"combination of MySQL server and MySQL-python. " "this MySQL-Python version; "
"MySQL-python >= 1.2.2 is recommended. Assuming latin1.") "please upgrade to a recent version of MySQL-Python. "
"Assuming latin1.")
return 'latin1' return 'latin1'
else:
return cset_name()
_isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
'READ COMMITTED', 'REPEATABLE READ',
'AUTOCOMMIT'])
def _set_isolation_level(self, connection, level):
if level == 'AUTOCOMMIT':
connection.autocommit(True)
else:
connection.autocommit(False)
super(MySQLDialect_mysqldb, self)._set_isolation_level(connection,
level)
dialect = MySQLDialect_mysqldb dialect = MySQLDialect_mysqldb

View File

@ -1,46 +1,31 @@
"""Support for the MySQL database via the oursql adapter. # mysql/oursql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
OurSQL is available at: """
http://packages.python.org/oursql/ .. dialect:: mysql+oursql
:name: OurSQL
:dbapi: oursql
:connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://packages.python.org/oursql/
Connecting Unicode
----------- -------
Connect string format:: Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
Character Sets
--------------
oursql defaults to using ``utf8`` as the connection charset, but other
encodings may be used instead. Like the MySQL-Python driver, unicode support
can be completely disabled::
# oursql sets the connection charset to utf8 automatically; all strings come
# back as utf8 str
create_engine('mysql+oursql:///mydb?use_unicode=0')
To not automatically use ``utf8`` and instead use whatever the connection
defaults to, there is a separate parameter::
# use the default connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?default_charset=1')
# use latin1 as the connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?charset=latin1')
""" """
import re import re
from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext, from .base import (BIT, MySQLDialect, MySQLExecutionContext)
MySQLCompiler, MySQLIdentifierPreparer) from ... import types as sqltypes, util
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors
class _oursqlBIT(BIT): class _oursqlBIT(BIT):
@ -56,14 +41,13 @@ class MySQLExecutionContext_oursql(MySQLExecutionContext):
def plain_query(self): def plain_query(self):
return self.execution_options.get('_oursql_plain_query', False) return self.execution_options.get('_oursql_plain_query', False)
class MySQLDialect_oursql(MySQLDialect): class MySQLDialect_oursql(MySQLDialect):
driver = 'oursql' driver = 'oursql'
# Py3K
# description_encoding = None if util.py2k:
# Py2K
supports_unicode_binds = True supports_unicode_binds = True
supports_unicode_statements = True supports_unicode_statements = True
# end Py2K
supports_native_decimal = True supports_native_decimal = True
@ -84,7 +68,8 @@ class MySQLDialect_oursql(MySQLDialect):
return __import__('oursql') return __import__('oursql')
def do_execute(self, cursor, statement, parameters, context=None): def do_execute(self, cursor, statement, parameters, context=None):
"""Provide an implementation of *cursor.execute(statement, parameters)*.""" """Provide an implementation of
*cursor.execute(statement, parameters)*."""
if context and context.plain_query: if context and context.plain_query:
cursor.execute(statement, plain_query=True) cursor.execute(statement, plain_query=True)
@ -95,13 +80,15 @@ class MySQLDialect_oursql(MySQLDialect):
connection.cursor().execute('BEGIN', plain_query=True) connection.cursor().execute('BEGIN', plain_query=True)
def _xa_query(self, connection, query, xid): def _xa_query(self, connection, query, xid):
# Py2K if util.py2k:
arg = connection.connection._escape_string(xid) arg = connection.connection._escape_string(xid)
# end Py2K else:
# Py3K charset = self._connection_charset
# charset = self._connection_charset arg = connection.connection._escape_string(
# arg = connection.connection._escape_string(xid.encode(charset)).decode(charset) xid.encode(charset)).decode(charset)
connection.execution_options(_oursql_plain_query=True).execute(query % arg) arg = "'%s'" % arg
connection.execution_options(
_oursql_plain_query=True).execute(query % arg)
# Because mysql is bad, these methods have to be # Because mysql is bad, these methods have to be
# reimplemented to use _PlainQuery. Basically, some queries # reimplemented to use _PlainQuery. Basically, some queries
@ -109,70 +96,71 @@ class MySQLDialect_oursql(MySQLDialect):
# the parameterized query API, or refuse to be parameterized # the parameterized query API, or refuse to be parameterized
# in the first place. # in the first place.
def do_begin_twophase(self, connection, xid): def do_begin_twophase(self, connection, xid):
self._xa_query(connection, 'XA BEGIN "%s"', xid) self._xa_query(connection, 'XA BEGIN %s', xid)
def do_prepare_twophase(self, connection, xid): def do_prepare_twophase(self, connection, xid):
self._xa_query(connection, 'XA END "%s"', xid) self._xa_query(connection, 'XA END %s', xid)
self._xa_query(connection, 'XA PREPARE "%s"', xid) self._xa_query(connection, 'XA PREPARE %s', xid)
def do_rollback_twophase(self, connection, xid, is_prepared=True, def do_rollback_twophase(self, connection, xid, is_prepared=True,
recover=False): recover=False):
if not is_prepared: if not is_prepared:
self._xa_query(connection, 'XA END "%s"', xid) self._xa_query(connection, 'XA END %s', xid)
self._xa_query(connection, 'XA ROLLBACK "%s"', xid) self._xa_query(connection, 'XA ROLLBACK %s', xid)
def do_commit_twophase(self, connection, xid, is_prepared=True, def do_commit_twophase(self, connection, xid, is_prepared=True,
recover=False): recover=False):
if not is_prepared: if not is_prepared:
self.do_prepare_twophase(connection, xid) self.do_prepare_twophase(connection, xid)
self._xa_query(connection, 'XA COMMIT "%s"', xid) self._xa_query(connection, 'XA COMMIT %s', xid)
# Q: why didn't we need all these "plain_query" overrides earlier ? # Q: why didn't we need all these "plain_query" overrides earlier ?
# am i on a newer/older version of OurSQL ? # am i on a newer/older version of OurSQL ?
def has_table(self, connection, table_name, schema=None): def has_table(self, connection, table_name, schema=None):
return MySQLDialect.has_table(self, return MySQLDialect.has_table(
connection.connect().\ self,
execution_options(_oursql_plain_query=True), connection.connect().execution_options(_oursql_plain_query=True),
table_name, schema) table_name,
schema
)
def get_table_options(self, connection, table_name, schema=None, **kw): def get_table_options(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_table_options(self, return MySQLDialect.get_table_options(
connection.connect().\ self,
execution_options(_oursql_plain_query=True), connection.connect().execution_options(_oursql_plain_query=True),
table_name, table_name,
schema=schema, schema=schema,
**kw **kw
) )
def get_columns(self, connection, table_name, schema=None, **kw): def get_columns(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_columns(self, return MySQLDialect.get_columns(
connection.connect().\ self,
execution_options(_oursql_plain_query=True), connection.connect().execution_options(_oursql_plain_query=True),
table_name, table_name,
schema=schema, schema=schema,
**kw **kw
) )
def get_view_names(self, connection, schema=None, **kw): def get_view_names(self, connection, schema=None, **kw):
return MySQLDialect.get_view_names(self, return MySQLDialect.get_view_names(
connection.connect().\ self,
execution_options(_oursql_plain_query=True), connection.connect().execution_options(_oursql_plain_query=True),
schema=schema, schema=schema,
**kw **kw
) )
def get_table_names(self, connection, schema=None, **kw): def get_table_names(self, connection, schema=None, **kw):
return MySQLDialect.get_table_names(self, return MySQLDialect.get_table_names(
connection.connect().\ self,
execution_options(_oursql_plain_query=True), connection.connect().execution_options(_oursql_plain_query=True),
schema schema
) )
def get_schema_names(self, connection, **kw): def get_schema_names(self, connection, **kw):
return MySQLDialect.get_schema_names(self, return MySQLDialect.get_schema_names(
connection.connect().\ self,
execution_options(_oursql_plain_query=True), connection.connect().execution_options(_oursql_plain_query=True),
**kw **kw
) )
@ -184,14 +172,17 @@ class MySQLDialect_oursql(MySQLDialect):
def _show_create_table(self, connection, table, charset=None, def _show_create_table(self, connection, table, charset=None,
full_name=None): full_name=None):
return MySQLDialect._show_create_table(self, return MySQLDialect._show_create_table(
self,
connection.contextual_connect(close_with_result=True). connection.contextual_connect(close_with_result=True).
execution_options(_oursql_plain_query=True), execution_options(_oursql_plain_query=True),
table, charset, full_name) table, charset, full_name
)
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError): if isinstance(e, self.dbapi.ProgrammingError):
return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed') return e.errno is None and 'cursor' not in e.args[1] \
and e.args[1].endswith('closed')
else: else:
return e.errno in (2006, 2013, 2014, 2045, 2055) return e.errno in (2006, 2013, 2014, 2045, 2055)
@ -203,6 +194,7 @@ class MySQLDialect_oursql(MySQLDialect):
util.coerce_kw_type(opts, 'port', int) util.coerce_kw_type(opts, 'port', int)
util.coerce_kw_type(opts, 'compress', bool) util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'autoping', bool) util.coerce_kw_type(opts, 'autoping', bool)
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
util.coerce_kw_type(opts, 'default_charset', bool) util.coerce_kw_type(opts, 'default_charset', bool)
if opts.pop('default_charset', False): if opts.pop('default_charset', False):
@ -216,12 +208,22 @@ class MySQLDialect_oursql(MySQLDialect):
# supports_sane_rowcount. # supports_sane_rowcount.
opts.setdefault('found_rows', True) opts.setdefault('found_rows', True)
ssl = {}
for key in ['ssl_ca', 'ssl_key', 'ssl_cert',
'ssl_capath', 'ssl_cipher']:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
opts['ssl'] = ssl
return [[], opts] return [[], opts]
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
dbapi_con = connection.connection dbapi_con = connection.connection
version = [] version = []
r = re.compile('[.\-]') r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.server_info): for n in r.split(dbapi_con.server_info):
try: try:
version.append(int(n)) version.append(int(n))
@ -230,10 +232,7 @@ class MySQLDialect_oursql(MySQLDialect):
return tuple(version) return tuple(version)
def _extract_error_code(self, exception): def _extract_error_code(self, exception):
try: return exception.errno
return exception.orig.errno
except AttributeError:
return None
def _detect_charset(self, connection): def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results.""" """Sniff out the character set in use for connection results."""

View File

@ -0,0 +1,70 @@
# mysql/pymysql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+pymysql
:name: PyMySQL
:dbapi: pymysql
:connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>\
[?<options>]
:url: http://www.pymysql.org/
Unicode
-------
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
MySQL-Python Compatibility
--------------------------
The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
and targets 100% compatibility. Most behavioral notes for MySQL-python apply
to the pymysql driver as well.
"""
from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers, py3k
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
driver = 'pymysql'
description_encoding = None
# generally, these two values should be both True
# or both False. PyMySQL unicode tests pass all the way back
# to 0.4 either way. See [ticket:3337]
supports_unicode_statements = True
supports_unicode_binds = True
def __init__(self, server_side_cursors=False, **kwargs):
super(MySQLDialect_pymysql, self).__init__(**kwargs)
self.server_side_cursors = server_side_cursors
@langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
cursors = __import__('pymysql.cursors').cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
return False
@classmethod
def dbapi(cls):
return __import__('pymysql')
if py3k:
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]
dialect = MySQLDialect_pymysql

View File

@ -1,32 +1,33 @@
"""Support for the MySQL database via the pyodbc adapter. # mysql/pyodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
pyodbc is available at: # <see AUTHORS file>
#
http://pypi.python.org/pypi/pyodbc/ # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Connecting
----------
Connect string::
mysql+pyodbc://<username>:<password>@<dsnname>
Limitations
-----------
The mysql-pyodbc dialect is subject to unresolved character encoding issues
which exist within the current ODBC drivers available.
(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage
of OurSQL, MySQLdb, or MySQL-connector/Python.
""" """
from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector .. dialect:: mysql+pyodbc
from sqlalchemy.engine import base as engine_base :name: PyODBC
from sqlalchemy import util :dbapi: pyodbc
:connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
:url: http://pypi.python.org/pypi/pyodbc/
.. note:: The PyODBC for MySQL dialect is not well supported, and
is subject to unresolved character encoding issues
which exist within the current ODBC drivers available.
(see http://code.google.com/p/pyodbc/issues/detail?id=25).
Other dialects for MySQL are recommended.
"""
from .base import MySQLDialect, MySQLExecutionContext
from ...connectors.pyodbc import PyODBCConnector
from ... import util
import re import re
class MySQLExecutionContext_pyodbc(MySQLExecutionContext): class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self): def get_lastrowid(self):
@ -36,6 +37,7 @@ class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
cursor.close() cursor.close()
return lastrowid return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
supports_unicode_statements = False supports_unicode_statements = False
execution_ctx_cls = MySQLExecutionContext_pyodbc execution_ctx_cls = MySQLExecutionContext_pyodbc
@ -62,11 +64,12 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
if opts.get(key, None): if opts.get(key, None):
return opts[key] return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.") util.warn("Could not detect the connection character set. "
"Assuming latin1.")
return 'latin1' return 'latin1'
def _extract_error_code(self, exception): def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.orig.args)) m = re.compile(r"\((\d+)\)").search(str(exception.args))
c = m.group(1) c = m.group(1)
if c: if c:
return int(c) return int(c)

View File

@ -0,0 +1,450 @@
# mysql/reflection.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import re
from ... import log, util
from ... import types as sqltypes
from .enumerated import _EnumeratedValues, SET
from .types import DATETIME, TIME, TIMESTAMP
class ReflectedState(object):
"""Stores raw information about a SHOW CREATE TABLE statement."""
def __init__(self):
self.columns = []
self.table_options = {}
self.table_name = None
self.keys = []
self.constraints = []
@log.class_logger
class MySQLTableDefinitionParser(object):
"""Parses the results of a SHOW CREATE TABLE statement."""
def __init__(self, dialect, preparer):
self.dialect = dialect
self.preparer = preparer
self._prep_regexes()
def parse(self, show_create, charset):
state = ReflectedState()
state.charset = charset
for line in re.split(r'\r?\n', show_create):
if line.startswith(' ' + self.preparer.initial_quote):
self._parse_column(line, state)
# a regular table options line
elif line.startswith(') '):
self._parse_table_options(line, state)
# an ANSI-mode table options line
elif line == ')':
pass
elif line.startswith('CREATE '):
self._parse_table_name(line, state)
# Not present in real reflection, but may be if
# loading from a file.
elif not line:
pass
else:
type_, spec = self._parse_constraints(line)
if type_ is None:
util.warn("Unknown schema content: %r" % line)
elif type_ == 'key':
state.keys.append(spec)
elif type_ == 'constraint':
state.constraints.append(spec)
else:
pass
return state
def _parse_constraints(self, line):
"""Parse a KEY or CONSTRAINT line.
:param line: A line of SHOW CREATE TABLE output
"""
# KEY
m = self._re_key.match(line)
if m:
spec = m.groupdict()
# convert columns into name, length pairs
spec['columns'] = self._parse_keyexprs(spec['columns'])
return 'key', spec
# CONSTRAINT
m = self._re_constraint.match(line)
if m:
spec = m.groupdict()
spec['table'] = \
self.preparer.unformat_identifiers(spec['table'])
spec['local'] = [c[0]
for c in self._parse_keyexprs(spec['local'])]
spec['foreign'] = [c[0]
for c in self._parse_keyexprs(spec['foreign'])]
return 'constraint', spec
# PARTITION and SUBPARTITION
m = self._re_partition.match(line)
if m:
# Punt!
return 'partition', line
# No match.
return (None, line)
def _parse_table_name(self, line, state):
"""Extract the table name.
:param line: The first line of SHOW CREATE TABLE
"""
regex, cleanup = self._pr_name
m = regex.match(line)
if m:
state.table_name = cleanup(m.group('name'))
def _parse_table_options(self, line, state):
"""Build a dictionary of all reflected table-level options.
:param line: The final line of SHOW CREATE TABLE output.
"""
options = {}
if not line or line == ')':
pass
else:
rest_of_line = line[:]
for regex, cleanup in self._pr_options:
m = regex.search(rest_of_line)
if not m:
continue
directive, value = m.group('directive'), m.group('val')
if cleanup:
value = cleanup(value)
options[directive.lower()] = value
rest_of_line = regex.sub('', rest_of_line)
for nope in ('auto_increment', 'data directory', 'index directory'):
options.pop(nope, None)
for opt, val in options.items():
state.table_options['%s_%s' % (self.dialect.name, opt)] = val
def _parse_column(self, line, state):
"""Extract column details.
Falls back to a 'minimal support' variant if full parse fails.
:param line: Any column-bearing line from SHOW CREATE TABLE
"""
spec = None
m = self._re_column.match(line)
if m:
spec = m.groupdict()
spec['full'] = True
else:
m = self._re_column_loose.match(line)
if m:
spec = m.groupdict()
spec['full'] = False
if not spec:
util.warn("Unknown column definition %r" % line)
return
if not spec['full']:
util.warn("Incomplete reflection of column definition %r" % line)
name, type_, args = spec['name'], spec['coltype'], spec['arg']
try:
col_type = self.dialect.ischema_names[type_]
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" %
(type_, name))
col_type = sqltypes.NullType
# Column type positional arguments eg. varchar(32)
if args is None or args == '':
type_args = []
elif args[0] == "'" and args[-1] == "'":
type_args = self._re_csv_str.findall(args)
else:
type_args = [int(v) for v in self._re_csv_int.findall(args)]
# Column type keyword options
type_kw = {}
if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
if type_args:
type_kw['fsp'] = type_args.pop(0)
for kw in ('unsigned', 'zerofill'):
if spec.get(kw, False):
type_kw[kw] = True
for kw in ('charset', 'collate'):
if spec.get(kw, False):
type_kw[kw] = spec[kw]
if issubclass(col_type, _EnumeratedValues):
type_args = _EnumeratedValues._strip_values(type_args)
if issubclass(col_type, SET) and '' in type_args:
type_kw['retrieve_as_bitwise'] = True
type_instance = col_type(*type_args, **type_kw)
col_kw = {}
# NOT NULL
col_kw['nullable'] = True
# this can be "NULL" in the case of TIMESTAMP
if spec.get('notnull', False) == 'NOT NULL':
col_kw['nullable'] = False
# AUTO_INCREMENT
if spec.get('autoincr', False):
col_kw['autoincrement'] = True
elif issubclass(col_type, sqltypes.Integer):
col_kw['autoincrement'] = False
# DEFAULT
default = spec.get('default', None)
if default == 'NULL':
# eliminates the need to deal with this later.
default = None
col_d = dict(name=name, type=type_instance, default=default)
col_d.update(col_kw)
state.columns.append(col_d)
def _describe_to_create(self, table_name, columns):
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
DESCRIBE is a much simpler reflection and is sufficient for
reflecting views for runtime use. This method formats DDL
for columns only- keys are omitted.
:param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
SHOW FULL COLUMNS FROM rows must be rearranged for use with
this function.
"""
buffer = []
for row in columns:
(name, col_type, nullable, default, extra) = \
[row[i] for i in (0, 1, 2, 4, 5)]
line = [' ']
line.append(self.preparer.quote_identifier(name))
line.append(col_type)
if not nullable:
line.append('NOT NULL')
if default:
if 'auto_increment' in default:
pass
elif (col_type.startswith('timestamp') and
default.startswith('C')):
line.append('DEFAULT')
line.append(default)
elif default == 'NULL':
line.append('DEFAULT')
line.append(default)
else:
line.append('DEFAULT')
line.append("'%s'" % default.replace("'", "''"))
if extra:
line.append(extra)
buffer.append(' '.join(line))
return ''.join([('CREATE TABLE %s (\n' %
self.preparer.quote_identifier(table_name)),
',\n'.join(buffer),
'\n) '])
def _parse_keyexprs(self, identifiers):
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
return self._re_keyexprs.findall(identifiers)
def _prep_regexes(self):
"""Pre-compile regular expressions."""
self._re_columns = []
self._pr_options = []
_final = self.preparer.final_quote
quotes = dict(zip(('iq', 'fq', 'esc_fq'),
[re.escape(s) for s in
(self.preparer.initial_quote,
_final,
self.preparer._escape_identifier(_final))]))
self._pr_name = _pr_compile(
r'^CREATE (?:\w+ +)?TABLE +'
r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes,
self.preparer._unescape_identifier)
# `col`,`col2`(32),`col3`(15) DESC
#
# Note: ASC and DESC aren't reflected, so we'll punt...
self._re_keyexprs = _re_compile(
r'(?:'
r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)'
r'(?:\((\d+)\))?(?=\,|$))+' % quotes)
# 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27')
# 123 or 123,456
self._re_csv_int = _re_compile(r'\d+')
# `colname` <type> [type opts]
# (NOT NULL | NULL)
# DEFAULT ('value' | CURRENT_TIMESTAMP...)
# COMMENT 'comment'
# COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
# STORAGE (DISK|MEMORY)
self._re_column = _re_compile(
r' '
r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
r'(?P<coltype>\w+)'
r'(?:\((?P<arg>(?:\d+|\d+,\d+|'
r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?'
r'(?: +(?P<unsigned>UNSIGNED))?'
r'(?: +(?P<zerofill>ZEROFILL))?'
r'(?: +CHARACTER SET +(?P<charset>[\w_]+))?'
r'(?: +COLLATE +(?P<collate>[\w_]+))?'
r'(?: +(?P<notnull>(?:NOT )?NULL))?'
r'(?: +DEFAULT +(?P<default>'
r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+'
r'(?: +ON UPDATE \w+)?)'
r'))?'
r'(?: +(?P<autoincr>AUTO_INCREMENT))?'
r'(?: +COMMENT +(P<comment>(?:\x27\x27|[^\x27])+))?'
r'(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?'
r'(?: +STORAGE +(?P<storage>\w+))?'
r'(?: +(?P<extra>.*))?'
r',?$'
% quotes
)
# Fallback, try to parse as little as possible
self._re_column_loose = _re_compile(
r' '
r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
r'(?P<coltype>\w+)'
r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?'
r'.*?(?P<notnull>(?:NOT )NULL)?'
% quotes
)
# (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
# (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
# KEY_BLOCK_SIZE size | WITH PARSER name
self._re_key = _re_compile(
r' '
r'(?:(?P<type>\S+) )?KEY'
r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?'
r'(?: +USING +(?P<using_pre>\S+))?'
r' +\((?P<columns>.+?)\)'
r'(?: +USING +(?P<using_post>\S+))?'
r'(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?'
r'(?: +WITH PARSER +(?P<parser>\S+))?'
r'(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?'
r',?$'
% quotes
)
# CONSTRAINT `name` FOREIGN KEY (`local_col`)
# REFERENCES `remote` (`remote_col`)
# MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
# ON DELETE CASCADE ON UPDATE RESTRICT
#
# unique constraints come back as KEYs
kw = quotes.copy()
kw['on'] = 'RESTRICT|CASCADE|SET NULL|NOACTION'
self._re_constraint = _re_compile(
r' '
r'CONSTRAINT +'
r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
r'FOREIGN KEY +'
r'\((?P<local>[^\)]+?)\) REFERENCES +'
r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s'
r'(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +'
r'\((?P<foreign>[^\)]+?)\)'
r'(?: +(?P<match>MATCH \w+))?'
r'(?: +ON DELETE (?P<ondelete>%(on)s))?'
r'(?: +ON UPDATE (?P<onupdate>%(on)s))?'
% kw
)
# PARTITION
#
# punt!
self._re_partition = _re_compile(r'(?:.*)(?:SUB)?PARTITION(?:.*)')
# Table-level options (COLLATE, ENGINE, etc.)
# Do the string options first, since they have quoted
# strings we need to get rid of.
for option in _options_of_type_string:
self._add_option_string(option)
for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT',
'AVG_ROW_LENGTH', 'CHARACTER SET',
'DEFAULT CHARSET', 'CHECKSUM',
'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD',
'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT',
'KEY_BLOCK_SIZE'):
self._add_option_word(option)
self._add_option_regex('UNION', r'\([^\)]+\)')
self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK')
self._add_option_regex(
'RAID_TYPE',
r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+')
_optional_equals = r'(?:\s*(?:=\s*)|\s+)'
def _add_option_string(self, directive):
regex = (r'(?P<directive>%s)%s'
r"'(?P<val>(?:[^']|'')*?)'(?!')" %
(re.escape(directive), self._optional_equals))
self._pr_options.append(_pr_compile(
regex, lambda v: v.replace("\\\\", "\\").replace("''", "'")
))
def _add_option_word(self, directive):
regex = (r'(?P<directive>%s)%s'
r'(?P<val>\w+)' %
(re.escape(directive), self._optional_equals))
self._pr_options.append(_pr_compile(regex))
def _add_option_regex(self, directive, regex):
regex = (r'(?P<directive>%s)%s'
r'(?P<val>%s)' %
(re.escape(directive), self._optional_equals, regex))
self._pr_options.append(_pr_compile(regex))
_options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY',
'PASSWORD', 'CONNECTION')
def _pr_compile(regex, cleanup=None):
"""Prepare a 2-tuple of compiled regex and callable."""
return (_re_compile(regex), cleanup)
def _re_compile(regex):
"""Compile a string to regex, I and UNICODE."""
return re.compile(regex, re.I | re.UNICODE)

View File

@ -0,0 +1,766 @@
# mysql/types.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import datetime
from ... import exc, util
from ... import types as sqltypes
class _NumericType(object):
"""Base for MySQL numeric types.
This is the base both for NUMERIC as well as INTEGER, hence
it's a mixin.
"""
def __init__(self, unsigned=False, zerofill=False, **kw):
self.unsigned = unsigned
self.zerofill = zerofill
super(_NumericType, self).__init__(**kw)
def __repr__(self):
return util.generic_repr(self,
to_inspect=[_NumericType, sqltypes.Numeric])
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and \
(
(precision is None and scale is not None) or
(precision is not None and scale is None)
):
raise exc.ArgumentError(
"You must specify both precision and scale or omit "
"both altogether.")
super(_FloatType, self).__init__(
precision=precision, asdecimal=asdecimal, **kw)
self.scale = scale
def __repr__(self):
return util.generic_repr(self, to_inspect=[_FloatType,
_NumericType,
sqltypes.Float])
class _IntegerType(_NumericType, sqltypes.Integer):
def __init__(self, display_width=None, **kw):
self.display_width = display_width
super(_IntegerType, self).__init__(**kw)
def __repr__(self):
return util.generic_repr(self, to_inspect=[_IntegerType,
_NumericType,
sqltypes.Integer])
class _StringType(sqltypes.String):
"""Base for MySQL string types."""
def __init__(self, charset=None, collation=None,
ascii=False, binary=False, unicode=False,
national=False, **kw):
self.charset = charset
# allow collate= or collation=
kw.setdefault('collation', kw.pop('collate', collation))
self.ascii = ascii
self.unicode = unicode
self.binary = binary
self.national = national
super(_StringType, self).__init__(**kw)
def __repr__(self):
return util.generic_repr(self,
to_inspect=[_StringType, sqltypes.String])
class _MatchType(sqltypes.Float, sqltypes.MatchType):
def __init__(self, **kw):
# TODO: float arguments?
sqltypes.Float.__init__(self)
sqltypes.MatchType.__init__(self)
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
__visit_name__ = 'NUMERIC'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(NUMERIC, self).__init__(precision=precision,
scale=scale, asdecimal=asdecimal, **kw)
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""MySQL DECIMAL type."""
__visit_name__ = 'DECIMAL'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(DECIMAL, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class DOUBLE(_FloatType):
"""MySQL DOUBLE type."""
__visit_name__ = 'DOUBLE'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
.. note::
The :class:`.DOUBLE` type by default converts from float
to Decimal, using a truncation that defaults to 10 digits.
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
to change this scale, or ``asdecimal=False`` to return values
directly as Python floating points.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(DOUBLE, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class REAL(_FloatType, sqltypes.REAL):
"""MySQL REAL type."""
__visit_name__ = 'REAL'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
.. note::
The :class:`.REAL` type by default converts from float
to Decimal, using a truncation that defaults to 10 digits.
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
to change this scale, or ``asdecimal=False`` to return values
directly as Python floating points.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(REAL, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class FLOAT(_FloatType, sqltypes.FLOAT):
"""MySQL FLOAT type."""
__visit_name__ = 'FLOAT'
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(FLOAT, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
def bind_processor(self, dialect):
return None
class INTEGER(_IntegerType, sqltypes.INTEGER):
"""MySQL INTEGER type."""
__visit_name__ = 'INTEGER'
def __init__(self, display_width=None, **kw):
"""Construct an INTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(INTEGER, self).__init__(display_width=display_width, **kw)
class BIGINT(_IntegerType, sqltypes.BIGINT):
"""MySQL BIGINTEGER type."""
__visit_name__ = 'BIGINT'
def __init__(self, display_width=None, **kw):
"""Construct a BIGINTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(BIGINT, self).__init__(display_width=display_width, **kw)
class MEDIUMINT(_IntegerType):
"""MySQL MEDIUMINTEGER type."""
__visit_name__ = 'MEDIUMINT'
def __init__(self, display_width=None, **kw):
"""Construct a MEDIUMINTEGER
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(MEDIUMINT, self).__init__(display_width=display_width, **kw)
class TINYINT(_IntegerType):
"""MySQL TINYINT type."""
__visit_name__ = 'TINYINT'
def __init__(self, display_width=None, **kw):
"""Construct a TINYINT.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(TINYINT, self).__init__(display_width=display_width, **kw)
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
"""MySQL SMALLINTEGER type."""
__visit_name__ = 'SMALLINT'
def __init__(self, display_width=None, **kw):
"""Construct a SMALLINTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not effect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(SMALLINT, self).__init__(display_width=display_width, **kw)
class BIT(sqltypes.TypeEngine):
"""MySQL BIT type.
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
for MyISAM, MEMORY, InnoDB and BDB. For older versions, use a
MSTinyInteger() type.
"""
__visit_name__ = 'BIT'
def __init__(self, length=None):
"""Construct a BIT.
:param length: Optional, number of bits.
"""
self.length = length
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
already do this, so this logic should be moved to those dialects.
"""
def process(value):
if value is not None:
v = 0
for i in value:
if not isinstance(i, int):
i = ord(i) # convert byte to int on Python 2
v = v << 8 | i
return v
return value
return process
class TIME(sqltypes.TIME):
"""MySQL TIME type. """
__visit_name__ = 'TIME'
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIME type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the TIME type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
.. versionadded:: 0.8 The MySQL-specific TIME
type as well as fractional seconds support.
"""
super(TIME, self).__init__(timezone=timezone)
self.fsp = fsp
def result_processor(self, dialect, coltype):
time = datetime.time
def process(value):
# convert from a timedelta value
if value is not None:
microseconds = value.microseconds
seconds = value.seconds
minutes = seconds // 60
return time(minutes // 60,
minutes % 60,
seconds - minutes * 60,
microsecond=microseconds)
else:
return None
return process
class TIMESTAMP(sqltypes.TIMESTAMP):
"""MySQL TIMESTAMP type.
"""
__visit_name__ = 'TIMESTAMP'
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIMESTAMP type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6.4 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the TIMESTAMP type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
.. versionadded:: 0.8.5 Added MySQL-specific :class:`.mysql.TIMESTAMP`
with fractional seconds support.
"""
super(TIMESTAMP, self).__init__(timezone=timezone)
self.fsp = fsp
class DATETIME(sqltypes.DATETIME):
"""MySQL DATETIME type.
"""
__visit_name__ = 'DATETIME'
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL DATETIME type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6.4 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the DATETIME type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
.. versionadded:: 0.8.5 Added MySQL-specific :class:`.mysql.DATETIME`
with fractional seconds support.
"""
super(DATETIME, self).__init__(timezone=timezone)
self.fsp = fsp
class YEAR(sqltypes.TypeEngine):
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
__visit_name__ = 'YEAR'
def __init__(self, display_width=None):
self.display_width = display_width
class TEXT(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for text up to 2^16 characters."""
__visit_name__ = 'TEXT'
def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
by substituting the smallest TEXT type sufficient to store
``length`` characters.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(TEXT, self).__init__(length=length, **kw)
class TINYTEXT(_StringType):
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
__visit_name__ = 'TINYTEXT'
def __init__(self, **kwargs):
"""Construct a TINYTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(TINYTEXT, self).__init__(**kwargs)
class MEDIUMTEXT(_StringType):
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
__visit_name__ = 'MEDIUMTEXT'
def __init__(self, **kwargs):
"""Construct a MEDIUMTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(MEDIUMTEXT, self).__init__(**kwargs)
class LONGTEXT(_StringType):
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
__visit_name__ = 'LONGTEXT'
def __init__(self, **kwargs):
"""Construct a LONGTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(LONGTEXT, self).__init__(**kwargs)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""MySQL VARCHAR type, for variable-length character data."""
__visit_name__ = 'VARCHAR'
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""MySQL CHAR type, for fixed-length character data."""
__visit_name__ = 'CHAR'
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
super(CHAR, self).__init__(length=length, **kwargs)
@classmethod
def _adapt_string_for_cast(self, type_):
# copy the given string type into a CHAR
# for the purposes of rendering a CAST expression
type_ = sqltypes.to_instance(type_)
if isinstance(type_, sqltypes.CHAR):
return type_
elif isinstance(type_, _StringType):
return CHAR(
length=type_.length,
charset=type_.charset,
collation=type_.collation,
ascii=type_.ascii,
binary=type_.binary,
unicode=type_.unicode,
national=False # not supported in CAST
)
else:
return CHAR(length=type_.length)
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
"""MySQL NVARCHAR type.
For variable-length character data in the server's configured national
character set.
"""
__visit_name__ = 'NVARCHAR'
def __init__(self, length=None, **kwargs):
"""Construct an NVARCHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
kwargs['national'] = True
super(NVARCHAR, self).__init__(length=length, **kwargs)
class NCHAR(_StringType, sqltypes.NCHAR):
"""MySQL NCHAR type.
For fixed-length character data in the server's configured national
character set.
"""
__visit_name__ = 'NCHAR'
def __init__(self, length=None, **kwargs):
"""Construct an NCHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
kwargs['national'] = True
super(NCHAR, self).__init__(length=length, **kwargs)
class TINYBLOB(sqltypes._Binary):
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
__visit_name__ = 'TINYBLOB'
class MEDIUMBLOB(sqltypes._Binary):
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
__visit_name__ = 'MEDIUMBLOB'
class LONGBLOB(sqltypes._Binary):
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
__visit_name__ = 'LONGBLOB'

View File

@ -1,17 +1,21 @@
"""Support for the MySQL database via Jython's zxjdbc JDBC connector. # mysql/zxjdbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
JDBC Driver """
-----------
The official MySQL JDBC driver is at .. dialect:: mysql+zxjdbc
http://dev.mysql.com/downloads/connector/j/. :name: zxjdbc for Jython
:dbapi: zxjdbc
:connectstring: mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/\
<database>
:driverurl: http://dev.mysql.com/downloads/connector/j/
Connecting .. note:: Jython is not supported by current versions of SQLAlchemy. The
---------- zxjdbc dialect should be considered as experimental.
Connect string format:
mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
Character Sets Character Sets
-------------- --------------
@ -20,14 +24,15 @@ SQLAlchemy zxjdbc dialects pass unicode straight through to the
zxjdbc/JDBC layer. To allow multiple character sets to be sent from the zxjdbc/JDBC layer. To allow multiple character sets to be sent from the
MySQL Connector/J JDBC driver, by default SQLAlchemy sets its MySQL Connector/J JDBC driver, by default SQLAlchemy sets its
``characterEncoding`` connection property to ``UTF-8``. It may be ``characterEncoding`` connection property to ``UTF-8``. It may be
overriden via a ``create_engine`` URL parameter. overridden via a ``create_engine`` URL parameter.
""" """
import re import re
from sqlalchemy import types as sqltypes, util from ... import types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector from ...connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext from .base import BIT, MySQLDialect, MySQLExecutionContext
class _ZxJDBCBit(BIT): class _ZxJDBCBit(BIT):
def result_processor(self, dialect, coltype): def result_processor(self, dialect, coltype):
@ -37,7 +42,7 @@ class _ZxJDBCBit(BIT):
return value return value
if isinstance(value, bool): if isinstance(value, bool):
return int(value) return int(value)
v = 0L v = 0
for i in value: for i in value:
v = v << 8 | (i & 0xff) v = v << 8 | (i & 0xff)
value = v value = v
@ -82,7 +87,8 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
if opts.get(key, None): if opts.get(key, None):
return opts[key] return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.") util.warn("Could not detect the connection character set. "
"Assuming latin1.")
return 'latin1' return 'latin1'
def _driver_kwargs(self): def _driver_kwargs(self):
@ -92,7 +98,7 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
def _extract_error_code(self, exception): def _extract_error_code(self, exception):
# e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
# [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' () # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args)) m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.args))
c = m.group(1) c = m.group(1)
if c: if c:
return int(c) return int(c)
@ -100,7 +106,7 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
dbapi_con = connection.connection dbapi_con = connection.connection
version = [] version = []
r = re.compile('[.\-]') r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.dbversion): for n in r.split(dbapi_con.dbversion):
try: try:
version.append(int(n)) version.append(int(n))

View File

@ -1,17 +1,24 @@
# oracle/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
base.dialect = cx_oracle.dialect base.dialect = cx_oracle.dialect
from sqlalchemy.dialects.oracle.base import \ from sqlalchemy.dialects.oracle.base import \
VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, NUMBER,\ VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\
BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\ BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\
FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\ FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\
VARCHAR2, NVARCHAR2 VARCHAR2, NVARCHAR2, ROWID, dialect
__all__ = ( __all__ = (
'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'DATETIME', 'NUMBER', 'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER',
'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW', 'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL', 'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
'VARCHAR2', 'NVARCHAR2' 'VARCHAR2', 'NVARCHAR2', 'ROWID'
) )

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,19 @@
"""Support for the Oracle database via the zxjdbc JDBC connector. # oracle/zxjdbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
JDBC Driver """
----------- .. dialect:: oracle+zxjdbc
:name: zxJDBC for Jython
:dbapi: zxjdbc
:connectstring: oracle+zxjdbc://user:pass@host/dbname
:driverurl: http://www.oracle.com/technetwork/database/features/jdbc/index-091264.html
The official Oracle JDBC driver is at .. note:: Jython is not supported by current versions of SQLAlchemy. The
http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html. zxjdbc dialect should be considered as experimental.
""" """
import decimal import decimal
@ -12,12 +21,16 @@ import re
from sqlalchemy import sql, types as sqltypes, util from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext from sqlalchemy.dialects.oracle.base import (OracleCompiler,
from sqlalchemy.engine import base, default OracleDialect,
OracleExecutionContext)
from sqlalchemy.engine import result as _result
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
import collections
SQLException = zxJDBC = None SQLException = zxJDBC = None
class _ZxJDBCDate(sqltypes.Date): class _ZxJDBCDate(sqltypes.Date):
def result_processor(self, dialect, coltype): def result_processor(self, dialect, coltype):
@ -53,10 +66,11 @@ class _ZxJDBCNumeric(sqltypes.Numeric):
class OracleCompiler_zxjdbc(OracleCompiler): class OracleCompiler_zxjdbc(OracleCompiler):
def returning_clause(self, stmt, returning_cols): def returning_clause(self, stmt, returning_cols):
self.returning_cols = list(expression._select_iterables(returning_cols)) self.returning_cols = list(
expression._select_iterables(returning_cols))
# within_columns_clause=False so that labels (foo AS bar) don't render # within_columns_clause=False so that labels (foo AS bar) don't render
columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) columns = [self.process(c, within_columns_clause=False)
for c in self.returning_cols] for c in self.returning_cols]
if not hasattr(self, 'returning_parameters'): if not hasattr(self, 'returning_parameters'):
@ -64,12 +78,15 @@ class OracleCompiler_zxjdbc(OracleCompiler):
binds = [] binds = []
for i, col in enumerate(self.returning_cols): for i, col in enumerate(self.returning_cols):
dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) dbtype = col.type.dialect_impl(
self.dialect).get_dbapi_type(self.dialect.dbapi)
self.returning_parameters.append((i + 1, dbtype)) self.returning_parameters.append((i + 1, dbtype))
bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype)) bindparam = sql.bindparam(
"ret_%d" % i, value=ReturningParam(dbtype))
self.binds[bindparam.key] = bindparam self.binds[bindparam.key] = bindparam
binds.append(self.bindparam_string(self._truncate_bindparam(bindparam))) binds.append(
self.bindparam_string(self._truncate_bindparam(bindparam)))
return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
@ -88,15 +105,19 @@ class OracleExecutionContext_zxjdbc(OracleExecutionContext):
try: try:
try: try:
rrs = self.statement.__statement__.getReturnResultSet() rrs = self.statement.__statement__.getReturnResultSet()
rrs.next() next(rrs)
except SQLException, sqle: except SQLException as sqle:
msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode()) msg = '%s [SQLCode: %d]' % (
sqle.getMessage(), sqle.getErrorCode())
if sqle.getSQLState() is not None: if sqle.getSQLState() is not None:
msg += ' [SQLState: %s]' % sqle.getSQLState() msg += ' [SQLState: %s]' % sqle.getSQLState()
raise zxJDBC.Error(msg) raise zxJDBC.Error(msg)
else: else:
row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype) row = tuple(
for index, dbtype in self.compiled.returning_parameters) self.cursor.datahandler.getPyObject(
rrs, index, dbtype)
for index, dbtype in
self.compiled.returning_parameters)
return ReturningResultProxy(self, row) return ReturningResultProxy(self, row)
finally: finally:
if rrs is not None: if rrs is not None:
@ -106,15 +127,15 @@ class OracleExecutionContext_zxjdbc(OracleExecutionContext):
pass pass
self.statement.close() self.statement.close()
return base.ResultProxy(self) return _result.ResultProxy(self)
def create_cursor(self): def create_cursor(self):
cursor = self._connection.connection.cursor() cursor = self._dbapi_connection.cursor()
cursor.datahandler = self.dialect.DataHandler(cursor.datahandler) cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
return cursor return cursor
class ReturningResultProxy(base.FullyBufferedResultProxy): class ReturningResultProxy(_result.FullyBufferedResultProxy):
"""ResultProxy backed by the RETURNING ResultSet results.""" """ResultProxy backed by the RETURNING ResultSet results."""
@ -132,7 +153,7 @@ class ReturningResultProxy(base.FullyBufferedResultProxy):
return ret return ret
def _buffer_rows(self): def _buffer_rows(self):
return [self._returning_row] return collections.deque([self._returning_row])
class ReturningParam(object): class ReturningParam(object):
@ -157,8 +178,8 @@ class ReturningParam(object):
def __repr__(self): def __repr__(self):
kls = self.__class__ kls = self.__class__
return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self), return '<%s.%s object at 0x%x type=%s>' % (
self.type) kls.__module__, kls.__name__, id(self), self.type)
class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
@ -182,28 +203,33 @@ class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
from java.sql import SQLException from java.sql import SQLException
from com.ziclix.python.sql import zxJDBC from com.ziclix.python.sql import zxJDBC
from com.ziclix.python.sql.handler import OracleDataHandler from com.ziclix.python.sql.handler import OracleDataHandler
class OracleReturningDataHandler(OracleDataHandler):
class OracleReturningDataHandler(OracleDataHandler):
"""zxJDBC DataHandler that specially handles ReturningParam.""" """zxJDBC DataHandler that specially handles ReturningParam."""
def setJDBCObject(self, statement, index, object, dbtype=None): def setJDBCObject(self, statement, index, object, dbtype=None):
if type(object) is ReturningParam: if type(object) is ReturningParam:
statement.registerReturnParameter(index, object.type) statement.registerReturnParameter(index, object.type)
elif dbtype is None: elif dbtype is None:
OracleDataHandler.setJDBCObject(self, statement, index, object) OracleDataHandler.setJDBCObject(
self, statement, index, object)
else: else:
OracleDataHandler.setJDBCObject(self, statement, index, object, dbtype) OracleDataHandler.setJDBCObject(
self, statement, index, object, dbtype)
self.DataHandler = OracleReturningDataHandler self.DataHandler = OracleReturningDataHandler
def initialize(self, connection): def initialize(self, connection):
super(OracleDialect_zxjdbc, self).initialize(connection) super(OracleDialect_zxjdbc, self).initialize(connection)
self.implicit_returning = connection.connection.driverversion >= '10.2' self.implicit_returning = \
connection.connection.driverversion >= '10.2'
def _create_jdbc_url(self, url): def _create_jdbc_url(self, url):
return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database) return 'jdbc:oracle:thin:@%s:%s:%s' % (
url.host, url.port or 1521, url.database)
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1) version = re.search(
r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
return tuple(int(x) for x in version.split('.')) return tuple(int(x) for x in version.split('.'))
dialect = OracleDialect_zxjdbc dialect = OracleDialect_zxjdbc

View File

@ -0,0 +1,314 @@
# postgresql/array.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .base import ischema_names
from ...sql import expression, operators
from ...sql.base import SchemaEventTarget
from ... import types as sqltypes
try:
from uuid import UUID as _python_UUID
except ImportError:
_python_UUID = None
def Any(other, arrexpr, operator=operators.eq):
"""A synonym for the :meth:`.ARRAY.Comparator.any` method.
This method is legacy and is here for backwards-compatibility.
.. seealso::
:func:`.expression.any_`
"""
return arrexpr.any(other, operator)
def All(other, arrexpr, operator=operators.eq):
"""A synonym for the :meth:`.ARRAY.Comparator.all` method.
This method is legacy and is here for backwards-compatibility.
.. seealso::
:func:`.expression.all_`
"""
return arrexpr.all(other, operator)
class array(expression.Tuple):
"""A PostgreSQL ARRAY literal.
This is used to produce ARRAY literals in SQL expressions, e.g.::
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects import postgresql
from sqlalchemy import select, func
stmt = select([
array([1,2]) + array([3,4,5])
])
print stmt.compile(dialect=postgresql.dialect())
Produces the SQL::
SELECT ARRAY[%(param_1)s, %(param_2)s] ||
ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1
An instance of :class:`.array` will always have the datatype
:class:`.ARRAY`. The "inner" type of the array is inferred from
the values present, unless the ``type_`` keyword argument is passed::
array(['foo', 'bar'], type_=CHAR)
.. versionadded:: 0.8 Added the :class:`~.postgresql.array` literal type.
See also:
:class:`.postgresql.ARRAY`
"""
__visit_name__ = 'array'
def __init__(self, clauses, **kw):
super(array, self).__init__(*clauses, **kw)
self.type = ARRAY(self.type)
def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
if _assume_scalar or operator is operators.getitem:
# if getitem->slice were called, Indexable produces
# a Slice object from that
assert isinstance(obj, int)
return expression.BindParameter(
None, obj, _compared_to_operator=operator,
type_=type_,
_compared_to_type=self.type, unique=True)
else:
return array([
self._bind_param(operator, o, _assume_scalar=True, type_=type_)
for o in obj])
def self_group(self, against=None):
if (against in (
operators.any_op, operators.all_op, operators.getitem)):
return expression.Grouping(self)
else:
return self
CONTAINS = operators.custom_op("@>", precedence=5)
CONTAINED_BY = operators.custom_op("<@", precedence=5)
OVERLAP = operators.custom_op("&&", precedence=5)
class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
"""PostgreSQL ARRAY type.
.. versionchanged:: 1.1 The :class:`.postgresql.ARRAY` type is now
a subclass of the core :class:`.types.ARRAY` type.
The :class:`.postgresql.ARRAY` type is constructed in the same way
as the core :class:`.types.ARRAY` type; a member type is required, and a
number of dimensions is recommended if the type is to be used for more
than one dimension::
from sqlalchemy.dialects import postgresql
mytable = Table("mytable", metadata,
Column("data", postgresql.ARRAY(Integer, dimensions=2))
)
The :class:`.postgresql.ARRAY` type provides all operations defined on the
core :class:`.types.ARRAY` type, including support for "dimensions", indexed
access, and simple matching such as :meth:`.types.ARRAY.Comparator.any`
and :meth:`.types.ARRAY.Comparator.all`. :class:`.postgresql.ARRAY` class also
provides PostgreSQL-specific methods for containment operations, including
:meth:`.postgresql.ARRAY.Comparator.contains`
:meth:`.postgresql.ARRAY.Comparator.contained_by`,
and :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.::
mytable.c.data.contains([1, 2])
The :class:`.postgresql.ARRAY` type may not be supported on all
PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
Additionally, the :class:`.postgresql.ARRAY` type does not work directly in
conjunction with the :class:`.ENUM` type. For a workaround, see the
special type at :ref:`postgresql_array_of_enum`.
.. seealso::
:class:`.types.ARRAY` - base array type
:class:`.postgresql.array` - produces a literal array value.
"""
class Comparator(sqltypes.ARRAY.Comparator):
"""Define comparison operations for :class:`.ARRAY`.
Note that these operations are in addition to those provided
by the base :class:`.types.ARRAY.Comparator` class, including
:meth:`.types.ARRAY.Comparator.any` and
:meth:`.types.ARRAY.Comparator.all`.
"""
def contains(self, other, **kwargs):
"""Boolean expression. Test if elements are a superset of the
elements of the argument array expression.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if elements are a proper subset of the
elements of the argument array expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean)
def overlap(self, other):
"""Boolean expression. Test if array has elements in common with
an argument array expression.
"""
return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
comparator_factory = Comparator
def __init__(self, item_type, as_tuple=False, dimensions=None,
zero_indexes=False):
"""Construct an ARRAY.
E.g.::
Column('myarray', ARRAY(Integer))
Arguments are:
:param item_type: The data type of items of this array. Note that
dimensionality is irrelevant here, so multi-dimensional arrays like
``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as
``ARRAY(ARRAY(Integer))`` or such.
:param as_tuple=False: Specify whether return results
should be converted to tuples from lists. DBAPIs such
as psycopg2 return lists by default. When tuples are
returned, the results are hashable.
:param dimensions: if non-None, the ARRAY will assume a fixed
number of dimensions. This will cause the DDL emitted for this
ARRAY to include the exact number of bracket clauses ``[]``,
and will also optimize the performance of the type overall.
Note that PG arrays are always implicitly "non-dimensioned",
meaning they can store any number of dimensions no matter how
they were declared.
:param zero_indexes=False: when True, index values will be converted
between Python zero-based and PostgreSQL one-based indexes, e.g.
a value of one will be added to all index values before passing
to the database.
.. versionadded:: 0.9.5
"""
if isinstance(item_type, ARRAY):
raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
"handles multi-dimensional arrays of basetype")
if isinstance(item_type, type):
item_type = item_type()
self.item_type = item_type
self.as_tuple = as_tuple
self.dimensions = dimensions
self.zero_indexes = zero_indexes
@property
def hashable(self):
return self.as_tuple
@property
def python_type(self):
return list
def compare_values(self, x, y):
return x == y
def _set_parent(self, column):
"""Support SchemaEventTarget"""
if isinstance(self.item_type, SchemaEventTarget):
self.item_type._set_parent(column)
def _set_parent_with_dispatch(self, parent):
"""Support SchemaEventTarget"""
if isinstance(self.item_type, SchemaEventTarget):
self.item_type._set_parent_with_dispatch(parent)
def _proc_array(self, arr, itemproc, dim, collection):
if dim is None:
arr = list(arr)
if dim == 1 or dim is None and (
# this has to be (list, tuple), or at least
# not hasattr('__iter__'), since Py3K strings
# etc. have __iter__
not arr or not isinstance(arr[0], (list, tuple))):
if itemproc:
return collection(itemproc(x) for x in arr)
else:
return collection(arr)
else:
return collection(
self._proc_array(
x, itemproc,
dim - 1 if dim is not None else None,
collection)
for x in arr
)
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).\
bind_processor(dialect)
def process(value):
if value is None:
return value
else:
return self._proc_array(
value,
item_proc,
self.dimensions,
list)
return process
def result_processor(self, dialect, coltype):
item_proc = self.item_type.dialect_impl(dialect).\
result_processor(dialect, coltype)
def process(value):
if value is None:
return value
else:
return self._proc_array(
value,
item_proc,
self.dimensions,
tuple if self.as_tuple else list)
return process
ischema_names['_array'] = ARRAY

View File

@ -0,0 +1,213 @@
# postgresql/on_conflict.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from ...sql.elements import ClauseElement, _literal_as_binds
from ...sql.dml import Insert as StandardInsert
from ...sql.expression import alias
from ...sql import schema
from ...util.langhelpers import public_factory
from ...sql.base import _generative
from ... import util
from . import ext
__all__ = ('Insert', 'insert')
class Insert(StandardInsert):
"""PostgreSQL-specific implementation of INSERT.
Adds methods for PG-specific syntaxes such as ON CONFLICT.
.. versionadded:: 1.1
"""
@util.memoized_property
def excluded(self):
"""Provide the ``excluded`` namespace for an ON CONFLICT statement
PG's ON CONFLICT clause allows reference to the row that would
be inserted, known as ``excluded``. This attribute provides
all columns in this row to be referenaceable.
.. seealso::
:ref:`postgresql_insert_on_conflict` - example of how
to use :attr:`.Insert.excluded`
"""
return alias(self.table, name='excluded').columns
@_generative
def on_conflict_do_update(
self,
constraint=None, index_elements=None,
index_where=None, set_=None, where=None):
"""
Specifies a DO UPDATE SET action for ON CONFLICT clause.
Either the ``constraint`` or ``index_elements`` argument is
required, but only one of these can be specified.
:param constraint:
The name of a unique or exclusion constraint on the table,
or the constraint object itself if it has a .name attribute.
:param index_elements:
A sequence consisting of string column names, :class:`.Column`
objects, or other column expression objects that will be used
to infer a target index.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
:param set_:
Required argument. A dictionary or other mapping object
with column names as keys and expressions or literals as values,
specifying the ``SET`` actions to take.
If the target :class:`.Column` specifies a ".key" attribute distinct
from the column name, that key should be used.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
e.g. those specified using :paramref:`.Column.onupdate`.
These values will not be exercised for an ON CONFLICT style of
UPDATE, unless they are manually specified in the
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).
.. versionadded:: 1.1
.. seealso::
:ref:`postgresql_insert_on_conflict`
"""
self._post_values_clause = OnConflictDoUpdate(
constraint, index_elements, index_where, set_, where)
return self
@_generative
def on_conflict_do_nothing(
self,
constraint=None, index_elements=None, index_where=None):
"""
Specifies a DO NOTHING action for ON CONFLICT clause.
The ``constraint`` and ``index_elements`` arguments
are optional, but only one of these can be specified.
:param constraint:
The name of a unique or exclusion constraint on the table,
or the constraint object itself if it has a .name attribute.
:param index_elements:
A sequence consisting of string column names, :class:`.Column`
objects, or other column expression objects that will be used
to infer a target index.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
.. versionadded:: 1.1
.. seealso::
:ref:`postgresql_insert_on_conflict`
"""
self._post_values_clause = OnConflictDoNothing(
constraint, index_elements, index_where)
return self
insert = public_factory(Insert, '.dialects.postgresql.insert')
class OnConflictClause(ClauseElement):
def __init__(
self,
constraint=None,
index_elements=None,
index_where=None):
if constraint is not None:
if not isinstance(constraint, util.string_types) and \
isinstance(constraint, (
schema.Index, schema.Constraint,
ext.ExcludeConstraint)):
constraint = getattr(constraint, 'name') or constraint
if constraint is not None:
if index_elements is not None:
raise ValueError(
"'constraint' and 'index_elements' are mutually exclusive")
if isinstance(constraint, util.string_types):
self.constraint_target = constraint
self.inferred_target_elements = None
self.inferred_target_whereclause = None
elif isinstance(constraint, schema.Index):
index_elements = constraint.expressions
index_where = \
constraint.dialect_options['postgresql'].get("where")
elif isinstance(constraint, ext.ExcludeConstraint):
index_elements = constraint.columns
index_where = constraint.where
else:
index_elements = constraint.columns
index_where = \
constraint.dialect_options['postgresql'].get("where")
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
elif constraint is None:
self.constraint_target = self.inferred_target_elements = \
self.inferred_target_whereclause = None
class OnConflictDoNothing(OnConflictClause):
__visit_name__ = 'on_conflict_do_nothing'
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = 'on_conflict_do_update'
def __init__(
self,
constraint=None,
index_elements=None,
index_where=None,
set_=None,
where=None):
super(OnConflictDoUpdate, self).__init__(
constraint=constraint,
index_elements=index_elements,
index_where=index_where)
if self.inferred_target_elements is None and \
self.constraint_target is None:
raise ValueError(
"Either constraint or index_elements, "
"but not both, must be specified unless DO NOTHING")
if (not isinstance(set_, dict) or not set_):
raise ValueError("set parameter must be a non-empty dictionary")
self.update_values_to_set = [
(key, value)
for key, value in set_.items()
]
self.update_whereclause = where

View File

@ -0,0 +1,218 @@
# postgresql/ext.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from ...sql import expression
from ...sql import elements
from ...sql import functions
from ...sql.schema import ColumnCollectionConstraint
from .array import ARRAY
class aggregate_order_by(expression.ColumnElement):
"""Represent a PostgreSQL aggregate order by expression.
E.g.::
from sqlalchemy.dialects.postgresql import aggregate_order_by
expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
stmt = select([expr])
would represent the expression::
SELECT array_agg(a ORDER BY b DESC) FROM table;
Similarly::
expr = func.string_agg(
table.c.a,
aggregate_order_by(literal_column("','"), table.c.a)
)
stmt = select([expr])
Would represent::
SELECT string_agg(a, ',' ORDER BY a) FROM table;
.. versionadded:: 1.1
.. seealso::
:class:`.array_agg`
"""
__visit_name__ = 'aggregate_order_by'
def __init__(self, target, order_by):
self.target = elements._literal_as_binds(target)
self.order_by = elements._literal_as_binds(order_by)
def self_group(self, against=None):
return self
def get_children(self, **kwargs):
return self.target, self.order_by
def _copy_internals(self, clone=elements._clone, **kw):
self.target = clone(self.target, **kw)
self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self):
return self.target._from_objects + self.order_by._from_objects
class ExcludeConstraint(ColumnCollectionConstraint):
"""A table-level EXCLUDE constraint.
Defines an EXCLUDE constraint as described in the `postgres
documentation`__.
__ http://www.postgresql.org/docs/9.0/\
static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
"""
__visit_name__ = 'exclude_constraint'
where = None
def __init__(self, *elements, **kw):
r"""
Create an :class:`.ExcludeConstraint` object.
E.g.::
const = ExcludeConstraint(
(Column('period'), '&&'),
(Column('group'), '='),
where=(Column('group') != 'some group')
)
The constraint is normally embedded into the :class:`.Table` construct
directly, or added later using :meth:`.append_constraint`::
some_table = Table(
'some_table', metadata,
Column('id', Integer, primary_key=True),
Column('period', TSRANGE()),
Column('group', String)
)
some_table.append_constraint(
ExcludeConstraint(
(some_table.c.period, '&&'),
(some_table.c.group, '='),
where=some_table.c.group != 'some group',
name='some_table_excl_const'
)
)
:param \*elements:
A sequence of two tuples of the form ``(column, operator)`` where
"column" is a SQL expression element or a raw SQL string, most
typically a :class:`.Column` object,
and "operator" is a string containing the operator to use.
.. note::
A plain string passed for the value of "column" is interpreted
as an arbitrary SQL expression; when passing a plain string,
any necessary quoting and escaping syntaxes must be applied
manually. In order to specify a column name when a
:class:`.Column` object is not available, while ensuring that
any necessary quoting rules take effect, an ad-hoc
:class:`.Column` or :func:`.sql.expression.column` object may
be used.
:param name:
Optional, the in-database name of this constraint.
:param deferrable:
Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
issuing DDL for this constraint.
:param initially:
Optional string. If set, emit INITIALLY <value> when issuing DDL
for this constraint.
:param using:
Optional string. If set, emit USING <index_method> when issuing DDL
for this constraint. Defaults to 'gist'.
:param where:
Optional SQL expression construct or literal SQL string.
If set, emit WHERE <predicate> when issuing DDL
for this constraint.
.. note::
A plain string passed here is interpreted as an arbitrary SQL
expression; when passing a plain string, any necessary quoting
and escaping syntaxes must be applied manually.
"""
columns = []
render_exprs = []
self.operators = {}
expressions, operators = zip(*elements)
for (expr, column, strname, add_element), operator in zip(
self._extract_col_expression_collection(expressions),
operators
):
if add_element is not None:
columns.append(add_element)
name = column.name if column is not None else strname
if name is not None:
# backwards compat
self.operators[name] = operator
expr = expression._literal_as_text(expr)
render_exprs.append(
(expr, name, operator)
)
self._render_exprs = render_exprs
ColumnCollectionConstraint.__init__(
self,
*columns,
name=kw.get('name'),
deferrable=kw.get('deferrable'),
initially=kw.get('initially')
)
self.using = kw.get('using', 'gist')
where = kw.get('where')
if where is not None:
self.where = expression._literal_as_text(where)
def copy(self, **kw):
elements = [(col, self.operators[col])
for col in self.columns.keys()]
c = self.__class__(*elements,
name=self.name,
deferrable=self.deferrable,
initially=self.initially,
where=self.where,
using=self.using)
c.dispatch._update(self.dispatch)
return c
def array_agg(*arg, **kw):
"""PostgreSQL-specific form of :class:`.array_agg`, ensures
return type is :class:`.postgresql.ARRAY` and not
the plain :class:`.types.ARRAY`.
.. versionadded:: 1.1
"""
kw['type_'] = ARRAY(functions._type_from_args(arg))
return functions.func.array_agg(*arg, **kw)

View File

@ -0,0 +1,420 @@
# postgresql/hstore.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import re
from .base import ischema_names
from .array import ARRAY
from ... import types as sqltypes
from ...sql import functions as sqlfunc
from ...sql import operators
from ... import util
__all__ = ('HSTORE', 'hstore')
idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
GETITEM = operators.custom_op(
"->", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
HAS_KEY = operators.custom_op(
"?", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
HAS_ALL = operators.custom_op(
"?&", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
HAS_ANY = operators.custom_op(
"?|", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
CONTAINS = operators.custom_op(
"@>", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
CONTAINED_BY = operators.custom_op(
"<@", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
"""Represent the PostgreSQL HSTORE type.
The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', HSTORE)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
:class:`.HSTORE` provides for a wide range of operations, including:
* Index operations::
data_table.c.data['some key'] == 'some value'
* Containment operations::
data_table.c.data.has_key('some key')
data_table.c.data.has_all(['one', 'two', 'three'])
* Concatenation::
data_table.c.data + {"k1": "v1"}
For a full list of special methods see
:class:`.HSTORE.comparator_factory`.
For usage with the SQLAlchemy ORM, it may be desirable to combine
the usage of :class:`.HSTORE` with :class:`.MutableDict` dictionary
now part of the :mod:`sqlalchemy.ext.mutable`
extension. This extension will allow "in-place" changes to the
dictionary, e.g. addition of new keys or replacement/removal of existing
keys to/from the current dictionary, to produce events which will be
detected by the unit of work::
from sqlalchemy.ext.mutable import MutableDict
class MyClass(Base):
__tablename__ = 'data_table'
id = Column(Integer, primary_key=True)
data = Column(MutableDict.as_mutable(HSTORE))
my_object = session.query(MyClass).one()
# in-place mutation, requires Mutable extension
# in order for the ORM to detect
my_object.data['some_key'] = 'some value'
session.commit()
When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM
will not be alerted to any changes to the contents of an existing
dictionary, unless that dictionary value is re-assigned to the
HSTORE-attribute itself, thus generating a change event.
.. versionadded:: 0.8
.. seealso::
:class:`.hstore` - render the PostgreSQL ``hstore()`` function.
"""
__visit_name__ = 'HSTORE'
hashable = False
text_type = sqltypes.Text()
def __init__(self, text_type=None):
"""Construct a new :class:`.HSTORE`.
:param text_type: the type that should be used for indexed values.
Defaults to :class:`.types.Text`.
.. versionadded:: 1.1.0
"""
if text_type is not None:
self.text_type = text_type
class Comparator(
sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator):
"""Define comparison operations for :class:`.HSTORE`."""
def has_key(self, other):
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other):
"""Boolean expression. Test for presence of all keys in jsonb
"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other):
"""Boolean expression. Test for presence of any key in jsonb
"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other, **kwargs):
"""Boolean expression. Test if keys (or array) are a superset
of/contained the keys of the argument jsonb expression.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument jsonb expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean)
def _setup_getitem(self, index):
return GETITEM, index, self.type.text_type
def defined(self, key):
"""Boolean expression. Test for presence of a non-NULL value for
the key. Note that the key may be a SQLA expression.
"""
return _HStoreDefinedFunction(self.expr, key)
def delete(self, key):
"""HStore expression. Returns the contents of this hstore with the
given key deleted. Note that the key may be a SQLA expression.
"""
if isinstance(key, dict):
key = _serialize_hstore(key)
return _HStoreDeleteFunction(self.expr, key)
def slice(self, array):
"""HStore expression. Returns a subset of an hstore defined by
array of keys.
"""
return _HStoreSliceFunction(self.expr, array)
def keys(self):
"""Text array expression. Returns array of keys."""
return _HStoreKeysFunction(self.expr)
def vals(self):
"""Text array expression. Returns array of values."""
return _HStoreValsFunction(self.expr)
def array(self):
"""Text array expression. Returns array of alternating keys and
values.
"""
return _HStoreArrayFunction(self.expr)
def matrix(self):
"""Text array expression. Returns array of [key, value] pairs."""
return _HStoreMatrixFunction(self.expr)
comparator_factory = Comparator
def bind_processor(self, dialect):
if util.py2k:
encoding = dialect.encoding
def process(value):
if isinstance(value, dict):
return _serialize_hstore(value).encode(encoding)
else:
return value
else:
def process(value):
if isinstance(value, dict):
return _serialize_hstore(value)
else:
return value
return process
def result_processor(self, dialect, coltype):
if util.py2k:
encoding = dialect.encoding
def process(value):
if value is not None:
return _parse_hstore(value.decode(encoding))
else:
return value
else:
def process(value):
if value is not None:
return _parse_hstore(value)
else:
return value
return process
ischema_names['hstore'] = HSTORE
class hstore(sqlfunc.GenericFunction):
"""Construct an hstore value within a SQL expression using the
PostgreSQL ``hstore()`` function.
The :class:`.hstore` function accepts one or two arguments as described
in the PostgreSQL documentation.
E.g.::
from sqlalchemy.dialects.postgresql import array, hstore
select([hstore('key1', 'value1')])
select([
hstore(
array(['key1', 'key2', 'key3']),
array(['value1', 'value2', 'value3'])
)
])
.. versionadded:: 0.8
.. seealso::
:class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
"""
type = HSTORE
name = 'hstore'
class _HStoreDefinedFunction(sqlfunc.GenericFunction):
type = sqltypes.Boolean
name = 'defined'
class _HStoreDeleteFunction(sqlfunc.GenericFunction):
type = HSTORE
name = 'delete'
class _HStoreSliceFunction(sqlfunc.GenericFunction):
type = HSTORE
name = 'slice'
class _HStoreKeysFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = 'akeys'
class _HStoreValsFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = 'avals'
class _HStoreArrayFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = 'hstore_to_array'
class _HStoreMatrixFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = 'hstore_to_matrix'
#
# parsing. note that none of this is used with the psycopg2 backend,
# which provides its own native extensions.
#
# My best guess at the parsing rules of hstore literals, since no formal
# grammar is given. This is mostly reverse engineered from PG's input parser
# behavior.
HSTORE_PAIR_RE = re.compile(r"""
(
"(?P<key> (\\ . | [^"])* )" # Quoted key
)
[ ]* => [ ]* # Pair operator, optional adjoining whitespace
(
(?P<value_null> NULL ) # NULL value
| "(?P<value> (\\ . | [^"])* )" # Quoted value
)
""", re.VERBOSE)
HSTORE_DELIMITER_RE = re.compile(r"""
[ ]* , [ ]*
""", re.VERBOSE)
def _parse_error(hstore_str, pos):
"""format an unmarshalling error."""
ctx = 20
hslen = len(hstore_str)
parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)]
residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)]
if len(parsed_tail) > ctx:
parsed_tail = '[...]' + parsed_tail[1:]
if len(residual) > ctx:
residual = residual[:-1] + '[...]'
return "After %r, could not parse residual at position %d: %r" % (
parsed_tail, pos, residual)
def _parse_hstore(hstore_str):
"""Parse an hstore from its literal string representation.
Attempts to approximate PG's hstore input parsing rules as closely as
possible. Although currently this is not strictly necessary, since the
current implementation of hstore's output syntax is stricter than what it
accepts as input, the documentation makes no guarantees that will always
be the case.
"""
result = {}
pos = 0
pair_match = HSTORE_PAIR_RE.match(hstore_str)
while pair_match is not None:
key = pair_match.group('key').replace(r'\"', '"').replace(
"\\\\", "\\")
if pair_match.group('value_null'):
value = None
else:
value = pair_match.group('value').replace(
r'\"', '"').replace("\\\\", "\\")
result[key] = value
pos += pair_match.end()
delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
if delim_match is not None:
pos += delim_match.end()
pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
if pos != len(hstore_str):
raise ValueError(_parse_error(hstore_str, pos))
return result
def _serialize_hstore(val):
"""Serialize a dictionary into an hstore literal. Keys and values must
both be strings (except None for values).
"""
def esc(s, position):
if position == 'value' and s is None:
return 'NULL'
elif isinstance(s, util.string_types):
return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"')
else:
raise ValueError("%r in %s position is not a string." %
(s, position))
return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value'))
for k, v in val.items())

View File

@ -0,0 +1,301 @@
# postgresql/json.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import
import json
import collections
from .base import ischema_names, colspecs
from ... import types as sqltypes
from ...sql import operators
from ...sql import elements
from ... import util
__all__ = ('JSON', 'JSONB')
idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
ASTEXT = operators.custom_op(
"->>", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
JSONPATH_ASTEXT = operators.custom_op(
"#>>", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
HAS_KEY = operators.custom_op(
"?", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
HAS_ALL = operators.custom_op(
"?&", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
HAS_ANY = operators.custom_op(
"?|", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
CONTAINS = operators.custom_op(
"@>", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
CONTAINED_BY = operators.custom_op(
"<@", precedence=idx_precedence, natural_self_precedent=True,
eager_grouping=True
)
class JSONPathType(sqltypes.JSON.JSONPathType):
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
assert isinstance(value, collections.Sequence)
tokens = [util.text_type(elem)for elem in value]
value = "{%s}" % (", ".join(tokens))
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
assert isinstance(value, collections.Sequence)
tokens = [util.text_type(elem)for elem in value]
value = "{%s}" % (", ".join(tokens))
if super_proc:
value = super_proc(value)
return value
return process
colspecs[sqltypes.JSON.JSONPathType] = JSONPathType
class JSON(sqltypes.JSON):
"""Represent the PostgreSQL JSON type.
This type is a specialization of the Core-level :class:`.types.JSON`
type. Be sure to read the documentation for :class:`.types.JSON` for
important tips regarding treatment of NULL values and ORM use.
.. versionchanged:: 1.1 :class:`.postgresql.JSON` is now a PostgreSQL-
specific specialization of the new :class:`.types.JSON` type.
The operators provided by the PostgreSQL version of :class:`.JSON`
include:
* Index operations (the ``->`` operator)::
data_table.c.data['some key']
data_table.c.data[5]
* Index operations returning text (the ``->>`` operator)::
data_table.c.data['some key'].astext == 'some value'
* Index operations with CAST
(equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::
data_table.c.data['some key'].astext.cast(Integer) == 5
* Path index operations (the ``#>`` operator)::
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
* Path index operations returning text (the ``#>>`` operator)::
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == \
'some value'
.. versionchanged:: 1.1 The :meth:`.ColumnElement.cast` operator on
JSON objects now requires that the :attr:`.JSON.Comparator.astext`
modifier be called explicitly, if the cast works only from a textual
string.
Index operations return an expression object whose type defaults to
:class:`.JSON` by default, so that further JSON-oriented instructions
may be called upon the result type.
Custom serializers and deserializers are specified at the dialect level,
that is using :func:`.create_engine`. The reason for this is that when
using psycopg2, the DBAPI only allows serializers at the per-cursor
or per-connection level. E.g.::
engine = create_engine("postgresql://scott:tiger@localhost/test",
json_serializer=my_serialize_fn,
json_deserializer=my_deserialize_fn
)
When using the psycopg2 dialect, the json_deserializer is registered
against the database using ``psycopg2.extras.register_default_json``.
.. seealso::
:class:`.types.JSON` - Core level JSON type
:class:`.JSONB`
"""
astext_type = sqltypes.Text()
def __init__(self, none_as_null=False, astext_type=None):
"""Construct a :class:`.JSON` type.
:param none_as_null: if True, persist the value ``None`` as a
SQL NULL value, not the JSON encoding of ``null``. Note that
when this flag is False, the :func:`.null` construct can still
be used to persist a NULL value::
from sqlalchemy import null
conn.execute(table.insert(), data=null())
.. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null`
is now supported in order to persist a NULL value.
.. seealso::
:attr:`.JSON.NULL`
:param astext_type: the type to use for the
:attr:`.JSON.Comparator.astext`
accessor on indexed attributes. Defaults to :class:`.types.Text`.
.. versionadded:: 1.1
"""
super(JSON, self).__init__(none_as_null=none_as_null)
if astext_type is not None:
self.astext_type = astext_type
class Comparator(sqltypes.JSON.Comparator):
"""Define comparison operations for :class:`.JSON`."""
@property
def astext(self):
"""On an indexed expression, use the "astext" (e.g. "->>")
conversion when rendered in SQL.
E.g.::
select([data_table.c.data['some key'].astext])
.. seealso::
:meth:`.ColumnElement.cast`
"""
if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
return self.expr.left.operate(
JSONPATH_ASTEXT,
self.expr.right, result_type=self.type.astext_type)
else:
return self.expr.left.operate(
ASTEXT, self.expr.right, result_type=self.type.astext_type)
comparator_factory = Comparator
colspecs[sqltypes.JSON] = JSON
ischema_names['json'] = JSON
class JSONB(JSON):
"""Represent the PostgreSQL JSONB type.
The :class:`.JSONB` type stores arbitrary JSONB format data, e.g.::
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', JSONB)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
The :class:`.JSONB` type includes all operations provided by
:class:`.JSON`, including the same behaviors for indexing operations.
It also adds additional operators specific to JSONB, including
:meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`,
:meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`,
and :meth:`.JSONB.Comparator.contained_by`.
Like the :class:`.JSON` type, the :class:`.JSONB` type does not detect
in-place changes when used with the ORM, unless the
:mod:`sqlalchemy.ext.mutable` extension is used.
Custom serializers and deserializers
are shared with the :class:`.JSON` class, using the ``json_serializer``
and ``json_deserializer`` keyword arguments. These must be specified
at the dialect level using :func:`.create_engine`. When using
psycopg2, the serializers are associated with the jsonb type using
``psycopg2.extras.register_default_jsonb`` on a per-connection basis,
in the same way that ``psycopg2.extras.register_default_json`` is used
to register these handlers with the json type.
.. versionadded:: 0.9.7
.. seealso::
:class:`.JSON`
"""
__visit_name__ = 'JSONB'
class Comparator(JSON.Comparator):
"""Define comparison operations for :class:`.JSON`."""
def has_key(self, other):
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other):
"""Boolean expression. Test for presence of all keys in jsonb
"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other):
"""Boolean expression. Test for presence of any key in jsonb
"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other, **kwargs):
"""Boolean expression. Test if keys (or array) are a superset
of/contained the keys of the argument jsonb expression.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument jsonb expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean)
comparator_factory = Comparator
ischema_names['jsonb'] = JSONB

View File

@ -0,0 +1,61 @@
# testing/engines.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: postgresql+psycopg2cffi
:name: psycopg2cffi
:dbapi: psycopg2cffi
:connectstring: \
postgresql+psycopg2cffi://user:password@host:port/dbname\
[?key=value&key=value...]
:url: http://pypi.python.org/pypi/psycopg2cffi/
``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C
layer. This makes it suitable for use in e.g. PyPy. Documentation
is as per ``psycopg2``.
.. versionadded:: 1.0.0
.. seealso::
:mod:`sqlalchemy.dialects.postgresql.psycopg2`
"""
from .psycopg2 import PGDialect_psycopg2
class PGDialect_psycopg2cffi(PGDialect_psycopg2):
driver = 'psycopg2cffi'
supports_unicode_statements = True
# psycopg2cffi's first release is 2.5.0, but reports
# __version__ as 2.4.4. Subsequent releases seem to have
# fixed this.
FEATURE_VERSION_MAP = dict(
native_json=(2, 4, 4),
native_jsonb=(2, 7, 1),
sane_multi_rowcount=(2, 4, 4),
array_oid=(2, 4, 4),
hstore_adapter=(2, 4, 4)
)
@classmethod
def dbapi(cls):
return __import__('psycopg2cffi')
@classmethod
def _psycopg2_extensions(cls):
root = __import__('psycopg2cffi', fromlist=['extensions'])
return root.extensions
@classmethod
def _psycopg2_extras(cls):
root = __import__('psycopg2cffi', fromlist=['extras'])
return root.extras
dialect = PGDialect_psycopg2cffi

View File

@ -0,0 +1,243 @@
# postgresql/pygresql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: postgresql+pygresql
:name: pygresql
:dbapi: pgdb
:connectstring: postgresql+pygresql://user:password@host:port/dbname\
[?key=value&key=value...]
:url: http://www.pygresql.org/
"""
import decimal
import re
from ... import exc, processors, util
from ...types import Numeric, JSON as Json
from ...sql.elements import Null
from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
_DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB
class _PGNumeric(Numeric):
def bind_processor(self, dialect):
return None
def result_processor(self, dialect, coltype):
if not isinstance(coltype, int):
coltype = coltype.oid
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal,
self._effective_decimal_return_scale)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# PyGreSQL returns Decimal natively for 1700 (numeric)
return None
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
else:
if coltype in _FLOAT_TYPES:
# PyGreSQL returns float natively for 701 (float8)
return None
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
class _PGHStore(HSTORE):
def bind_processor(self, dialect):
if not dialect.has_native_hstore:
return super(_PGHStore, self).bind_processor(dialect)
hstore = dialect.dbapi.Hstore
def process(value):
if isinstance(value, dict):
return hstore(value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_hstore:
return super(_PGHStore, self).result_processor(dialect, coltype)
class _PGJSON(JSON):
def bind_processor(self, dialect):
if not dialect.has_native_json:
return super(_PGJSON, self).bind_processor(dialect)
json = dialect.dbapi.Json
def process(value):
if value is self.NULL:
value = None
elif isinstance(value, Null) or (
value is None and self.none_as_null):
return None
if value is None or isinstance(value, (dict, list)):
return json(value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_json:
return super(_PGJSON, self).result_processor(dialect, coltype)
class _PGJSONB(JSONB):
def bind_processor(self, dialect):
if not dialect.has_native_json:
return super(_PGJSONB, self).bind_processor(dialect)
json = dialect.dbapi.Json
def process(value):
if value is self.NULL:
value = None
elif isinstance(value, Null) or (
value is None and self.none_as_null):
return None
if value is None or isinstance(value, (dict, list)):
return json(value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_json:
return super(_PGJSONB, self).result_processor(dialect, coltype)
class _PGUUID(UUID):
def bind_processor(self, dialect):
if not dialect.has_native_uuid:
return super(_PGUUID, self).bind_processor(dialect)
uuid = dialect.dbapi.Uuid
def process(value):
if value is None:
return None
if isinstance(value, (str, bytes)):
if len(value) == 16:
return uuid(bytes=value)
return uuid(value)
if isinstance(value, int):
return uuid(int=value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_uuid:
return super(_PGUUID, self).result_processor(dialect, coltype)
if not self.as_uuid:
def process(value):
if value is not None:
return str(value)
return process
class _PGCompiler(PGCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text):
return text.replace('%', '%%')
class _PGIdentifierPreparer(PGIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%')
class PGDialect_pygresql(PGDialect):
driver = 'pygresql'
statement_compiler = _PGCompiler
preparer = _PGIdentifierPreparer
@classmethod
def dbapi(cls):
import pgdb
return pgdb
colspecs = util.update_copy(
PGDialect.colspecs,
{
Numeric: _PGNumeric,
HSTORE: _PGHStore,
Json: _PGJSON,
JSON: _PGJSON,
JSONB: _PGJSONB,
UUID: _PGUUID,
}
)
def __init__(self, **kwargs):
super(PGDialect_pygresql, self).__init__(**kwargs)
try:
version = self.dbapi.version
m = re.match(r'(\d+)\.(\d+)', version)
version = (int(m.group(1)), int(m.group(2)))
except (AttributeError, ValueError, TypeError):
version = (0, 0)
self.dbapi_version = version
if version < (5, 0):
has_native_hstore = has_native_json = has_native_uuid = False
if version != (0, 0):
util.warn("PyGreSQL is only fully supported by SQLAlchemy"
" since version 5.0.")
else:
self.supports_unicode_statements = True
self.supports_unicode_binds = True
has_native_hstore = has_native_json = has_native_uuid = True
self.has_native_hstore = has_native_hstore
self.has_native_json = has_native_json
self.has_native_uuid = has_native_uuid
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if 'port' in opts:
opts['host'] = '%s:%s' % (
opts.get('host', '').rsplit(':', 1)[0], opts.pop('port'))
opts.update(url.query)
return [], opts
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
if not connection:
return False
try:
connection = connection.connection
except AttributeError:
pass
else:
if not connection:
return False
try:
return connection.closed
except AttributeError: # PyGreSQL < 5.0
return connection._cnx is None
return False
dialect = PGDialect_pygresql

View File

@ -0,0 +1,168 @@
# Copyright (C) 2013-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .base import ischema_names
from ... import types as sqltypes
__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE')
class RangeOperators(object):
"""
This mixin provides functionality for the Range Operators
listed in Table 9-44 of the `postgres documentation`__ for Range
Functions and Operators. It is used by all the range types
provided in the ``postgres`` dialect and can likely be used for
any range types you create yourself.
__ http://www.postgresql.org/docs/devel/static/functions-range.html
No extra support is provided for the Range Functions listed in
Table 9-45 of the postgres documentation. For these, the normal
:func:`~sqlalchemy.sql.expression.func` object should be used.
.. versionadded:: 0.8.2 Support for PostgreSQL RANGE operations.
"""
class comparator_factory(sqltypes.Concatenable.Comparator):
"""Define comparison operations for range types."""
def __ne__(self, other):
"Boolean expression. Returns true if two ranges are not equal"
return self.expr.op('<>')(other)
def contains(self, other, **kw):
"""Boolean expression. Returns true if the right hand operand,
which can be an element or a range, is contained within the
column.
"""
return self.expr.op('@>')(other)
def contained_by(self, other):
"""Boolean expression. Returns true if the column is contained
within the right hand operand.
"""
return self.expr.op('<@')(other)
def overlaps(self, other):
"""Boolean expression. Returns true if the column overlaps
(has points in common with) the right hand operand.
"""
return self.expr.op('&&')(other)
def strictly_left_of(self, other):
"""Boolean expression. Returns true if the column is strictly
left of the right hand operand.
"""
return self.expr.op('<<')(other)
__lshift__ = strictly_left_of
def strictly_right_of(self, other):
"""Boolean expression. Returns true if the column is strictly
right of the right hand operand.
"""
return self.expr.op('>>')(other)
__rshift__ = strictly_right_of
def not_extend_right_of(self, other):
"""Boolean expression. Returns true if the range in the column
does not extend right of the range in the operand.
"""
return self.expr.op('&<')(other)
def not_extend_left_of(self, other):
"""Boolean expression. Returns true if the range in the column
does not extend left of the range in the operand.
"""
return self.expr.op('&>')(other)
def adjacent_to(self, other):
"""Boolean expression. Returns true if the range in the column
is adjacent to the range in the operand.
"""
return self.expr.op('-|-')(other)
def __add__(self, other):
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contigous.
"""
return self.expr.op('+')(other)
class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL INT4RANGE type.
.. versionadded:: 0.8.2
"""
__visit_name__ = 'INT4RANGE'
ischema_names['int4range'] = INT4RANGE
class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL INT8RANGE type.
.. versionadded:: 0.8.2
"""
__visit_name__ = 'INT8RANGE'
ischema_names['int8range'] = INT8RANGE
class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL NUMRANGE type.
.. versionadded:: 0.8.2
"""
__visit_name__ = 'NUMRANGE'
ischema_names['numrange'] = NUMRANGE
class DATERANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL DATERANGE type.
.. versionadded:: 0.8.2
"""
__visit_name__ = 'DATERANGE'
ischema_names['daterange'] = DATERANGE
class TSRANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL TSRANGE type.
.. versionadded:: 0.8.2
"""
__visit_name__ = 'TSRANGE'
ischema_names['tsrange'] = TSRANGE
class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL TSTZRANGE type.
.. versionadded:: 0.8.2
"""
__visit_name__ = 'TSTZRANGE'
ischema_names['tstzrange'] = TSTZRANGE

View File

@ -1,14 +1,20 @@
from sqlalchemy.dialects.sqlite import base, pysqlite # sqlite/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.sqlite import base, pysqlite, pysqlcipher
# default dialect # default dialect
base.dialect = pysqlite.dialect base.dialect = pysqlite.dialect
from sqlalchemy.dialects.sqlite.base import (
from sqlalchemy.dialects.sqlite.base import \ BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER, REAL,
BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER,\ NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect,
NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect
__all__ = (
'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'FLOAT', 'INTEGER',
'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', 'TIMESTAMP', 'VARCHAR', 'dialect'
) )
__all__ = ('BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL',
'FLOAT', 'INTEGER', 'NUMERIC', 'SMALLINT', 'TEXT', 'TIME',
'TIMESTAMP', 'VARCHAR', 'REAL', 'dialect')

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,130 @@
# sqlite/pysqlcipher.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: sqlite+pysqlcipher
:name: pysqlcipher
:dbapi: pysqlcipher
:connectstring: sqlite+pysqlcipher://:passphrase/file_path[?kdf_iter=<iter>]
:url: https://pypi.python.org/pypi/pysqlcipher
``pysqlcipher`` is a fork of the standard ``pysqlite`` driver to make
use of the `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.
``pysqlcipher3`` is a fork of ``pysqlcipher`` for Python 3. This dialect
will attempt to import it if ``pysqlcipher`` is non-present.
.. versionadded:: 1.1.4 - added fallback import for pysqlcipher3
.. versionadded:: 0.9.9 - added pysqlcipher dialect
Driver
------
The driver here is the `pysqlcipher <https://pypi.python.org/pypi/pysqlcipher>`_
driver, which makes use of the SQLCipher engine. This system essentially
introduces new PRAGMA commands to SQLite which allows the setting of a
passphrase and other encryption parameters, allowing the database
file to be encrypted.
`pysqlcipher3` is a fork of `pysqlcipher` with support for Python 3,
the driver is the same.
Connect Strings
---------------
The format of the connect string is in every way the same as that
of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
"password" field is now accepted, which should contain a passphrase::
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
For an absolute file path, two leading slashes should be used for the
database name::
e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
A selection of additional encryption-related pragmas supported by SQLCipher
as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
in the query string, and will result in that PRAGMA being called for each
new connection. Currently, ``cipher``, ``kdf_iter``
``cipher_page_size`` and ``cipher_use_hmac`` are supported::
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
Pooling Behavior
----------------
The driver makes a change to the default pool behavior of pysqlite
as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver
has been observed to be significantly slower on connection than the
pysqlite driver, most likely due to the encryption overhead, so the
dialect here defaults to using the :class:`.SingletonThreadPool`
implementation,
instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool
implementation is entirely configurable using the
:paramref:`.create_engine.poolclass` parameter; the :class:`.StaticPool` may
be more feasible for single-threaded use, or :class:`.NullPool` may be used
to prevent unencrypted connections from being held open for long periods of
time, at the expense of slower startup time for new connections.
"""
from __future__ import absolute_import
from .pysqlite import SQLiteDialect_pysqlite
from ...engine import url as _url
from ... import pool
class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
driver = 'pysqlcipher'
pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac')
@classmethod
def dbapi(cls):
try:
from pysqlcipher import dbapi2 as sqlcipher
except ImportError as e:
try:
from pysqlcipher3 import dbapi2 as sqlcipher
except ImportError:
raise e
return sqlcipher
@classmethod
def get_pool_class(cls, url):
return pool.SingletonThreadPool
def connect(self, *cargs, **cparams):
passphrase = cparams.pop('passphrase', '')
pragmas = dict(
(key, cparams.pop(key, None)) for key in
self.pragmas
)
conn = super(SQLiteDialect_pysqlcipher, self).\
connect(*cargs, **cparams)
conn.execute('pragma key="%s"' % passphrase)
for prag, value in pragmas.items():
if value is not None:
conn.execute('pragma %s="%s"' % (prag, value))
return conn
def create_connect_args(self, url):
super_url = _url.URL(
url.drivername, username=url.username,
host=url.host, database=url.database, query=url.query)
c_args, opts = super(SQLiteDialect_pysqlcipher, self).\
create_connect_args(super_url)
opts['passphrase'] = url.password
return c_args, opts
dialect = SQLiteDialect_pysqlcipher

View File

@ -1,6 +1,18 @@
"""Support for the SQLite database via pysqlite. # sqlite/pysqlite.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Note that pysqlite is the same driver as the ``sqlite3`` r"""
.. dialect:: sqlite+pysqlite
:name: pysqlite
:dbapi: sqlite3
:connectstring: sqlite+pysqlite:///file_path
:url: http://docs.python.org/library/sqlite3.html
Note that ``pysqlite`` is the same driver as the ``sqlite3``
module included with the Python distribution. module included with the Python distribution.
Driver Driver
@ -20,37 +32,36 @@ this explicitly::
from sqlite3 import dbapi2 as sqlite from sqlite3 import dbapi2 as sqlite
e = create_engine('sqlite+pysqlite:///file.db', module=sqlite) e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)
Full documentation on pysqlite is available at:
`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
Connect Strings Connect Strings
--------------- ---------------
The file specification for the SQLite database is taken as the "database" portion of The file specification for the SQLite database is taken as the "database"
the URL. Note that the format of a url is:: portion of the URL. Note that the format of a SQLAlchemy url is::
driver://user:pass@host/database driver://user:pass@host/database
This means that the actual filename to be used starts with the characters to the This means that the actual filename to be used starts with the characters to
**right** of the third slash. So connecting to a relative filepath looks like:: the **right** of the third slash. So connecting to a relative filepath
looks like::
# relative path # relative path
e = create_engine('sqlite:///path/to/database.db') e = create_engine('sqlite:///path/to/database.db')
An absolute path, which is denoted by starting with a slash, means you need **four** An absolute path, which is denoted by starting with a slash, means you
slashes:: need **four** slashes::
# absolute path # absolute path
e = create_engine('sqlite:////path/to/database.db') e = create_engine('sqlite:////path/to/database.db')
To use a Windows path, regular drive specifications and backslashes can be used. To use a Windows path, regular drive specifications and backslashes can be
Double backslashes are probably needed:: used. Double backslashes are probably needed::
# absolute path on Windows # absolute path on Windows
e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') e = create_engine('sqlite:///C:\\path\\to\\database.db')
The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify The sqlite ``:memory:`` identifier is the default if no filepath is
``sqlite://`` and nothing else:: present. Specify ``sqlite://`` and nothing else::
# in-memory database # in-memory database
e = create_engine('sqlite://') e = create_engine('sqlite://')
@ -68,79 +79,198 @@ pysqlite's driver does not. Additionally, SQLAlchemy does not at
this time automatically render the "cast" syntax required for the this time automatically render the "cast" syntax required for the
freestanding functions "current_timestamp" and "current_date" to return freestanding functions "current_timestamp" and "current_date" to return
datetime/date types natively. Unfortunately, pysqlite datetime/date types natively. Unfortunately, pysqlite
does not provide the standard DBAPI types in `cursor.description`, does not provide the standard DBAPI types in ``cursor.description``,
leaving SQLAlchemy with no way to detect these types on the fly leaving SQLAlchemy with no way to detect these types on the fly
without expensive per-row type checks. without expensive per-row type checks.
Usage of PARSE_DECLTYPES can be forced if one configures Keeping in mind that pysqlite's parsing option is not recommended,
"native_datetime=True" on create_engine():: nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
can be forced if one configures "native_datetime=True" on create_engine()::
engine = create_engine('sqlite://', engine = create_engine('sqlite://',
connect_args={'detect_types': sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, connect_args={'detect_types':
sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
native_datetime=True native_datetime=True
) )
With this flag enabled, the DATE and TIMESTAMP types (but note - not the DATETIME With this flag enabled, the DATE and TIMESTAMP types (but note - not the
or TIME types...confused yet ?) will not perform any bind parameter or result DATETIME or TIME types...confused yet ?) will not perform any bind parameter
processing. Execution of "func.current_date()" will return a string. or result processing. Execution of "func.current_date()" will return a string.
"func.current_timestamp()" is registered as returning a DATETIME type in "func.current_timestamp()" is registered as returning a DATETIME type in
SQLAlchemy, so this function still receives SQLAlchemy-level result processing. SQLAlchemy, so this function still receives SQLAlchemy-level result
processing.
Threading Behavior .. _pysqlite_threading_pooling:
------------------
Pysqlite connections do not support being moved between threads, unless Threading/Pooling Behavior
the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, ---------------------------
when using an in-memory SQLite database, the full database exists only within
the scope of a single connection. It is reported that an in-memory
database does not support being shared between threads regardless of the
``check_same_thread`` flag - which means that a multithreaded
application **cannot** share data from a ``:memory:`` database across threads
unless access to the connection is limited to a single worker thread which communicates
through a queueing mechanism to concurrent threads.
To provide a default which accomodates SQLite's default threading capabilities Pysqlite's default behavior is to prohibit the usage of a single connection
somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` in more than one thread. This is originally intended to work with older
be used by default. This pool maintains a single SQLite connection per thread versions of SQLite that did not support multithreaded operation under
that is held open up to a count of five concurrent threads. When more than five threads various circumstances. In particular, older SQLite versions
are used, a cleanup mechanism will dispose of excess unused connections. did not allow a ``:memory:`` database to be used in multiple threads
under any circumstances.
Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: Pysqlite does include a now-undocumented flag known as
``check_same_thread`` which will disable this check, however note that
pysqlite connections are still not safe to use in concurrently in multiple
threads. In particular, any statement execution calls would need to be
externally mutexed, as Pysqlite does not provide for thread-safe propagation
of error messages among other things. So while even ``:memory:`` databases
can be shared among threads in modern SQLite, Pysqlite doesn't provide enough
thread-safety to make this usage worth it.
* the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded SQLAlchemy sets up pooling to work with Pysqlite's default behavior:
application using an in-memory database, assuming the threading issues inherent in
pysqlite are somehow accomodated for. This pool holds persistently onto a single connection
which is never closed, and is returned for all requests.
* the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that * When a ``:memory:`` SQLite database is specified, the dialect by default
makes use of a file-based sqlite database. This pool disables any actual "pooling" will use :class:`.SingletonThreadPool`. This pool maintains a single
behavior, and simply opens and closes real connections corresonding to the :func:`connect()` connection per thread, so that all access to the engine within the current
and :func:`close()` methods. SQLite can "connect" to a particular file with very high thread use the same ``:memory:`` database - other threads would access a
efficiency, so this option may actually perform better without the extra overhead different ``:memory:`` database.
of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection * When a file-based database is specified, the dialect will use
useless since the database would be lost as soon as the connection is "returned" to the pool. :class:`.NullPool` as the source of connections. This pool closes and
discards connections which are returned to the pool immediately. SQLite
file-based connections have extremely low overhead, so pooling is not
necessary. The scheme also prevents a connection from being used again in
a different thread and works best with SQLite's coarse-grained file locking.
.. versionchanged:: 0.7
Default selection of :class:`.NullPool` for SQLite file-based databases.
Previous versions select :class:`.SingletonThreadPool` by
default for all SQLite databases.
Using a Memory Database in Multiple Threads
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To use a ``:memory:`` database in a multithreaded scenario, the same
connection object must be shared among threads, since the database exists
only within the scope of that connection. The
:class:`.StaticPool` implementation will maintain a single connection
globally, and the ``check_same_thread`` flag can be passed to Pysqlite
as ``False``::
from sqlalchemy.pool import StaticPool
engine = create_engine('sqlite://',
connect_args={'check_same_thread':False},
poolclass=StaticPool)
Note that using a ``:memory:`` database in multiple threads requires a recent
version of SQLite.
Using Temporary Tables with SQLite
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Due to the way SQLite deals with temporary tables, if you wish to use a
temporary table in a file-based SQLite database across multiple checkouts
from the connection pool, such as when using an ORM :class:`.Session` where
the temporary table should continue to remain after :meth:`.Session.commit` or
:meth:`.Session.rollback` is called, a pool which maintains a single
connection must be used. Use :class:`.SingletonThreadPool` if the scope is
only needed within the current thread, or :class:`.StaticPool` is scope is
needed within multiple threads for this case::
# maintain the same connection per thread
from sqlalchemy.pool import SingletonThreadPool
engine = create_engine('sqlite:///mydb.db',
poolclass=SingletonThreadPool)
# maintain the same connection across all threads
from sqlalchemy.pool import StaticPool
engine = create_engine('sqlite:///mydb.db',
poolclass=StaticPool)
Note that :class:`.SingletonThreadPool` should be configured for the number
of threads that are to be used; beyond that number, connections will be
closed out in a non deterministic way.
Unicode Unicode
------- -------
In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's The pysqlite driver only returns Python ``unicode`` objects in result sets,
default behavior regarding Unicode is that all strings are returned as Python unicode objects never plain strings, and accommodates ``unicode`` objects within bound
in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is parameter values in all cases. Regardless of the SQLAlchemy string type in
*not* used, you will still always receive unicode data back from a result set. It is use, string-based result values will by Python ``unicode`` in Python 2.
**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type The :class:`.Unicode` type should still be used to indicate those columns that
to represent strings, since it will raise a warning if a non-unicode Python string is require unicode, however, so that non-``unicode`` values passed inadvertently
passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can will emit a warning. Pysqlite will emit an error if a non-``unicode`` string
quickly create confusion, particularly when using the ORM as internal data is not is passed containing non-ASCII characters.
always represented by an actual database result string.
.. _pysqlite_serializable:
Serializable isolation / Savepoints / Transactional DDL
-------------------------------------------------------
In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
driver's assortment of issues that prevent several features of SQLite
from working correctly. The pysqlite DBAPI driver has several
long-standing bugs which impact the correctness of its transactional
behavior. In its default mode of operation, SQLite features such as
SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
non-functional, and in order to use these features, workarounds must
be taken.
The issue is essentially that the driver attempts to second-guess the user's
intent, failing to start transactions and sometimes ending them prematurely, in
an effort to minimize the SQLite databases's file locking behavior, even
though SQLite itself uses "shared" locks for read-only activities.
SQLAlchemy chooses to not alter this behavior by default, as it is the
long-expected behavior of the pysqlite driver; if and when the pysqlite
driver attempts to repair these issues, that will be more of a driver towards
defaults for SQLAlchemy.
The good news is that with a few events, we can implement transactional
support fully, by disabling pysqlite's feature entirely and emitting BEGIN
ourselves. This is achieved using two event listeners::
from sqlalchemy import create_engine, event
engine = create_engine("sqlite:///myfile.db")
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None
@event.listens_for(engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.execute("BEGIN")
Above, we intercept a new pysqlite connection and disable any transactional
integration. Then, at the point at which SQLAlchemy knows that transaction
scope is to begin, we emit ``"BEGIN"`` ourselves.
When we take control of ``"BEGIN"``, we can also control directly SQLite's
locking modes, introduced at `BEGIN TRANSACTION <http://sqlite.org/lang_transaction.html>`_,
by adding the desired locking mode to our ``"BEGIN"``::
@event.listens_for(engine, "begin")
def do_begin(conn):
conn.execute("BEGIN EXCLUSIVE")
.. seealso::
`BEGIN TRANSACTION <http://sqlite.org/lang_transaction.html>`_ - on the SQLite site
`sqlite3 SELECT does not BEGIN a transaction <http://bugs.python.org/issue9924>`_ - on the Python bug tracker
`sqlite3 module breaks transactions and potentially corrupts data <http://bugs.python.org/issue10740>`_ - on the Python bug tracker
""" """
from sqlalchemy.dialects.sqlite.base import SQLiteDialect, DATETIME, DATE from sqlalchemy.dialects.sqlite.base import SQLiteDialect, DATETIME, DATE
from sqlalchemy import schema, exc, pool from sqlalchemy import exc, pool
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes from sqlalchemy import types as sqltypes
from sqlalchemy import util from sqlalchemy import util
import os
class _SQLite_pysqliteTimeStamp(DATETIME): class _SQLite_pysqliteTimeStamp(DATETIME):
def bind_processor(self, dialect): def bind_processor(self, dialect):
@ -155,6 +285,7 @@ class _SQLite_pysqliteTimeStamp(DATETIME):
else: else:
return DATETIME.result_processor(self, dialect, coltype) return DATETIME.result_processor(self, dialect, coltype)
class _SQLite_pysqliteDate(DATE): class _SQLite_pysqliteDate(DATE):
def bind_processor(self, dialect): def bind_processor(self, dialect):
if dialect.native_datetime: if dialect.native_datetime:
@ -168,9 +299,9 @@ class _SQLite_pysqliteDate(DATE):
else: else:
return DATE.result_processor(self, dialect, coltype) return DATE.result_processor(self, dialect, coltype)
class SQLiteDialect_pysqlite(SQLiteDialect): class SQLiteDialect_pysqlite(SQLiteDialect):
default_paramstyle = 'qmark' default_paramstyle = 'qmark'
poolclass = pool.SingletonThreadPool
colspecs = util.update_copy( colspecs = util.update_copy(
SQLiteDialect.colspecs, SQLiteDialect.colspecs,
@ -180,8 +311,8 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
} }
) )
# Py3K if not util.py2k:
#description_encoding = None description_encoding = None
driver = 'pysqlite' driver = 'pysqlite'
@ -201,13 +332,20 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
def dbapi(cls): def dbapi(cls):
try: try:
from pysqlite2 import dbapi2 as sqlite from pysqlite2 import dbapi2 as sqlite
except ImportError, e: except ImportError as e:
try: try:
from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. from sqlite3 import dbapi2 as sqlite # try 2.5+ stdlib name.
except ImportError: except ImportError:
raise e raise e
return sqlite return sqlite
@classmethod
def get_pool_class(cls, url):
if url.database and url.database != ':memory:':
return pool.NullPool
else:
return pool.SingletonThreadPool
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
return self.dbapi.sqlite_version_info return self.dbapi.sqlite_version_info
@ -220,6 +358,8 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
" sqlite:///relative/path/to/file.db\n" " sqlite:///relative/path/to/file.db\n"
" sqlite:////absolute/path/to/file.db" % (url,)) " sqlite:////absolute/path/to/file.db" % (url,))
filename = url.database or ':memory:' filename = url.database or ':memory:'
if filename != ':memory:':
filename = os.path.abspath(filename)
opts = url.query.copy() opts = url.query.copy()
util.coerce_kw_type(opts, 'timeout', float) util.coerce_kw_type(opts, 'timeout', float)
@ -230,7 +370,8 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
return ([filename], opts) return ([filename], opts)
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) return isinstance(e, self.dbapi.ProgrammingError) and \
"Cannot operate on a closed database." in str(e)
dialect = SQLiteDialect_pysqlite dialect = SQLiteDialect_pysqlite

View File

@ -1,15 +1,23 @@
# sybase/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.sybase import base, pysybase, pyodbc from sqlalchemy.dialects.sybase import base, pysybase, pyodbc
from base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
BIGINT,INT, INTEGER, SMALLINT, BINARY,\
VARBINARY,UNITEXT,UNICHAR,UNIVARCHAR,\
IMAGE,BIT,MONEY,SMALLMONEY,TINYINT
# default dialect # default dialect
base.dialect = pyodbc.dialect base.dialect = pyodbc.dialect
from .base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
BIGINT, INT, INTEGER, SMALLINT, BINARY,\
VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\
IMAGE, BIT, MONEY, SMALLMONEY, TINYINT,\
dialect
__all__ = ( __all__ = (
'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR', 'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR',
'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC', 'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC',

View File

@ -1,5 +1,6 @@
# sybase/base.py # sybase/base.py
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2010-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# get_select_precolumns(), limit_clause() implementation # get_select_precolumns(), limit_clause() implementation
# copyright (C) 2007 Fisch Asset Management # copyright (C) 2007 Fisch Asset Management
# AG http://www.fam.ch, with coding by Alexander Houben # AG http://www.fam.ch, with coding by Alexander Houben
@ -8,14 +9,21 @@
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for Sybase Adaptive Server Enterprise (ASE).
Note that this dialect is no longer specific to Sybase iAnywhere.
ASE is the primary support platform.
""" """
.. dialect:: sybase
:name: Sybase
.. note::
The Sybase dialect functions on current SQLAlchemy versions
but is not regularly tested, and may have many issues and
caveats not currently handled.
"""
import operator import operator
import re
from sqlalchemy.sql import compiler, expression, text, bindparam from sqlalchemy.sql import compiler, expression, text, bindparam
from sqlalchemy.engine import default, base, reflection from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes from sqlalchemy import types as sqltypes
@ -27,7 +35,7 @@ from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
BIGINT, INT, INTEGER, SMALLINT, BINARY,\ BIGINT, INT, INTEGER, SMALLINT, BINARY,\
VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
UnicodeText UnicodeText, REAL
RESERVED_WORDS = set([ RESERVED_WORDS = set([
"add", "all", "alter", "and", "add", "all", "alter", "and",
@ -95,103 +103,148 @@ class _SybaseUnitypeMixin(object):
def result_processor(self, dialect, coltype): def result_processor(self, dialect, coltype):
def process(value): def process(value):
if value is not None: if value is not None:
return str(value) #.decode("ucs-2") return str(value) # decode("ucs-2")
else: else:
return None return None
return process return process
class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode): class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
__visit_name__ = 'UNICHAR' __visit_name__ = 'UNICHAR'
class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode): class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
__visit_name__ = 'UNIVARCHAR' __visit_name__ = 'UNIVARCHAR'
class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText): class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText):
__visit_name__ = 'UNITEXT' __visit_name__ = 'UNITEXT'
class TINYINT(sqltypes.Integer): class TINYINT(sqltypes.Integer):
__visit_name__ = 'TINYINT' __visit_name__ = 'TINYINT'
class BIT(sqltypes.TypeEngine): class BIT(sqltypes.TypeEngine):
__visit_name__ = 'BIT' __visit_name__ = 'BIT'
class MONEY(sqltypes.TypeEngine): class MONEY(sqltypes.TypeEngine):
__visit_name__ = "MONEY" __visit_name__ = "MONEY"
class SMALLMONEY(sqltypes.TypeEngine): class SMALLMONEY(sqltypes.TypeEngine):
__visit_name__ = "SMALLMONEY" __visit_name__ = "SMALLMONEY"
class UNIQUEIDENTIFIER(sqltypes.TypeEngine): class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
__visit_name__ = "UNIQUEIDENTIFIER" __visit_name__ = "UNIQUEIDENTIFIER"
class IMAGE(sqltypes.LargeBinary): class IMAGE(sqltypes.LargeBinary):
__visit_name__ = 'IMAGE' __visit_name__ = 'IMAGE'
class SybaseTypeCompiler(compiler.GenericTypeCompiler): class SybaseTypeCompiler(compiler.GenericTypeCompiler):
def visit_large_binary(self, type_): def visit_large_binary(self, type_, **kw):
return self.visit_IMAGE(type_) return self.visit_IMAGE(type_)
def visit_boolean(self, type_): def visit_boolean(self, type_, **kw):
return self.visit_BIT(type_) return self.visit_BIT(type_)
def visit_unicode(self, type_): def visit_unicode(self, type_, **kw):
return self.visit_NVARCHAR(type_) return self.visit_NVARCHAR(type_)
def visit_UNICHAR(self, type_): def visit_UNICHAR(self, type_, **kw):
return "UNICHAR(%d)" % type_.length return "UNICHAR(%d)" % type_.length
def visit_UNIVARCHAR(self, type_): def visit_UNIVARCHAR(self, type_, **kw):
return "UNIVARCHAR(%d)" % type_.length return "UNIVARCHAR(%d)" % type_.length
def visit_UNITEXT(self, type_): def visit_UNITEXT(self, type_, **kw):
return "UNITEXT" return "UNITEXT"
def visit_TINYINT(self, type_): def visit_TINYINT(self, type_, **kw):
return "TINYINT" return "TINYINT"
def visit_IMAGE(self, type_): def visit_IMAGE(self, type_, **kw):
return "IMAGE" return "IMAGE"
def visit_BIT(self, type_): def visit_BIT(self, type_, **kw):
return "BIT" return "BIT"
def visit_MONEY(self, type_): def visit_MONEY(self, type_, **kw):
return "MONEY" return "MONEY"
def visit_SMALLMONEY(self, type_): def visit_SMALLMONEY(self, type_, **kw):
return "SMALLMONEY" return "SMALLMONEY"
def visit_UNIQUEIDENTIFIER(self, type_): def visit_UNIQUEIDENTIFIER(self, type_, **kw):
return "UNIQUEIDENTIFIER" return "UNIQUEIDENTIFIER"
ischema_names = { ischema_names = {
'integer' : INTEGER,
'unsigned int' : INTEGER, # TODO: unsigned flags
'unsigned smallint' : SMALLINT, # TODO: unsigned flags
'unsigned bigint' : BIGINT, # TODO: unsigned flags
'bigint': BIGINT, 'bigint': BIGINT,
'int': INTEGER,
'integer': INTEGER,
'smallint': SMALLINT, 'smallint': SMALLINT,
'tinyint': TINYINT, 'tinyint': TINYINT,
'varchar' : VARCHAR, 'unsigned bigint': BIGINT, # TODO: unsigned flags
'long varchar' : TEXT, # TODO 'unsigned int': INTEGER, # TODO: unsigned flags
'char' : CHAR, 'unsigned smallint': SMALLINT, # TODO: unsigned flags
'decimal' : DECIMAL,
'numeric': NUMERIC, 'numeric': NUMERIC,
'decimal': DECIMAL,
'dec': DECIMAL,
'float': FLOAT, 'float': FLOAT,
'double': NUMERIC, # TODO 'double': NUMERIC, # TODO
'double precision': NUMERIC, # TODO
'real': REAL,
'smallmoney': SMALLMONEY,
'money': MONEY,
'smalldatetime': DATETIME,
'datetime': DATETIME,
'date': DATE,
'time': TIME,
'char': CHAR,
'character': CHAR,
'varchar': VARCHAR,
'character varying': VARCHAR,
'char varying': VARCHAR,
'unichar': UNICHAR,
'unicode character': UNIVARCHAR,
'nchar': NCHAR,
'national char': NCHAR,
'national character': NCHAR,
'nvarchar': NVARCHAR,
'nchar varying': NVARCHAR,
'national char varying': NVARCHAR,
'national character varying': NVARCHAR,
'text': TEXT,
'unitext': UNITEXT,
'binary': BINARY, 'binary': BINARY,
'varbinary': VARBINARY, 'varbinary': VARBINARY,
'bit': BIT,
'image': IMAGE, 'image': IMAGE,
'bit': BIT,
# not in documentation for ASE 15.7
'long varchar': TEXT, # TODO
'timestamp': TIMESTAMP, 'timestamp': TIMESTAMP,
'money': MONEY,
'smallmoney': MONEY,
'uniqueidentifier': UNIQUEIDENTIFIER, 'uniqueidentifier': UNIQUEIDENTIFIER,
} }
class SybaseInspector(reflection.Inspector):
def __init__(self, conn):
reflection.Inspector.__init__(self, conn)
def get_table_id(self, table_name, schema=None):
"""Return the table id from `table_name` and `schema`."""
return self.dialect.get_table_id(self.bind, table_name, schema,
info_cache=self.info_cache)
class SybaseExecutionContext(default.DefaultExecutionContext): class SybaseExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False _enable_identity_insert = False
@ -214,12 +267,14 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
insert_has_sequence = seq_column is not None insert_has_sequence = seq_column is not None
if insert_has_sequence: if insert_has_sequence:
self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] self._enable_identity_insert = \
seq_column.key in self.compiled_parameters[0]
else: else:
self._enable_identity_insert = False self._enable_identity_insert = False
if self._enable_identity_insert: if self._enable_identity_insert:
self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.cursor.execute(
"SET IDENTITY_INSERT %s ON" %
self.dialect.identifier_preparer.format_table(tbl)) self.dialect.identifier_preparer.format_table(tbl))
if self.isddl: if self.isddl:
@ -227,13 +282,16 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
# database settings. this error message should be improved to # database settings. this error message should be improved to
# include a note about that. # include a note about that.
if not self.should_autocommit: if not self.should_autocommit:
raise exc.InvalidRequestError("The Sybase dialect only supports " raise exc.InvalidRequestError(
"The Sybase dialect only supports "
"DDL in 'autocommit' mode at this time.") "DDL in 'autocommit' mode at this time.")
self.root_connection.engine.logger.info("AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") self.root_connection.engine.logger.info(
"AUTOCOMMIT (Assuming no Sybase 'ddl in tran')")
self.set_ddl_autocommit(self.root_connection.connection.connection, True)
self.set_ddl_autocommit(
self.root_connection.connection.connection,
True)
def post_exec(self): def post_exec(self):
if self.isddl: if self.isddl:
@ -253,6 +311,7 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
cursor.close() cursor.close()
return lastrowid return lastrowid
class SybaseSQLCompiler(compiler.SQLCompiler): class SybaseSQLCompiler(compiler.SQLCompiler):
ansi_bind_rules = True ansi_bind_rules = True
@ -264,35 +323,40 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
'milliseconds': 'millisecond' 'milliseconds': 'millisecond'
}) })
def get_select_precolumns(self, select): def get_select_precolumns(self, select, **kw):
s = select._distinct and "DISTINCT " or "" s = select._distinct and "DISTINCT " or ""
if select._limit: # TODO: don't think Sybase supports
# bind params for FIRST / TOP
limit = select._limit
if limit:
# if select._limit == 1: # if select._limit == 1:
# s += "FIRST " # s += "FIRST "
# else: # else:
# s += "TOP %s " % (select._limit,) # s += "TOP %s " % (select._limit,)
s += "TOP %s " % (select._limit,) s += "TOP %s " % (limit,)
if select._offset: offset = select._offset
if not select._limit: if offset:
# FIXME: sybase doesn't allow an offset without a limit raise NotImplementedError("Sybase ASE does not support OFFSET")
# so use a huge value for TOP here
s += "TOP 1000000 "
s += "START AT %s " % (select._offset+1,)
return s return s
def get_from_hint_text(self, table, text): def get_from_hint_text(self, table, text):
return text return text
def limit_clause(self, select): def limit_clause(self, select, **kw):
# Limit in sybase is after the select keyword # Limit in sybase is after the select keyword
return "" return ""
def visit_extract(self, extract, **kw): def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field) field = self.extract_map.get(extract.field, extract.field)
return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) return 'DATEPART("%s", %s)' % (
field, self.process(extract.expr, **kw))
def visit_now_func(self, fn, **kw):
return "GETDATE()"
def for_update_clause(self, select): def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use # "FOR UPDATE" is only allowed on "DECLARE CURSOR"
# which SQLAlchemy doesn't use
return '' return ''
def order_by_clause(self, select, **kw): def order_by_clause(self, select, **kw):
@ -309,18 +373,22 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
class SybaseDDLCompiler(compiler.DDLCompiler): class SybaseDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs): def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + \ colspec = self.preparer.format_column(column) + " " + \
self.dialect.type_compiler.process(column.type) self.dialect.type_compiler.process(
column.type, type_expression=column)
if column.table is None: if column.table is None:
raise exc.InvalidRequestError("The Sybase dialect requires Table-bound "\ raise exc.CompileError(
"The Sybase dialect requires Table-bound "
"columns in order to generate DDL") "columns in order to generate DDL")
seq_col = column.table._autoincrement_column seq_col = column.table._autoincrement_column
# install a IDENTITY Sequence if we have an implicit IDENTITY column # install a IDENTITY Sequence if we have an implicit IDENTITY column
if seq_col is column: if seq_col is column:
sequence = isinstance(column.default, sa_schema.Sequence) and column.default sequence = isinstance(column.default, sa_schema.Sequence) \
and column.default
if sequence: if sequence:
start, increment = sequence.start or 1, sequence.increment or 1 start, increment = sequence.start or 1, \
sequence.increment or 1
else: else:
start, increment = 1, 1 start, increment = 1, 1
if (start, increment) == (1, 1): if (start, increment) == (1, 1):
@ -329,28 +397,31 @@ class SybaseDDLCompiler(compiler.DDLCompiler):
# TODO: need correct syntax for this # TODO: need correct syntax for this
colspec += " IDENTITY(%s,%s)" % (start, increment) colspec += " IDENTITY(%s,%s)" % (start, increment)
else: else:
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
if column.nullable is not None: if column.nullable is not None:
if not column.nullable or column.primary_key: if not column.nullable or column.primary_key:
colspec += " NOT NULL" colspec += " NOT NULL"
else: else:
colspec += " NULL" colspec += " NULL"
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
return colspec return colspec
def visit_drop_index(self, drop): def visit_drop_index(self, drop):
index = drop.element index = drop.element
return "\nDROP INDEX %s.%s" % ( return "\nDROP INDEX %s.%s" % (
self.preparer.quote_identifier(index.table.name), self.preparer.quote_identifier(index.table.name),
self.preparer.quote(self._validate_identifier(index.name, False), index.quote) self._prepared_index_name(drop.element,
include_schema=False)
) )
class SybaseIdentifierPreparer(compiler.IdentifierPreparer): class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS reserved_words = RESERVED_WORDS
class SybaseDialect(default.DefaultDialect): class SybaseDialect(default.DefaultDialect):
name = 'sybase' name = 'sybase'
supports_unicode_statements = False supports_unicode_statements = False
@ -368,10 +439,14 @@ class SybaseDialect(default.DefaultDialect):
statement_compiler = SybaseSQLCompiler statement_compiler = SybaseSQLCompiler
ddl_compiler = SybaseDDLCompiler ddl_compiler = SybaseDDLCompiler
preparer = SybaseIdentifierPreparer preparer = SybaseIdentifierPreparer
inspector = SybaseInspector
construct_arguments = []
def _get_default_schema_name(self, connection): def _get_default_schema_name(self, connection):
return connection.scalar( return connection.scalar(
text("SELECT user_name() as user_name", typemap={'user_name':Unicode}) text("SELECT user_name() as user_name",
typemap={'user_name': Unicode})
) )
def initialize(self, connection): def initialize(self, connection):
@ -382,39 +457,365 @@ class SybaseDialect(default.DefaultDialect):
else: else:
self.max_identifier_length = 255 self.max_identifier_length = 255
def get_table_id(self, connection, table_name, schema=None, **kw):
"""Fetch the id for schema.table_name.
Several reflection methods require the table id. The idea for using
this method is that it can be fetched one time and cached for
subsequent calls.
"""
table_id = None
if schema is None:
schema = self.default_schema_name
TABLEID_SQL = text("""
SELECT o.id AS id
FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
WHERE u.name = :schema_name
AND o.name = :table_name
AND o.type in ('U', 'V')
""")
if util.py2k:
if isinstance(schema, unicode):
schema = schema.encode("ascii")
if isinstance(table_name, unicode):
table_name = table_name.encode("ascii")
result = connection.execute(TABLEID_SQL,
schema_name=schema,
table_name=table_name)
table_id = result.scalar()
if table_id is None:
raise exc.NoSuchTableError(table_name)
return table_id
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
table_id = self.get_table_id(connection, table_name, schema,
info_cache=kw.get("info_cache"))
COLUMN_SQL = text("""
SELECT col.name AS name,
t.name AS type,
(col.status & 8) AS nullable,
(col.status & 128) AS autoincrement,
com.text AS 'default',
col.prec AS precision,
col.scale AS scale,
col.length AS length
FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON
col.cdefault = com.id
WHERE col.usertype = t.usertype
AND col.id = :table_id
ORDER BY col.colid
""")
results = connection.execute(COLUMN_SQL, table_id=table_id)
columns = []
for (name, type_, nullable, autoincrement, default, precision, scale,
length) in results:
col_info = self._get_column_info(name, type_, bool(nullable),
bool(autoincrement),
default, precision, scale,
length)
columns.append(col_info)
return columns
def _get_column_info(self, name, type_, nullable, autoincrement, default,
precision, scale, length):
coltype = self.ischema_names.get(type_, None)
kwargs = {}
if coltype in (NUMERIC, DECIMAL):
args = (precision, scale)
elif coltype == FLOAT:
args = (precision,)
elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR):
args = (length,)
else:
args = ()
if coltype:
coltype = coltype(*args, **kwargs)
# is this necessary
# if is_array:
# coltype = ARRAY(coltype)
else:
util.warn("Did not recognize type '%s' of column '%s'" %
(type_, name))
coltype = sqltypes.NULLTYPE
if default:
default = default.replace("DEFAULT", "").strip()
default = re.sub("^'(.*)'$", lambda m: m.group(1), default)
else:
default = None
column_info = dict(name=name, type=coltype, nullable=nullable,
default=default, autoincrement=autoincrement)
return column_info
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
table_id = self.get_table_id(connection, table_name, schema,
info_cache=kw.get("info_cache"))
table_cache = {}
column_cache = {}
foreign_keys = []
table_cache[table_id] = {"name": table_name, "schema": schema}
COLUMN_SQL = text("""
SELECT c.colid AS id, c.name AS name
FROM syscolumns c
WHERE c.id = :table_id
""")
results = connection.execute(COLUMN_SQL, table_id=table_id)
columns = {}
for col in results:
columns[col["id"]] = col["name"]
column_cache[table_id] = columns
REFCONSTRAINT_SQL = text("""
SELECT o.name AS name, r.reftabid AS reftable_id,
r.keycnt AS 'count',
r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6,
r.fokey7 AS fokey7, r.fokey1 AS fokey8, r.fokey9 AS fokey9,
r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12,
r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15,
r.fokey16 AS fokey16,
r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3,
r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6,
r.refkey7 AS refkey7, r.refkey1 AS refkey8, r.refkey9 AS refkey9,
r.refkey10 AS refkey10, r.refkey11 AS refkey11,
r.refkey12 AS refkey12, r.refkey13 AS refkey13,
r.refkey14 AS refkey14, r.refkey15 AS refkey15,
r.refkey16 AS refkey16
FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
WHERE r.tableid = :table_id
""")
referential_constraints = connection.execute(
REFCONSTRAINT_SQL, table_id=table_id).fetchall()
REFTABLE_SQL = text("""
SELECT o.name AS name, u.name AS 'schema'
FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
WHERE o.id = :table_id
""")
for r in referential_constraints:
reftable_id = r["reftable_id"]
if reftable_id not in table_cache:
c = connection.execute(REFTABLE_SQL, table_id=reftable_id)
reftable = c.fetchone()
c.close()
table_info = {"name": reftable["name"], "schema": None}
if (schema is not None or
reftable["schema"] != self.default_schema_name):
table_info["schema"] = reftable["schema"]
table_cache[reftable_id] = table_info
results = connection.execute(COLUMN_SQL, table_id=reftable_id)
reftable_columns = {}
for col in results:
reftable_columns[col["id"]] = col["name"]
column_cache[reftable_id] = reftable_columns
reftable = table_cache[reftable_id]
reftable_columns = column_cache[reftable_id]
constrained_columns = []
referred_columns = []
for i in range(1, r["count"] + 1):
constrained_columns.append(columns[r["fokey%i" % i]])
referred_columns.append(reftable_columns[r["refkey%i" % i]])
fk_info = {
"constrained_columns": constrained_columns,
"referred_schema": reftable["schema"],
"referred_table": reftable["name"],
"referred_columns": referred_columns,
"name": r["name"]
}
foreign_keys.append(fk_info)
return foreign_keys
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
table_id = self.get_table_id(connection, table_name, schema,
info_cache=kw.get("info_cache"))
INDEX_SQL = text("""
SELECT object_name(i.id) AS table_name,
i.keycnt AS 'count',
i.name AS name,
(i.status & 0x2) AS 'unique',
index_col(object_name(i.id), i.indid, 1) AS col_1,
index_col(object_name(i.id), i.indid, 2) AS col_2,
index_col(object_name(i.id), i.indid, 3) AS col_3,
index_col(object_name(i.id), i.indid, 4) AS col_4,
index_col(object_name(i.id), i.indid, 5) AS col_5,
index_col(object_name(i.id), i.indid, 6) AS col_6,
index_col(object_name(i.id), i.indid, 7) AS col_7,
index_col(object_name(i.id), i.indid, 8) AS col_8,
index_col(object_name(i.id), i.indid, 9) AS col_9,
index_col(object_name(i.id), i.indid, 10) AS col_10,
index_col(object_name(i.id), i.indid, 11) AS col_11,
index_col(object_name(i.id), i.indid, 12) AS col_12,
index_col(object_name(i.id), i.indid, 13) AS col_13,
index_col(object_name(i.id), i.indid, 14) AS col_14,
index_col(object_name(i.id), i.indid, 15) AS col_15,
index_col(object_name(i.id), i.indid, 16) AS col_16
FROM sysindexes i, sysobjects o
WHERE o.id = i.id
AND o.id = :table_id
AND (i.status & 2048) = 0
AND i.indid BETWEEN 1 AND 254
""")
results = connection.execute(INDEX_SQL, table_id=table_id)
indexes = []
for r in results:
column_names = []
for i in range(1, r["count"]):
column_names.append(r["col_%i" % (i,)])
index_info = {"name": r["name"],
"unique": bool(r["unique"]),
"column_names": column_names}
indexes.append(index_info)
return indexes
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
table_id = self.get_table_id(connection, table_name, schema,
info_cache=kw.get("info_cache"))
PK_SQL = text("""
SELECT object_name(i.id) AS table_name,
i.keycnt AS 'count',
i.name AS name,
index_col(object_name(i.id), i.indid, 1) AS pk_1,
index_col(object_name(i.id), i.indid, 2) AS pk_2,
index_col(object_name(i.id), i.indid, 3) AS pk_3,
index_col(object_name(i.id), i.indid, 4) AS pk_4,
index_col(object_name(i.id), i.indid, 5) AS pk_5,
index_col(object_name(i.id), i.indid, 6) AS pk_6,
index_col(object_name(i.id), i.indid, 7) AS pk_7,
index_col(object_name(i.id), i.indid, 8) AS pk_8,
index_col(object_name(i.id), i.indid, 9) AS pk_9,
index_col(object_name(i.id), i.indid, 10) AS pk_10,
index_col(object_name(i.id), i.indid, 11) AS pk_11,
index_col(object_name(i.id), i.indid, 12) AS pk_12,
index_col(object_name(i.id), i.indid, 13) AS pk_13,
index_col(object_name(i.id), i.indid, 14) AS pk_14,
index_col(object_name(i.id), i.indid, 15) AS pk_15,
index_col(object_name(i.id), i.indid, 16) AS pk_16
FROM sysindexes i, sysobjects o
WHERE o.id = i.id
AND o.id = :table_id
AND (i.status & 2048) = 2048
AND i.indid BETWEEN 1 AND 254
""")
results = connection.execute(PK_SQL, table_id=table_id)
pks = results.fetchone()
results.close()
constrained_columns = []
if pks:
for i in range(1, pks["count"] + 1):
constrained_columns.append(pks["pk_%i" % (i,)])
return {"constrained_columns": constrained_columns,
"name": pks["name"]}
else:
return {"constrained_columns": [], "name": None}
@reflection.cache
def get_schema_names(self, connection, **kw):
SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u")
schemas = connection.execute(SCHEMA_SQL)
return [s["name"] for s in schemas]
@reflection.cache @reflection.cache
def get_table_names(self, connection, schema=None, **kw): def get_table_names(self, connection, schema=None, **kw):
if schema is None: if schema is None:
schema = self.default_schema_name schema = self.default_schema_name
result = connection.execute( TABLE_SQL = text("""
text("select sysobjects.name from sysobjects, sysusers " SELECT o.name AS name
"where sysobjects.uid=sysusers.uid and " FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
"sysusers.name=:schemaname and " WHERE u.name = :schema_name
"sysobjects.type='U'", AND o.type = 'U'
bindparams=[ """)
bindparam('schemaname', schema)
])
)
return [r[0] for r in result]
def has_table(self, connection, tablename, schema=None): if util.py2k:
if isinstance(schema, unicode):
schema = schema.encode("ascii")
tables = connection.execute(TABLE_SQL, schema_name=schema)
return [t["name"] for t in tables]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
if schema is None: if schema is None:
schema = self.default_schema_name schema = self.default_schema_name
result = connection.execute( VIEW_DEF_SQL = text("""
text("select sysobjects.name from sysobjects, sysusers " SELECT c.text
"where sysobjects.uid=sysusers.uid and " FROM syscomments c JOIN sysobjects o ON c.id = o.id
"sysobjects.name=:tablename and " WHERE o.name = :view_name
"sysusers.name=:schemaname and " AND o.type = 'V'
"sysobjects.type='U'", """)
bindparams=[
bindparam('tablename', tablename),
bindparam('schemaname', schema)
])
)
return result.scalar() is not None
def reflecttable(self, connection, table, include_columns): if util.py2k:
raise NotImplementedError() if isinstance(view_name, unicode):
view_name = view_name.encode("ascii")
view = connection.execute(VIEW_DEF_SQL, view_name=view_name)
return view.scalar()
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
if schema is None:
schema = self.default_schema_name
VIEW_SQL = text("""
SELECT o.name AS name
FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
WHERE u.name = :schema_name
AND o.type = 'V'
""")
if util.py2k:
if isinstance(schema, unicode):
schema = schema.encode("ascii")
views = connection.execute(VIEW_SQL, schema_name=schema)
return [v["name"] for v in views]
def has_table(self, connection, table_name, schema=None):
try:
self.get_table_id(connection, table_name, schema)
except exc.NoSuchTableError:
return False
else:
return True

View File

@ -1,16 +1,32 @@
# sybase/mxodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
Support for Sybase via mxodbc.
.. dialect:: sybase+mxodbc
:name: mxODBC
:dbapi: mxodbc
:connectstring: sybase+mxodbc://<username>:<password>@<dsnname>
:url: http://www.egenix.com/
.. note::
This dialect is a stub only and is likely non functional at this time. This dialect is a stub only and is likely non functional at this time.
""" """
from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext from sqlalchemy.dialects.sybase.base import SybaseDialect
from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
from sqlalchemy.connectors.mxodbc import MxODBCConnector from sqlalchemy.connectors.mxodbc import MxODBCConnector
class SybaseExecutionContext_mxodbc(SybaseExecutionContext): class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
pass pass
class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect): class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_mxodbc execution_ctx_cls = SybaseExecutionContext_mxodbc

View File

@ -1,12 +1,18 @@
# sybase/pyodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
Support for Sybase via pyodbc. .. dialect:: sybase+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: sybase+pyodbc://<username>:<password>@<dsnname>\
[/<database>]
:url: http://pypi.python.org/pypi/pyodbc/
http://pypi.python.org/pypi/pyodbc/
Connect strings are of the form::
sybase+pyodbc://<username>:<password>@<dsn>/
sybase+pyodbc://<username>:<password>@<host>/<database>
Unicode Support Unicode Support
--------------- ---------------
@ -28,10 +34,12 @@ Currently *not* supported are::
""" """
from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext from sqlalchemy.dialects.sybase.base import SybaseDialect,\
SybaseExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes, processors
import decimal import decimal
from sqlalchemy import types as sqltypes, util, processors
class _SybNumeric_pyodbc(sqltypes.Numeric): class _SybNumeric_pyodbc(sqltypes.Numeric):
"""Turns Decimals with adjusted() < -6 into floats. """Turns Decimals with adjusted() < -6 into floats.
@ -43,7 +51,8 @@ class _SybNumeric_pyodbc(sqltypes.Numeric):
""" """
def bind_processor(self, dialect): def bind_processor(self, dialect):
super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect) super_process = super(_SybNumeric_pyodbc, self).\
bind_processor(dialect)
def process(value): def process(value):
if self.asdecimal and \ if self.asdecimal and \
@ -58,6 +67,7 @@ class _SybNumeric_pyodbc(sqltypes.Numeric):
return value return value
return process return process
class SybaseExecutionContext_pyodbc(SybaseExecutionContext): class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
def set_ddl_autocommit(self, connection, value): def set_ddl_autocommit(self, connection, value):
if value: if value:
@ -65,6 +75,7 @@ class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
else: else:
connection.autocommit = False connection.autocommit = False
class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_pyodbc execution_ctx_cls = SybaseExecutionContext_pyodbc

View File

@ -1,17 +1,17 @@
# pysybase.py # sybase/pysybase.py
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2010-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
Support for Sybase via the python-sybase driver. .. dialect:: sybase+pysybase
:name: Python-Sybase
http://python-sybase.sourceforge.net/ :dbapi: Sybase
:connectstring: sybase+pysybase://<username>:<password>@<dsn>/\
Connect strings are of the form:: [database name]
:url: http://python-sybase.sourceforge.net/
sybase+pysybase://<username>:<password>@<dsn>/[database name]
Unicode Support Unicode Support
--------------- ---------------
@ -33,6 +33,7 @@ class _SybNumeric(sqltypes.Numeric):
else: else:
return sqltypes.Numeric.result_processor(self, dialect, type_) return sqltypes.Numeric.result_processor(self, dialect, type_)
class SybaseExecutionContext_pysybase(SybaseExecutionContext): class SybaseExecutionContext_pysybase(SybaseExecutionContext):
def set_ddl_autocommit(self, dbapi_connection, value): def set_ddl_autocommit(self, dbapi_connection, value):
@ -52,9 +53,10 @@ class SybaseExecutionContext_pysybase(SybaseExecutionContext):
class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
def bindparam_string(self, name): def bindparam_string(self, name, **kw):
return "@" + name return "@" + name
class SybaseDialect_pysybase(SybaseDialect): class SybaseDialect_pysybase(SybaseDialect):
driver = 'pysybase' driver = 'pysybase'
execution_ctx_cls = SybaseExecutionContext_pysybase execution_ctx_cls = SybaseExecutionContext_pysybase
@ -83,11 +85,13 @@ class SybaseDialect_pysybase(SybaseDialect):
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version_number") vers = connection.scalar("select @@version_number")
# i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0), (12, 5, 0, 0) # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0),
# (12, 5, 0, 0)
return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10) return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10)
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)): if isinstance(e, (self.dbapi.OperationalError,
self.dbapi.ProgrammingError)):
msg = str(e) msg = str(e)
return ('Unable to complete network request to host' in msg or return ('Unable to complete network request to host' in msg or
'Invalid connection state' in msg or 'Invalid connection state' in msg or

File diff suppressed because it is too large Load Diff

1435
sqlalchemy/engine/result.py Normal file

File diff suppressed because it is too large Load Diff

74
sqlalchemy/engine/util.py Normal file
View File

@ -0,0 +1,74 @@
# engine/util.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .. import util
def connection_memoize(key):
"""Decorator, memoize a function in a connection.info stash.
Only applicable to functions which take no arguments other than a
connection. The memo will be stored in ``connection.info[key]``.
"""
@util.decorator
def decorated(fn, self, connection):
connection = connection.connect()
try:
return connection.info[key]
except KeyError:
connection.info[key] = val = fn(self, connection)
return val
return decorated
def py_fallback():
def _distill_params(multiparams, params):
"""Given arguments from the calling form *multiparams, **params,
return a list of bind parameter structures, usually a list of
dictionaries.
In the case of 'raw' execution which accepts positional parameters,
it may be a list of tuples or lists.
"""
if not multiparams:
if params:
return [params]
else:
return []
elif len(multiparams) == 1:
zero = multiparams[0]
if isinstance(zero, (list, tuple)):
if not zero or hasattr(zero[0], '__iter__') and \
not hasattr(zero[0], 'strip'):
# execute(stmt, [{}, {}, {}, ...])
# execute(stmt, [(), (), (), ...])
return zero
else:
# execute(stmt, ("value", "value"))
return [zero]
elif hasattr(zero, 'keys'):
# execute(stmt, {"key":"value"})
return [zero]
else:
# execute(stmt, "value")
return [[zero]]
else:
if hasattr(multiparams[0], '__iter__') and \
not hasattr(multiparams[0], 'strip'):
return multiparams
else:
return [multiparams]
return locals()
try:
from sqlalchemy.cutils import _distill_params
except ImportError:
globals().update(py_fallback())

View File

@ -0,0 +1,11 @@
# event/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .api import CANCEL, NO_RETVAL, listen, listens_for, remove, contains
from .base import Events, dispatcher
from .attr import RefCollection
from .legacy import _legacy_signature

188
sqlalchemy/event/api.py Normal file
View File

@ -0,0 +1,188 @@
# event/api.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Public API functions for the event system.
"""
from __future__ import absolute_import
from .. import util, exc
from .base import _registrars
from .registry import _EventKey
CANCEL = util.symbol('CANCEL')
NO_RETVAL = util.symbol('NO_RETVAL')
def _event_key(target, identifier, fn):
for evt_cls in _registrars[identifier]:
tgt = evt_cls._accept_with(target)
if tgt is not None:
return _EventKey(target, identifier, fn, tgt)
else:
raise exc.InvalidRequestError("No such event '%s' for target '%s'" %
(identifier, target))
def listen(target, identifier, fn, *args, **kw):
"""Register a listener function for the given target.
e.g.::
from sqlalchemy import event
from sqlalchemy.schema import UniqueConstraint
def unique_constraint_name(const, table):
const.name = "uq_%s_%s" % (
table.name,
list(const.columns)[0].name
)
event.listen(
UniqueConstraint,
"after_parent_attach",
unique_constraint_name)
A given function can also be invoked for only the first invocation
of the event using the ``once`` argument::
def on_config():
do_config()
event.listen(Mapper, "before_configure", on_config, once=True)
.. versionadded:: 0.9.4 Added ``once=True`` to :func:`.event.listen`
and :func:`.event.listens_for`.
.. note::
The :func:`.listen` function cannot be called at the same time
that the target event is being run. This has implications
for thread safety, and also means an event cannot be added
from inside the listener function for itself. The list of
events to be run are present inside of a mutable collection
that can't be changed during iteration.
Event registration and removal is not intended to be a "high
velocity" operation; it is a configurational operation. For
systems that need to quickly associate and deassociate with
events at high scale, use a mutable structure that is handled
from inside of a single listener.
.. versionchanged:: 1.0.0 - a ``collections.deque()`` object is now
used as the container for the list of events, which explicitly
disallows collection mutation while the collection is being
iterated.
.. seealso::
:func:`.listens_for`
:func:`.remove`
"""
_event_key(target, identifier, fn).listen(*args, **kw)
def listens_for(target, identifier, *args, **kw):
"""Decorate a function as a listener for the given target + identifier.
e.g.::
from sqlalchemy import event
from sqlalchemy.schema import UniqueConstraint
@event.listens_for(UniqueConstraint, "after_parent_attach")
def unique_constraint_name(const, table):
const.name = "uq_%s_%s" % (
table.name,
list(const.columns)[0].name
)
A given function can also be invoked for only the first invocation
of the event using the ``once`` argument::
@event.listens_for(Mapper, "before_configure", once=True)
def on_config():
do_config()
.. versionadded:: 0.9.4 Added ``once=True`` to :func:`.event.listen`
and :func:`.event.listens_for`.
.. seealso::
:func:`.listen` - general description of event listening
"""
def decorate(fn):
listen(target, identifier, fn, *args, **kw)
return fn
return decorate
def remove(target, identifier, fn):
"""Remove an event listener.
The arguments here should match exactly those which were sent to
:func:`.listen`; all the event registration which proceeded as a result
of this call will be reverted by calling :func:`.remove` with the same
arguments.
e.g.::
# if a function was registered like this...
@event.listens_for(SomeMappedClass, "before_insert", propagate=True)
def my_listener_function(*arg):
pass
# ... it's removed like this
event.remove(SomeMappedClass, "before_insert", my_listener_function)
Above, the listener function associated with ``SomeMappedClass`` was also
propagated to subclasses of ``SomeMappedClass``; the :func:`.remove`
function will revert all of these operations.
.. versionadded:: 0.9.0
.. note::
The :func:`.remove` function cannot be called at the same time
that the target event is being run. This has implications
for thread safety, and also means an event cannot be removed
from inside the listener function for itself. The list of
events to be run are present inside of a mutable collection
that can't be changed during iteration.
Event registration and removal is not intended to be a "high
velocity" operation; it is a configurational operation. For
systems that need to quickly associate and deassociate with
events at high scale, use a mutable structure that is handled
from inside of a single listener.
.. versionchanged:: 1.0.0 - a ``collections.deque()`` object is now
used as the container for the list of events, which explicitly
disallows collection mutation while the collection is being
iterated.
.. seealso::
:func:`.listen`
"""
_event_key(target, identifier, fn).remove()
def contains(target, identifier, fn):
"""Return True if the given target/ident/fn is set up to listen.
.. versionadded:: 0.9.0
"""
return _event_key(target, identifier, fn).contains()

373
sqlalchemy/event/attr.py Normal file
View File

@ -0,0 +1,373 @@
# event/attr.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Attribute implementation for _Dispatch classes.
The various listener targets for a particular event class are represented
as attributes, which refer to collections of listeners to be fired off.
These collections can exist at the class level as well as at the instance
level. An event is fired off using code like this::
some_object.dispatch.first_connect(arg1, arg2)
Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and
``first_connect`` is typically an instance of ``_ListenerCollection``
if event listeners are present, or ``_EmptyListener`` if none are present.
The attribute mechanics here spend effort trying to ensure listener functions
are available with a minimum of function call overhead, that unnecessary
objects aren't created (i.e. many empty per-instance listener collections),
as well as that everything is garbage collectable when owning references are
lost. Other features such as "propagation" of listener functions across
many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances,
as well as support for subclass propagation (e.g. events assigned to
``Pool`` vs. ``QueuePool``) are all implemented here.
"""
from __future__ import absolute_import, with_statement
from .. import util
from ..util import threading
from . import registry
from . import legacy
from itertools import chain
import weakref
import collections
class RefCollection(util.MemoizedSlots):
__slots__ = 'ref',
def _memoized_attr_ref(self):
return weakref.ref(self, registry._collection_gced)
class _ClsLevelDispatch(RefCollection):
"""Class-level events on :class:`._Dispatch` classes."""
__slots__ = ('name', 'arg_names', 'has_kw',
'legacy_signatures', '_clslevel', '__weakref__')
def __init__(self, parent_dispatch_cls, fn):
self.name = fn.__name__
argspec = util.inspect_getargspec(fn)
self.arg_names = argspec.args[1:]
self.has_kw = bool(argspec.keywords)
self.legacy_signatures = list(reversed(
sorted(
getattr(fn, '_legacy_signatures', []),
key=lambda s: s[0]
)
))
fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn)
self._clslevel = weakref.WeakKeyDictionary()
def _adjust_fn_spec(self, fn, named):
if named:
fn = self._wrap_fn_for_kw(fn)
if self.legacy_signatures:
try:
argspec = util.get_callable_argspec(fn, no_self=True)
except TypeError:
pass
else:
fn = legacy._wrap_fn_for_legacy(self, fn, argspec)
return fn
def _wrap_fn_for_kw(self, fn):
def wrap_kw(*args, **kw):
argdict = dict(zip(self.arg_names, args))
argdict.update(kw)
return fn(**argdict)
return wrap_kw
def insert(self, event_key, propagate):
target = event_key.dispatch_target
assert isinstance(target, type), \
"Class-level Event targets must be classes."
stack = [target]
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
if cls is not target and cls not in self._clslevel:
self.update_subclass(cls)
else:
if cls not in self._clslevel:
self._clslevel[cls] = collections.deque()
self._clslevel[cls].appendleft(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
def append(self, event_key, propagate):
target = event_key.dispatch_target
assert isinstance(target, type), \
"Class-level Event targets must be classes."
stack = [target]
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
if cls is not target and cls not in self._clslevel:
self.update_subclass(cls)
else:
if cls not in self._clslevel:
self._clslevel[cls] = collections.deque()
self._clslevel[cls].append(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
def update_subclass(self, target):
if target not in self._clslevel:
self._clslevel[target] = collections.deque()
clslevel = self._clslevel[target]
for cls in target.__mro__[1:]:
if cls in self._clslevel:
clslevel.extend([
fn for fn
in self._clslevel[cls]
if fn not in clslevel
])
def remove(self, event_key):
target = event_key.dispatch_target
stack = [target]
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
if cls in self._clslevel:
self._clslevel[cls].remove(event_key._listen_fn)
registry._removed_from_collection(event_key, self)
def clear(self):
"""Clear all class level listeners"""
to_clear = set()
for dispatcher in self._clslevel.values():
to_clear.update(dispatcher)
dispatcher.clear()
registry._clear(self, to_clear)
def for_modify(self, obj):
"""Return an event collection which can be modified.
For _ClsLevelDispatch at the class level of
a dispatcher, this returns self.
"""
return self
class _InstanceLevelDispatch(RefCollection):
__slots__ = ()
def _adjust_fn_spec(self, fn, named):
return self.parent._adjust_fn_spec(fn, named)
class _EmptyListener(_InstanceLevelDispatch):
"""Serves as a proxy interface to the events
served by a _ClsLevelDispatch, when there are no
instance-level events present.
Is replaced by _ListenerCollection when instance-level
events are added.
"""
propagate = frozenset()
listeners = ()
__slots__ = 'parent', 'parent_listeners', 'name'
def __init__(self, parent, target_cls):
if target_cls not in parent._clslevel:
parent.update_subclass(target_cls)
self.parent = parent # _ClsLevelDispatch
self.parent_listeners = parent._clslevel[target_cls]
self.name = parent.name
def for_modify(self, obj):
"""Return an event collection which can be modified.
For _EmptyListener at the instance level of
a dispatcher, this generates a new
_ListenerCollection, applies it to the instance,
and returns it.
"""
result = _ListenerCollection(self.parent, obj._instance_cls)
if getattr(obj, self.name) is self:
setattr(obj, self.name, result)
else:
assert isinstance(getattr(obj, self.name), _JoinedListener)
return result
def _needs_modify(self, *args, **kw):
raise NotImplementedError("need to call for_modify()")
exec_once = insert = append = remove = clear = _needs_modify
def __call__(self, *args, **kw):
"""Execute this event."""
for fn in self.parent_listeners:
fn(*args, **kw)
def __len__(self):
return len(self.parent_listeners)
def __iter__(self):
return iter(self.parent_listeners)
def __bool__(self):
return bool(self.parent_listeners)
__nonzero__ = __bool__
class _CompoundListener(_InstanceLevelDispatch):
__slots__ = '_exec_once_mutex', '_exec_once'
def _memoized_attr__exec_once_mutex(self):
return threading.Lock()
def exec_once(self, *args, **kw):
"""Execute this event, but only if it has not been
executed already for this collection."""
if not self._exec_once:
with self._exec_once_mutex:
if not self._exec_once:
try:
self(*args, **kw)
finally:
self._exec_once = True
def __call__(self, *args, **kw):
"""Execute this event."""
for fn in self.parent_listeners:
fn(*args, **kw)
for fn in self.listeners:
fn(*args, **kw)
def __len__(self):
return len(self.parent_listeners) + len(self.listeners)
def __iter__(self):
return chain(self.parent_listeners, self.listeners)
def __bool__(self):
return bool(self.listeners or self.parent_listeners)
__nonzero__ = __bool__
class _ListenerCollection(_CompoundListener):
"""Instance-level attributes on instances of :class:`._Dispatch`.
Represents a collection of listeners.
As of 0.7.9, _ListenerCollection is only first
created via the _EmptyListener.for_modify() method.
"""
__slots__ = (
'parent_listeners', 'parent', 'name', 'listeners',
'propagate', '__weakref__')
def __init__(self, parent, target_cls):
if target_cls not in parent._clslevel:
parent.update_subclass(target_cls)
self._exec_once = False
self.parent_listeners = parent._clslevel[target_cls]
self.parent = parent
self.name = parent.name
self.listeners = collections.deque()
self.propagate = set()
def for_modify(self, obj):
"""Return an event collection which can be modified.
For _ListenerCollection at the instance level of
a dispatcher, this returns self.
"""
return self
def _update(self, other, only_propagate=True):
"""Populate from the listeners in another :class:`_Dispatch`
object."""
existing_listeners = self.listeners
existing_listener_set = set(existing_listeners)
self.propagate.update(other.propagate)
other_listeners = [l for l
in other.listeners
if l not in existing_listener_set
and not only_propagate or l in self.propagate
]
existing_listeners.extend(other_listeners)
to_associate = other.propagate.union(other_listeners)
registry._stored_in_collection_multi(self, other, to_associate)
def insert(self, event_key, propagate):
if event_key.prepend_to_list(self, self.listeners):
if propagate:
self.propagate.add(event_key._listen_fn)
def append(self, event_key, propagate):
if event_key.append_to_list(self, self.listeners):
if propagate:
self.propagate.add(event_key._listen_fn)
def remove(self, event_key):
self.listeners.remove(event_key._listen_fn)
self.propagate.discard(event_key._listen_fn)
registry._removed_from_collection(event_key, self)
def clear(self):
registry._clear(self, self.listeners)
self.propagate.clear()
self.listeners.clear()
class _JoinedListener(_CompoundListener):
__slots__ = 'parent', 'name', 'local', 'parent_listeners'
def __init__(self, parent, name, local):
self._exec_once = False
self.parent = parent
self.name = name
self.local = local
self.parent_listeners = self.local
@property
def listeners(self):
return getattr(self.parent, self.name)
def _adjust_fn_spec(self, fn, named):
return self.local._adjust_fn_spec(fn, named)
def for_modify(self, obj):
self.local = self.parent_listeners = self.local.for_modify(obj)
return self
def insert(self, event_key, propagate):
self.local.insert(event_key, propagate)
def append(self, event_key, propagate):
self.local.append(event_key, propagate)
def remove(self, event_key):
self.local.remove(event_key)
def clear(self):
raise NotImplementedError()

289
sqlalchemy/event/base.py Normal file
View File

@ -0,0 +1,289 @@
# event/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Base implementation classes.
The public-facing ``Events`` serves as the base class for an event interface;
its public attributes represent different kinds of events. These attributes
are mirrored onto a ``_Dispatch`` class, which serves as a container for
collections of listener functions. These collections are represented both
at the class level of a particular ``_Dispatch`` class as well as within
instances of ``_Dispatch``.
"""
from __future__ import absolute_import
import weakref
from .. import util
from .attr import _JoinedListener, \
_EmptyListener, _ClsLevelDispatch
_registrars = util.defaultdict(list)
def _is_event_name(name):
return not name.startswith('_') and name != 'dispatch'
class _UnpickleDispatch(object):
"""Serializable callable that re-generates an instance of
:class:`_Dispatch` given a particular :class:`.Events` subclass.
"""
def __call__(self, _instance_cls):
for cls in _instance_cls.__mro__:
if 'dispatch' in cls.__dict__:
return cls.__dict__['dispatch'].\
dispatch_cls._for_class(_instance_cls)
else:
raise AttributeError("No class with a 'dispatch' member present.")
class _Dispatch(object):
"""Mirror the event listening definitions of an Events class with
listener collections.
Classes which define a "dispatch" member will return a
non-instantiated :class:`._Dispatch` subclass when the member
is accessed at the class level. When the "dispatch" member is
accessed at the instance level of its owner, an instance
of the :class:`._Dispatch` class is returned.
A :class:`._Dispatch` class is generated for each :class:`.Events`
class defined, by the :func:`._create_dispatcher_class` function.
The original :class:`.Events` classes remain untouched.
This decouples the construction of :class:`.Events` subclasses from
the implementation used by the event internals, and allows
inspecting tools like Sphinx to work in an unsurprising
way against the public API.
"""
# in one ORM edge case, an attribute is added to _Dispatch,
# so __dict__ is used in just that case and potentially others.
__slots__ = '_parent', '_instance_cls', '__dict__', '_empty_listeners'
_empty_listener_reg = weakref.WeakKeyDictionary()
def __init__(self, parent, instance_cls=None):
self._parent = parent
self._instance_cls = instance_cls
if instance_cls:
try:
self._empty_listeners = self._empty_listener_reg[instance_cls]
except KeyError:
self._empty_listeners = \
self._empty_listener_reg[instance_cls] = dict(
(ls.name, _EmptyListener(ls, instance_cls))
for ls in parent._event_descriptors
)
else:
self._empty_listeners = {}
def __getattr__(self, name):
# assign EmptyListeners as attributes on demand
# to reduce startup time for new dispatch objects
try:
ls = self._empty_listeners[name]
except KeyError:
raise AttributeError(name)
else:
setattr(self, ls.name, ls)
return ls
@property
def _event_descriptors(self):
for k in self._event_names:
yield getattr(self, k)
def _for_class(self, instance_cls):
return self.__class__(self, instance_cls)
def _for_instance(self, instance):
instance_cls = instance.__class__
return self._for_class(instance_cls)
@property
def _listen(self):
return self._events._listen
def _join(self, other):
"""Create a 'join' of this :class:`._Dispatch` and another.
This new dispatcher will dispatch events to both
:class:`._Dispatch` objects.
"""
if '_joined_dispatch_cls' not in self.__class__.__dict__:
cls = type(
"Joined%s" % self.__class__.__name__,
(_JoinedDispatcher, ), {'__slots__': self._event_names}
)
self.__class__._joined_dispatch_cls = cls
return self._joined_dispatch_cls(self, other)
def __reduce__(self):
return _UnpickleDispatch(), (self._instance_cls, )
def _update(self, other, only_propagate=True):
"""Populate from the listeners in another :class:`_Dispatch`
object."""
for ls in other._event_descriptors:
if isinstance(ls, _EmptyListener):
continue
getattr(self, ls.name).\
for_modify(self)._update(ls, only_propagate=only_propagate)
def _clear(self):
for ls in self._event_descriptors:
ls.for_modify(self).clear()
class _EventMeta(type):
"""Intercept new Event subclasses and create
associated _Dispatch classes."""
def __init__(cls, classname, bases, dict_):
_create_dispatcher_class(cls, classname, bases, dict_)
return type.__init__(cls, classname, bases, dict_)
def _create_dispatcher_class(cls, classname, bases, dict_):
"""Create a :class:`._Dispatch` class corresponding to an
:class:`.Events` class."""
# there's all kinds of ways to do this,
# i.e. make a Dispatch class that shares the '_listen' method
# of the Event class, this is the straight monkeypatch.
if hasattr(cls, 'dispatch'):
dispatch_base = cls.dispatch.__class__
else:
dispatch_base = _Dispatch
event_names = [k for k in dict_ if _is_event_name(k)]
dispatch_cls = type("%sDispatch" % classname,
(dispatch_base, ), {'__slots__': event_names})
dispatch_cls._event_names = event_names
dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
for k in dispatch_cls._event_names:
setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
_registrars[k].append(cls)
for super_ in dispatch_cls.__bases__:
if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
for ls in super_._events.dispatch._event_descriptors:
setattr(dispatch_inst, ls.name, ls)
dispatch_cls._event_names.append(ls.name)
if getattr(cls, '_dispatch_target', None):
cls._dispatch_target.dispatch = dispatcher(cls)
def _remove_dispatcher(cls):
for k in cls.dispatch._event_names:
_registrars[k].remove(cls)
if not _registrars[k]:
del _registrars[k]
class Events(util.with_metaclass(_EventMeta, object)):
"""Define event listening functions for a particular target type."""
@staticmethod
def _set_dispatch(cls, dispatch_cls):
# this allows an Events subclass to define additional utility
# methods made available to the target via
# "self.dispatch._events.<utilitymethod>"
# @staticemethod to allow easy "super" calls while in a metaclass
# constructor.
cls.dispatch = dispatch_cls(None)
dispatch_cls._events = cls
return cls.dispatch
@classmethod
def _accept_with(cls, target):
# Mapper, ClassManager, Session override this to
# also accept classes, scoped_sessions, sessionmakers, etc.
if hasattr(target, 'dispatch') and (
isinstance(target.dispatch, cls.dispatch.__class__) or
(
isinstance(target.dispatch, type) and
isinstance(target.dispatch, cls.dispatch.__class__)
) or
(
isinstance(target.dispatch, _JoinedDispatcher) and
isinstance(target.dispatch.parent, cls.dispatch.__class__)
)
):
return target
else:
return None
@classmethod
def _listen(cls, event_key, propagate=False, insert=False, named=False):
event_key.base_listen(propagate=propagate, insert=insert, named=named)
@classmethod
def _remove(cls, event_key):
event_key.remove()
@classmethod
def _clear(cls):
cls.dispatch._clear()
class _JoinedDispatcher(object):
"""Represent a connection between two _Dispatch objects."""
__slots__ = 'local', 'parent', '_instance_cls'
def __init__(self, local, parent):
self.local = local
self.parent = parent
self._instance_cls = self.local._instance_cls
def __getattr__(self, name):
# assign _JoinedListeners as attributes on demand
# to reduce startup time for new dispatch objects
ls = getattr(self.local, name)
jl = _JoinedListener(self.parent, ls.name, ls)
setattr(self, ls.name, jl)
return jl
@property
def _listen(self):
return self.parent._listen
class dispatcher(object):
"""Descriptor used by target classes to
deliver the _Dispatch class at the class level
and produce new _Dispatch instances for target
instances.
"""
def __init__(self, events):
self.dispatch_cls = events.dispatch
self.events = events
def __get__(self, obj, cls):
if obj is None:
return self.dispatch_cls
obj.__dict__['dispatch'] = disp = self.dispatch_cls._for_instance(obj)
return disp

169
sqlalchemy/event/legacy.py Normal file
View File

@ -0,0 +1,169 @@
# event/legacy.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Routines to handle adaption of legacy call signatures,
generation of deprecation notes and docstrings.
"""
from .. import util
def _legacy_signature(since, argnames, converter=None):
def leg(fn):
if not hasattr(fn, '_legacy_signatures'):
fn._legacy_signatures = []
fn._legacy_signatures.append((since, argnames, converter))
return fn
return leg
def _wrap_fn_for_legacy(dispatch_collection, fn, argspec):
for since, argnames, conv in dispatch_collection.legacy_signatures:
if argnames[-1] == "**kw":
has_kw = True
argnames = argnames[0:-1]
else:
has_kw = False
if len(argnames) == len(argspec.args) \
and has_kw is bool(argspec.keywords):
if conv:
assert not has_kw
def wrap_leg(*args):
return fn(*conv(*args))
else:
def wrap_leg(*args, **kw):
argdict = dict(zip(dispatch_collection.arg_names, args))
args = [argdict[name] for name in argnames]
if has_kw:
return fn(*args, **kw)
else:
return fn(*args)
return wrap_leg
else:
return fn
def _indent(text, indent):
return "\n".join(
indent + line
for line in text.split("\n")
)
def _standard_listen_example(dispatch_collection, sample_target, fn):
example_kw_arg = _indent(
"\n".join(
"%(arg)s = kw['%(arg)s']" % {"arg": arg}
for arg in dispatch_collection.arg_names[0:2]
),
" ")
if dispatch_collection.legacy_signatures:
current_since = max(since for since, args, conv
in dispatch_collection.legacy_signatures)
else:
current_since = None
text = (
"from sqlalchemy import event\n\n"
"# standard decorator style%(current_since)s\n"
"@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
"def receive_%(event_name)s("
"%(named_event_arguments)s%(has_kw_arguments)s):\n"
" \"listen for the '%(event_name)s' event\"\n"
"\n # ... (event handling logic) ...\n"
)
if len(dispatch_collection.arg_names) > 3:
text += (
"\n# named argument style (new in 0.9)\n"
"@event.listens_for("
"%(sample_target)s, '%(event_name)s', named=True)\n"
"def receive_%(event_name)s(**kw):\n"
" \"listen for the '%(event_name)s' event\"\n"
"%(example_kw_arg)s\n"
"\n # ... (event handling logic) ...\n"
)
text %= {
"current_since": " (arguments as of %s)" %
current_since if current_since else "",
"event_name": fn.__name__,
"has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
"named_event_arguments": ", ".join(dispatch_collection.arg_names),
"example_kw_arg": example_kw_arg,
"sample_target": sample_target
}
return text
def _legacy_listen_examples(dispatch_collection, sample_target, fn):
text = ""
for since, args, conv in dispatch_collection.legacy_signatures:
text += (
"\n# legacy calling style (pre-%(since)s)\n"
"@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
"def receive_%(event_name)s("
"%(named_event_arguments)s%(has_kw_arguments)s):\n"
" \"listen for the '%(event_name)s' event\"\n"
"\n # ... (event handling logic) ...\n" % {
"since": since,
"event_name": fn.__name__,
"has_kw_arguments": " **kw"
if dispatch_collection.has_kw else "",
"named_event_arguments": ", ".join(args),
"sample_target": sample_target
}
)
return text
def _version_signature_changes(dispatch_collection):
since, args, conv = dispatch_collection.legacy_signatures[0]
return (
"\n.. versionchanged:: %(since)s\n"
" The ``%(event_name)s`` event now accepts the \n"
" arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n"
" Listener functions which accept the previous argument \n"
" signature(s) listed above will be automatically \n"
" adapted to the new signature." % {
"since": since,
"event_name": dispatch_collection.name,
"named_event_arguments": ", ".join(dispatch_collection.arg_names),
"has_kw_arguments": ", **kw" if dispatch_collection.has_kw else ""
}
)
def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn):
header = ".. container:: event_signatures\n\n"\
" Example argument forms::\n"\
"\n"
sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj")
text = (
header +
_indent(
_standard_listen_example(
dispatch_collection, sample_target, fn),
" " * 8)
)
if dispatch_collection.legacy_signatures:
text += _indent(
_legacy_listen_examples(
dispatch_collection, sample_target, fn),
" " * 8)
text += _version_signature_changes(dispatch_collection)
return util.inject_docstring_text(fn.__doc__,
text,
1
)

View File

@ -0,0 +1,262 @@
# event/registry.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides managed registration services on behalf of :func:`.listen`
arguments.
By "managed registration", we mean that event listening functions and
other objects can be added to various collections in such a way that their
membership in all those collections can be revoked at once, based on
an equivalent :class:`._EventKey`.
"""
from __future__ import absolute_import
import weakref
import collections
import types
from .. import exc, util
_key_to_collection = collections.defaultdict(dict)
"""
Given an original listen() argument, can locate all
listener collections and the listener fn contained
(target, identifier, fn) -> {
ref(listenercollection) -> ref(listener_fn)
ref(listenercollection) -> ref(listener_fn)
ref(listenercollection) -> ref(listener_fn)
}
"""
_collection_to_key = collections.defaultdict(dict)
"""
Given a _ListenerCollection or _ClsLevelListener, can locate
all the original listen() arguments and the listener fn contained
ref(listenercollection) -> {
ref(listener_fn) -> (target, identifier, fn),
ref(listener_fn) -> (target, identifier, fn),
ref(listener_fn) -> (target, identifier, fn),
}
"""
def _collection_gced(ref):
# defaultdict, so can't get a KeyError
if not _collection_to_key or ref not in _collection_to_key:
return
listener_to_key = _collection_to_key.pop(ref)
for key in listener_to_key.values():
if key in _key_to_collection:
# defaultdict, so can't get a KeyError
dispatch_reg = _key_to_collection[key]
dispatch_reg.pop(ref)
if not dispatch_reg:
_key_to_collection.pop(key)
def _stored_in_collection(event_key, owner):
key = event_key._key
dispatch_reg = _key_to_collection[key]
owner_ref = owner.ref
listen_ref = weakref.ref(event_key._listen_fn)
if owner_ref in dispatch_reg:
return False
dispatch_reg[owner_ref] = listen_ref
listener_to_key = _collection_to_key[owner_ref]
listener_to_key[listen_ref] = key
return True
def _removed_from_collection(event_key, owner):
key = event_key._key
dispatch_reg = _key_to_collection[key]
listen_ref = weakref.ref(event_key._listen_fn)
owner_ref = owner.ref
dispatch_reg.pop(owner_ref, None)
if not dispatch_reg:
del _key_to_collection[key]
if owner_ref in _collection_to_key:
listener_to_key = _collection_to_key[owner_ref]
listener_to_key.pop(listen_ref)
def _stored_in_collection_multi(newowner, oldowner, elements):
if not elements:
return
oldowner = oldowner.ref
newowner = newowner.ref
old_listener_to_key = _collection_to_key[oldowner]
new_listener_to_key = _collection_to_key[newowner]
for listen_fn in elements:
listen_ref = weakref.ref(listen_fn)
key = old_listener_to_key[listen_ref]
dispatch_reg = _key_to_collection[key]
if newowner in dispatch_reg:
assert dispatch_reg[newowner] == listen_ref
else:
dispatch_reg[newowner] = listen_ref
new_listener_to_key[listen_ref] = key
def _clear(owner, elements):
if not elements:
return
owner = owner.ref
listener_to_key = _collection_to_key[owner]
for listen_fn in elements:
listen_ref = weakref.ref(listen_fn)
key = listener_to_key[listen_ref]
dispatch_reg = _key_to_collection[key]
dispatch_reg.pop(owner, None)
if not dispatch_reg:
del _key_to_collection[key]
class _EventKey(object):
"""Represent :func:`.listen` arguments.
"""
__slots__ = (
'target', 'identifier', 'fn', 'fn_key', 'fn_wrap', 'dispatch_target'
)
def __init__(self, target, identifier,
fn, dispatch_target, _fn_wrap=None):
self.target = target
self.identifier = identifier
self.fn = fn
if isinstance(fn, types.MethodType):
self.fn_key = id(fn.__func__), id(fn.__self__)
else:
self.fn_key = id(fn)
self.fn_wrap = _fn_wrap
self.dispatch_target = dispatch_target
@property
def _key(self):
return (id(self.target), self.identifier, self.fn_key)
def with_wrapper(self, fn_wrap):
if fn_wrap is self._listen_fn:
return self
else:
return _EventKey(
self.target,
self.identifier,
self.fn,
self.dispatch_target,
_fn_wrap=fn_wrap
)
def with_dispatch_target(self, dispatch_target):
if dispatch_target is self.dispatch_target:
return self
else:
return _EventKey(
self.target,
self.identifier,
self.fn,
dispatch_target,
_fn_wrap=self.fn_wrap
)
def listen(self, *args, **kw):
once = kw.pop("once", False)
named = kw.pop("named", False)
target, identifier, fn = \
self.dispatch_target, self.identifier, self._listen_fn
dispatch_collection = getattr(target.dispatch, identifier)
adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named)
self = self.with_wrapper(adjusted_fn)
if once:
self.with_wrapper(
util.only_once(self._listen_fn)).listen(*args, **kw)
else:
self.dispatch_target.dispatch._listen(self, *args, **kw)
def remove(self):
key = self._key
if key not in _key_to_collection:
raise exc.InvalidRequestError(
"No listeners found for event %s / %r / %s " %
(self.target, self.identifier, self.fn)
)
dispatch_reg = _key_to_collection.pop(key)
for collection_ref, listener_ref in dispatch_reg.items():
collection = collection_ref()
listener_fn = listener_ref()
if collection is not None and listener_fn is not None:
collection.remove(self.with_wrapper(listener_fn))
def contains(self):
"""Return True if this event key is registered to listen.
"""
return self._key in _key_to_collection
def base_listen(self, propagate=False, insert=False,
named=False):
target, identifier, fn = \
self.dispatch_target, self.identifier, self._listen_fn
dispatch_collection = getattr(target.dispatch, identifier)
if insert:
dispatch_collection.\
for_modify(target.dispatch).insert(self, propagate)
else:
dispatch_collection.\
for_modify(target.dispatch).append(self, propagate)
@property
def _listen_fn(self):
return self.fn_wrap or self.fn
def append_to_list(self, owner, list_):
if _stored_in_collection(self, owner):
list_.append(self._listen_fn)
return True
else:
return False
def remove_from_list(self, owner, list_):
_removed_from_collection(self, owner)
list_.remove(self._listen_fn)
def prepend_to_list(self, owner, list_):
if _stored_in_collection(self, owner):
list_.appendleft(self._listen_fn)
return True
else:
return False

1173
sqlalchemy/events.py Normal file

File diff suppressed because it is too large Load Diff

1048
sqlalchemy/ext/automap.py Normal file

File diff suppressed because it is too large Load Diff

559
sqlalchemy/ext/baked.py Normal file
View File

@ -0,0 +1,559 @@
# sqlalchemy/ext/baked.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Baked query extension.
Provides a creational pattern for the :class:`.query.Query` object which
allows the fully constructed object, Core select statement, and string
compiled result to be fully cached.
"""
from ..orm.query import Query
from ..orm import strategies, attributes, properties, \
strategy_options, util as orm_util, interfaces
from .. import log as sqla_log
from ..sql import util as sql_util, func, literal_column
from ..orm import exc as orm_exc
from .. import exc as sa_exc
from .. import util
import copy
import logging
log = logging.getLogger(__name__)
class BakedQuery(object):
"""A builder object for :class:`.query.Query` objects."""
__slots__ = 'steps', '_bakery', '_cache_key', '_spoiled'
def __init__(self, bakery, initial_fn, args=()):
self._cache_key = ()
self._update_cache_key(initial_fn, args)
self.steps = [initial_fn]
self._spoiled = False
self._bakery = bakery
@classmethod
def bakery(cls, size=200):
"""Construct a new bakery."""
_bakery = util.LRUCache(size)
def call(initial_fn, *args):
return cls(_bakery, initial_fn, args)
return call
def _clone(self):
b1 = BakedQuery.__new__(BakedQuery)
b1._cache_key = self._cache_key
b1.steps = list(self.steps)
b1._bakery = self._bakery
b1._spoiled = self._spoiled
return b1
def _update_cache_key(self, fn, args=()):
self._cache_key += (fn.__code__,) + args
def __iadd__(self, other):
if isinstance(other, tuple):
self.add_criteria(*other)
else:
self.add_criteria(other)
return self
def __add__(self, other):
if isinstance(other, tuple):
return self.with_criteria(*other)
else:
return self.with_criteria(other)
def add_criteria(self, fn, *args):
"""Add a criteria function to this :class:`.BakedQuery`.
This is equivalent to using the ``+=`` operator to
modify a :class:`.BakedQuery` in-place.
"""
self._update_cache_key(fn, args)
self.steps.append(fn)
return self
def with_criteria(self, fn, *args):
"""Add a criteria function to a :class:`.BakedQuery` cloned from this one.
This is equivalent to using the ``+`` operator to
produce a new :class:`.BakedQuery` with modifications.
"""
return self._clone().add_criteria(fn, *args)
def for_session(self, session):
"""Return a :class:`.Result` object for this :class:`.BakedQuery`.
This is equivalent to calling the :class:`.BakedQuery` as a
Python callable, e.g. ``result = my_baked_query(session)``.
"""
return Result(self, session)
def __call__(self, session):
return self.for_session(session)
def spoil(self, full=False):
"""Cancel any query caching that will occur on this BakedQuery object.
The BakedQuery can continue to be used normally, however additional
creational functions will not be cached; they will be called
on every invocation.
This is to support the case where a particular step in constructing
a baked query disqualifies the query from being cacheable, such
as a variant that relies upon some uncacheable value.
:param full: if False, only functions added to this
:class:`.BakedQuery` object subsequent to the spoil step will be
non-cached; the state of the :class:`.BakedQuery` up until
this point will be pulled from the cache. If True, then the
entire :class:`.Query` object is built from scratch each
time, with all creational functions being called on each
invocation.
"""
if not full:
_spoil_point = self._clone()
_spoil_point._cache_key += ('_query_only', )
self.steps = [_spoil_point._retrieve_baked_query]
self._spoiled = True
return self
def _retrieve_baked_query(self, session):
query = self._bakery.get(self._cache_key, None)
if query is None:
query = self._as_query(session)
self._bakery[self._cache_key] = query.with_session(None)
return query.with_session(session)
def _bake(self, session):
query = self._as_query(session)
context = query._compile_context()
self._bake_subquery_loaders(session, context)
context.session = None
context.query = query = context.query.with_session(None)
query._execution_options = query._execution_options.union(
{"compiled_cache": self._bakery}
)
# we'll be holding onto the query for some of its state,
# so delete some compilation-use-only attributes that can take up
# space
for attr in (
'_correlate', '_from_obj', '_mapper_adapter_map',
'_joinpath', '_joinpoint'):
query.__dict__.pop(attr, None)
self._bakery[self._cache_key] = context
return context
def _as_query(self, session):
query = self.steps[0](session)
for step in self.steps[1:]:
query = step(query)
return query
def _bake_subquery_loaders(self, session, context):
"""convert subquery eager loaders in the cache into baked queries.
For subquery eager loading to work, all we need here is that the
Query point to the correct session when it is run. However, since
we are "baking" anyway, we may as well also turn the query into
a "baked" query so that we save on performance too.
"""
context.attributes['baked_queries'] = baked_queries = []
for k, v in list(context.attributes.items()):
if isinstance(v, Query):
if 'subquery' in k:
bk = BakedQuery(self._bakery, lambda *args: v)
bk._cache_key = self._cache_key + k
bk._bake(session)
baked_queries.append((k, bk._cache_key, v))
del context.attributes[k]
def _unbake_subquery_loaders(self, session, context, params):
"""Retrieve subquery eager loaders stored by _bake_subquery_loaders
and turn them back into Result objects that will iterate just
like a Query object.
"""
for k, cache_key, query in context.attributes["baked_queries"]:
bk = BakedQuery(self._bakery,
lambda sess, q=query: q.with_session(sess))
bk._cache_key = cache_key
context.attributes[k] = bk.for_session(session).params(**params)
class Result(object):
"""Invokes a :class:`.BakedQuery` against a :class:`.Session`.
The :class:`.Result` object is where the actual :class:`.query.Query`
object gets created, or retrieved from the cache,
against a target :class:`.Session`, and is then invoked for results.
"""
__slots__ = 'bq', 'session', '_params'
def __init__(self, bq, session):
self.bq = bq
self.session = session
self._params = {}
def params(self, *args, **kw):
"""Specify parameters to be replaced into the string SQL statement."""
if len(args) == 1:
kw.update(args[0])
elif len(args) > 0:
raise sa_exc.ArgumentError(
"params() takes zero or one positional argument, "
"which is a dictionary.")
self._params.update(kw)
return self
def _as_query(self):
return self.bq._as_query(self.session).params(self._params)
def __str__(self):
return str(self._as_query())
def __iter__(self):
bq = self.bq
if bq._spoiled:
return iter(self._as_query())
baked_context = bq._bakery.get(bq._cache_key, None)
if baked_context is None:
baked_context = bq._bake(self.session)
context = copy.copy(baked_context)
context.session = self.session
context.attributes = context.attributes.copy()
bq._unbake_subquery_loaders(self.session, context, self._params)
context.statement.use_labels = True
if context.autoflush and not context.populate_existing:
self.session._autoflush()
return context.query.params(self._params).\
with_session(self.session)._execute_and_instances(context)
def count(self):
"""return the 'count'.
Equivalent to :meth:`.Query.count`.
Note this uses a subquery to ensure an accurate count regardless
of the structure of the original statement.
.. versionadded:: 1.1.6
"""
col = func.count(literal_column('*'))
bq = self.bq.with_criteria(lambda q: q.from_self(col))
return bq.for_session(self.session).params(self._params).scalar()
def scalar(self):
"""Return the first element of the first result or None
if no rows present. If multiple rows are returned,
raises MultipleResultsFound.
Equivalent to :meth:`.Query.scalar`.
.. versionadded:: 1.1.6
"""
try:
ret = self.one()
if not isinstance(ret, tuple):
return ret
return ret[0]
except orm_exc.NoResultFound:
return None
def first(self):
"""Return the first row.
Equivalent to :meth:`.Query.first`.
"""
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
ret = list(bq.for_session(self.session).params(self._params))
if len(ret) > 0:
return ret[0]
else:
return None
def one(self):
"""Return exactly one result or raise an exception.
Equivalent to :meth:`.Query.one`.
"""
try:
ret = self.one_or_none()
except orm_exc.MultipleResultsFound:
raise orm_exc.MultipleResultsFound(
"Multiple rows were found for one()")
else:
if ret is None:
raise orm_exc.NoResultFound("No row was found for one()")
return ret
def one_or_none(self):
"""Return one or zero results, or raise an exception for multiple
rows.
Equivalent to :meth:`.Query.one_or_none`.
.. versionadded:: 1.0.9
"""
ret = list(self)
l = len(ret)
if l == 1:
return ret[0]
elif l == 0:
return None
else:
raise orm_exc.MultipleResultsFound(
"Multiple rows were found for one_or_none()")
def all(self):
"""Return all rows.
Equivalent to :meth:`.Query.all`.
"""
return list(self)
def get(self, ident):
"""Retrieve an object based on identity.
Equivalent to :meth:`.Query.get`.
"""
query = self.bq.steps[0](self.session)
return query._get_impl(ident, self._load_on_ident)
def _load_on_ident(self, query, key):
"""Load the given identity key from the database."""
ident = key[1]
mapper = query._mapper_zero()
_get_clause, _get_params = mapper._get_clause
def setup(query):
_lcl_get_clause = _get_clause
q = query._clone()
q._get_condition()
q._order_by = None
# None present in ident - turn those comparisons
# into "IS NULL"
if None in ident:
nones = set([
_get_params[col].key for col, value in
zip(mapper.primary_key, ident) if value is None
])
_lcl_get_clause = sql_util.adapt_criterion_to_null(
_lcl_get_clause, nones)
_lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False)
q._criterion = _lcl_get_clause
return q
# cache the query against a key that includes
# which positions in the primary key are NULL
# (remember, we can map to an OUTER JOIN)
bq = self.bq
# add the clause we got from mapper._get_clause to the cache
# key so that if a race causes multiple calls to _get_clause,
# we've cached on ours
bq = bq._clone()
bq._cache_key += (_get_clause, )
bq = bq.with_criteria(setup, tuple(elem is None for elem in ident))
params = dict([
(_get_params[primary_key].key, id_val)
for id_val, primary_key in zip(ident, mapper.primary_key)
])
result = list(bq.for_session(self.session).params(**params))
l = len(result)
if l > 1:
raise orm_exc.MultipleResultsFound()
elif l:
return result[0]
else:
return None
def bake_lazy_loaders():
"""Enable the use of baked queries for all lazyloaders systemwide.
This operation should be safe for all lazy loaders, and will reduce
Python overhead for these operations.
"""
BakedLazyLoader._strategy_keys[:] = []
properties.RelationshipProperty.strategy_for(
lazy="select")(BakedLazyLoader)
properties.RelationshipProperty.strategy_for(
lazy=True)(BakedLazyLoader)
properties.RelationshipProperty.strategy_for(
lazy="baked_select")(BakedLazyLoader)
strategies.LazyLoader._strategy_keys[:] = BakedLazyLoader._strategy_keys[:]
def unbake_lazy_loaders():
"""Disable the use of baked queries for all lazyloaders systemwide.
This operation reverts the changes produced by :func:`.bake_lazy_loaders`.
"""
strategies.LazyLoader._strategy_keys[:] = []
BakedLazyLoader._strategy_keys[:] = []
properties.RelationshipProperty.strategy_for(
lazy="select")(strategies.LazyLoader)
properties.RelationshipProperty.strategy_for(
lazy=True)(strategies.LazyLoader)
properties.RelationshipProperty.strategy_for(
lazy="baked_select")(BakedLazyLoader)
assert strategies.LazyLoader._strategy_keys
@sqla_log.class_logger
@properties.RelationshipProperty.strategy_for(lazy="baked_select")
class BakedLazyLoader(strategies.LazyLoader):
def _emit_lazyload(self, session, state, ident_key, passive):
q = BakedQuery(
self.mapper._compiled_cache,
lambda session: session.query(self.mapper))
q.add_criteria(
lambda q: q._adapt_all_clauses()._with_invoke_all_eagers(False),
self.parent_property)
if not self.parent_property.bake_queries:
q.spoil(full=True)
if self.parent_property.secondary is not None:
q.add_criteria(
lambda q:
q.select_from(self.mapper, self.parent_property.secondary))
pending = not state.key
# don't autoflush on pending
if pending or passive & attributes.NO_AUTOFLUSH:
q.add_criteria(lambda q: q.autoflush(False))
if state.load_options:
q.spoil()
args = state.load_path[self.parent_property]
q.add_criteria(
lambda q:
q._with_current_path(args), args)
q.add_criteria(
lambda q: q._conditional_options(*state.load_options))
if self.use_get:
return q(session)._load_on_ident(
session.query(self.mapper), ident_key)
if self.parent_property.order_by:
q.add_criteria(
lambda q:
q.order_by(*util.to_list(self.parent_property.order_by)))
for rev in self.parent_property._reverse_property:
# reverse props that are MANYTOONE are loading *this*
# object from get(), so don't need to eager out to those.
if rev.direction is interfaces.MANYTOONE and \
rev._use_get and \
not isinstance(rev.strategy, strategies.LazyLoader):
q.add_criteria(
lambda q:
q.options(
strategy_options.Load.for_existing_path(
q._current_path[rev.parent]
).baked_lazyload(rev.key)
)
)
lazy_clause, params = self._generate_lazy_clause(state, passive)
if pending:
if orm_util._none_set.intersection(params.values()):
return None
q.add_criteria(lambda q: q.filter(lazy_clause))
result = q(session).params(**params).all()
if self.uselist:
return result
else:
l = len(result)
if l:
if l > 1:
util.warn(
"Multiple rows returned with "
"uselist=False for lazily-loaded attribute '%s' "
% self.parent_property)
return result[0]
else:
return None
@strategy_options.loader_option()
def baked_lazyload(loadopt, attr):
"""Indicate that the given attribute should be loaded using "lazy"
loading with a "baked" query used in the load.
"""
return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"})
@baked_lazyload._add_unbound_fn
def baked_lazyload(*keys):
return strategy_options._UnboundLoad._from_keys(
strategy_options._UnboundLoad.baked_lazyload, keys, False, {})
@baked_lazyload._add_unbound_all_fn
def baked_lazyload_all(*keys):
return strategy_options._UnboundLoad._from_keys(
strategy_options._UnboundLoad.baked_lazyload, keys, True, {})
baked_lazyload = baked_lazyload._unbound_fn
baked_lazyload_all = baked_lazyload_all._unbound_all_fn
bakery = BakedQuery.bakery

View File

@ -0,0 +1,18 @@
# ext/declarative/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .api import declarative_base, synonym_for, comparable_using, \
instrument_declarative, ConcreteBase, AbstractConcreteBase, \
DeclarativeMeta, DeferredReflection, has_inherited_table,\
declared_attr, as_declarative
__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table',
'comparable_using', 'instrument_declarative', 'declared_attr',
'as_declarative',
'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta',
'DeferredReflection']

View File

@ -0,0 +1,696 @@
# ext/declarative/api.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Public API functions and helpers for declarative."""
from ...schema import Table, MetaData, Column
from ...orm import synonym as _orm_synonym, \
comparable_property,\
interfaces, properties, attributes
from ...orm.util import polymorphic_union
from ...orm.base import _mapper_or_none
from ...util import OrderedDict, hybridmethod, hybridproperty
from ... import util
from ... import exc
import weakref
from .base import _as_declarative, \
_declarative_constructor,\
_DeferredMapperConfig, _add_attribute
from .clsregistry import _class_resolver
def instrument_declarative(cls, registry, metadata):
"""Given a class, configure the class declaratively,
using the given registry, which can be any dictionary, and
MetaData object.
"""
if '_decl_class_registry' in cls.__dict__:
raise exc.InvalidRequestError(
"Class %r already has been "
"instrumented declaratively" % cls)
cls._decl_class_registry = registry
cls.metadata = metadata
_as_declarative(cls, cls.__name__, cls.__dict__)
def has_inherited_table(cls):
"""Given a class, return True if any of the classes it inherits from has a
mapped table, otherwise return False.
This is used in declarative mixins to build attributes that behave
differently for the base class vs. a subclass in an inheritance
hierarchy.
.. seealso::
:ref:`decl_mixin_inheritance`
"""
for class_ in cls.__mro__[1:]:
if getattr(class_, '__table__', None) is not None:
return True
return False
class DeclarativeMeta(type):
def __init__(cls, classname, bases, dict_):
if '_decl_class_registry' not in cls.__dict__:
_as_declarative(cls, classname, cls.__dict__)
type.__init__(cls, classname, bases, dict_)
def __setattr__(cls, key, value):
_add_attribute(cls, key, value)
def synonym_for(name, map_column=False):
"""Decorator, make a Python @property a query synonym for a column.
A decorator version of :func:`~sqlalchemy.orm.synonym`. The function being
decorated is the 'descriptor', otherwise passes its arguments through to
synonym()::
@synonym_for('col')
@property
def prop(self):
return 'special sauce'
The regular ``synonym()`` is also usable directly in a declarative setting
and may be convenient for read/write properties::
prop = synonym('col', descriptor=property(_read_prop, _write_prop))
"""
def decorate(fn):
return _orm_synonym(name, map_column=map_column, descriptor=fn)
return decorate
def comparable_using(comparator_factory):
"""Decorator, allow a Python @property to be used in query criteria.
This is a decorator front end to
:func:`~sqlalchemy.orm.comparable_property` that passes
through the comparator_factory and the function being decorated::
@comparable_using(MyComparatorType)
@property
def prop(self):
return 'special sauce'
The regular ``comparable_property()`` is also usable directly in a
declarative setting and may be convenient for read/write properties::
prop = comparable_property(MyComparatorType)
"""
def decorate(fn):
return comparable_property(comparator_factory, fn)
return decorate
class declared_attr(interfaces._MappedAttribute, property):
"""Mark a class-level method as representing the definition of
a mapped property or special declarative member name.
@declared_attr turns the attribute into a scalar-like
property that can be invoked from the uninstantiated class.
Declarative treats attributes specifically marked with
@declared_attr as returning a construct that is specific
to mapping or declarative table configuration. The name
of the attribute is that of what the non-dynamic version
of the attribute would be.
@declared_attr is more often than not applicable to mixins,
to define relationships that are to be applied to different
implementors of the class::
class ProvidesUser(object):
"A mixin that adds a 'user' relationship to classes."
@declared_attr
def user(self):
return relationship("User")
It also can be applied to mapped classes, such as to provide
a "polymorphic" scheme for inheritance::
class Employee(Base):
id = Column(Integer, primary_key=True)
type = Column(String(50), nullable=False)
@declared_attr
def __tablename__(cls):
return cls.__name__.lower()
@declared_attr
def __mapper_args__(cls):
if cls.__name__ == 'Employee':
return {
"polymorphic_on":cls.type,
"polymorphic_identity":"Employee"
}
else:
return {"polymorphic_identity":cls.__name__}
.. versionchanged:: 0.8 :class:`.declared_attr` can be used with
non-ORM or extension attributes, such as user-defined attributes
or :func:`.association_proxy` objects, which will be assigned
to the class at class construction time.
"""
def __init__(self, fget, cascading=False):
super(declared_attr, self).__init__(fget)
self.__doc__ = fget.__doc__
self._cascading = cascading
def __get__(desc, self, cls):
reg = cls.__dict__.get('_sa_declared_attr_reg', None)
if reg is None:
manager = attributes.manager_of_class(cls)
if manager is None:
util.warn(
"Unmanaged access of declarative attribute %s from "
"non-mapped class %s" %
(desc.fget.__name__, cls.__name__))
return desc.fget(cls)
elif desc in reg:
return reg[desc]
else:
reg[desc] = obj = desc.fget(cls)
return obj
@hybridmethod
def _stateful(cls, **kw):
return _stateful_declared_attr(**kw)
@hybridproperty
def cascading(cls):
"""Mark a :class:`.declared_attr` as cascading.
This is a special-use modifier which indicates that a column
or MapperProperty-based declared attribute should be configured
distinctly per mapped subclass, within a mapped-inheritance scenario.
Below, both MyClass as well as MySubClass will have a distinct
``id`` Column object established::
class HasIdMixin(object):
@declared_attr.cascading
def id(cls):
if has_inherited_table(cls):
return Column(ForeignKey('myclass.id'), primary_key=True)
else:
return Column(Integer, primary_key=True)
class MyClass(HasIdMixin, Base):
__tablename__ = 'myclass'
# ...
class MySubClass(MyClass):
""
# ...
The behavior of the above configuration is that ``MySubClass``
will refer to both its own ``id`` column as well as that of
``MyClass`` underneath the attribute named ``some_id``.
.. seealso::
:ref:`declarative_inheritance`
:ref:`mixin_inheritance_columns`
"""
return cls._stateful(cascading=True)
class _stateful_declared_attr(declared_attr):
def __init__(self, **kw):
self.kw = kw
def _stateful(self, **kw):
new_kw = self.kw.copy()
new_kw.update(kw)
return _stateful_declared_attr(**new_kw)
def __call__(self, fn):
return declared_attr(fn, **self.kw)
def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
name='Base', constructor=_declarative_constructor,
class_registry=None,
metaclass=DeclarativeMeta):
r"""Construct a base class for declarative class definitions.
The new base class will be given a metaclass that produces
appropriate :class:`~sqlalchemy.schema.Table` objects and makes
the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the
information provided declaratively in the class and any subclasses
of the class.
:param bind: An optional
:class:`~sqlalchemy.engine.Connectable`, will be assigned
the ``bind`` attribute on the :class:`~sqlalchemy.schema.MetaData`
instance.
:param metadata:
An optional :class:`~sqlalchemy.schema.MetaData` instance. All
:class:`~sqlalchemy.schema.Table` objects implicitly declared by
subclasses of the base will share this MetaData. A MetaData instance
will be created if none is provided. The
:class:`~sqlalchemy.schema.MetaData` instance will be available via the
`metadata` attribute of the generated declarative base class.
:param mapper:
An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will
be used to map subclasses to their Tables.
:param cls:
Defaults to :class:`object`. A type to use as the base for the generated
declarative base class. May be a class or tuple of classes.
:param name:
Defaults to ``Base``. The display name for the generated
class. Customizing this is not required, but can improve clarity in
tracebacks and debugging.
:param constructor:
Defaults to
:func:`~sqlalchemy.ext.declarative.base._declarative_constructor`, an
__init__ implementation that assigns \**kwargs for declared
fields and relationships to an instance. If ``None`` is supplied,
no __init__ will be provided and construction will fall back to
cls.__init__ by way of the normal Python semantics.
:param class_registry: optional dictionary that will serve as the
registry of class names-> mapped classes when string names
are used to identify classes inside of :func:`.relationship`
and others. Allows two or more declarative base classes
to share the same registry of class names for simplified
inter-base relationships.
:param metaclass:
Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__
compatible callable to use as the meta type of the generated
declarative base class.
.. versionchanged:: 1.1 if :paramref:`.declarative_base.cls` is a single class (rather
than a tuple), the constructed base class will inherit its docstring.
.. seealso::
:func:`.as_declarative`
"""
lcl_metadata = metadata or MetaData()
if bind:
lcl_metadata.bind = bind
if class_registry is None:
class_registry = weakref.WeakValueDictionary()
bases = not isinstance(cls, tuple) and (cls,) or cls
class_dict = dict(_decl_class_registry=class_registry,
metadata=lcl_metadata)
if isinstance(cls, type):
class_dict['__doc__'] = cls.__doc__
if constructor:
class_dict['__init__'] = constructor
if mapper:
class_dict['__mapper_cls__'] = mapper
return metaclass(name, bases, class_dict)
def as_declarative(**kw):
"""
Class decorator for :func:`.declarative_base`.
Provides a syntactical shortcut to the ``cls`` argument
sent to :func:`.declarative_base`, allowing the base class
to be converted in-place to a "declarative" base::
from sqlalchemy.ext.declarative import as_declarative
@as_declarative()
class Base(object):
@declared_attr
def __tablename__(cls):
return cls.__name__.lower()
id = Column(Integer, primary_key=True)
class MyMappedClass(Base):
# ...
All keyword arguments passed to :func:`.as_declarative` are passed
along to :func:`.declarative_base`.
.. versionadded:: 0.8.3
.. seealso::
:func:`.declarative_base`
"""
def decorate(cls):
kw['cls'] = cls
kw['name'] = cls.__name__
return declarative_base(**kw)
return decorate
class ConcreteBase(object):
"""A helper class for 'concrete' declarative mappings.
:class:`.ConcreteBase` will use the :func:`.polymorphic_union`
function automatically, against all tables mapped as a subclass
to this class. The function is called via the
``__declare_last__()`` function, which is essentially
a hook for the :meth:`.after_configured` event.
:class:`.ConcreteBase` produces a mapped
table for the class itself. Compare to :class:`.AbstractConcreteBase`,
which does not.
Example::
from sqlalchemy.ext.declarative import ConcreteBase
class Employee(ConcreteBase, Base):
__tablename__ = 'employee'
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
__mapper_args__ = {
'polymorphic_identity':'employee',
'concrete':True}
class Manager(Employee):
__tablename__ = 'manager'
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True}
.. seealso::
:class:`.AbstractConcreteBase`
:ref:`concrete_inheritance`
:ref:`inheritance_concrete_helpers`
"""
@classmethod
def _create_polymorphic_union(cls, mappers):
return polymorphic_union(OrderedDict(
(mp.polymorphic_identity, mp.local_table)
for mp in mappers
), 'type', 'pjoin')
@classmethod
def __declare_first__(cls):
m = cls.__mapper__
if m.with_polymorphic:
return
mappers = list(m.self_and_descendants)
pjoin = cls._create_polymorphic_union(mappers)
m._set_with_polymorphic(("*", pjoin))
m._set_polymorphic_on(pjoin.c.type)
class AbstractConcreteBase(ConcreteBase):
"""A helper class for 'concrete' declarative mappings.
:class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
function automatically, against all tables mapped as a subclass
to this class. The function is called via the
``__declare_last__()`` function, which is essentially
a hook for the :meth:`.after_configured` event.
:class:`.AbstractConcreteBase` does produce a mapped class
for the base class, however it is not persisted to any table; it
is instead mapped directly to the "polymorphic" selectable directly
and is only used for selecting. Compare to :class:`.ConcreteBase`,
which does create a persisted table for the base class.
Example::
from sqlalchemy.ext.declarative import AbstractConcreteBase
class Employee(AbstractConcreteBase, Base):
pass
class Manager(Employee):
__tablename__ = 'manager'
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True}
The abstract base class is handled by declarative in a special way;
at class configuration time, it behaves like a declarative mixin
or an ``__abstract__`` base class. Once classes are configured
and mappings are produced, it then gets mapped itself, but
after all of its decscendants. This is a very unique system of mapping
not found in any other SQLAlchemy system.
Using this approach, we can specify columns and properties
that will take place on mapped subclasses, in the way that
we normally do as in :ref:`declarative_mixins`::
class Company(Base):
__tablename__ = 'company'
id = Column(Integer, primary_key=True)
class Employee(AbstractConcreteBase, Base):
employee_id = Column(Integer, primary_key=True)
@declared_attr
def company_id(cls):
return Column(ForeignKey('company.id'))
@declared_attr
def company(cls):
return relationship("Company")
class Manager(Employee):
__tablename__ = 'manager'
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True}
When we make use of our mappings however, both ``Manager`` and
``Employee`` will have an independently usable ``.company`` attribute::
session.query(Employee).filter(Employee.company.has(id=5))
.. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
have been reworked to support relationships established directly
on the abstract base, without any special configurational steps.
.. seealso::
:class:`.ConcreteBase`
:ref:`concrete_inheritance`
:ref:`inheritance_concrete_helpers`
"""
__no_table__ = True
@classmethod
def __declare_first__(cls):
cls._sa_decl_prepare_nocascade()
@classmethod
def _sa_decl_prepare_nocascade(cls):
if getattr(cls, '__mapper__', None):
return
to_map = _DeferredMapperConfig.config_for_cls(cls)
# can't rely on 'self_and_descendants' here
# since technically an immediate subclass
# might not be mapped, but a subclass
# may be.
mappers = []
stack = list(cls.__subclasses__())
while stack:
klass = stack.pop()
stack.extend(klass.__subclasses__())
mn = _mapper_or_none(klass)
if mn is not None:
mappers.append(mn)
pjoin = cls._create_polymorphic_union(mappers)
# For columns that were declared on the class, these
# are normally ignored with the "__no_table__" mapping,
# unless they have a different attribute key vs. col name
# and are in the properties argument.
# In that case, ensure we update the properties entry
# to the correct column from the pjoin target table.
declared_cols = set(to_map.declared_columns)
for k, v in list(to_map.properties.items()):
if v in declared_cols:
to_map.properties[k] = pjoin.c[v.key]
to_map.local_table = pjoin
m_args = to_map.mapper_args_fn or dict
def mapper_args():
args = m_args()
args['polymorphic_on'] = pjoin.c.type
return args
to_map.mapper_args_fn = mapper_args
m = to_map.map()
for scls in cls.__subclasses__():
sm = _mapper_or_none(scls)
if sm and sm.concrete and cls in scls.__bases__:
sm._set_concrete_base(m)
class DeferredReflection(object):
"""A helper class for construction of mappings based on
a deferred reflection step.
Normally, declarative can be used with reflection by
setting a :class:`.Table` object using autoload=True
as the ``__table__`` attribute on a declarative class.
The caveat is that the :class:`.Table` must be fully
reflected, or at the very least have a primary key column,
at the point at which a normal declarative mapping is
constructed, meaning the :class:`.Engine` must be available
at class declaration time.
The :class:`.DeferredReflection` mixin moves the construction
of mappers to be at a later point, after a specific
method is called which first reflects all :class:`.Table`
objects created so far. Classes can define it as such::
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeferredReflection
Base = declarative_base()
class MyClass(DeferredReflection, Base):
__tablename__ = 'mytable'
Above, ``MyClass`` is not yet mapped. After a series of
classes have been defined in the above fashion, all tables
can be reflected and mappings created using
:meth:`.prepare`::
engine = create_engine("someengine://...")
DeferredReflection.prepare(engine)
The :class:`.DeferredReflection` mixin can be applied to individual
classes, used as the base for the declarative base itself,
or used in a custom abstract class. Using an abstract base
allows that only a subset of classes to be prepared for a
particular prepare step, which is necessary for applications
that use more than one engine. For example, if an application
has two engines, you might use two bases, and prepare each
separately, e.g.::
class ReflectedOne(DeferredReflection, Base):
__abstract__ = True
class ReflectedTwo(DeferredReflection, Base):
__abstract__ = True
class MyClass(ReflectedOne):
__tablename__ = 'mytable'
class MyOtherClass(ReflectedOne):
__tablename__ = 'myothertable'
class YetAnotherClass(ReflectedTwo):
__tablename__ = 'yetanothertable'
# ... etc.
Above, the class hierarchies for ``ReflectedOne`` and
``ReflectedTwo`` can be configured separately::
ReflectedOne.prepare(engine_one)
ReflectedTwo.prepare(engine_two)
.. versionadded:: 0.8
"""
@classmethod
def prepare(cls, engine):
"""Reflect all :class:`.Table` objects for all current
:class:`.DeferredReflection` subclasses"""
to_map = _DeferredMapperConfig.classes_for_base(cls)
for thingy in to_map:
cls._sa_decl_prepare(thingy.local_table, engine)
thingy.map()
mapper = thingy.cls.__mapper__
metadata = mapper.class_.metadata
for rel in mapper._props.values():
if isinstance(rel, properties.RelationshipProperty) and \
rel.secondary is not None:
if isinstance(rel.secondary, Table):
cls._reflect_table(rel.secondary, engine)
elif isinstance(rel.secondary, _class_resolver):
rel.secondary._resolvers += (
cls._sa_deferred_table_resolver(engine, metadata),
)
@classmethod
def _sa_deferred_table_resolver(cls, engine, metadata):
def _resolve(key):
t1 = Table(key, metadata)
cls._reflect_table(t1, engine)
return t1
return _resolve
@classmethod
def _sa_decl_prepare(cls, local_table, engine):
# autoload Table, which is already
# present in the metadata. This
# will fill in db-loaded columns
# into the existing Table object.
if local_table is not None:
cls._reflect_table(local_table, engine)
@classmethod
def _reflect_table(cls, table, engine):
Table(table.name,
table.metadata,
extend_existing=True,
autoload_replace=False,
autoload=True,
autoload_with=engine,
schema=table.schema)

View File

@ -0,0 +1,662 @@
# ext/declarative/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Internal implementation for declarative."""
from ...schema import Table, Column
from ...orm import mapper, class_mapper, synonym
from ...orm.interfaces import MapperProperty
from ...orm.properties import ColumnProperty, CompositeProperty
from ...orm.attributes import QueryableAttribute
from ...orm.base import _is_mapped_class
from ... import util, exc
from ...util import topological
from ...sql import expression
from ... import event
from . import clsregistry
import collections
import weakref
from sqlalchemy.orm import instrumentation
declared_attr = declarative_props = None
def _declared_mapping_info(cls):
# deferred mapping
if _DeferredMapperConfig.has_cls(cls):
return _DeferredMapperConfig.config_for_cls(cls)
# regular mapping
elif _is_mapped_class(cls):
return class_mapper(cls, configure=False)
else:
return None
def _resolve_for_abstract(cls):
if cls is object:
return None
if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
for sup in cls.__bases__:
sup = _resolve_for_abstract(sup)
if sup is not None:
return sup
else:
return None
else:
return cls
def _get_immediate_cls_attr(cls, attrname, strict=False):
"""return an attribute of the class that is either present directly
on the class, e.g. not on a superclass, or is from a superclass but
this superclass is a mixin, that is, not a descendant of
the declarative base.
This is used to detect attributes that indicate something about
a mapped class independently from any mapped classes that it may
inherit from.
"""
if not issubclass(cls, object):
return None
for base in cls.__mro__:
_is_declarative_inherits = hasattr(base, '_decl_class_registry')
if attrname in base.__dict__ and (
base is cls or
((base in cls.__bases__ if strict else True)
and not _is_declarative_inherits)
):
return getattr(base, attrname)
else:
return None
def _as_declarative(cls, classname, dict_):
global declared_attr, declarative_props
if declared_attr is None:
from .api import declared_attr
declarative_props = (declared_attr, util.classproperty)
if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
return
_MapperConfig.setup_mapping(cls, classname, dict_)
class _MapperConfig(object):
@classmethod
def setup_mapping(cls, cls_, classname, dict_):
defer_map = _get_immediate_cls_attr(
cls_, '_sa_decl_prepare_nocascade', strict=True) or \
hasattr(cls_, '_sa_decl_prepare')
if defer_map:
cfg_cls = _DeferredMapperConfig
else:
cfg_cls = _MapperConfig
cfg_cls(cls_, classname, dict_)
def __init__(self, cls_, classname, dict_):
self.cls = cls_
# dict_ will be a dictproxy, which we can't write to, and we need to!
self.dict_ = dict(dict_)
self.classname = classname
self.mapped_table = None
self.properties = util.OrderedDict()
self.declared_columns = set()
self.column_copies = {}
self._setup_declared_events()
# temporary registry. While early 1.0 versions
# set up the ClassManager here, by API contract
# we can't do that until there's a mapper.
self.cls._sa_declared_attr_reg = {}
self._scan_attributes()
clsregistry.add_class(self.classname, self.cls)
self._extract_mappable_attributes()
self._extract_declared_columns()
self._setup_table()
self._setup_inheritance()
self._early_mapping()
def _early_mapping(self):
self.map()
def _setup_declared_events(self):
if _get_immediate_cls_attr(self.cls, '__declare_last__'):
@event.listens_for(mapper, "after_configured")
def after_configured():
self.cls.__declare_last__()
if _get_immediate_cls_attr(self.cls, '__declare_first__'):
@event.listens_for(mapper, "before_configured")
def before_configured():
self.cls.__declare_first__()
def _scan_attributes(self):
cls = self.cls
dict_ = self.dict_
column_copies = self.column_copies
mapper_args_fn = None
table_args = inherited_table_args = None
tablename = None
for base in cls.__mro__:
class_mapped = base is not cls and \
_declared_mapping_info(base) is not None and \
not _get_immediate_cls_attr(
base, '_sa_decl_prepare_nocascade', strict=True)
if not class_mapped and base is not cls:
self._produce_column_copies(base)
for name, obj in vars(base).items():
if name == '__mapper_args__':
if not mapper_args_fn and (
not class_mapped or
isinstance(obj, declarative_props)
):
# don't even invoke __mapper_args__ until
# after we've determined everything about the
# mapped table.
# make a copy of it so a class-level dictionary
# is not overwritten when we update column-based
# arguments.
mapper_args_fn = lambda: dict(cls.__mapper_args__)
elif name == '__tablename__':
if not tablename and (
not class_mapped or
isinstance(obj, declarative_props)
):
tablename = cls.__tablename__
elif name == '__table_args__':
if not table_args and (
not class_mapped or
isinstance(obj, declarative_props)
):
table_args = cls.__table_args__
if not isinstance(
table_args, (tuple, dict, type(None))):
raise exc.ArgumentError(
"__table_args__ value must be a tuple, "
"dict, or None")
if base is not cls:
inherited_table_args = True
elif class_mapped:
if isinstance(obj, declarative_props):
util.warn("Regular (i.e. not __special__) "
"attribute '%s.%s' uses @declared_attr, "
"but owning class %s is mapped - "
"not applying to subclass %s."
% (base.__name__, name, base, cls))
continue
elif base is not cls:
# we're a mixin, abstract base, or something that is
# acting like that for now.
if isinstance(obj, Column):
# already copied columns to the mapped class.
continue
elif isinstance(obj, MapperProperty):
raise exc.InvalidRequestError(
"Mapper properties (i.e. deferred,"
"column_property(), relationship(), etc.) must "
"be declared as @declared_attr callables "
"on declarative mixin classes.")
elif isinstance(obj, declarative_props):
oldclassprop = isinstance(obj, util.classproperty)
if not oldclassprop and obj._cascading:
dict_[name] = column_copies[obj] = \
ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
if oldclassprop:
util.warn_deprecated(
"Use of sqlalchemy.util.classproperty on "
"declarative classes is deprecated.")
dict_[name] = column_copies[obj] = \
ret = getattr(cls, name)
if isinstance(ret, (Column, MapperProperty)) and \
ret.doc is None:
ret.doc = obj.__doc__
if inherited_table_args and not tablename:
table_args = None
self.table_args = table_args
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
def _produce_column_copies(self, base):
cls = self.cls
dict_ = self.dict_
column_copies = self.column_copies
# copy mixin columns to the mapped class
for name, obj in vars(base).items():
if isinstance(obj, Column):
if getattr(cls, name) is not obj:
# if column has been overridden
# (like by the InstrumentedAttribute of the
# superclass), skip
continue
elif obj.foreign_keys:
raise exc.InvalidRequestError(
"Columns with foreign keys to other columns "
"must be declared as @declared_attr callables "
"on declarative mixin classes. ")
elif name not in dict_ and not (
'__table__' in dict_ and
(obj.name or name) in dict_['__table__'].c
):
column_copies[obj] = copy_ = obj.copy()
copy_._creation_order = obj._creation_order
setattr(cls, name, copy_)
dict_[name] = copy_
def _extract_mappable_attributes(self):
cls = self.cls
dict_ = self.dict_
our_stuff = self.properties
for k in list(dict_):
if k in ('__table__', '__tablename__', '__mapper_args__'):
continue
value = dict_[k]
if isinstance(value, declarative_props):
value = getattr(cls, k)
elif isinstance(value, QueryableAttribute) and \
value.class_ is not cls and \
value.key != k:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = synonym(value.key)
setattr(cls, k, value)
if (isinstance(value, tuple) and len(value) == 1 and
isinstance(value[0], (Column, MapperProperty))):
util.warn("Ignoring declarative-like tuple value of attribute "
"%s: possibly a copy-and-paste error with a comma "
"left at the end of the line?" % k)
continue
elif not isinstance(value, (Column, MapperProperty)):
# using @declared_attr for some object that
# isn't Column/MapperProperty; remove from the dict_
# and place the evaluated value onto the class.
if not k.startswith('__'):
dict_.pop(k)
setattr(cls, k, value)
continue
# we expect to see the name 'metadata' in some valid cases;
# however at this point we see it's assigned to something trying
# to be mapped, so raise for that.
elif k == 'metadata':
raise exc.InvalidRequestError(
"Attribute name 'metadata' is reserved "
"for the MetaData instance when using a "
"declarative base class."
)
prop = clsregistry._deferred_relationship(cls, value)
our_stuff[k] = prop
def _extract_declared_columns(self):
our_stuff = self.properties
# set up attributes in the order they were created
our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
for key, c in list(our_stuff.items()):
if isinstance(c, (ColumnProperty, CompositeProperty)):
for col in c.columns:
if isinstance(col, Column) and \
col.table is None:
_undefer_column_name(key, col)
if not isinstance(c, CompositeProperty):
name_to_prop_key[col.name].add(key)
declared_columns.add(col)
elif isinstance(c, Column):
_undefer_column_name(key, c)
name_to_prop_key[c.name].add(key)
declared_columns.add(c)
# if the column is the same name as the key,
# remove it from the explicit properties dict.
# the normal rules for assigning column-based properties
# will take over, including precedence of columns
# in multi-column ColumnProperties.
if key == c.key:
del our_stuff[key]
for name, keys in name_to_prop_key.items():
if len(keys) > 1:
util.warn(
"On class %r, Column object %r named "
"directly multiple times, "
"only one will be used: %s. "
"Consider using orm.synonym instead" %
(self.classname, name, (", ".join(sorted(keys))))
)
def _setup_table(self):
cls = self.cls
tablename = self.tablename
table_args = self.table_args
dict_ = self.dict_
declared_columns = self.declared_columns
declared_columns = self.declared_columns = sorted(
declared_columns, key=lambda c: c._creation_order)
table = None
if hasattr(cls, '__table_cls__'):
table_cls = util.unbound_method_to_callable(cls.__table_cls__)
else:
table_cls = Table
if '__table__' not in dict_:
if tablename is not None:
args, table_kw = (), {}
if table_args:
if isinstance(table_args, dict):
table_kw = table_args
elif isinstance(table_args, tuple):
if isinstance(table_args[-1], dict):
args, table_kw = table_args[0:-1], table_args[-1]
else:
args = table_args
autoload = dict_.get('__autoload__')
if autoload:
table_kw['autoload'] = True
cls.__table__ = table = table_cls(
tablename, cls.metadata,
*(tuple(declared_columns) + tuple(args)),
**table_kw)
else:
table = cls.__table__
if declared_columns:
for c in declared_columns:
if not table.c.contains_column(c):
raise exc.ArgumentError(
"Can't add additional column %r when "
"specifying __table__" % c.key
)
self.local_table = table
def _setup_inheritance(self):
table = self.local_table
cls = self.cls
table_args = self.table_args
declared_columns = self.declared_columns
for c in cls.__bases__:
c = _resolve_for_abstract(c)
if c is None:
continue
if _declared_mapping_info(c) is not None and \
not _get_immediate_cls_attr(
c, '_sa_decl_prepare_nocascade', strict=True):
self.inherits = c
break
else:
self.inherits = None
if table is None and self.inherits is None and \
not _get_immediate_cls_attr(cls, '__no_table__'):
raise exc.InvalidRequestError(
"Class %r does not have a __table__ or __tablename__ "
"specified and does not inherit from an existing "
"table-mapped class." % cls
)
elif self.inherits:
inherited_mapper = _declared_mapping_info(self.inherits)
inherited_table = inherited_mapper.local_table
inherited_mapped_table = inherited_mapper.mapped_table
if table is None:
# single table inheritance.
# ensure no table args
if table_args:
raise exc.ArgumentError(
"Can't place __table_args__ on an inherited class "
"with no table."
)
# add any columns declared here to the inherited table.
for c in declared_columns:
if c.primary_key:
raise exc.ArgumentError(
"Can't place primary key columns on an inherited "
"class with no table."
)
if c.name in inherited_table.c:
if inherited_table.c[c.name] is c:
continue
raise exc.ArgumentError(
"Column '%s' on class %s conflicts with "
"existing column '%s'" %
(c, cls, inherited_table.c[c.name])
)
inherited_table.append_column(c)
if inherited_mapped_table is not None and \
inherited_mapped_table is not inherited_table:
inherited_mapped_table._refresh_for_new_column(c)
def _prepare_mapper_arguments(self):
properties = self.properties
if self.mapper_args_fn:
mapper_args = self.mapper_args_fn()
else:
mapper_args = {}
# make sure that column copies are used rather
# than the original columns from any mixins
for k in ('version_id_col', 'polymorphic_on',):
if k in mapper_args:
v = mapper_args[k]
mapper_args[k] = self.column_copies.get(v, v)
assert 'inherits' not in mapper_args, \
"Can't specify 'inherits' explicitly with declarative mappings"
if self.inherits:
mapper_args['inherits'] = self.inherits
if self.inherits and not mapper_args.get('concrete', False):
# single or joined inheritance
# exclude any cols on the inherited table which are
# not mapped on the parent class, to avoid
# mapping columns specific to sibling/nephew classes
inherited_mapper = _declared_mapping_info(self.inherits)
inherited_table = inherited_mapper.local_table
if 'exclude_properties' not in mapper_args:
mapper_args['exclude_properties'] = exclude_properties = \
set(
[c.key for c in inherited_table.c
if c not in inherited_mapper._columntoproperty]
).union(
inherited_mapper.exclude_properties or ()
)
exclude_properties.difference_update(
[c.key for c in self.declared_columns])
# look through columns in the current mapper that
# are keyed to a propname different than the colname
# (if names were the same, we'd have popped it out above,
# in which case the mapper makes this combination).
# See if the superclass has a similar column property.
# If so, join them together.
for k, col in list(properties.items()):
if not isinstance(col, expression.ColumnElement):
continue
if k in inherited_mapper._props:
p = inherited_mapper._props[k]
if isinstance(p, ColumnProperty):
# note here we place the subclass column
# first. See [ticket:1892] for background.
properties[k] = [col] + p.columns
result_mapper_args = mapper_args.copy()
result_mapper_args['properties'] = properties
self.mapper_args = result_mapper_args
def map(self):
self._prepare_mapper_arguments()
if hasattr(self.cls, '__mapper_cls__'):
mapper_cls = util.unbound_method_to_callable(
self.cls.__mapper_cls__)
else:
mapper_cls = mapper
self.cls.__mapper__ = mp_ = mapper_cls(
self.cls,
self.local_table,
**self.mapper_args
)
del self.cls._sa_declared_attr_reg
return mp_
class _DeferredMapperConfig(_MapperConfig):
_configs = util.OrderedDict()
def _early_mapping(self):
pass
@property
def cls(self):
return self._cls()
@cls.setter
def cls(self, class_):
self._cls = weakref.ref(class_, self._remove_config_cls)
self._configs[self._cls] = self
@classmethod
def _remove_config_cls(cls, ref):
cls._configs.pop(ref, None)
@classmethod
def has_cls(cls, class_):
# 2.6 fails on weakref if class_ is an old style class
return isinstance(class_, type) and \
weakref.ref(class_) in cls._configs
@classmethod
def config_for_cls(cls, class_):
return cls._configs[weakref.ref(class_)]
@classmethod
def classes_for_base(cls, base_cls, sort=True):
classes_for_base = [m for m in cls._configs.values()
if issubclass(m.cls, base_cls)]
if not sort:
return classes_for_base
all_m_by_cls = dict(
(m.cls, m)
for m in classes_for_base
)
tuples = []
for m_cls in all_m_by_cls:
tuples.extend(
(all_m_by_cls[base_cls], all_m_by_cls[m_cls])
for base_cls in m_cls.__bases__
if base_cls in all_m_by_cls
)
return list(
topological.sort(
tuples,
classes_for_base
)
)
def map(self):
self._configs.pop(self._cls, None)
return super(_DeferredMapperConfig, self).map()
def _add_attribute(cls, key, value):
"""add an attribute to an existing declarative class.
This runs through the logic to determine MapperProperty,
adds it to the Mapper, adds a column to the mapped Table, etc.
"""
if '__mapper__' in cls.__dict__:
if isinstance(value, Column):
_undefer_column_name(key, value)
cls.__table__.append_column(value)
cls.__mapper__.add_property(key, value)
elif isinstance(value, ColumnProperty):
for col in value.columns:
if isinstance(col, Column) and col.table is None:
_undefer_column_name(key, col)
cls.__table__.append_column(col)
cls.__mapper__.add_property(key, value)
elif isinstance(value, MapperProperty):
cls.__mapper__.add_property(
key,
clsregistry._deferred_relationship(cls, value)
)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = synonym(value.key)
cls.__mapper__.add_property(
key,
clsregistry._deferred_relationship(cls, value)
)
else:
type.__setattr__(cls, key, value)
else:
type.__setattr__(cls, key, value)
def _declarative_constructor(self, **kwargs):
"""A simple constructor that allows initialization from kwargs.
Sets attributes on the constructed instance using the names and
values in ``kwargs``.
Only keys that are present as
attributes of the instance's class are allowed. These could be,
for example, any mapped columns or relationships.
"""
cls_ = type(self)
for k in kwargs:
if not hasattr(cls_, k):
raise TypeError(
"%r is an invalid keyword argument for %s" %
(k, cls_.__name__))
setattr(self, k, kwargs[k])
_declarative_constructor.__name__ = '__init__'
def _undefer_column_name(key, column):
if column.key is None:
column.key = key
if column.name is None:
column.name = key

View File

@ -0,0 +1,328 @@
# ext/declarative/clsregistry.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Routines to handle the string class registry used by declarative.
This system allows specification of classes and expressions used in
:func:`.relationship` using strings.
"""
from ...orm.properties import ColumnProperty, RelationshipProperty, \
SynonymProperty
from ...schema import _get_table_key
from ...orm import class_mapper, interfaces
from ... import util
from ... import inspection
from ... import exc
import weakref
# strong references to registries which we place in
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
_registries = set()
def add_class(classname, cls):
"""Add a class to the _decl_class_registry associated with the
given declarative class.
"""
if classname in cls._decl_class_registry:
# class already exists.
existing = cls._decl_class_registry[classname]
if not isinstance(existing, _MultipleClassMarker):
existing = \
cls._decl_class_registry[classname] = \
_MultipleClassMarker([cls, existing])
else:
cls._decl_class_registry[classname] = cls
try:
root_module = cls._decl_class_registry['_sa_module_registry']
except KeyError:
cls._decl_class_registry['_sa_module_registry'] = \
root_module = _ModuleMarker('_sa_module_registry', None)
tokens = cls.__module__.split(".")
# build up a tree like this:
# modulename: myapp.snacks.nuts
#
# myapp->snack->nuts->(classes)
# snack->nuts->(classes)
# nuts->(classes)
#
# this allows partial token paths to be used.
while tokens:
token = tokens.pop(0)
module = root_module.get_module(token)
for token in tokens:
module = module.get_module(token)
module.add_class(classname, cls)
class _MultipleClassMarker(object):
"""refers to multiple classes of the same name
within _decl_class_registry.
"""
__slots__ = 'on_remove', 'contents', '__weakref__'
def __init__(self, classes, on_remove=None):
self.on_remove = on_remove
self.contents = set([
weakref.ref(item, self._remove_item) for item in classes])
_registries.add(self)
def __iter__(self):
return (ref() for ref in self.contents)
def attempt_get(self, path, key):
if len(self.contents) > 1:
raise exc.InvalidRequestError(
"Multiple classes found for path \"%s\" "
"in the registry of this declarative "
"base. Please use a fully module-qualified path." %
(".".join(path + [key]))
)
else:
ref = list(self.contents)[0]
cls = ref()
if cls is None:
raise NameError(key)
return cls
def _remove_item(self, ref):
self.contents.remove(ref)
if not self.contents:
_registries.discard(self)
if self.on_remove:
self.on_remove()
def add_item(self, item):
# protect against class registration race condition against
# asynchronous garbage collection calling _remove_item,
# [ticket:3208]
modules = set([
cls.__module__ for cls in
[ref() for ref in self.contents] if cls is not None])
if item.__module__ in modules:
util.warn(
"This declarative base already contains a class with the "
"same class name and module name as %s.%s, and will "
"be replaced in the string-lookup table." % (
item.__module__,
item.__name__
)
)
self.contents.add(weakref.ref(item, self._remove_item))
class _ModuleMarker(object):
""""refers to a module name within
_decl_class_registry.
"""
__slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__'
def __init__(self, name, parent):
self.parent = parent
self.name = name
self.contents = {}
self.mod_ns = _ModNS(self)
if self.parent:
self.path = self.parent.path + [self.name]
else:
self.path = []
_registries.add(self)
def __contains__(self, name):
return name in self.contents
def __getitem__(self, name):
return self.contents[name]
def _remove_item(self, name):
self.contents.pop(name, None)
if not self.contents and self.parent is not None:
self.parent._remove_item(self.name)
_registries.discard(self)
def resolve_attr(self, key):
return getattr(self.mod_ns, key)
def get_module(self, name):
if name not in self.contents:
marker = _ModuleMarker(name, self)
self.contents[name] = marker
else:
marker = self.contents[name]
return marker
def add_class(self, name, cls):
if name in self.contents:
existing = self.contents[name]
existing.add_item(cls)
else:
existing = self.contents[name] = \
_MultipleClassMarker([cls],
on_remove=lambda: self._remove_item(name))
class _ModNS(object):
__slots__ = '__parent',
def __init__(self, parent):
self.__parent = parent
def __getattr__(self, key):
try:
value = self.__parent.contents[key]
except KeyError:
pass
else:
if value is not None:
if isinstance(value, _ModuleMarker):
return value.mod_ns
else:
assert isinstance(value, _MultipleClassMarker)
return value.attempt_get(self.__parent.path, key)
raise AttributeError("Module %r has no mapped classes "
"registered under the name %r" % (
self.__parent.name, key))
class _GetColumns(object):
__slots__ = 'cls',
def __init__(self, cls):
self.cls = cls
def __getattr__(self, key):
mp = class_mapper(self.cls, configure=False)
if mp:
if key not in mp.all_orm_descriptors:
raise exc.InvalidRequestError(
"Class %r does not have a mapped column named %r"
% (self.cls, key))
desc = mp.all_orm_descriptors[key]
if desc.extension_type is interfaces.NOT_EXTENSION:
prop = desc.property
if isinstance(prop, SynonymProperty):
key = prop.name
elif not isinstance(prop, ColumnProperty):
raise exc.InvalidRequestError(
"Property %r is not an instance of"
" ColumnProperty (i.e. does not correspond"
" directly to a Column)." % key)
return getattr(self.cls, key)
inspection._inspects(_GetColumns)(
lambda target: inspection.inspect(target.cls))
class _GetTable(object):
__slots__ = 'key', 'metadata'
def __init__(self, key, metadata):
self.key = key
self.metadata = metadata
def __getattr__(self, key):
return self.metadata.tables[
_get_table_key(key, self.key)
]
def _determine_container(key, value):
if isinstance(value, _MultipleClassMarker):
value = value.attempt_get([], key)
return _GetColumns(value)
class _class_resolver(object):
def __init__(self, cls, prop, fallback, arg):
self.cls = cls
self.prop = prop
self.arg = self._declarative_arg = arg
self.fallback = fallback
self._dict = util.PopulateDict(self._access_cls)
self._resolvers = ()
def _access_cls(self, key):
cls = self.cls
if key in cls._decl_class_registry:
return _determine_container(key, cls._decl_class_registry[key])
elif key in cls.metadata.tables:
return cls.metadata.tables[key]
elif key in cls.metadata._schemas:
return _GetTable(key, cls.metadata)
elif '_sa_module_registry' in cls._decl_class_registry and \
key in cls._decl_class_registry['_sa_module_registry']:
registry = cls._decl_class_registry['_sa_module_registry']
return registry.resolve_attr(key)
elif self._resolvers:
for resolv in self._resolvers:
value = resolv(key)
if value is not None:
return value
return self.fallback[key]
def __call__(self):
try:
x = eval(self.arg, globals(), self._dict)
if isinstance(x, _GetColumns):
return x.cls
else:
return x
except NameError as n:
raise exc.InvalidRequestError(
"When initializing mapper %s, expression %r failed to "
"locate a name (%r). If this is a class name, consider "
"adding this relationship() to the %r class after "
"both dependent classes have been defined." %
(self.prop.parent, self.arg, n.args[0], self.cls)
)
def _resolver(cls, prop):
import sqlalchemy
from sqlalchemy.orm import foreign, remote
fallback = sqlalchemy.__dict__.copy()
fallback.update({'foreign': foreign, 'remote': remote})
def resolve_arg(arg):
return _class_resolver(cls, prop, fallback, arg)
return resolve_arg
def _deferred_relationship(cls, prop):
if isinstance(prop, RelationshipProperty):
resolve_arg = _resolver(cls, prop)
for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
'secondary', '_user_defined_foreign_keys', 'remote_side'):
v = getattr(prop, attr)
if isinstance(v, util.string_types):
setattr(prop, attr, resolve_arg(v))
if prop.backref and isinstance(prop.backref, tuple):
key, kwargs = prop.backref
for attr in ('primaryjoin', 'secondaryjoin', 'secondary',
'foreign_keys', 'remote_side', 'order_by'):
if attr in kwargs and isinstance(kwargs[attr],
util.string_types):
kwargs[attr] = resolve_arg(kwargs[attr])
return prop

841
sqlalchemy/ext/hybrid.py Normal file
View File

@ -0,0 +1,841 @@
# ext/hybrid.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
r"""Define attributes on ORM-mapped classes that have "hybrid" behavior.
"hybrid" means the attribute has distinct behaviors defined at the
class level and at the instance level.
The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of
method decorator, is around 50 lines of code and has almost no
dependencies on the rest of SQLAlchemy. It can, in theory, work with
any descriptor-based expression system.
Consider a mapping ``Interval``, representing integer ``start`` and ``end``
values. We can define higher level functions on mapped classes that produce
SQL expressions at the class level, and Python expression evaluation at the
instance level. Below, each function decorated with :class:`.hybrid_method` or
:class:`.hybrid_property` may receive ``self`` as an instance of the class, or
as the class itself::
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, aliased
from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
Base = declarative_base()
class Interval(Base):
__tablename__ = 'interval'
id = Column(Integer, primary_key=True)
start = Column(Integer, nullable=False)
end = Column(Integer, nullable=False)
def __init__(self, start, end):
self.start = start
self.end = end
@hybrid_property
def length(self):
return self.end - self.start
@hybrid_method
def contains(self, point):
return (self.start <= point) & (point <= self.end)
@hybrid_method
def intersects(self, other):
return self.contains(other.start) | self.contains(other.end)
Above, the ``length`` property returns the difference between the
``end`` and ``start`` attributes. With an instance of ``Interval``,
this subtraction occurs in Python, using normal Python descriptor
mechanics::
>>> i1 = Interval(5, 10)
>>> i1.length
5
When dealing with the ``Interval`` class itself, the :class:`.hybrid_property`
descriptor evaluates the function body given the ``Interval`` class as
the argument, which when evaluated with SQLAlchemy expression mechanics
returns a new SQL expression::
>>> print Interval.length
interval."end" - interval.start
>>> print Session().query(Interval).filter(Interval.length > 10)
SELECT interval.id AS interval_id, interval.start AS interval_start,
interval."end" AS interval_end
FROM interval
WHERE interval."end" - interval.start > :param_1
ORM methods such as :meth:`~.Query.filter_by` generally use ``getattr()`` to
locate attributes, so can also be used with hybrid attributes::
>>> print Session().query(Interval).filter_by(length=5)
SELECT interval.id AS interval_id, interval.start AS interval_start,
interval."end" AS interval_end
FROM interval
WHERE interval."end" - interval.start = :param_1
The ``Interval`` class example also illustrates two methods,
``contains()`` and ``intersects()``, decorated with
:class:`.hybrid_method`. This decorator applies the same idea to
methods that :class:`.hybrid_property` applies to attributes. The
methods return boolean values, and take advantage of the Python ``|``
and ``&`` bitwise operators to produce equivalent instance-level and
SQL expression-level boolean behavior::
>>> i1.contains(6)
True
>>> i1.contains(15)
False
>>> i1.intersects(Interval(7, 18))
True
>>> i1.intersects(Interval(25, 29))
False
>>> print Session().query(Interval).filter(Interval.contains(15))
SELECT interval.id AS interval_id, interval.start AS interval_start,
interval."end" AS interval_end
FROM interval
WHERE interval.start <= :start_1 AND interval."end" > :end_1
>>> ia = aliased(Interval)
>>> print Session().query(Interval, ia).filter(Interval.intersects(ia))
SELECT interval.id AS interval_id, interval.start AS interval_start,
interval."end" AS interval_end, interval_1.id AS interval_1_id,
interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end
FROM interval, interval AS interval_1
WHERE interval.start <= interval_1.start
AND interval."end" > interval_1.start
OR interval.start <= interval_1."end"
AND interval."end" > interval_1."end"
Defining Expression Behavior Distinct from Attribute Behavior
--------------------------------------------------------------
Our usage of the ``&`` and ``|`` bitwise operators above was
fortunate, considering our functions operated on two boolean values to
return a new one. In many cases, the construction of an in-Python
function and a SQLAlchemy SQL expression have enough differences that
two separate Python expressions should be defined. The
:mod:`~sqlalchemy.ext.hybrid` decorators define the
:meth:`.hybrid_property.expression` modifier for this purpose. As an
example we'll define the radius of the interval, which requires the
usage of the absolute value function::
from sqlalchemy import func
class Interval(object):
# ...
@hybrid_property
def radius(self):
return abs(self.length) / 2
@radius.expression
def radius(cls):
return func.abs(cls.length) / 2
Above the Python function ``abs()`` is used for instance-level
operations, the SQL function ``ABS()`` is used via the :data:`.func`
object for class-level expressions::
>>> i1.radius
2
>>> print Session().query(Interval).filter(Interval.radius > 5)
SELECT interval.id AS interval_id, interval.start AS interval_start,
interval."end" AS interval_end
FROM interval
WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1
Defining Setters
----------------
Hybrid properties can also define setter methods. If we wanted
``length`` above, when set, to modify the endpoint value::
class Interval(object):
# ...
@hybrid_property
def length(self):
return self.end - self.start
@length.setter
def length(self, value):
self.end = self.start + value
The ``length(self, value)`` method is now called upon set::
>>> i1 = Interval(5, 10)
>>> i1.length
5
>>> i1.length = 12
>>> i1.end
17
Working with Relationships
--------------------------
There's no essential difference when creating hybrids that work with
related objects as opposed to column-based data. The need for distinct
expressions tends to be greater. The two variants we'll illustrate
are the "join-dependent" hybrid, and the "correlated subquery" hybrid.
Join-Dependent Relationship Hybrid
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Consider the following declarative
mapping which relates a ``User`` to a ``SavingsAccount``::
from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
Base = declarative_base()
class SavingsAccount(Base):
__tablename__ = 'account'
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
balance = Column(Numeric(15, 5))
class User(Base):
__tablename__ = 'user'
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
accounts = relationship("SavingsAccount", backref="owner")
@hybrid_property
def balance(self):
if self.accounts:
return self.accounts[0].balance
else:
return None
@balance.setter
def balance(self, value):
if not self.accounts:
account = Account(owner=self)
else:
account = self.accounts[0]
account.balance = value
@balance.expression
def balance(cls):
return SavingsAccount.balance
The above hybrid property ``balance`` works with the first
``SavingsAccount`` entry in the list of accounts for this user. The
in-Python getter/setter methods can treat ``accounts`` as a Python
list available on ``self``.
However, at the expression level, it's expected that the ``User`` class will
be used in an appropriate context such that an appropriate join to
``SavingsAccount`` will be present::
>>> print Session().query(User, User.balance).\
... join(User.accounts).filter(User.balance > 5000)
SELECT "user".id AS user_id, "user".name AS user_name,
account.balance AS account_balance
FROM "user" JOIN account ON "user".id = account.user_id
WHERE account.balance > :balance_1
Note however, that while the instance level accessors need to worry
about whether ``self.accounts`` is even present, this issue expresses
itself differently at the SQL expression level, where we basically
would use an outer join::
>>> from sqlalchemy import or_
>>> print (Session().query(User, User.balance).outerjoin(User.accounts).
... filter(or_(User.balance < 5000, User.balance == None)))
SELECT "user".id AS user_id, "user".name AS user_name,
account.balance AS account_balance
FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id
WHERE account.balance < :balance_1 OR account.balance IS NULL
Correlated Subquery Relationship Hybrid
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We can, of course, forego being dependent on the enclosing query's usage
of joins in favor of the correlated subquery, which can portably be packed
into a single column expression. A correlated subquery is more portable, but
often performs more poorly at the SQL level. Using the same technique
illustrated at :ref:`mapper_column_property_sql_expressions`,
we can adjust our ``SavingsAccount`` example to aggregate the balances for
*all* accounts, and use a correlated subquery for the column expression::
from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy import select, func
Base = declarative_base()
class SavingsAccount(Base):
__tablename__ = 'account'
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
balance = Column(Numeric(15, 5))
class User(Base):
__tablename__ = 'user'
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
accounts = relationship("SavingsAccount", backref="owner")
@hybrid_property
def balance(self):
return sum(acc.balance for acc in self.accounts)
@balance.expression
def balance(cls):
return select([func.sum(SavingsAccount.balance)]).\
where(SavingsAccount.user_id==cls.id).\
label('total_balance')
The above recipe will give us the ``balance`` column which renders
a correlated SELECT::
>>> print s.query(User).filter(User.balance > 400)
SELECT "user".id AS user_id, "user".name AS user_name
FROM "user"
WHERE (SELECT sum(account.balance) AS sum_1
FROM account
WHERE account.user_id = "user".id) > :param_1
.. _hybrid_custom_comparators:
Building Custom Comparators
---------------------------
The hybrid property also includes a helper that allows construction of
custom comparators. A comparator object allows one to customize the
behavior of each SQLAlchemy expression operator individually. They
are useful when creating custom types that have some highly
idiosyncratic behavior on the SQL side.
The example class below allows case-insensitive comparisons on the attribute
named ``word_insensitive``::
from sqlalchemy.ext.hybrid import Comparator, hybrid_property
from sqlalchemy import func, Column, Integer, String
from sqlalchemy.orm import Session
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class CaseInsensitiveComparator(Comparator):
def __eq__(self, other):
return func.lower(self.__clause_element__()) == func.lower(other)
class SearchWord(Base):
__tablename__ = 'searchword'
id = Column(Integer, primary_key=True)
word = Column(String(255), nullable=False)
@hybrid_property
def word_insensitive(self):
return self.word.lower()
@word_insensitive.comparator
def word_insensitive(cls):
return CaseInsensitiveComparator(cls.word)
Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()``
SQL function to both sides::
>>> print Session().query(SearchWord).filter_by(word_insensitive="Trucks")
SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
FROM searchword
WHERE lower(searchword.word) = lower(:lower_1)
The ``CaseInsensitiveComparator`` above implements part of the
:class:`.ColumnOperators` interface. A "coercion" operation like
lowercasing can be applied to all comparison operations (i.e. ``eq``,
``lt``, ``gt``, etc.) using :meth:`.Operators.operate`::
class CaseInsensitiveComparator(Comparator):
def operate(self, op, other):
return op(func.lower(self.__clause_element__()), func.lower(other))
Hybrid Value Objects
--------------------
Note in our previous example, if we were to compare the
``word_insensitive`` attribute of a ``SearchWord`` instance to a plain
Python string, the plain Python string would not be coerced to lower
case - the ``CaseInsensitiveComparator`` we built, being returned by
``@word_insensitive.comparator``, only applies to the SQL side.
A more comprehensive form of the custom comparator is to construct a
*Hybrid Value Object*. This technique applies the target value or
expression to a value object which is then returned by the accessor in
all cases. The value object allows control of all operations upon
the value as well as how compared values are treated, both on the SQL
expression side as well as the Python value side. Replacing the
previous ``CaseInsensitiveComparator`` class with a new
``CaseInsensitiveWord`` class::
class CaseInsensitiveWord(Comparator):
"Hybrid value representing a lower case representation of a word."
def __init__(self, word):
if isinstance(word, basestring):
self.word = word.lower()
elif isinstance(word, CaseInsensitiveWord):
self.word = word.word
else:
self.word = func.lower(word)
def operate(self, op, other):
if not isinstance(other, CaseInsensitiveWord):
other = CaseInsensitiveWord(other)
return op(self.word, other.word)
def __clause_element__(self):
return self.word
def __str__(self):
return self.word
key = 'word'
"Label to apply to Query tuple results"
Above, the ``CaseInsensitiveWord`` object represents ``self.word``,
which may be a SQL function, or may be a Python native. By
overriding ``operate()`` and ``__clause_element__()`` to work in terms
of ``self.word``, all comparison operations will work against the
"converted" form of ``word``, whether it be SQL side or Python side.
Our ``SearchWord`` class can now deliver the ``CaseInsensitiveWord``
object unconditionally from a single hybrid call::
class SearchWord(Base):
__tablename__ = 'searchword'
id = Column(Integer, primary_key=True)
word = Column(String(255), nullable=False)
@hybrid_property
def word_insensitive(self):
return CaseInsensitiveWord(self.word)
The ``word_insensitive`` attribute now has case-insensitive comparison
behavior universally, including SQL expression vs. Python expression
(note the Python value is converted to lower case on the Python side
here)::
>>> print Session().query(SearchWord).filter_by(word_insensitive="Trucks")
SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
FROM searchword
WHERE lower(searchword.word) = :lower_1
SQL expression versus SQL expression::
>>> sw1 = aliased(SearchWord)
>>> sw2 = aliased(SearchWord)
>>> print Session().query(
... sw1.word_insensitive,
... sw2.word_insensitive).\
... filter(
... sw1.word_insensitive > sw2.word_insensitive
... )
SELECT lower(searchword_1.word) AS lower_1,
lower(searchword_2.word) AS lower_2
FROM searchword AS searchword_1, searchword AS searchword_2
WHERE lower(searchword_1.word) > lower(searchword_2.word)
Python only expression::
>>> ws1 = SearchWord(word="SomeWord")
>>> ws1.word_insensitive == "sOmEwOrD"
True
>>> ws1.word_insensitive == "XOmEwOrX"
False
>>> print ws1.word_insensitive
someword
The Hybrid Value pattern is very useful for any kind of value that may
have multiple representations, such as timestamps, time deltas, units
of measurement, currencies and encrypted passwords.
.. seealso::
`Hybrids and Value Agnostic Types
<http://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/>`_
- on the techspot.zzzeek.org blog
`Value Agnostic Types, Part II
<http://techspot.zzzeek.org/2011/10/29/value-agnostic-types-part-ii/>`_ -
on the techspot.zzzeek.org blog
.. _hybrid_transformers:
Building Transformers
----------------------
A *transformer* is an object which can receive a :class:`.Query`
object and return a new one. The :class:`.Query` object includes a
method :meth:`.with_transformation` that returns a new :class:`.Query`
transformed by the given function.
We can combine this with the :class:`.Comparator` class to produce one type
of recipe which can both set up the FROM clause of a query as well as assign
filtering criterion.
Consider a mapped class ``Node``, which assembles using adjacency list
into a hierarchical tree pattern::
from sqlalchemy import Column, Integer, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class Node(Base):
__tablename__ = 'node'
id = Column(Integer, primary_key=True)
parent_id = Column(Integer, ForeignKey('node.id'))
parent = relationship("Node", remote_side=id)
Suppose we wanted to add an accessor ``grandparent``. This would
return the ``parent`` of ``Node.parent``. When we have an instance of
``Node``, this is simple::
from sqlalchemy.ext.hybrid import hybrid_property
class Node(Base):
# ...
@hybrid_property
def grandparent(self):
return self.parent.parent
For the expression, things are not so clear. We'd need to construct
a :class:`.Query` where we :meth:`~.Query.join` twice along
``Node.parent`` to get to the ``grandparent``. We can instead return
a transforming callable that we'll combine with the
:class:`.Comparator` class to receive any :class:`.Query` object, and
return a new one that's joined to the ``Node.parent`` attribute and
filtered based on the given criterion::
from sqlalchemy.ext.hybrid import Comparator
class GrandparentTransformer(Comparator):
def operate(self, op, other):
def transform(q):
cls = self.__clause_element__()
parent_alias = aliased(cls)
return q.join(parent_alias, cls.parent).\
filter(op(parent_alias.parent, other))
return transform
Base = declarative_base()
class Node(Base):
__tablename__ = 'node'
id =Column(Integer, primary_key=True)
parent_id = Column(Integer, ForeignKey('node.id'))
parent = relationship("Node", remote_side=id)
@hybrid_property
def grandparent(self):
return self.parent.parent
@grandparent.comparator
def grandparent(cls):
return GrandparentTransformer(cls)
The ``GrandparentTransformer`` overrides the core
:meth:`.Operators.operate` method at the base of the
:class:`.Comparator` hierarchy to return a query-transforming
callable, which then runs the given comparison operation in a
particular context. Such as, in the example above, the ``operate``
method is called, given the :attr:`.Operators.eq` callable as well as
the right side of the comparison ``Node(id=5)``. A function
``transform`` is then returned which will transform a :class:`.Query`
first to join to ``Node.parent``, then to compare ``parent_alias``
using :attr:`.Operators.eq` against the left and right sides, passing
into :class:`.Query.filter`:
.. sourcecode:: pycon+sql
>>> from sqlalchemy.orm import Session
>>> session = Session()
{sql}>>> session.query(Node).\
... with_transformation(Node.grandparent==Node(id=5)).\
... all()
SELECT node.id AS node_id, node.parent_id AS node_parent_id
FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
WHERE :param_1 = node_1.parent_id
{stop}
We can modify the pattern to be more verbose but flexible by separating
the "join" step from the "filter" step. The tricky part here is ensuring
that successive instances of ``GrandparentTransformer`` use the same
:class:`.AliasedClass` object against ``Node``. Below we use a simple
memoizing approach that associates a ``GrandparentTransformer``
with each class::
class Node(Base):
# ...
@grandparent.comparator
def grandparent(cls):
# memoize a GrandparentTransformer
# per class
if '_gp' not in cls.__dict__:
cls._gp = GrandparentTransformer(cls)
return cls._gp
class GrandparentTransformer(Comparator):
def __init__(self, cls):
self.parent_alias = aliased(cls)
@property
def join(self):
def go(q):
return q.join(self.parent_alias, Node.parent)
return go
def operate(self, op, other):
return op(self.parent_alias.parent, other)
.. sourcecode:: pycon+sql
{sql}>>> session.query(Node).\
... with_transformation(Node.grandparent.join).\
... filter(Node.grandparent==Node(id=5))
SELECT node.id AS node_id, node.parent_id AS node_parent_id
FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
WHERE :param_1 = node_1.parent_id
{stop}
The "transformer" pattern is an experimental pattern that starts
to make usage of some functional programming paradigms.
While it's only recommended for advanced and/or patient developers,
there's probably a whole lot of amazing things it can be used for.
"""
from .. import util
from ..orm import attributes, interfaces
HYBRID_METHOD = util.symbol('HYBRID_METHOD')
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_method`.
Is assigned to the :attr:`.InspectionAttr.extension_type`
attibute.
.. seealso::
:attr:`.Mapper.all_orm_attributes`
"""
HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY')
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_method`.
Is assigned to the :attr:`.InspectionAttr.extension_type`
attibute.
.. seealso::
:attr:`.Mapper.all_orm_attributes`
"""
class hybrid_method(interfaces.InspectionAttrInfo):
"""A decorator which allows definition of a Python object method with both
instance-level and class-level behavior.
"""
is_attribute = True
extension_type = HYBRID_METHOD
def __init__(self, func, expr=None):
"""Create a new :class:`.hybrid_method`.
Usage is typically via decorator::
from sqlalchemy.ext.hybrid import hybrid_method
class SomeClass(object):
@hybrid_method
def value(self, x, y):
return self._value + x + y
@value.expression
def value(self, x, y):
return func.some_function(self._value, x, y)
"""
self.func = func
self.expression(expr or func)
def __get__(self, instance, owner):
if instance is None:
return self.expr.__get__(owner, owner.__class__)
else:
return self.func.__get__(instance, owner)
def expression(self, expr):
"""Provide a modifying decorator that defines a
SQL-expression producing method."""
self.expr = expr
if not self.expr.__doc__:
self.expr.__doc__ = self.func.__doc__
return self
class hybrid_property(interfaces.InspectionAttrInfo):
"""A decorator which allows definition of a Python descriptor with both
instance-level and class-level behavior.
"""
is_attribute = True
extension_type = HYBRID_PROPERTY
def __init__(self, fget, fset=None, fdel=None, expr=None):
"""Create a new :class:`.hybrid_property`.
Usage is typically via decorator::
from sqlalchemy.ext.hybrid import hybrid_property
class SomeClass(object):
@hybrid_property
def value(self):
return self._value
@value.setter
def value(self, value):
self._value = value
"""
self.fget = fget
self.fset = fset
self.fdel = fdel
self.expression(expr or fget)
util.update_wrapper(self, fget)
def __get__(self, instance, owner):
if instance is None:
return self.expr(owner)
else:
return self.fget(instance)
def __set__(self, instance, value):
if self.fset is None:
raise AttributeError("can't set attribute")
self.fset(instance, value)
def __delete__(self, instance):
if self.fdel is None:
raise AttributeError("can't delete attribute")
self.fdel(instance)
def setter(self, fset):
"""Provide a modifying decorator that defines a value-setter method."""
self.fset = fset
return self
def deleter(self, fdel):
"""Provide a modifying decorator that defines a
value-deletion method."""
self.fdel = fdel
return self
def expression(self, expr):
"""Provide a modifying decorator that defines a SQL-expression
producing method."""
def _expr(cls):
return ExprComparator(expr(cls), self)
util.update_wrapper(_expr, expr)
self.expr = _expr
return self.comparator(_expr)
def comparator(self, comparator):
"""Provide a modifying decorator that defines a custom
comparator producing method.
The return value of the decorated method should be an instance of
:class:`~.hybrid.Comparator`.
"""
proxy_attr = attributes.\
create_proxied_attribute(self)
def expr(owner):
return proxy_attr(
owner, self.__name__, self, comparator(owner),
doc=comparator.__doc__ or self.__doc__)
self.expr = expr
return self
class Comparator(interfaces.PropComparator):
"""A helper class that allows easy construction of custom
:class:`~.orm.interfaces.PropComparator`
classes for usage with hybrids."""
property = None
def __init__(self, expression):
self.expression = expression
def __clause_element__(self):
expr = self.expression
if hasattr(expr, '__clause_element__'):
expr = expr.__clause_element__()
return expr
def adapt_to_entity(self, adapt_to_entity):
# interesting....
return self
class ExprComparator(Comparator):
def __init__(self, expression, hybrid):
self.expression = expression
self.hybrid = hybrid
def __getattr__(self, key):
return getattr(self.expression, key)
@property
def info(self):
return self.hybrid.info
@property
def property(self):
return self.expression.property
def operate(self, op, *other, **kwargs):
return op(self.expression, *other, **kwargs)
def reverse_operate(self, op, other, **kwargs):
return op(other, self.expression, **kwargs)

349
sqlalchemy/ext/indexable.py Normal file
View File

@ -0,0 +1,349 @@
# ext/index.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Define attributes on ORM-mapped classes that have "index" attributes for
columns with :class:`~.types.Indexable` types.
"index" means the attribute is associated with an element of an
:class:`~.types.Indexable` column with the predefined index to access it.
The :class:`~.types.Indexable` types include types such as
:class:`~.types.ARRAY`, :class:`~.types.JSON` and
:class:`~.postgresql.HSTORE`.
The :mod:`~sqlalchemy.ext.indexable` extension provides
:class:`~.schema.Column`-like interface for any element of an
:class:`~.types.Indexable` typed column. In simple cases, it can be
treated as a :class:`~.schema.Column` - mapped attribute.
.. versionadded:: 1.1
Synopsis
========
Given ``Person`` as a model with a primary key and JSON data field.
While this field may have any number of elements encoded within it,
we would like to refer to the element called ``name`` individually
as a dedicated attribute which behaves like a standalone column::
from sqlalchemy import Column, JSON, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.indexable import index_property
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
id = Column(Integer, primary_key=True)
data = Column(JSON)
name = index_property('data', 'name')
Above, the ``name`` attribute now behaves like a mapped column. We
can compose a new ``Person`` and set the value of ``name``::
>>> person = Person(name='Alchemist')
The value is now accessible::
>>> person.name
'Alchemist'
Behind the scenes, the JSON field was initialized to a new blank dictionary
and the field was set::
>>> person.data
{"name": "Alchemist'}
The field is mutable in place::
>>> person.name = 'Renamed'
>>> person.name
'Renamed'
>>> person.data
{'name': 'Renamed'}
When using :class:`.index_property`, the change that we make to the indexable
structure is also automatically tracked as history; we no longer need
to use :class:`~.mutable.MutableDict` in order to track this change
for the unit of work.
Deletions work normally as well::
>>> del person.name
>>> person.data
{}
Above, deletion of ``person.name`` deletes the value from the dictionary,
but not the dictionary itself.
A missing key will produce ``AttributeError``::
>>> person = Person()
>>> person.name
...
AttributeError: 'name'
Unless you set a default value::
>>> class Person(Base):
>>> __tablename__ = 'person'
>>>
>>> id = Column(Integer, primary_key=True)
>>> data = Column(JSON)
>>>
>>> name = index_property('data', 'name', default=None) # See default
>>> person = Person()
>>> print(person.name)
None
The attributes are also accessible at the class level.
Below, we illustrate ``Person.name`` used to generate
an indexed SQL criteria::
>>> from sqlalchemy.orm import Session
>>> session = Session()
>>> query = session.query(Person).filter(Person.name == 'Alchemist')
The above query is equivalent to::
>>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
Multiple :class:`.index_property` objects can be chained to produce
multiple levels of indexing::
from sqlalchemy import Column, JSON, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.indexable import index_property
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
id = Column(Integer, primary_key=True)
data = Column(JSON)
birthday = index_property('data', 'birthday')
year = index_property('birthday', 'year')
month = index_property('birthday', 'month')
day = index_property('birthday', 'day')
Above, a query such as::
q = session.query(Person).filter(Person.year == '1980')
On a PostgreSQL backend, the above query will render as::
SELECT person.id, person.data
FROM person
WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
Default Values
==============
:class:`.index_property` includes special behaviors for when the indexed
data structure does not exist, and a set operation is called:
* For an :class:`.index_property` that is given an integer index value,
the default data structure will be a Python list of ``None`` values,
at least as long as the index value; the value is then set at its
place in the list. This means for an index value of zero, the list
will be initialized to ``[None]`` before setting the given value,
and for an index value of five, the list will be initialized to
``[None, None, None, None, None]`` before setting the fifth element
to the given value. Note that an existing list is **not** extended
in place to receive a value.
* for an :class:`.index_property` that is given any other kind of index
value (e.g. strings usually), a Python dictionary is used as the
default data structure.
* The default data structure can be set to any Python callable using the
:paramref:`.index_property.datatype` parameter, overriding the previous
rules.
Subclassing
===========
:class:`.index_property` can be subclassed, in particular for the common
use case of providing coercion of values or SQL expressions as they are
accessed. Below is a common recipe for use with a PostgreSQL JSON type,
where we want to also include automatic casting plus ``astext()``::
class pg_json_property(index_property):
def __init__(self, attr_name, index, cast_type):
super(pg_json_property, self).__init__(attr_name, index)
self.cast_type = cast_type
def expr(self, model):
expr = super(pg_json_property, self).expr(model)
return expr.astext.cast(self.cast_type)
The above subclass can be used with the PostgreSQL-specific
version of :class:`.postgresql.JSON`::
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.postgresql import JSON
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
id = Column(Integer, primary_key=True)
data = Column(JSON)
age = pg_json_property('data', 'age', Integer)
The ``age`` attribute at the instance level works as before; however
when rendering SQL, PostgreSQL's ``->>`` operator will be used
for indexed access, instead of the usual index opearator of ``->``::
>>> query = session.query(Person).filter(Person.age < 20)
The above query will render::
SELECT person.id, person.data
FROM person
WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
"""
from __future__ import absolute_import
from sqlalchemy import inspect
from ..orm.attributes import flag_modified
from ..ext.hybrid import hybrid_property
__all__ = ['index_property']
class index_property(hybrid_property): # noqa
"""A property generator. The generated property describes an object
attribute that corresponds to an :class:`~.types.Indexable`
column.
.. versionadded:: 1.1
.. seealso::
:mod:`sqlalchemy.ext.indexable`
"""
_NO_DEFAULT_ARGUMENT = object()
def __init__(
self, attr_name, index, default=_NO_DEFAULT_ARGUMENT,
datatype=None, mutable=True, onebased=True):
"""Create a new :class:`.index_property`.
:param attr_name:
An attribute name of an `Indexable` typed column, or other
attribute that returns an indexable structure.
:param index:
The index to be used for getting and setting this value. This
should be the Python-side index value for integers.
:param default:
A value which will be returned instead of `AttributeError`
when there is not a value at given index.
:param datatype: default datatype to use when the field is empty.
By default, this is derived from the type of index used; a
Python list for an integer index, or a Python dictionary for
any other style of index. For a list, the list will be
initialized to a list of None values that is at least
``index`` elements long.
:param mutable: if False, writes and deletes to the attribute will
be disallowed.
:param onebased: assume the SQL representation of this value is
one-based; that is, the first index in SQL is 1, not zero.
"""
if mutable:
super(index_property, self).__init__(
self.fget, self.fset, self.fdel, self.expr
)
else:
super(index_property, self).__init__(
self.fget, None, None, self.expr
)
self.attr_name = attr_name
self.index = index
self.default = default
is_numeric = isinstance(index, int)
onebased = is_numeric and onebased
if datatype is not None:
self.datatype = datatype
else:
if is_numeric:
self.datatype = lambda: [None for x in range(index + 1)]
else:
self.datatype = dict
self.onebased = onebased
def _fget_default(self):
if self.default == self._NO_DEFAULT_ARGUMENT:
raise AttributeError(self.attr_name)
else:
return self.default
def fget(self, instance):
attr_name = self.attr_name
column_value = getattr(instance, attr_name)
if column_value is None:
return self._fget_default()
try:
value = column_value[self.index]
except (KeyError, IndexError):
return self._fget_default()
else:
return value
def fset(self, instance, value):
attr_name = self.attr_name
column_value = getattr(instance, attr_name, None)
if column_value is None:
column_value = self.datatype()
setattr(instance, attr_name, column_value)
column_value[self.index] = value
setattr(instance, attr_name, column_value)
if attr_name in inspect(instance).mapper.attrs:
flag_modified(instance, attr_name)
def fdel(self, instance):
attr_name = self.attr_name
column_value = getattr(instance, attr_name)
if column_value is None:
raise AttributeError(self.attr_name)
try:
del column_value[self.index]
except KeyError:
raise AttributeError(self.attr_name)
else:
setattr(instance, attr_name, column_value)
flag_modified(instance, attr_name)
def expr(self, model):
column = getattr(model, self.attr_name)
index = self.index
if self.onebased:
index += 1
return column[index]

View File

@ -0,0 +1,414 @@
"""Extensible class instrumentation.
The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
systems of class instrumentation within the ORM. Class instrumentation
refers to how the ORM places attributes on the class which maintain
data and track changes to that data, as well as event hooks installed
on the class.
.. note::
The extension package is provided for the benefit of integration
with other object management packages, which already perform
their own instrumentation. It is not intended for general use.
For examples of how the instrumentation extension is used,
see the example :ref:`examples_instrumentation`.
.. versionchanged:: 0.8
The :mod:`sqlalchemy.orm.instrumentation` was split out so
that all functionality having to do with non-standard
instrumentation was moved out to :mod:`sqlalchemy.ext.instrumentation`.
When imported, the module installs itself within
:mod:`sqlalchemy.orm.instrumentation` so that it
takes effect, including recognition of
``__sa_instrumentation_manager__`` on mapped classes, as
well :data:`.instrumentation_finders`
being used to determine class instrumentation resolution.
"""
from ..orm import instrumentation as orm_instrumentation
from ..orm.instrumentation import (
ClassManager, InstrumentationFactory, _default_state_getter,
_default_dict_getter, _default_manager_getter
)
from ..orm import attributes, collections, base as orm_base
from .. import util
from ..orm import exc as orm_exc
import weakref
INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__'
"""Attribute, elects custom instrumentation when present on a mapped class.
Allows a class to specify a slightly or wildly different technique for
tracking changes made to mapped attributes and collections.
Only one instrumentation implementation is allowed in a given object
inheritance hierarchy.
The value of this attribute must be a callable and will be passed a class
object. The callable must return one of:
- An instance of an InstrumentationManager or subclass
- An object implementing all or some of InstrumentationManager (TODO)
- A dictionary of callables, implementing all or some of the above (TODO)
- An instance of a ClassManager or subclass
This attribute is consulted by SQLAlchemy instrumentation
resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
has been imported. If custom finders are installed in the global
instrumentation_finders list, they may or may not choose to honor this
attribute.
"""
def find_native_user_instrumentation_hook(cls):
"""Find user-specified instrumentation management for a class."""
return getattr(cls, INSTRUMENTATION_MANAGER, None)
instrumentation_finders = [find_native_user_instrumentation_hook]
"""An extensible sequence of callables which return instrumentation
implementations
When a class is registered, each callable will be passed a class object.
If None is returned, the
next finder in the sequence is consulted. Otherwise the return must be an
instrumentation factory that follows the same guidelines as
sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.
By default, the only finder is find_native_user_instrumentation_hook, which
searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
ClassManager instrumentation is used.
"""
class ExtendedInstrumentationRegistry(InstrumentationFactory):
"""Extends :class:`.InstrumentationFactory` with additional
bookkeeping, to accommodate multiple types of
class managers.
"""
_manager_finders = weakref.WeakKeyDictionary()
_state_finders = weakref.WeakKeyDictionary()
_dict_finders = weakref.WeakKeyDictionary()
_extended = False
def _locate_extended_factory(self, class_):
for finder in instrumentation_finders:
factory = finder(class_)
if factory is not None:
manager = self._extended_class_manager(class_, factory)
return manager, factory
else:
return None, None
def _check_conflicts(self, class_, factory):
existing_factories = self._collect_management_factories_for(class_).\
difference([factory])
if existing_factories:
raise TypeError(
"multiple instrumentation implementations specified "
"in %s inheritance hierarchy: %r" % (
class_.__name__, list(existing_factories)))
def _extended_class_manager(self, class_, factory):
manager = factory(class_)
if not isinstance(manager, ClassManager):
manager = _ClassInstrumentationAdapter(class_, manager)
if factory != ClassManager and not self._extended:
# somebody invoked a custom ClassManager.
# reinstall global "getter" functions with the more
# expensive ones.
self._extended = True
_install_instrumented_lookups()
self._manager_finders[class_] = manager.manager_getter()
self._state_finders[class_] = manager.state_getter()
self._dict_finders[class_] = manager.dict_getter()
return manager
def _collect_management_factories_for(self, cls):
"""Return a collection of factories in play or specified for a
hierarchy.
Traverses the entire inheritance graph of a cls and returns a
collection of instrumentation factories for those classes. Factories
are extracted from active ClassManagers, if available, otherwise
instrumentation_finders is consulted.
"""
hierarchy = util.class_hierarchy(cls)
factories = set()
for member in hierarchy:
manager = self.manager_of_class(member)
if manager is not None:
factories.add(manager.factory)
else:
for finder in instrumentation_finders:
factory = finder(member)
if factory is not None:
break
else:
factory = None
factories.add(factory)
factories.discard(None)
return factories
def unregister(self, class_):
if class_ in self._manager_finders:
del self._manager_finders[class_]
del self._state_finders[class_]
del self._dict_finders[class_]
super(ExtendedInstrumentationRegistry, self).unregister(class_)
def manager_of_class(self, cls):
if cls is None:
return None
try:
finder = self._manager_finders.get(cls, _default_manager_getter)
except TypeError:
# due to weakref lookup on invalid object
return None
else:
return finder(cls)
def state_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._state_finders.get(
instance.__class__, _default_state_getter)(instance)
def dict_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._dict_finders.get(
instance.__class__, _default_dict_getter)(instance)
orm_instrumentation._instrumentation_factory = \
_instrumentation_factory = ExtendedInstrumentationRegistry()
orm_instrumentation.instrumentation_finders = instrumentation_finders
class InstrumentationManager(object):
"""User-defined class instrumentation extension.
:class:`.InstrumentationManager` can be subclassed in order
to change
how class instrumentation proceeds. This class exists for
the purposes of integration with other object management
frameworks which would like to entirely modify the
instrumentation methodology of the ORM, and is not intended
for regular usage. For interception of class instrumentation
events, see :class:`.InstrumentationEvents`.
The API for this class should be considered as semi-stable,
and may change slightly with new releases.
.. versionchanged:: 0.8
:class:`.InstrumentationManager` was moved from
:mod:`sqlalchemy.orm.instrumentation` to
:mod:`sqlalchemy.ext.instrumentation`.
"""
# r4361 added a mandatory (cls) constructor to this interface.
# given that, perhaps class_ should be dropped from all of these
# signatures.
def __init__(self, class_):
pass
def manage(self, class_, manager):
setattr(class_, '_default_class_manager', manager)
def dispose(self, class_, manager):
delattr(class_, '_default_class_manager')
def manager_getter(self, class_):
def get(cls):
return cls._default_class_manager
return get
def instrument_attribute(self, class_, key, inst):
pass
def post_configure_attribute(self, class_, key, inst):
pass
def install_descriptor(self, class_, key, inst):
setattr(class_, key, inst)
def uninstall_descriptor(self, class_, key):
delattr(class_, key)
def install_member(self, class_, key, implementation):
setattr(class_, key, implementation)
def uninstall_member(self, class_, key):
delattr(class_, key)
def instrument_collection_class(self, class_, key, collection_class):
return collections.prepare_instrumentation(collection_class)
def get_instance_dict(self, class_, instance):
return instance.__dict__
def initialize_instance_dict(self, class_, instance):
pass
def install_state(self, class_, instance, state):
setattr(instance, '_default_state', state)
def remove_state(self, class_, instance):
delattr(instance, '_default_state')
def state_getter(self, class_):
return lambda instance: getattr(instance, '_default_state')
def dict_getter(self, class_):
return lambda inst: self.get_instance_dict(class_, inst)
class _ClassInstrumentationAdapter(ClassManager):
"""Adapts a user-defined InstrumentationManager to a ClassManager."""
def __init__(self, class_, override):
self._adapted = override
self._get_state = self._adapted.state_getter(class_)
self._get_dict = self._adapted.dict_getter(class_)
ClassManager.__init__(self, class_)
def manage(self):
self._adapted.manage(self.class_, self)
def dispose(self):
self._adapted.dispose(self.class_)
def manager_getter(self):
return self._adapted.manager_getter(self.class_)
def instrument_attribute(self, key, inst, propagated=False):
ClassManager.instrument_attribute(self, key, inst, propagated)
if not propagated:
self._adapted.instrument_attribute(self.class_, key, inst)
def post_configure_attribute(self, key):
super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
self._adapted.post_configure_attribute(self.class_, key, self[key])
def install_descriptor(self, key, inst):
self._adapted.install_descriptor(self.class_, key, inst)
def uninstall_descriptor(self, key):
self._adapted.uninstall_descriptor(self.class_, key)
def install_member(self, key, implementation):
self._adapted.install_member(self.class_, key, implementation)
def uninstall_member(self, key):
self._adapted.uninstall_member(self.class_, key)
def instrument_collection_class(self, key, collection_class):
return self._adapted.instrument_collection_class(
self.class_, key, collection_class)
def initialize_collection(self, key, state, factory):
delegate = getattr(self._adapted, 'initialize_collection', None)
if delegate:
return delegate(key, state, factory)
else:
return ClassManager.initialize_collection(self, key,
state, factory)
def new_instance(self, state=None):
instance = self.class_.__new__(self.class_)
self.setup_instance(instance, state)
return instance
def _new_state_if_none(self, instance):
"""Install a default InstanceState if none is present.
A private convenience method used by the __init__ decorator.
"""
if self.has_state(instance):
return False
else:
return self.setup_instance(instance)
def setup_instance(self, instance, state=None):
self._adapted.initialize_instance_dict(self.class_, instance)
if state is None:
state = self._state_constructor(instance, self)
# the given instance is assumed to have no state
self._adapted.install_state(self.class_, instance, state)
return state
def teardown_instance(self, instance):
self._adapted.remove_state(self.class_, instance)
def has_state(self, instance):
try:
self._get_state(instance)
except orm_exc.NO_STATE:
return False
else:
return True
def state_getter(self):
return self._get_state
def dict_getter(self):
return self._get_dict
def _install_instrumented_lookups():
"""Replace global class/object management functions
with ExtendedInstrumentationRegistry implementations, which
allow multiple types of class managers to be present,
at the cost of performance.
This function is called only by ExtendedInstrumentationRegistry
and unit tests specific to this behavior.
The _reinstall_default_lookups() function can be called
after this one to re-establish the default functions.
"""
_install_lookups(
dict(
instance_state=_instrumentation_factory.state_of,
instance_dict=_instrumentation_factory.dict_of,
manager_of_class=_instrumentation_factory.manager_of_class
)
)
def _reinstall_default_lookups():
"""Restore simplified lookups."""
_install_lookups(
dict(
instance_state=_default_state_getter,
instance_dict=_default_dict_getter,
manager_of_class=_default_manager_getter
)
)
_instrumentation_factory._extended = False
def _install_lookups(lookups):
global instance_state, instance_dict, manager_of_class
instance_state = lookups['instance_state']
instance_dict = lookups['instance_dict']
manager_of_class = lookups['manager_of_class']
orm_base.instance_state = attributes.instance_state = \
orm_instrumentation.instance_state = instance_state
orm_base.instance_dict = attributes.instance_dict = \
orm_instrumentation.instance_dict = instance_dict
orm_base.manager_of_class = attributes.manager_of_class = \
orm_instrumentation.manager_of_class = manager_of_class

904
sqlalchemy/ext/mutable.py Normal file
View File

@ -0,0 +1,904 @@
# ext/mutable.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
r"""Provide support for tracking of in-place changes to scalar values,
which are propagated into ORM change events on owning parent objects.
.. versionadded:: 0.7 :mod:`sqlalchemy.ext.mutable` replaces SQLAlchemy's
legacy approach to in-place mutations of scalar values; see
:ref:`07_migration_mutation_extension`.
.. _mutable_scalars:
Establishing Mutability on Scalar Column Values
===============================================
A typical example of a "mutable" structure is a Python dictionary.
Following the example introduced in :ref:`types_toplevel`, we
begin with a custom type that marshals Python dictionaries into
JSON strings before being persisted::
from sqlalchemy.types import TypeDecorator, VARCHAR
import json
class JSONEncodedDict(TypeDecorator):
"Represents an immutable structure as a json-encoded string."
impl = VARCHAR
def process_bind_param(self, value, dialect):
if value is not None:
value = json.dumps(value)
return value
def process_result_value(self, value, dialect):
if value is not None:
value = json.loads(value)
return value
The usage of ``json`` is only for the purposes of example. The
:mod:`sqlalchemy.ext.mutable` extension can be used
with any type whose target Python type may be mutable, including
:class:`.PickleType`, :class:`.postgresql.ARRAY`, etc.
When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
tracks all parents which reference it. Below, we illustrate a simple
version of the :class:`.MutableDict` dictionary object, which applies
the :class:`.Mutable` mixin to a plain Python dictionary::
from sqlalchemy.ext.mutable import Mutable
class MutableDict(Mutable, dict):
@classmethod
def coerce(cls, key, value):
"Convert plain dictionaries to MutableDict."
if not isinstance(value, MutableDict):
if isinstance(value, dict):
return MutableDict(value)
# this call will raise ValueError
return Mutable.coerce(key, value)
else:
return value
def __setitem__(self, key, value):
"Detect dictionary set events and emit change events."
dict.__setitem__(self, key, value)
self.changed()
def __delitem__(self, key):
"Detect dictionary del events and emit change events."
dict.__delitem__(self, key)
self.changed()
The above dictionary class takes the approach of subclassing the Python
built-in ``dict`` to produce a dict
subclass which routes all mutation events through ``__setitem__``. There are
variants on this approach, such as subclassing ``UserDict.UserDict`` or
``collections.MutableMapping``; the part that's important to this example is
that the :meth:`.Mutable.changed` method is called whenever an in-place
change to the datastructure takes place.
We also redefine the :meth:`.Mutable.coerce` method which will be used to
convert any values that are not instances of ``MutableDict``, such
as the plain dictionaries returned by the ``json`` module, into the
appropriate type. Defining this method is optional; we could just as well
created our ``JSONEncodedDict`` such that it always returns an instance
of ``MutableDict``, and additionally ensured that all calling code
uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not
overridden, any values applied to a parent object which are not instances
of the mutable type will raise a ``ValueError``.
Our new ``MutableDict`` type offers a class method
:meth:`~.Mutable.as_mutable` which we can use within column metadata
to associate with types. This method grabs the given type object or
class and associates a listener that will detect all future mappings
of this type, applying event listening instrumentation to the mapped
attribute. Such as, with classical table metadata::
from sqlalchemy import Table, Column, Integer
my_data = Table('my_data', metadata,
Column('id', Integer, primary_key=True),
Column('data', MutableDict.as_mutable(JSONEncodedDict))
)
Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
(if the type object was not an instance already), which will intercept any
attributes which are mapped against this type. Below we establish a simple
mapping against the ``my_data`` table::
from sqlalchemy import mapper
class MyDataClass(object):
pass
# associates mutation listeners with MyDataClass.data
mapper(MyDataClass, my_data)
The ``MyDataClass.data`` member will now be notified of in place changes
to its value.
There's no difference in usage when using declarative::
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class MyDataClass(Base):
__tablename__ = 'my_data'
id = Column(Integer, primary_key=True)
data = Column(MutableDict.as_mutable(JSONEncodedDict))
Any in-place changes to the ``MyDataClass.data`` member
will flag the attribute as "dirty" on the parent object::
>>> from sqlalchemy.orm import Session
>>> sess = Session()
>>> m1 = MyDataClass(data={'value1':'foo'})
>>> sess.add(m1)
>>> sess.commit()
>>> m1.data['value1'] = 'bar'
>>> assert m1 in sess.dirty
True
The ``MutableDict`` can be associated with all future instances
of ``JSONEncodedDict`` in one step, using
:meth:`~.Mutable.associate_with`. This is similar to
:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
of ``MutableDict`` in all mappings unconditionally, without
the need to declare it individually::
MutableDict.associate_with(JSONEncodedDict)
class MyDataClass(Base):
__tablename__ = 'my_data'
id = Column(Integer, primary_key=True)
data = Column(JSONEncodedDict)
Supporting Pickling
--------------------
The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the
placement of a ``weakref.WeakKeyDictionary`` upon the value object, which
stores a mapping of parent mapped objects keyed to the attribute name under
which they are associated with this value. ``WeakKeyDictionary`` objects are
not picklable, due to the fact that they contain weakrefs and function
callbacks. In our case, this is a good thing, since if this dictionary were
picklable, it could lead to an excessively large pickle size for our value
objects that are pickled by themselves outside of the context of the parent.
The developer responsibility here is only to provide a ``__getstate__`` method
that excludes the :meth:`~MutableBase._parents` collection from the pickle
stream::
class MyMutableType(Mutable):
def __getstate__(self):
d = self.__dict__.copy()
d.pop('_parents', None)
return d
With our dictionary example, we need to return the contents of the dict itself
(and also restore them on __setstate__)::
class MutableDict(Mutable, dict):
# ....
def __getstate__(self):
return dict(self)
def __setstate__(self, state):
self.update(state)
In the case that our mutable value object is pickled as it is attached to one
or more parent objects that are also part of the pickle, the :class:`.Mutable`
mixin will re-establish the :attr:`.Mutable._parents` collection on each value
object as the owning parents themselves are unpickled.
.. _mutable_composites:
Establishing Mutability on Composites
=====================================
Composites are a special ORM feature which allow a single scalar attribute to
be assigned an object value which represents information "composed" from one
or more columns from the underlying mapped table. The usual example is that of
a geometric "point", and is introduced in :ref:`mapper_composite`.
.. versionchanged:: 0.7
The internals of :func:`.orm.composite` have been
greatly simplified and in-place mutation detection is no longer enabled by
default; instead, the user-defined value must detect changes on its own and
propagate them to all owning parents. The :mod:`sqlalchemy.ext.mutable`
extension provides the helper class :class:`.MutableComposite`, which is a
slight variant on the :class:`.Mutable` class.
As is the case with :class:`.Mutable`, the user-defined composite class
subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
change events to its parents via the :meth:`.MutableComposite.changed` method.
In the case of a composite class, the detection is usually via the usage of
Python descriptors (i.e. ``@property``), or alternatively via the special
Python method ``__setattr__()``. Below we expand upon the ``Point`` class
introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
and to also route attribute set events via ``__setattr__`` to the
:meth:`.MutableComposite.changed` method::
from sqlalchemy.ext.mutable import MutableComposite
class Point(MutableComposite):
def __init__(self, x, y):
self.x = x
self.y = y
def __setattr__(self, key, value):
"Intercept set events"
# set the attribute
object.__setattr__(self, key, value)
# alert all parents to the change
self.changed()
def __composite_values__(self):
return self.x, self.y
def __eq__(self, other):
return isinstance(other, Point) and \
other.x == self.x and \
other.y == self.y
def __ne__(self, other):
return not self.__eq__(other)
The :class:`.MutableComposite` class uses a Python metaclass to automatically
establish listeners for any usage of :func:`.orm.composite` that specifies our
``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
listeners are established which will route change events from ``Point``
objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
from sqlalchemy.orm import composite, mapper
from sqlalchemy import Table, Column
vertices = Table('vertices', metadata,
Column('id', Integer, primary_key=True),
Column('x1', Integer),
Column('y1', Integer),
Column('x2', Integer),
Column('y2', Integer),
)
class Vertex(object):
pass
mapper(Vertex, vertices, properties={
'start': composite(Point, vertices.c.x1, vertices.c.y1),
'end': composite(Point, vertices.c.x2, vertices.c.y2)
})
Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
will flag the attribute as "dirty" on the parent object::
>>> from sqlalchemy.orm import Session
>>> sess = Session()
>>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
>>> sess.add(v1)
>>> sess.commit()
>>> v1.end.x = 8
>>> assert v1 in sess.dirty
True
Coercing Mutable Composites
---------------------------
The :meth:`.MutableBase.coerce` method is also supported on composite types.
In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce`
method is only called for attribute set operations, not load operations.
Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent
to using a :func:`.validates` validation routine for all attributes which
make use of the custom composite type::
class Point(MutableComposite):
# other Point methods
# ...
def coerce(cls, key, value):
if isinstance(value, tuple):
value = Point(*value)
elif not isinstance(value, Point):
raise ValueError("tuple or Point expected")
return value
.. versionadded:: 0.7.10,0.8.0b2
Support for the :meth:`.MutableBase.coerce` method in conjunction with
objects of type :class:`.MutableComposite`.
Supporting Pickling
--------------------
As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper
class uses a ``weakref.WeakKeyDictionary`` available via the
:meth:`MutableBase._parents` attribute which isn't picklable. If we need to
pickle instances of ``Point`` or its owning class ``Vertex``, we at least need
to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary.
Below we define both a ``__getstate__`` and a ``__setstate__`` that package up
the minimal form of our ``Point`` class::
class Point(MutableComposite):
# ...
def __getstate__(self):
return self.x, self.y
def __setstate__(self, state):
self.x, self.y = state
As with :class:`.Mutable`, the :class:`.MutableComposite` augments the
pickling process of the parent's object-relational state so that the
:meth:`MutableBase._parents` collection is restored to all ``Point`` objects.
"""
from ..orm.attributes import flag_modified
from .. import event, types
from ..orm import mapper, object_mapper, Mapper
from ..util import memoized_property
from ..sql.base import SchemaEventTarget
import weakref
class MutableBase(object):
"""Common base class to :class:`.Mutable`
and :class:`.MutableComposite`.
"""
@memoized_property
def _parents(self):
"""Dictionary of parent object->attribute name on the parent.
This attribute is a so-called "memoized" property. It initializes
itself with a new ``weakref.WeakKeyDictionary`` the first time
it is accessed, returning the same object upon subsequent access.
"""
return weakref.WeakKeyDictionary()
@classmethod
def coerce(cls, key, value):
"""Given a value, coerce it into the target type.
Can be overridden by custom subclasses to coerce incoming
data into a particular type.
By default, raises ``ValueError``.
This method is called in different scenarios depending on if
the parent class is of type :class:`.Mutable` or of type
:class:`.MutableComposite`. In the case of the former, it is called
for both attribute-set operations as well as during ORM loading
operations. For the latter, it is only called during attribute-set
operations; the mechanics of the :func:`.composite` construct
handle coercion during load operations.
:param key: string name of the ORM-mapped attribute being set.
:param value: the incoming value.
:return: the method should return the coerced value, or raise
``ValueError`` if the coercion cannot be completed.
"""
if value is None:
return None
msg = "Attribute '%s' does not accept objects of type %s"
raise ValueError(msg % (key, type(value)))
@classmethod
def _get_listen_keys(cls, attribute):
"""Given a descriptor attribute, return a ``set()`` of the attribute
keys which indicate a change in the state of this attribute.
This is normally just ``set([attribute.key])``, but can be overridden
to provide for additional keys. E.g. a :class:`.MutableComposite`
augments this set with the attribute keys associated with the columns
that comprise the composite value.
This collection is consulted in the case of intercepting the
:meth:`.InstanceEvents.refresh` and
:meth:`.InstanceEvents.refresh_flush` events, which pass along a list
of attribute names that have been refreshed; the list is compared
against this set to determine if action needs to be taken.
.. versionadded:: 1.0.5
"""
return set([attribute.key])
@classmethod
def _listen_on_attribute(cls, attribute, coerce, parent_cls):
"""Establish this type as a mutation listener for the given
mapped descriptor.
"""
key = attribute.key
if parent_cls is not attribute.class_:
return
# rely on "propagate" here
parent_cls = attribute.class_
listen_keys = cls._get_listen_keys(attribute)
def load(state, *args):
"""Listen for objects loaded or refreshed.
Wrap the target data member's value with
``Mutable``.
"""
val = state.dict.get(key, None)
if val is not None:
if coerce:
val = cls.coerce(key, val)
state.dict[key] = val
val._parents[state.obj()] = key
def load_attrs(state, ctx, attrs):
if not attrs or listen_keys.intersection(attrs):
load(state)
def set(target, value, oldvalue, initiator):
"""Listen for set/replace events on the target
data member.
Establish a weak reference to the parent object
on the incoming value, remove it for the one
outgoing.
"""
if value is oldvalue:
return value
if not isinstance(value, cls):
value = cls.coerce(key, value)
if value is not None:
value._parents[target.obj()] = key
if isinstance(oldvalue, cls):
oldvalue._parents.pop(target.obj(), None)
return value
def pickle(state, state_dict):
val = state.dict.get(key, None)
if val is not None:
if 'ext.mutable.values' not in state_dict:
state_dict['ext.mutable.values'] = []
state_dict['ext.mutable.values'].append(val)
def unpickle(state, state_dict):
if 'ext.mutable.values' in state_dict:
for val in state_dict['ext.mutable.values']:
val._parents[state.obj()] = key
event.listen(parent_cls, 'load', load,
raw=True, propagate=True)
event.listen(parent_cls, 'refresh', load_attrs,
raw=True, propagate=True)
event.listen(parent_cls, 'refresh_flush', load_attrs,
raw=True, propagate=True)
event.listen(attribute, 'set', set,
raw=True, retval=True, propagate=True)
event.listen(parent_cls, 'pickle', pickle,
raw=True, propagate=True)
event.listen(parent_cls, 'unpickle', unpickle,
raw=True, propagate=True)
class Mutable(MutableBase):
"""Mixin that defines transparent propagation of change
events to a parent object.
See the example in :ref:`mutable_scalars` for usage information.
"""
def changed(self):
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
flag_modified(parent, key)
@classmethod
def associate_with_attribute(cls, attribute):
"""Establish this type as a mutation listener for the given
mapped descriptor.
"""
cls._listen_on_attribute(attribute, True, attribute.class_)
@classmethod
def associate_with(cls, sqltype):
"""Associate this wrapper with all future mapped columns
of the given type.
This is a convenience method that calls
``associate_with_attribute`` automatically.
.. warning::
The listeners established by this method are *global*
to all mappers, and are *not* garbage collected. Only use
:meth:`.associate_with` for types that are permanent to an
application, not with ad-hoc types else this will cause unbounded
growth in memory usage.
"""
def listen_for_type(mapper, class_):
for prop in mapper.column_attrs:
if isinstance(prop.columns[0].type, sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
event.listen(mapper, 'mapper_configured', listen_for_type)
@classmethod
def as_mutable(cls, sqltype):
"""Associate a SQL type with this mutable Python type.
This establishes listeners that will detect ORM mappings against
the given type, adding mutation event trackers to those mappings.
The type is returned, unconditionally as an instance, so that
:meth:`.as_mutable` can be used inline::
Table('mytable', metadata,
Column('id', Integer, primary_key=True),
Column('data', MyMutableType.as_mutable(PickleType))
)
Note that the returned type is always an instance, even if a class
is given, and that only columns which are declared specifically with
that type instance receive additional instrumentation.
To associate a particular mutable type with all occurrences of a
particular type, use the :meth:`.Mutable.associate_with` classmethod
of the particular :class:`.Mutable` subclass to establish a global
association.
.. warning::
The listeners established by this method are *global*
to all mappers, and are *not* garbage collected. Only use
:meth:`.as_mutable` for types that are permanent to an application,
not with ad-hoc types else this will cause unbounded growth
in memory usage.
"""
sqltype = types.to_instance(sqltype)
# a SchemaType will be copied when the Column is copied,
# and we'll lose our ability to link that type back to the original.
# so track our original type w/ columns
if isinstance(sqltype, SchemaEventTarget):
@event.listens_for(sqltype, "before_parent_attach")
def _add_column_memo(sqltyp, parent):
parent.info['_ext_mutable_orig_type'] = sqltyp
schema_event_check = True
else:
schema_event_check = False
def listen_for_type(mapper, class_):
for prop in mapper.column_attrs:
if (
schema_event_check and
hasattr(prop.expression, 'info') and
prop.expression.info.get('_ext_mutable_orig_type')
is sqltype
) or (
prop.columns[0].type is sqltype
):
cls.associate_with_attribute(getattr(class_, prop.key))
event.listen(mapper, 'mapper_configured', listen_for_type)
return sqltype
class MutableComposite(MutableBase):
"""Mixin that defines transparent propagation of change
events on a SQLAlchemy "composite" object to its
owning parent or parents.
See the example in :ref:`mutable_composites` for usage information.
"""
@classmethod
def _get_listen_keys(cls, attribute):
return set([attribute.key]).union(attribute.property._attribute_keys)
def changed(self):
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
prop = object_mapper(parent).get_property(key)
for value, attr_name in zip(
self.__composite_values__(),
prop._attribute_keys):
setattr(parent, attr_name, value)
def _setup_composite_listener():
def _listen_for_type(mapper, class_):
for prop in mapper.iterate_properties:
if (hasattr(prop, 'composite_class') and
isinstance(prop.composite_class, type) and
issubclass(prop.composite_class, MutableComposite)):
prop.composite_class._listen_on_attribute(
getattr(class_, prop.key), False, class_)
if not event.contains(Mapper, "mapper_configured", _listen_for_type):
event.listen(Mapper, 'mapper_configured', _listen_for_type)
_setup_composite_listener()
class MutableDict(Mutable, dict):
"""A dictionary type that implements :class:`.Mutable`.
The :class:`.MutableDict` object implements a dictionary that will
emit change events to the underlying mapping when the contents of
the dictionary are altered, including when values are added or removed.
Note that :class:`.MutableDict` does **not** apply mutable tracking to the
*values themselves* inside the dictionary. Therefore it is not a sufficient
solution for the use case of tracking deep changes to a *recursive*
dictionary structure, such as a JSON structure. To support this use case,
build a subclass of :class:`.MutableDict` that provides appropriate
coersion to the values placed in the dictionary so that they too are
"mutable", and emit events up to their parent structure.
.. versionadded:: 0.8
.. seealso::
:class:`.MutableList`
:class:`.MutableSet`
"""
def __setitem__(self, key, value):
"""Detect dictionary set events and emit change events."""
dict.__setitem__(self, key, value)
self.changed()
def setdefault(self, key, value):
result = dict.setdefault(self, key, value)
self.changed()
return result
def __delitem__(self, key):
"""Detect dictionary del events and emit change events."""
dict.__delitem__(self, key)
self.changed()
def update(self, *a, **kw):
dict.update(self, *a, **kw)
self.changed()
def pop(self, *arg):
result = dict.pop(self, *arg)
self.changed()
return result
def popitem(self):
result = dict.popitem(self)
self.changed()
return result
def clear(self):
dict.clear(self)
self.changed()
@classmethod
def coerce(cls, key, value):
"""Convert plain dictionary to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, dict):
return cls(value)
return Mutable.coerce(key, value)
else:
return value
def __getstate__(self):
return dict(self)
def __setstate__(self, state):
self.update(state)
class MutableList(Mutable, list):
"""A list type that implements :class:`.Mutable`.
The :class:`.MutableList` object implements a list that will
emit change events to the underlying mapping when the contents of
the list are altered, including when values are added or removed.
Note that :class:`.MutableList` does **not** apply mutable tracking to the
*values themselves* inside the list. Therefore it is not a sufficient
solution for the use case of tracking deep changes to a *recursive*
mutable structure, such as a JSON structure. To support this use case,
build a subclass of :class:`.MutableList` that provides appropriate
coersion to the values placed in the dictionary so that they too are
"mutable", and emit events up to their parent structure.
.. versionadded:: 1.1
.. seealso::
:class:`.MutableDict`
:class:`.MutableSet`
"""
def __setitem__(self, index, value):
"""Detect list set events and emit change events."""
list.__setitem__(self, index, value)
self.changed()
def __setslice__(self, start, end, value):
"""Detect list set events and emit change events."""
list.__setslice__(self, start, end, value)
self.changed()
def __delitem__(self, index):
"""Detect list del events and emit change events."""
list.__delitem__(self, index)
self.changed()
def __delslice__(self, start, end):
"""Detect list del events and emit change events."""
list.__delslice__(self, start, end)
self.changed()
def pop(self, *arg):
result = list.pop(self, *arg)
self.changed()
return result
def append(self, x):
list.append(self, x)
self.changed()
def extend(self, x):
list.extend(self, x)
self.changed()
def insert(self, i, x):
list.insert(self, i, x)
self.changed()
def remove(self, i):
list.remove(self, i)
self.changed()
def clear(self):
list.clear(self)
self.changed()
def sort(self):
list.sort(self)
self.changed()
def reverse(self):
list.reverse(self)
self.changed()
@classmethod
def coerce(cls, index, value):
"""Convert plain list to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, list):
return cls(value)
return Mutable.coerce(index, value)
else:
return value
def __getstate__(self):
return list(self)
def __setstate__(self, state):
self[:] = state
class MutableSet(Mutable, set):
"""A set type that implements :class:`.Mutable`.
The :class:`.MutableSet` object implements a set that will
emit change events to the underlying mapping when the contents of
the set are altered, including when values are added or removed.
Note that :class:`.MutableSet` does **not** apply mutable tracking to the
*values themselves* inside the set. Therefore it is not a sufficient
solution for the use case of tracking deep changes to a *recursive*
mutable structure. To support this use case,
build a subclass of :class:`.MutableSet` that provides appropriate
coersion to the values placed in the dictionary so that they too are
"mutable", and emit events up to their parent structure.
.. versionadded:: 1.1
.. seealso::
:class:`.MutableDict`
:class:`.MutableList`
"""
def update(self, *arg):
set.update(self, *arg)
self.changed()
def intersection_update(self, *arg):
set.intersection_update(self, *arg)
self.changed()
def difference_update(self, *arg):
set.difference_update(self, *arg)
self.changed()
def symmetric_difference_update(self, *arg):
set.symmetric_difference_update(self, *arg)
self.changed()
def add(self, elem):
set.add(self, elem)
self.changed()
def remove(self, elem):
set.remove(self, elem)
self.changed()
def discard(self, elem):
set.discard(self, elem)
self.changed()
def pop(self, *arg):
result = set.pop(self, *arg)
self.changed()
return result
def clear(self):
set.clear(self)
self.changed()
@classmethod
def coerce(cls, index, value):
"""Convert plain set to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, set):
return cls(value)
return Mutable.coerce(index, value)
else:
return value
def __getstate__(self):
return set(self)
def __setstate__(self, state):
self.update(state)
def __reduce_ex__(self, proto):
return (self.__class__, (list(self), ))

93
sqlalchemy/inspection.py Normal file
View File

@ -0,0 +1,93 @@
# sqlalchemy/inspect.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""The inspection module provides the :func:`.inspect` function,
which delivers runtime information about a wide variety
of SQLAlchemy objects, both within the Core as well as the
ORM.
The :func:`.inspect` function is the entry point to SQLAlchemy's
public API for viewing the configuration and construction
of in-memory objects. Depending on the type of object
passed to :func:`.inspect`, the return value will either be
a related object which provides a known interface, or in many
cases it will return the object itself.
The rationale for :func:`.inspect` is twofold. One is that
it replaces the need to be aware of a large variety of "information
getting" functions in SQLAlchemy, such as :meth:`.Inspector.from_engine`,
:func:`.orm.attributes.instance_state`, :func:`.orm.class_mapper`,
and others. The other is that the return value of :func:`.inspect`
is guaranteed to obey a documented API, thus allowing third party
tools which build on top of SQLAlchemy configurations to be constructed
in a forwards-compatible way.
.. versionadded:: 0.8 The :func:`.inspect` system is introduced
as of version 0.8.
"""
from . import util, exc
_registrars = util.defaultdict(list)
def inspect(subject, raiseerr=True):
"""Produce an inspection object for the given target.
The returned value in some cases may be the
same object as the one given, such as if a
:class:`.Mapper` object is passed. In other
cases, it will be an instance of the registered
inspection type for the given object, such as
if an :class:`.engine.Engine` is passed, an
:class:`.Inspector` object is returned.
:param subject: the subject to be inspected.
:param raiseerr: When ``True``, if the given subject
does not
correspond to a known SQLAlchemy inspected type,
:class:`sqlalchemy.exc.NoInspectionAvailable`
is raised. If ``False``, ``None`` is returned.
"""
type_ = type(subject)
for cls in type_.__mro__:
if cls in _registrars:
reg = _registrars[cls]
if reg is True:
return subject
ret = reg(subject)
if ret is not None:
break
else:
reg = ret = None
if raiseerr and (
reg is None or ret is None
):
raise exc.NoInspectionAvailable(
"No inspection system is "
"available for object of type %s" %
type_)
return ret
def _inspects(*types):
def decorate(fn_or_cls):
for type_ in types:
if type_ in _registrars:
raise AssertionError(
"Type %s is already "
"registered" % type_)
_registrars[type_] = fn_or_cls
return fn_or_cls
return decorate
def _self_inspects(cls):
_inspects(cls)(True)
return cls

540
sqlalchemy/orm/base.py Normal file
View File

@ -0,0 +1,540 @@
# orm/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Constants and rudimental functions used throughout the ORM.
"""
from .. import util, inspection, exc as sa_exc
from ..sql import expression
from . import exc
import operator
PASSIVE_NO_RESULT = util.symbol(
'PASSIVE_NO_RESULT',
"""Symbol returned by a loader callable or other attribute/history
retrieval operation when a value could not be determined, based
on loader callable flags.
"""
)
ATTR_WAS_SET = util.symbol(
'ATTR_WAS_SET',
"""Symbol returned by a loader callable to indicate the
retrieved value, or values, were assigned to their attributes
on the target object.
"""
)
ATTR_EMPTY = util.symbol(
'ATTR_EMPTY',
"""Symbol used internally to indicate an attribute had no callable."""
)
NO_VALUE = util.symbol(
'NO_VALUE',
"""Symbol which may be placed as the 'previous' value of an attribute,
indicating no value was loaded for an attribute when it was modified,
and flags indicated we were not to load it.
"""
)
NEVER_SET = util.symbol(
'NEVER_SET',
"""Symbol which may be placed as the 'previous' value of an attribute
indicating that the attribute had not been assigned to previously.
"""
)
NO_CHANGE = util.symbol(
"NO_CHANGE",
"""No callables or SQL should be emitted on attribute access
and no state should change
""", canonical=0
)
CALLABLES_OK = util.symbol(
"CALLABLES_OK",
"""Loader callables can be fired off if a value
is not present.
""", canonical=1
)
SQL_OK = util.symbol(
"SQL_OK",
"""Loader callables can emit SQL at least on scalar value attributes.""",
canonical=2
)
RELATED_OBJECT_OK = util.symbol(
"RELATED_OBJECT_OK",
"""Callables can use SQL to load related objects as well
as scalar value attributes.
""", canonical=4
)
INIT_OK = util.symbol(
"INIT_OK",
"""Attributes should be initialized with a blank
value (None or an empty collection) upon get, if no other
value can be obtained.
""", canonical=8
)
NON_PERSISTENT_OK = util.symbol(
"NON_PERSISTENT_OK",
"""Callables can be emitted if the parent is not persistent.""",
canonical=16
)
LOAD_AGAINST_COMMITTED = util.symbol(
"LOAD_AGAINST_COMMITTED",
"""Callables should use committed values as primary/foreign keys during a
load.
""", canonical=32
)
NO_AUTOFLUSH = util.symbol(
"NO_AUTOFLUSH",
"""Loader callables should disable autoflush.""",
canonical=64
)
# pre-packaged sets of flags used as inputs
PASSIVE_OFF = util.symbol(
"PASSIVE_OFF",
"Callables can be emitted in all cases.",
canonical=(RELATED_OBJECT_OK | NON_PERSISTENT_OK |
INIT_OK | CALLABLES_OK | SQL_OK)
)
PASSIVE_RETURN_NEVER_SET = util.symbol(
"PASSIVE_RETURN_NEVER_SET",
"""PASSIVE_OFF ^ INIT_OK""",
canonical=PASSIVE_OFF ^ INIT_OK
)
PASSIVE_NO_INITIALIZE = util.symbol(
"PASSIVE_NO_INITIALIZE",
"PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK",
canonical=PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK
)
PASSIVE_NO_FETCH = util.symbol(
"PASSIVE_NO_FETCH",
"PASSIVE_OFF ^ SQL_OK",
canonical=PASSIVE_OFF ^ SQL_OK
)
PASSIVE_NO_FETCH_RELATED = util.symbol(
"PASSIVE_NO_FETCH_RELATED",
"PASSIVE_OFF ^ RELATED_OBJECT_OK",
canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK
)
PASSIVE_ONLY_PERSISTENT = util.symbol(
"PASSIVE_ONLY_PERSISTENT",
"PASSIVE_OFF ^ NON_PERSISTENT_OK",
canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK
)
DEFAULT_MANAGER_ATTR = '_sa_class_manager'
DEFAULT_STATE_ATTR = '_sa_instance_state'
_INSTRUMENTOR = ('mapper', 'instrumentor')
EXT_CONTINUE = util.symbol('EXT_CONTINUE')
EXT_STOP = util.symbol('EXT_STOP')
ONETOMANY = util.symbol(
'ONETOMANY',
"""Indicates the one-to-many direction for a :func:`.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
""")
MANYTOONE = util.symbol(
'MANYTOONE',
"""Indicates the many-to-one direction for a :func:`.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
""")
MANYTOMANY = util.symbol(
'MANYTOMANY',
"""Indicates the many-to-many direction for a :func:`.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
""")
NOT_EXTENSION = util.symbol(
'NOT_EXTENSION',
"""Symbol indicating an :class:`InspectionAttr` that's
not part of sqlalchemy.ext.
Is assigned to the :attr:`.InspectionAttr.extension_type`
attibute.
""")
_never_set = frozenset([NEVER_SET])
_none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT])
_SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED")
_DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
def _generative(*assertions):
"""Mark a method as generative, e.g. method-chained."""
@util.decorator
def generate(fn, *args, **kw):
self = args[0]._clone()
for assertion in assertions:
assertion(self, fn.__name__)
fn(self, *args[1:], **kw)
return self
return generate
# these can be replaced by sqlalchemy.ext.instrumentation
# if augmented class instrumentation is enabled.
def manager_of_class(cls):
return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
instance_dict = operator.attrgetter('__dict__')
def instance_str(instance):
"""Return a string describing an instance."""
return state_str(instance_state(instance))
def state_str(state):
"""Return a string describing an instance via its InstanceState."""
if state is None:
return "None"
else:
return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))
def state_class_str(state):
"""Return a string describing an instance's class via its
InstanceState.
"""
if state is None:
return "None"
else:
return '<%s>' % (state.class_.__name__, )
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
def state_attribute_str(state, attribute):
return state_str(state) + "." + attribute
def object_mapper(instance):
"""Given an object, return the primary Mapper associated with the object
instance.
Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
if no mapping is configured.
This function is available via the inspection system as::
inspect(instance).mapper
Using the inspection system will raise
:class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
not part of a mapping.
"""
return object_state(instance).mapper
def object_state(instance):
"""Given an object, return the :class:`.InstanceState`
associated with the object.
Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
if no mapping is configured.
Equivalent functionality is available via the :func:`.inspect`
function as::
inspect(instance)
Using the inspection system will raise
:class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
not part of a mapping.
"""
state = _inspect_mapped_object(instance)
if state is None:
raise exc.UnmappedInstanceError(instance)
else:
return state
@inspection._inspects(object)
def _inspect_mapped_object(instance):
try:
return instance_state(instance)
# TODO: whats the py-2/3 syntax to catch two
# different kinds of exceptions at once ?
except exc.UnmappedClassError:
return None
except exc.NO_STATE:
return None
def _class_to_mapper(class_or_mapper):
insp = inspection.inspect(class_or_mapper, False)
if insp is not None:
return insp.mapper
else:
raise exc.UnmappedClassError(class_or_mapper)
def _mapper_or_none(entity):
"""Return the :class:`.Mapper` for the given class or None if the
class is not mapped.
"""
insp = inspection.inspect(entity, False)
if insp is not None:
return insp.mapper
else:
return None
def _is_mapped_class(entity):
"""Return True if the given object is a mapped class,
:class:`.Mapper`, or :class:`.AliasedClass`.
"""
insp = inspection.inspect(entity, False)
return insp is not None and \
not insp.is_clause_element and \
(
insp.is_mapper or insp.is_aliased_class
)
def _attr_as_key(attr):
if hasattr(attr, 'key'):
return attr.key
else:
return expression._column_as_key(attr)
def _orm_columns(entity):
insp = inspection.inspect(entity, False)
if hasattr(insp, 'selectable') and hasattr(insp.selectable, 'c'):
return [c for c in insp.selectable.c]
else:
return [entity]
def _is_aliased_class(entity):
insp = inspection.inspect(entity, False)
return insp is not None and \
getattr(insp, "is_aliased_class", False)
def _entity_descriptor(entity, key):
"""Return a class attribute given an entity and string name.
May return :class:`.InstrumentedAttribute` or user-defined
attribute.
"""
insp = inspection.inspect(entity)
if insp.is_selectable:
description = entity
entity = insp.c
elif insp.is_aliased_class:
entity = insp.entity
description = entity
elif hasattr(insp, "mapper"):
description = entity = insp.mapper.class_
else:
description = entity
try:
return getattr(entity, key)
except AttributeError:
raise sa_exc.InvalidRequestError(
"Entity '%s' has no property '%s'" %
(description, key)
)
_state_mapper = util.dottedgetter('manager.mapper')
@inspection._inspects(type)
def _inspect_mapped_class(class_, configure=False):
try:
class_manager = manager_of_class(class_)
if not class_manager.is_mapped:
return None
mapper = class_manager.mapper
except exc.NO_STATE:
return None
else:
if configure and mapper._new_mappers:
mapper._configure_all()
return mapper
def class_mapper(class_, configure=True):
"""Given a class, return the primary :class:`.Mapper` associated
with the key.
Raises :exc:`.UnmappedClassError` if no mapping is configured
on the given class, or :exc:`.ArgumentError` if a non-class
object is passed.
Equivalent functionality is available via the :func:`.inspect`
function as::
inspect(some_mapped_class)
Using the inspection system will raise
:class:`sqlalchemy.exc.NoInspectionAvailable` if the class is not mapped.
"""
mapper = _inspect_mapped_class(class_, configure=configure)
if mapper is None:
if not isinstance(class_, type):
raise sa_exc.ArgumentError(
"Class object expected, got '%r'." % (class_, ))
raise exc.UnmappedClassError(class_)
else:
return mapper
class InspectionAttr(object):
"""A base class applied to all ORM objects that can be returned
by the :func:`.inspect` function.
The attributes defined here allow the usage of simple boolean
checks to test basic facts about the object returned.
While the boolean checks here are basically the same as using
the Python isinstance() function, the flags here can be used without
the need to import all of these classes, and also such that
the SQLAlchemy class system can change while leaving the flags
here intact for forwards-compatibility.
"""
__slots__ = ()
is_selectable = False
"""Return True if this object is an instance of :class:`.Selectable`."""
is_aliased_class = False
"""True if this object is an instance of :class:`.AliasedClass`."""
is_instance = False
"""True if this object is an instance of :class:`.InstanceState`."""
is_mapper = False
"""True if this object is an instance of :class:`.Mapper`."""
is_property = False
"""True if this object is an instance of :class:`.MapperProperty`."""
is_attribute = False
"""True if this object is a Python :term:`descriptor`.
This can refer to one of many types. Usually a
:class:`.QueryableAttribute` which handles attributes events on behalf
of a :class:`.MapperProperty`. But can also be an extension type
such as :class:`.AssociationProxy` or :class:`.hybrid_property`.
The :attr:`.InspectionAttr.extension_type` will refer to a constant
identifying the specific subtype.
.. seealso::
:attr:`.Mapper.all_orm_descriptors`
"""
is_clause_element = False
"""True if this object is an instance of :class:`.ClauseElement`."""
extension_type = NOT_EXTENSION
"""The extension type, if any.
Defaults to :data:`.interfaces.NOT_EXTENSION`
.. versionadded:: 0.8.0
.. seealso::
:data:`.HYBRID_METHOD`
:data:`.HYBRID_PROPERTY`
:data:`.ASSOCIATION_PROXY`
"""
class InspectionAttrInfo(InspectionAttr):
"""Adds the ``.info`` attribute to :class:`.InspectionAttr`.
The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo`
is that the former is compatible as a mixin for classes that specify
``__slots__``; this is essentially an implementation artifact.
"""
@util.memoized_property
def info(self):
"""Info dictionary associated with the object, allowing user-defined
data to be associated with this :class:`.InspectionAttr`.
The dictionary is generated when first accessed. Alternatively,
it can be specified as a constructor argument to the
:func:`.column_property`, :func:`.relationship`, or :func:`.composite`
functions.
.. versionadded:: 0.8 Added support for .info to all
:class:`.MapperProperty` subclasses.
.. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also
available on extension types via the
:attr:`.InspectionAttrInfo.info` attribute, so that it can apply
to a wider variety of ORM and extension constructs.
.. seealso::
:attr:`.QueryableAttribute.info`
:attr:`.SchemaItem.info`
"""
return {}
class _MappedAttribute(object):
"""Mixin for attributes which should be replaced by mapper-assigned
attributes.
"""
__slots__ = ()

View File

@ -0,0 +1,487 @@
# orm/deprecated_interfaces.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .. import event, util
from .interfaces import EXT_CONTINUE
@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces")
class MapperExtension(object):
"""Base implementation for :class:`.Mapper` event hooks.
.. note::
:class:`.MapperExtension` is deprecated. Please
refer to :func:`.event.listen` as well as
:class:`.MapperEvents`.
New extension classes subclass :class:`.MapperExtension` and are specified
using the ``extension`` mapper() argument, which is a single
:class:`.MapperExtension` or a list of such::
from sqlalchemy.orm.interfaces import MapperExtension
class MyExtension(MapperExtension):
def before_insert(self, mapper, connection, instance):
print "instance %s before insert !" % instance
m = mapper(User, users_table, extension=MyExtension())
A single mapper can maintain a chain of ``MapperExtension``
objects. When a particular mapping event occurs, the
corresponding method on each ``MapperExtension`` is invoked
serially, and each method has the ability to halt the chain
from proceeding further::
m = mapper(User, users_table, extension=[ext1, ext2, ext3])
Each ``MapperExtension`` method returns the symbol
EXT_CONTINUE by default. This symbol generally means "move
to the next ``MapperExtension`` for processing". For methods
that return objects like translated rows or new object
instances, EXT_CONTINUE means the result of the method
should be ignored. In some cases it's required for a
default mapper activity to be performed, such as adding a
new instance to a result list.
The symbol EXT_STOP has significance within a chain
of ``MapperExtension`` objects that the chain will be stopped
when this symbol is returned. Like EXT_CONTINUE, it also
has additional significance in some cases that a default
mapper activity will not be performed.
"""
@classmethod
def _adapt_instrument_class(cls, self, listener):
cls._adapt_listener_methods(self, listener, ('instrument_class',))
@classmethod
def _adapt_listener(cls, self, listener):
cls._adapt_listener_methods(
self, listener,
(
'init_instance',
'init_failed',
'reconstruct_instance',
'before_insert',
'after_insert',
'before_update',
'after_update',
'before_delete',
'after_delete'
))
@classmethod
def _adapt_listener_methods(cls, self, listener, methods):
for meth in methods:
me_meth = getattr(MapperExtension, meth)
ls_meth = getattr(listener, meth)
if not util.methods_equivalent(me_meth, ls_meth):
if meth == 'reconstruct_instance':
def go(ls_meth):
def reconstruct(instance, ctx):
ls_meth(self, instance)
return reconstruct
event.listen(self.class_manager, 'load',
go(ls_meth), raw=False, propagate=True)
elif meth == 'init_instance':
def go(ls_meth):
def init_instance(instance, args, kwargs):
ls_meth(self, self.class_,
self.class_manager.original_init,
instance, args, kwargs)
return init_instance
event.listen(self.class_manager, 'init',
go(ls_meth), raw=False, propagate=True)
elif meth == 'init_failed':
def go(ls_meth):
def init_failed(instance, args, kwargs):
util.warn_exception(
ls_meth, self, self.class_,
self.class_manager.original_init,
instance, args, kwargs)
return init_failed
event.listen(self.class_manager, 'init_failure',
go(ls_meth), raw=False, propagate=True)
else:
event.listen(self, "%s" % meth, ls_meth,
raw=False, retval=True, propagate=True)
def instrument_class(self, mapper, class_):
"""Receive a class when the mapper is first constructed, and has
applied instrumentation to the mapped class.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
"""Receive an instance when its constructor is called.
This method is only called during a userland construction of
an object. It is not called when an object is loaded from the
database.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
"""Receive an instance when its constructor has been called,
and raised an exception.
This method is only called during a userland construction of
an object. It is not called when an object is loaded from the
database.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def reconstruct_instance(self, mapper, instance):
"""Receive an object instance after it has been created via
``__new__``, and after initial attribute population has
occurred.
This typically occurs when the instance is created based on
incoming result rows, and is only called once for that
instance's lifetime.
Note that during a result-row load, this method is called upon
the first row received for this instance. Note that some
attributes and collections may or may not be loaded or even
initialized, depending on what's present in the result rows.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def before_insert(self, mapper, connection, instance):
"""Receive an object instance before that instance is inserted
into its table.
This is a good place to set up primary key values and such
that aren't handled otherwise.
Column-based attributes can be modified within this method
which will result in the new value being inserted. However
*no* changes to the overall flush plan can be made, and
manipulation of the ``Session`` will not have the desired effect.
To manipulate the ``Session`` within an extension, use
``SessionExtension``.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def after_insert(self, mapper, connection, instance):
"""Receive an object instance after that instance is inserted.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def before_update(self, mapper, connection, instance):
"""Receive an object instance before that instance is updated.
Note that this method is called for all instances that are marked as
"dirty", even those which have no net changes to their column-based
attributes. An object is marked as dirty when any of its column-based
attributes have a "set attribute" operation called or when any of its
collections are modified. If, at update time, no column-based
attributes have any net changes, no UPDATE statement will be issued.
This means that an instance being sent to before_update is *not* a
guarantee that an UPDATE statement will be issued (although you can
affect the outcome here).
To detect if the column-based attributes on the object have net
changes, and will therefore generate an UPDATE statement, use
``object_session(instance).is_modified(instance,
include_collections=False)``.
Column-based attributes can be modified within this method
which will result in the new value being updated. However
*no* changes to the overall flush plan can be made, and
manipulation of the ``Session`` will not have the desired effect.
To manipulate the ``Session`` within an extension, use
``SessionExtension``.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def after_update(self, mapper, connection, instance):
"""Receive an object instance after that instance is updated.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def before_delete(self, mapper, connection, instance):
"""Receive an object instance before that instance is deleted.
Note that *no* changes to the overall flush plan can be made
here; and manipulation of the ``Session`` will not have the
desired effect. To manipulate the ``Session`` within an
extension, use ``SessionExtension``.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
def after_delete(self, mapper, connection, instance):
"""Receive an object instance after that instance is deleted.
The return value is only significant within the ``MapperExtension``
chain; the parent mapper's behavior isn't modified by this method.
"""
return EXT_CONTINUE
@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces")
class SessionExtension(object):
"""Base implementation for :class:`.Session` event hooks.
.. note::
:class:`.SessionExtension` is deprecated. Please
refer to :func:`.event.listen` as well as
:class:`.SessionEvents`.
Subclasses may be installed into a :class:`.Session` (or
:class:`.sessionmaker`) using the ``extension`` keyword
argument::
from sqlalchemy.orm.interfaces import SessionExtension
class MySessionExtension(SessionExtension):
def before_commit(self, session):
print "before commit!"
Session = sessionmaker(extension=MySessionExtension())
The same :class:`.SessionExtension` instance can be used
with any number of sessions.
"""
@classmethod
def _adapt_listener(cls, self, listener):
for meth in [
'before_commit',
'after_commit',
'after_rollback',
'before_flush',
'after_flush',
'after_flush_postexec',
'after_begin',
'after_attach',
'after_bulk_update',
'after_bulk_delete',
]:
me_meth = getattr(SessionExtension, meth)
ls_meth = getattr(listener, meth)
if not util.methods_equivalent(me_meth, ls_meth):
event.listen(self, meth, getattr(listener, meth))
def before_commit(self, session):
"""Execute right before commit is called.
Note that this may not be per-flush if a longer running
transaction is ongoing."""
def after_commit(self, session):
"""Execute after a commit has occurred.
Note that this may not be per-flush if a longer running
transaction is ongoing."""
def after_rollback(self, session):
"""Execute after a rollback has occurred.
Note that this may not be per-flush if a longer running
transaction is ongoing."""
def before_flush(self, session, flush_context, instances):
"""Execute before flush process has started.
`instances` is an optional list of objects which were passed to
the ``flush()`` method. """
def after_flush(self, session, flush_context):
"""Execute after flush has completed, but before commit has been
called.
Note that the session's state is still in pre-flush, i.e. 'new',
'dirty', and 'deleted' lists still show pre-flush state as well
as the history settings on instance attributes."""
def after_flush_postexec(self, session, flush_context):
"""Execute after flush has completed, and after the post-exec
state occurs.
This will be when the 'new', 'dirty', and 'deleted' lists are in
their final state. An actual commit() may or may not have
occurred, depending on whether or not the flush started its own
transaction or participated in a larger transaction. """
def after_begin(self, session, transaction, connection):
"""Execute after a transaction is begun on a connection
`transaction` is the SessionTransaction. This method is called
after an engine level transaction is begun on a connection. """
def after_attach(self, session, instance):
"""Execute after an instance is attached to a session.
This is called after an add, delete or merge. """
def after_bulk_update(self, session, query, query_context, result):
"""Execute after a bulk update operation to the session.
This is called after a session.query(...).update()
`query` is the query object that this update operation was
called on. `query_context` was the query context object.
`result` is the result object returned from the bulk operation.
"""
def after_bulk_delete(self, session, query, query_context, result):
"""Execute after a bulk delete operation to the session.
This is called after a session.query(...).delete()
`query` is the query object that this delete operation was
called on. `query_context` was the query context object.
`result` is the result object returned from the bulk operation.
"""
@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces")
class AttributeExtension(object):
"""Base implementation for :class:`.AttributeImpl` event hooks, events
that fire upon attribute mutations in user code.
.. note::
:class:`.AttributeExtension` is deprecated. Please
refer to :func:`.event.listen` as well as
:class:`.AttributeEvents`.
:class:`.AttributeExtension` is used to listen for set,
remove, and append events on individual mapped attributes.
It is established on an individual mapped attribute using
the `extension` argument, available on
:func:`.column_property`, :func:`.relationship`, and
others::
from sqlalchemy.orm.interfaces import AttributeExtension
from sqlalchemy.orm import mapper, relationship, column_property
class MyAttrExt(AttributeExtension):
def append(self, state, value, initiator):
print "append event !"
return value
def set(self, state, value, oldvalue, initiator):
print "set event !"
return value
mapper(SomeClass, sometable, properties={
'foo':column_property(sometable.c.foo, extension=MyAttrExt()),
'bar':relationship(Bar, extension=MyAttrExt())
})
Note that the :class:`.AttributeExtension` methods
:meth:`~.AttributeExtension.append` and
:meth:`~.AttributeExtension.set` need to return the
``value`` parameter. The returned value is used as the
effective value, and allows the extension to change what is
ultimately persisted.
AttributeExtension is assembled within the descriptors associated
with a mapped class.
"""
active_history = True
"""indicates that the set() method would like to receive the 'old' value,
even if it means firing lazy callables.
Note that ``active_history`` can also be set directly via
:func:`.column_property` and :func:`.relationship`.
"""
@classmethod
def _adapt_listener(cls, self, listener):
event.listen(self, 'append', listener.append,
active_history=listener.active_history,
raw=True, retval=True)
event.listen(self, 'remove', listener.remove,
active_history=listener.active_history,
raw=True, retval=True)
event.listen(self, 'set', listener.set,
active_history=listener.active_history,
raw=True, retval=True)
def append(self, state, value, initiator):
"""Receive a collection append event.
The returned value will be used as the actual value to be
appended.
"""
return value
def remove(self, state, value, initiator):
"""Receive a remove event.
No return value is defined.
"""
pass
def set(self, state, value, oldvalue, initiator):
"""Receive a set event.
The returned value will be used as the actual value to be
set.
"""
return value

View File

@ -0,0 +1,699 @@
# orm/descriptor_props.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Descriptor properties are more "auxiliary" properties
that exist as configurational elements, but don't participate
as actively in the load/persist ORM loop.
"""
from .interfaces import MapperProperty, PropComparator
from .util import _none_set
from . import attributes
from .. import util, sql, exc as sa_exc, event, schema
from ..sql import expression
from . import properties
from . import query
class DescriptorProperty(MapperProperty):
""":class:`.MapperProperty` which proxies access to a
user-defined descriptor."""
doc = None
def instrument_class(self, mapper):
prop = self
class _ProxyImpl(object):
accepts_scalar_loader = False
expire_missing = True
collection = False
def __init__(self, key):
self.key = key
if hasattr(prop, 'get_history'):
def get_history(self, state, dict_,
passive=attributes.PASSIVE_OFF):
return prop.get_history(state, dict_, passive)
if self.descriptor is None:
desc = getattr(mapper.class_, self.key, None)
if mapper._is_userland_descriptor(desc):
self.descriptor = desc
if self.descriptor is None:
def fset(obj, value):
setattr(obj, self.name, value)
def fdel(obj):
delattr(obj, self.name)
def fget(obj):
return getattr(obj, self.name)
self.descriptor = property(
fget=fget,
fset=fset,
fdel=fdel,
)
proxy_attr = attributes.create_proxied_attribute(
self.descriptor)(
self.parent.class_,
self.key,
self.descriptor,
lambda: self._comparator_factory(mapper),
doc=self.doc,
original_property=self
)
proxy_attr.impl = _ProxyImpl(self.key)
mapper.class_manager.instrument_attribute(self.key, proxy_attr)
@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class CompositeProperty(DescriptorProperty):
"""Defines a "composite" mapped attribute, representing a collection
of columns as one attribute.
:class:`.CompositeProperty` is constructed using the :func:`.composite`
function.
.. seealso::
:ref:`mapper_composite`
"""
def __init__(self, class_, *attrs, **kwargs):
r"""Return a composite column-based property for use with a Mapper.
See the mapping documentation section :ref:`mapper_composite` for a
full usage example.
The :class:`.MapperProperty` returned by :func:`.composite`
is the :class:`.CompositeProperty`.
:param class\_:
The "composite type" class.
:param \*cols:
List of Column objects to be mapped.
:param active_history=False:
When ``True``, indicates that the "previous" value for a
scalar attribute should be loaded when replaced, if not
already loaded. See the same flag on :func:`.column_property`.
.. versionchanged:: 0.7
This flag specifically becomes meaningful
- previously it was a placeholder.
:param group:
A group name for this property when marked as deferred.
:param deferred:
When True, the column property is "deferred", meaning that it does
not load immediately, and is instead loaded when the attribute is
first accessed on an instance. See also
:func:`~sqlalchemy.orm.deferred`.
:param comparator_factory: a class which extends
:class:`.CompositeProperty.Comparator` which provides custom SQL
clause generation for comparison operations.
:param doc:
optional string that will be applied as the doc on the
class-bound descriptor.
:param info: Optional data dictionary which will be populated into the
:attr:`.MapperProperty.info` attribute of this object.
.. versionadded:: 0.8
:param extension:
an :class:`.AttributeExtension` instance,
or list of extensions, which will be prepended to the list of
attribute listeners for the resulting descriptor placed on the
class. **Deprecated.** Please see :class:`.AttributeEvents`.
"""
super(CompositeProperty, self).__init__()
self.attrs = attrs
self.composite_class = class_
self.active_history = kwargs.get('active_history', False)
self.deferred = kwargs.get('deferred', False)
self.group = kwargs.get('group', None)
self.comparator_factory = kwargs.pop('comparator_factory',
self.__class__.Comparator)
if 'info' in kwargs:
self.info = kwargs.pop('info')
util.set_creation_order(self)
self._create_descriptor()
def instrument_class(self, mapper):
super(CompositeProperty, self).instrument_class(mapper)
self._setup_event_handlers()
def do_init(self):
"""Initialization which occurs after the :class:`.CompositeProperty`
has been associated with its parent mapper.
"""
self._setup_arguments_on_columns()
def _create_descriptor(self):
"""Create the Python descriptor that will serve as
the access point on instances of the mapped class.
"""
def fget(instance):
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
if self.key not in dict_:
# key not present. Iterate through related
# attributes, retrieve their values. This
# ensures they all load.
values = [
getattr(instance, key)
for key in self._attribute_keys
]
# current expected behavior here is that the composite is
# created on access if the object is persistent or if
# col attributes have non-None. This would be better
# if the composite were created unconditionally,
# but that would be a behavioral change.
if self.key not in dict_ and (
state.key is not None or
not _none_set.issuperset(values)
):
dict_[self.key] = self.composite_class(*values)
state.manager.dispatch.refresh(state, None, [self.key])
return dict_.get(self.key, None)
def fset(instance, value):
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
attr = state.manager[self.key]
previous = dict_.get(self.key, attributes.NO_VALUE)
for fn in attr.dispatch.set:
value = fn(state, value, previous, attr.impl)
dict_[self.key] = value
if value is None:
for key in self._attribute_keys:
setattr(instance, key, None)
else:
for key, value in zip(
self._attribute_keys,
value.__composite_values__()):
setattr(instance, key, value)
def fdel(instance):
state = attributes.instance_state(instance)
dict_ = attributes.instance_dict(instance)
previous = dict_.pop(self.key, attributes.NO_VALUE)
attr = state.manager[self.key]
attr.dispatch.remove(state, previous, attr.impl)
for key in self._attribute_keys:
setattr(instance, key, None)
self.descriptor = property(fget, fset, fdel)
@util.memoized_property
def _comparable_elements(self):
return [
getattr(self.parent.class_, prop.key)
for prop in self.props
]
@util.memoized_property
def props(self):
props = []
for attr in self.attrs:
if isinstance(attr, str):
prop = self.parent.get_property(
attr, _configure_mappers=False)
elif isinstance(attr, schema.Column):
prop = self.parent._columntoproperty[attr]
elif isinstance(attr, attributes.InstrumentedAttribute):
prop = attr.property
else:
raise sa_exc.ArgumentError(
"Composite expects Column objects or mapped "
"attributes/attribute names as arguments, got: %r"
% (attr,))
props.append(prop)
return props
@property
def columns(self):
return [a for a in self.attrs if isinstance(a, schema.Column)]
def _setup_arguments_on_columns(self):
"""Propagate configuration arguments made on this composite
to the target columns, for those that apply.
"""
for prop in self.props:
prop.active_history = self.active_history
if self.deferred:
prop.deferred = self.deferred
prop.strategy_key = (
("deferred", True),
("instrument", True))
prop.group = self.group
def _setup_event_handlers(self):
"""Establish events that populate/expire the composite attribute."""
def load_handler(state, *args):
dict_ = state.dict
if self.key in dict_:
return
# if column elements aren't loaded, skip.
# __get__() will initiate a load for those
# columns
for k in self._attribute_keys:
if k not in dict_:
return
# assert self.key not in dict_
dict_[self.key] = self.composite_class(
*[state.dict[key] for key in
self._attribute_keys]
)
def expire_handler(state, keys):
if keys is None or set(self._attribute_keys).intersection(keys):
state.dict.pop(self.key, None)
def insert_update_handler(mapper, connection, state):
"""After an insert or update, some columns may be expired due
to server side defaults, or re-populated due to client side
defaults. Pop out the composite value here so that it
recreates.
"""
state.dict.pop(self.key, None)
event.listen(self.parent, 'after_insert',
insert_update_handler, raw=True)
event.listen(self.parent, 'after_update',
insert_update_handler, raw=True)
event.listen(self.parent, 'load',
load_handler, raw=True, propagate=True)
event.listen(self.parent, 'refresh',
load_handler, raw=True, propagate=True)
event.listen(self.parent, 'expire',
expire_handler, raw=True, propagate=True)
# TODO: need a deserialize hook here
@util.memoized_property
def _attribute_keys(self):
return [
prop.key for prop in self.props
]
def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
"""Provided for userland code that uses attributes.get_history()."""
added = []
deleted = []
has_history = False
for prop in self.props:
key = prop.key
hist = state.manager[key].impl.get_history(state, dict_)
if hist.has_changes():
has_history = True
non_deleted = hist.non_deleted()
if non_deleted:
added.extend(non_deleted)
else:
added.append(None)
if hist.deleted:
deleted.extend(hist.deleted)
else:
deleted.append(None)
if has_history:
return attributes.History(
[self.composite_class(*added)],
(),
[self.composite_class(*deleted)]
)
else:
return attributes.History(
(), [self.composite_class(*added)], ()
)
def _comparator_factory(self, mapper):
return self.comparator_factory(self, mapper)
class CompositeBundle(query.Bundle):
def __init__(self, property, expr):
self.property = property
super(CompositeProperty.CompositeBundle, self).__init__(
property.key, *expr)
def create_row_processor(self, query, procs, labels):
def proc(row):
return self.property.composite_class(
*[proc(row) for proc in procs])
return proc
class Comparator(PropComparator):
"""Produce boolean, comparison, and other operators for
:class:`.CompositeProperty` attributes.
See the example in :ref:`composite_operations` for an overview
of usage , as well as the documentation for :class:`.PropComparator`.
See also:
:class:`.PropComparator`
:class:`.ColumnOperators`
:ref:`types_operators`
:attr:`.TypeEngine.comparator_factory`
"""
__hash__ = None
@property
def clauses(self):
return self.__clause_element__()
def __clause_element__(self):
return expression.ClauseList(
group=False, *self._comparable_elements)
def _query_clause_element(self):
return CompositeProperty.CompositeBundle(
self.prop, self.__clause_element__())
@util.memoized_property
def _comparable_elements(self):
if self._adapt_to_entity:
return [
getattr(
self._adapt_to_entity.entity,
prop.key
) for prop in self.prop._comparable_elements
]
else:
return self.prop._comparable_elements
def __eq__(self, other):
if other is None:
values = [None] * len(self.prop._comparable_elements)
else:
values = other.__composite_values__()
comparisons = [
a == b
for a, b in zip(self.prop._comparable_elements, values)
]
if self._adapt_to_entity:
comparisons = [self.adapter(x) for x in comparisons]
return sql.and_(*comparisons)
def __ne__(self, other):
return sql.not_(self.__eq__(other))
def __str__(self):
return str(self.parent.class_.__name__) + "." + self.key
@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class ConcreteInheritedProperty(DescriptorProperty):
"""A 'do nothing' :class:`.MapperProperty` that disables
an attribute on a concrete subclass that is only present
on the inherited mapper, not the concrete classes' mapper.
Cases where this occurs include:
* When the superclass mapper is mapped against a
"polymorphic union", which includes all attributes from
all subclasses.
* When a relationship() is configured on an inherited mapper,
but not on the subclass mapper. Concrete mappers require
that relationship() is configured explicitly on each
subclass.
"""
def _comparator_factory(self, mapper):
comparator_callable = None
for m in self.parent.iterate_to_root():
p = m._props[self.key]
if not isinstance(p, ConcreteInheritedProperty):
comparator_callable = p.comparator_factory
break
return comparator_callable
def __init__(self):
super(ConcreteInheritedProperty, self).__init__()
def warn():
raise AttributeError("Concrete %s does not implement "
"attribute %r at the instance level. Add "
"this property explicitly to %s." %
(self.parent, self.key, self.parent))
class NoninheritedConcreteProp(object):
def __set__(s, obj, value):
warn()
def __delete__(s, obj):
warn()
def __get__(s, obj, owner):
if obj is None:
return self.descriptor
warn()
self.descriptor = NoninheritedConcreteProp()
@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class SynonymProperty(DescriptorProperty):
def __init__(self, name, map_column=None,
descriptor=None, comparator_factory=None,
doc=None, info=None):
"""Denote an attribute name as a synonym to a mapped property,
in that the attribute will mirror the value and expression behavior
of another attribute.
:param name: the name of the existing mapped property. This
can refer to the string name of any :class:`.MapperProperty`
configured on the class, including column-bound attributes
and relationships.
:param descriptor: a Python :term:`descriptor` that will be used
as a getter (and potentially a setter) when this attribute is
accessed at the instance level.
:param map_column: if ``True``, the :func:`.synonym` construct will
locate the existing named :class:`.MapperProperty` based on the
attribute name of this :func:`.synonym`, and assign it to a new
attribute linked to the name of this :func:`.synonym`.
That is, given a mapping like::
class MyClass(Base):
__tablename__ = 'my_table'
id = Column(Integer, primary_key=True)
job_status = Column(String(50))
job_status = synonym("_job_status", map_column=True)
The above class ``MyClass`` will now have the ``job_status``
:class:`.Column` object mapped to the attribute named
``_job_status``, and the attribute named ``job_status`` will refer
to the synonym itself. This feature is typically used in
conjunction with the ``descriptor`` argument in order to link a
user-defined descriptor as a "wrapper" for an existing column.
:param info: Optional data dictionary which will be populated into the
:attr:`.InspectionAttr.info` attribute of this object.
.. versionadded:: 1.0.0
:param comparator_factory: A subclass of :class:`.PropComparator`
that will provide custom comparison behavior at the SQL expression
level.
.. note::
For the use case of providing an attribute which redefines both
Python-level and SQL-expression level behavior of an attribute,
please refer to the Hybrid attribute introduced at
:ref:`mapper_hybrids` for a more effective technique.
.. seealso::
:ref:`synonyms` - examples of functionality.
:ref:`mapper_hybrids` - Hybrids provide a better approach for
more complicated attribute-wrapping schemes than synonyms.
"""
super(SynonymProperty, self).__init__()
self.name = name
self.map_column = map_column
self.descriptor = descriptor
self.comparator_factory = comparator_factory
self.doc = doc or (descriptor and descriptor.__doc__) or None
if info:
self.info = info
util.set_creation_order(self)
# TODO: when initialized, check _proxied_property,
# emit a warning if its not a column-based property
@util.memoized_property
def _proxied_property(self):
return getattr(self.parent.class_, self.name).property
def _comparator_factory(self, mapper):
prop = self._proxied_property
if self.comparator_factory:
comp = self.comparator_factory(prop, mapper)
else:
comp = prop.comparator_factory(prop, mapper)
return comp
def set_parent(self, parent, init):
if self.map_column:
# implement the 'map_column' option.
if self.key not in parent.mapped_table.c:
raise sa_exc.ArgumentError(
"Can't compile synonym '%s': no column on table "
"'%s' named '%s'"
% (self.name, parent.mapped_table.description, self.key))
elif parent.mapped_table.c[self.key] in \
parent._columntoproperty and \
parent._columntoproperty[
parent.mapped_table.c[self.key]
].key == self.name:
raise sa_exc.ArgumentError(
"Can't call map_column=True for synonym %r=%r, "
"a ColumnProperty already exists keyed to the name "
"%r for column %r" %
(self.key, self.name, self.name, self.key)
)
p = properties.ColumnProperty(parent.mapped_table.c[self.key])
parent._configure_property(
self.name, p,
init=init,
setparent=True)
p._mapped_by_synonym = self.key
self.parent = parent
@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class ComparableProperty(DescriptorProperty):
"""Instruments a Python property for use in query expressions."""
def __init__(
self, comparator_factory, descriptor=None, doc=None, info=None):
"""Provides a method of applying a :class:`.PropComparator`
to any Python descriptor attribute.
.. versionchanged:: 0.7
:func:`.comparable_property` is superseded by
the :mod:`~sqlalchemy.ext.hybrid` extension. See the example
at :ref:`hybrid_custom_comparators`.
Allows any Python descriptor to behave like a SQL-enabled
attribute when used at the class level in queries, allowing
redefinition of expression operator behavior.
In the example below we redefine :meth:`.PropComparator.operate`
to wrap both sides of an expression in ``func.lower()`` to produce
case-insensitive comparison::
from sqlalchemy.orm import comparable_property
from sqlalchemy.orm.interfaces import PropComparator
from sqlalchemy.sql import func
from sqlalchemy import Integer, String, Column
from sqlalchemy.ext.declarative import declarative_base
class CaseInsensitiveComparator(PropComparator):
def __clause_element__(self):
return self.prop
def operate(self, op, other):
return op(
func.lower(self.__clause_element__()),
func.lower(other)
)
Base = declarative_base()
class SearchWord(Base):
__tablename__ = 'search_word'
id = Column(Integer, primary_key=True)
word = Column(String)
word_insensitive = comparable_property(lambda prop, mapper:
CaseInsensitiveComparator(
mapper.c.word, mapper)
)
A mapping like the above allows the ``word_insensitive`` attribute
to render an expression like::
>>> print SearchWord.word_insensitive == "Trucks"
lower(search_word.word) = lower(:lower_1)
:param comparator_factory:
A PropComparator subclass or factory that defines operator behavior
for this property.
:param descriptor:
Optional when used in a ``properties={}`` declaration. The Python
descriptor or property to layer comparison behavior on top of.
The like-named descriptor will be automatically retrieved from the
mapped class if left blank in a ``properties`` declaration.
:param info: Optional data dictionary which will be populated into the
:attr:`.InspectionAttr.info` attribute of this object.
.. versionadded:: 1.0.0
"""
super(ComparableProperty, self).__init__()
self.descriptor = descriptor
self.comparator_factory = comparator_factory
self.doc = doc or (descriptor and descriptor.__doc__) or None
if info:
self.info = info
util.set_creation_order(self)
def _comparator_factory(self, mapper):
return self.comparator_factory(self, mapper)

2187
sqlalchemy/orm/events.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,528 @@
# orm/instrumentation.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Defines SQLAlchemy's system of class instrumentation.
This module is usually not directly visible to user applications, but
defines a large part of the ORM's interactivity.
instrumentation.py deals with registration of end-user classes
for state tracking. It interacts closely with state.py
and attributes.py which establish per-instance and per-class-attribute
instrumentation, respectively.
The class instrumentation system can be customized on a per-class
or global basis using the :mod:`sqlalchemy.ext.instrumentation`
module, which provides the means to build and specify
alternate instrumentation forms.
.. versionchanged: 0.8
The instrumentation extension system was moved out of the
ORM and into the external :mod:`sqlalchemy.ext.instrumentation`
package. When that package is imported, it installs
itself within sqlalchemy.orm so that its more comprehensive
resolution mechanics take effect.
"""
from . import exc, collections, interfaces, state
from .. import util
from . import base
_memoized_key_collection = util.group_expirable_memoized_property()
class ClassManager(dict):
"""tracks state information at the class level."""
MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
STATE_ATTR = base.DEFAULT_STATE_ATTR
_state_setter = staticmethod(util.attrsetter(STATE_ATTR))
deferred_scalar_loader = None
original_init = object.__init__
factory = None
def __init__(self, class_):
self.class_ = class_
self.info = {}
self.new_init = None
self.local_attrs = {}
self.originals = {}
self._bases = [mgr for mgr in [
manager_of_class(base)
for base in self.class_.__bases__
if isinstance(base, type)
] if mgr is not None]
for base in self._bases:
self.update(base)
self.dispatch._events._new_classmanager_instance(class_, self)
# events._InstanceEventsHold.populate(class_, self)
for basecls in class_.__mro__:
mgr = manager_of_class(basecls)
if mgr is not None:
self.dispatch._update(mgr.dispatch)
self.manage()
self._instrument_init()
if '__del__' in class_.__dict__:
util.warn("__del__() method on class %s will "
"cause unreachable cycles and memory leaks, "
"as SQLAlchemy instrumentation often creates "
"reference cycles. Please remove this method." %
class_)
def __hash__(self):
return id(self)
def __eq__(self, other):
return other is self
@property
def is_mapped(self):
return 'mapper' in self.__dict__
@_memoized_key_collection
def _all_key_set(self):
return frozenset(self)
@_memoized_key_collection
def _collection_impl_keys(self):
return frozenset([
attr.key for attr in self.values() if attr.impl.collection])
@_memoized_key_collection
def _scalar_loader_impls(self):
return frozenset([
attr.impl for attr in
self.values() if attr.impl.accepts_scalar_loader])
@util.memoized_property
def mapper(self):
# raises unless self.mapper has been assigned
raise exc.UnmappedClassError(self.class_)
def _all_sqla_attributes(self, exclude=None):
"""return an iterator of all classbound attributes that are
implement :class:`.InspectionAttr`.
This includes :class:`.QueryableAttribute` as well as extension
types such as :class:`.hybrid_property` and
:class:`.AssociationProxy`.
"""
if exclude is None:
exclude = set()
for supercls in self.class_.__mro__:
for key in set(supercls.__dict__).difference(exclude):
exclude.add(key)
val = supercls.__dict__[key]
if isinstance(val, interfaces.InspectionAttr):
yield key, val
def _attr_has_impl(self, key):
"""Return True if the given attribute is fully initialized.
i.e. has an impl.
"""
return key in self and self[key].impl is not None
def _subclass_manager(self, cls):
"""Create a new ClassManager for a subclass of this ClassManager's
class.
This is called automatically when attributes are instrumented so that
the attributes can be propagated to subclasses against their own
class-local manager, without the need for mappers etc. to have already
pre-configured managers for the full class hierarchy. Mappers
can post-configure the auto-generated ClassManager when needed.
"""
manager = manager_of_class(cls)
if manager is None:
manager = _instrumentation_factory.create_manager_for_cls(cls)
return manager
def _instrument_init(self):
# TODO: self.class_.__init__ is often the already-instrumented
# __init__ from an instrumented superclass. We still need to make
# our own wrapper, but it would
# be nice to wrap the original __init__ and not our existing wrapper
# of such, since this adds method overhead.
self.original_init = self.class_.__init__
self.new_init = _generate_init(self.class_, self)
self.install_member('__init__', self.new_init)
def _uninstrument_init(self):
if self.new_init:
self.uninstall_member('__init__')
self.new_init = None
@util.memoized_property
def _state_constructor(self):
self.dispatch.first_init(self, self.class_)
return state.InstanceState
def manage(self):
"""Mark this instance as the manager for its class."""
setattr(self.class_, self.MANAGER_ATTR, self)
def dispose(self):
"""Dissasociate this manager from its class."""
delattr(self.class_, self.MANAGER_ATTR)
@util.hybridmethod
def manager_getter(self):
return _default_manager_getter
@util.hybridmethod
def state_getter(self):
"""Return a (instance) -> InstanceState callable.
"state getter" callables should raise either KeyError or
AttributeError if no InstanceState could be found for the
instance.
"""
return _default_state_getter
@util.hybridmethod
def dict_getter(self):
return _default_dict_getter
def instrument_attribute(self, key, inst, propagated=False):
if propagated:
if key in self.local_attrs:
return # don't override local attr with inherited attr
else:
self.local_attrs[key] = inst
self.install_descriptor(key, inst)
_memoized_key_collection.expire_instance(self)
self[key] = inst
for cls in self.class_.__subclasses__():
manager = self._subclass_manager(cls)
manager.instrument_attribute(key, inst, True)
def subclass_managers(self, recursive):
for cls in self.class_.__subclasses__():
mgr = manager_of_class(cls)
if mgr is not None and mgr is not self:
yield mgr
if recursive:
for m in mgr.subclass_managers(True):
yield m
def post_configure_attribute(self, key):
_instrumentation_factory.dispatch.\
attribute_instrument(self.class_, key, self[key])
def uninstrument_attribute(self, key, propagated=False):
if key not in self:
return
if propagated:
if key in self.local_attrs:
return # don't get rid of local attr
else:
del self.local_attrs[key]
self.uninstall_descriptor(key)
_memoized_key_collection.expire_instance(self)
del self[key]
for cls in self.class_.__subclasses__():
manager = manager_of_class(cls)
if manager:
manager.uninstrument_attribute(key, True)
def unregister(self):
"""remove all instrumentation established by this ClassManager."""
self._uninstrument_init()
self.mapper = self.dispatch = None
self.info.clear()
for key in list(self):
if key in self.local_attrs:
self.uninstrument_attribute(key)
def install_descriptor(self, key, inst):
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError("%r: requested attribute name conflicts with "
"instrumentation attribute of the same name." %
key)
setattr(self.class_, key, inst)
def uninstall_descriptor(self, key):
delattr(self.class_, key)
def install_member(self, key, implementation):
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError("%r: requested attribute name conflicts with "
"instrumentation attribute of the same name." %
key)
self.originals.setdefault(key, getattr(self.class_, key, None))
setattr(self.class_, key, implementation)
def uninstall_member(self, key):
original = self.originals.pop(key, None)
if original is not None:
setattr(self.class_, key, original)
def instrument_collection_class(self, key, collection_class):
return collections.prepare_instrumentation(collection_class)
def initialize_collection(self, key, state, factory):
user_data = factory()
adapter = collections.CollectionAdapter(
self.get_impl(key), state, user_data)
return adapter, user_data
def is_instrumented(self, key, search=False):
if search:
return key in self
else:
return key in self.local_attrs
def get_impl(self, key):
return self[key].impl
@property
def attributes(self):
return iter(self.values())
# InstanceState management
def new_instance(self, state=None):
instance = self.class_.__new__(self.class_)
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
return instance
def setup_instance(self, instance, state=None):
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
def teardown_instance(self, instance):
delattr(instance, self.STATE_ATTR)
def _serialize(self, state, state_dict):
return _SerializeManager(state, state_dict)
def _new_state_if_none(self, instance):
"""Install a default InstanceState if none is present.
A private convenience method used by the __init__ decorator.
"""
if hasattr(instance, self.STATE_ATTR):
return False
elif self.class_ is not instance.__class__ and \
self.is_mapped:
# this will create a new ClassManager for the
# subclass, without a mapper. This is likely a
# user error situation but allow the object
# to be constructed, so that it is usable
# in a non-ORM context at least.
return self._subclass_manager(instance.__class__).\
_new_state_if_none(instance)
else:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
return state
def has_state(self, instance):
return hasattr(instance, self.STATE_ATTR)
def has_parent(self, state, key, optimistic=False):
"""TODO"""
return self.get_impl(key).hasparent(state, optimistic=optimistic)
def __bool__(self):
"""All ClassManagers are non-zero regardless of attribute state."""
return True
__nonzero__ = __bool__
def __repr__(self):
return '<%s of %r at %x>' % (
self.__class__.__name__, self.class_, id(self))
class _SerializeManager(object):
"""Provide serialization of a :class:`.ClassManager`.
The :class:`.InstanceState` uses ``__init__()`` on serialize
and ``__call__()`` on deserialize.
"""
def __init__(self, state, d):
self.class_ = state.class_
manager = state.manager
manager.dispatch.pickle(state, d)
def __call__(self, state, inst, state_dict):
state.manager = manager = manager_of_class(self.class_)
if manager is None:
raise exc.UnmappedInstanceError(
inst,
"Cannot deserialize object of type %r - "
"no mapper() has "
"been configured for this class within the current "
"Python process!" %
self.class_)
elif manager.is_mapped and not manager.mapper.configured:
manager.mapper._configure_all()
# setup _sa_instance_state ahead of time so that
# unpickle events can access the object normally.
# see [ticket:2362]
if inst is not None:
manager.setup_instance(inst, state)
manager.dispatch.unpickle(state, state_dict)
class InstrumentationFactory(object):
"""Factory for new ClassManager instances."""
def create_manager_for_cls(self, class_):
assert class_ is not None
assert manager_of_class(class_) is None
# give a more complicated subclass
# a chance to do what it wants here
manager, factory = self._locate_extended_factory(class_)
if factory is None:
factory = ClassManager
manager = factory(class_)
self._check_conflicts(class_, factory)
manager.factory = factory
self.dispatch.class_instrument(class_)
return manager
def _locate_extended_factory(self, class_):
"""Overridden by a subclass to do an extended lookup."""
return None, None
def _check_conflicts(self, class_, factory):
"""Overridden by a subclass to test for conflicting factories."""
return
def unregister(self, class_):
manager = manager_of_class(class_)
manager.unregister()
manager.dispose()
self.dispatch.class_uninstrument(class_)
if ClassManager.MANAGER_ATTR in class_.__dict__:
delattr(class_, ClassManager.MANAGER_ATTR)
# this attribute is replaced by sqlalchemy.ext.instrumentation
# when importred.
_instrumentation_factory = InstrumentationFactory()
# these attributes are replaced by sqlalchemy.ext.instrumentation
# when a non-standard InstrumentationManager class is first
# used to instrument a class.
instance_state = _default_state_getter = base.instance_state
instance_dict = _default_dict_getter = base.instance_dict
manager_of_class = _default_manager_getter = base.manager_of_class
def register_class(class_):
"""Register class instrumentation.
Returns the existing or newly created class manager.
"""
manager = manager_of_class(class_)
if manager is None:
manager = _instrumentation_factory.create_manager_for_cls(class_)
return manager
def unregister_class(class_):
"""Unregister class instrumentation."""
_instrumentation_factory.unregister(class_)
def is_instrumented(instance, key):
"""Return True if the given attribute on the given instance is
instrumented by the attributes package.
This function may be used regardless of instrumentation
applied directly to the class, i.e. no descriptors are required.
"""
return manager_of_class(instance.__class__).\
is_instrumented(key, search=True)
def _generate_init(class_, class_manager):
"""Build an __init__ decorator that triggers ClassManager events."""
# TODO: we should use the ClassManager's notion of the
# original '__init__' method, once ClassManager is fixed
# to always reference that.
original__init__ = class_.__init__
assert original__init__
# Go through some effort here and don't change the user's __init__
# calling signature, including the unlikely case that it has
# a return value.
# FIXME: need to juggle local names to avoid constructor argument
# clashes.
func_body = """\
def __init__(%(apply_pos)s):
new_state = class_manager._new_state_if_none(%(self_arg)s)
if new_state:
return new_state._initialize_instance(%(apply_kw)s)
else:
return original__init__(%(apply_kw)s)
"""
func_vars = util.format_argspec_init(original__init__, grouped=False)
func_text = func_body % func_vars
if util.py2k:
func = getattr(original__init__, 'im_func', original__init__)
func_defaults = getattr(func, 'func_defaults', None)
else:
func_defaults = getattr(original__init__, '__defaults__', None)
func_kw_defaults = getattr(original__init__, '__kwdefaults__', None)
env = locals().copy()
exec(func_text, env)
__init__ = env['__init__']
__init__.__doc__ = original__init__.__doc__
if func_defaults:
__init__.__defaults__ = func_defaults
if not util.py2k and func_kw_defaults:
__init__.__kwdefaults__ = func_kw_defaults
return __init__

703
sqlalchemy/orm/loading.py Normal file
View File

@ -0,0 +1,703 @@
# orm/loading.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""private module containing functions used to convert database
rows into object instances and associated state.
the functions here are called primarily by Query, Mapper,
as well as some of the attribute loading strategies.
"""
from __future__ import absolute_import
from .. import util
from . import attributes, exc as orm_exc
from ..sql import util as sql_util
from . import strategy_options
from .util import _none_set, state_str
from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE
from .. import exc as sa_exc
import collections
_new_runid = util.counter()
def instances(query, cursor, context):
"""Return an ORM result as an iterator."""
context.runid = _new_runid()
filtered = query._has_mapper_entities
single_entity = len(query._entities) == 1 and \
query._entities[0].supports_single_entity
if filtered:
if single_entity:
filter_fn = id
else:
def filter_fn(row):
return tuple(
id(item)
if ent.use_id_for_hash
else item
for ent, item in zip(query._entities, row)
)
try:
(process, labels) = \
list(zip(*[
query_entity.row_processor(query,
context, cursor)
for query_entity in query._entities
]))
if not single_entity:
keyed_tuple = util.lightweight_named_tuple('result', labels)
while True:
context.partials = {}
if query._yield_per:
fetch = cursor.fetchmany(query._yield_per)
if not fetch:
break
else:
fetch = cursor.fetchall()
if single_entity:
proc = process[0]
rows = [proc(row) for row in fetch]
else:
rows = [keyed_tuple([proc(row) for proc in process])
for row in fetch]
if filtered:
rows = util.unique_list(rows, filter_fn)
for row in rows:
yield row
if not query._yield_per:
break
except Exception as err:
cursor.close()
util.raise_from_cause(err)
@util.dependencies("sqlalchemy.orm.query")
def merge_result(querylib, query, iterator, load=True):
"""Merge a result into this :class:`.Query` object's Session."""
session = query.session
if load:
# flush current contents if we expect to load data
session._autoflush()
autoflush = session.autoflush
try:
session.autoflush = False
single_entity = len(query._entities) == 1
if single_entity:
if isinstance(query._entities[0], querylib._MapperEntity):
result = [session._merge(
attributes.instance_state(instance),
attributes.instance_dict(instance),
load=load, _recursive={}, _resolve_conflict_map={})
for instance in iterator]
else:
result = list(iterator)
else:
mapped_entities = [i for i, e in enumerate(query._entities)
if isinstance(e, querylib._MapperEntity)]
result = []
keys = [ent._label_name for ent in query._entities]
keyed_tuple = util.lightweight_named_tuple('result', keys)
for row in iterator:
newrow = list(row)
for i in mapped_entities:
if newrow[i] is not None:
newrow[i] = session._merge(
attributes.instance_state(newrow[i]),
attributes.instance_dict(newrow[i]),
load=load, _recursive={}, _resolve_conflict_map={})
result.append(keyed_tuple(newrow))
return iter(result)
finally:
session.autoflush = autoflush
def get_from_identity(session, key, passive):
"""Look up the given key in the given session's identity map,
check the object for expired state if found.
"""
instance = session.identity_map.get(key)
if instance is not None:
state = attributes.instance_state(instance)
# expired - ensure it still exists
if state.expired:
if not passive & attributes.SQL_OK:
# TODO: no coverage here
return attributes.PASSIVE_NO_RESULT
elif not passive & attributes.RELATED_OBJECT_OK:
# this mode is used within a flush and the instance's
# expired state will be checked soon enough, if necessary
return instance
try:
state._load_expired(state, passive)
except orm_exc.ObjectDeletedError:
session._remove_newly_deleted([state])
return None
return instance
else:
return None
def load_on_ident(query, key,
refresh_state=None, lockmode=None,
only_load_props=None):
"""Load the given identity key from the database."""
if key is not None:
ident = key[1]
else:
ident = None
if refresh_state is None:
q = query._clone()
q._get_condition()
else:
q = query._clone()
if ident is not None:
mapper = query._mapper_zero()
(_get_clause, _get_params) = mapper._get_clause
# None present in ident - turn those comparisons
# into "IS NULL"
if None in ident:
nones = set([
_get_params[col].key for col, value in
zip(mapper.primary_key, ident) if value is None
])
_get_clause = sql_util.adapt_criterion_to_null(
_get_clause, nones)
_get_clause = q._adapt_clause(_get_clause, True, False)
q._criterion = _get_clause
params = dict([
(_get_params[primary_key].key, id_val)
for id_val, primary_key in zip(ident, mapper.primary_key)
])
q._params = params
if lockmode is not None:
version_check = True
q = q.with_lockmode(lockmode)
elif query._for_update_arg is not None:
version_check = True
q._for_update_arg = query._for_update_arg
else:
version_check = False
q._get_options(
populate_existing=bool(refresh_state),
version_check=version_check,
only_load_props=only_load_props,
refresh_state=refresh_state)
q._order_by = None
try:
return q.one()
except orm_exc.NoResultFound:
return None
def _setup_entity_query(
context, mapper, query_entity,
path, adapter, column_collection,
with_polymorphic=None, only_load_props=None,
polymorphic_discriminator=None, **kw):
if with_polymorphic:
poly_properties = mapper._iterate_polymorphic_properties(
with_polymorphic)
else:
poly_properties = mapper._polymorphic_properties
quick_populators = {}
path.set(
context.attributes,
"memoized_setups",
quick_populators)
for value in poly_properties:
if only_load_props and \
value.key not in only_load_props:
continue
value.setup(
context,
query_entity,
path,
adapter,
only_load_props=only_load_props,
column_collection=column_collection,
memoized_populators=quick_populators,
**kw
)
if polymorphic_discriminator is not None and \
polymorphic_discriminator \
is not mapper.polymorphic_on:
if adapter:
pd = adapter.columns[polymorphic_discriminator]
else:
pd = polymorphic_discriminator
column_collection.append(pd)
def _instance_processor(
mapper, context, result, path, adapter,
only_load_props=None, refresh_state=None,
polymorphic_discriminator=None,
_polymorphic_from=None):
"""Produce a mapper level row processor callable
which processes rows into mapped instances."""
# note that this method, most of which exists in a closure
# called _instance(), resists being broken out, as
# attempts to do so tend to add significant function
# call overhead. _instance() is the most
# performance-critical section in the whole ORM.
pk_cols = mapper.primary_key
if adapter:
pk_cols = [adapter.columns[c] for c in pk_cols]
identity_class = mapper._identity_class
populators = collections.defaultdict(list)
props = mapper._prop_set
if only_load_props is not None:
props = props.intersection(
mapper._props[k] for k in only_load_props)
quick_populators = path.get(
context.attributes, "memoized_setups", _none_set)
for prop in props:
if prop in quick_populators:
# this is an inlined path just for column-based attributes.
col = quick_populators[prop]
if col is _DEFER_FOR_STATE:
populators["new"].append(
(prop.key, prop._deferred_column_loader))
elif col is _SET_DEFERRED_EXPIRED:
# note that in this path, we are no longer
# searching in the result to see if the column might
# be present in some unexpected way.
populators["expire"].append((prop.key, False))
else:
if adapter:
col = adapter.columns[col]
getter = result._getter(col, False)
if getter:
populators["quick"].append((prop.key, getter))
else:
# fall back to the ColumnProperty itself, which
# will iterate through all of its columns
# to see if one fits
prop.create_row_processor(
context, path, mapper, result, adapter, populators)
else:
prop.create_row_processor(
context, path, mapper, result, adapter, populators)
propagate_options = context.propagate_options
load_path = context.query._current_path + path \
if context.query._current_path.path else path
session_identity_map = context.session.identity_map
populate_existing = context.populate_existing or mapper.always_refresh
load_evt = bool(mapper.class_manager.dispatch.load)
refresh_evt = bool(mapper.class_manager.dispatch.refresh)
persistent_evt = bool(context.session.dispatch.loaded_as_persistent)
if persistent_evt:
loaded_as_persistent = context.session.dispatch.loaded_as_persistent
instance_state = attributes.instance_state
instance_dict = attributes.instance_dict
session_id = context.session.hash_key
version_check = context.version_check
runid = context.runid
if refresh_state:
refresh_identity_key = refresh_state.key
if refresh_identity_key is None:
# super-rare condition; a refresh is being called
# on a non-instance-key instance; this is meant to only
# occur within a flush()
refresh_identity_key = \
mapper._identity_key_from_state(refresh_state)
else:
refresh_identity_key = None
if mapper.allow_partial_pks:
is_not_primary_key = _none_set.issuperset
else:
is_not_primary_key = _none_set.intersection
def _instance(row):
# determine the state that we'll be populating
if refresh_identity_key:
# fixed state that we're refreshing
state = refresh_state
instance = state.obj()
dict_ = instance_dict(instance)
isnew = state.runid != runid
currentload = True
loaded_instance = False
else:
# look at the row, see if that identity is in the
# session, or we have to create a new one
identitykey = (
identity_class,
tuple([row[column] for column in pk_cols])
)
instance = session_identity_map.get(identitykey)
if instance is not None:
# existing instance
state = instance_state(instance)
dict_ = instance_dict(instance)
isnew = state.runid != runid
currentload = not isnew
loaded_instance = False
if version_check and not currentload:
_validate_version_id(mapper, state, dict_, row, adapter)
else:
# create a new instance
# check for non-NULL values in the primary key columns,
# else no entity is returned for the row
if is_not_primary_key(identitykey[1]):
return None
isnew = True
currentload = True
loaded_instance = True
instance = mapper.class_manager.new_instance()
dict_ = instance_dict(instance)
state = instance_state(instance)
state.key = identitykey
# attach instance to session.
state.session_id = session_id
session_identity_map._add_unpresent(state, identitykey)
# populate. this looks at whether this state is new
# for this load or was existing, and whether or not this
# row is the first row with this identity.
if currentload or populate_existing:
# full population routines. Objects here are either
# just created, or we are doing a populate_existing
# be conservative about setting load_path when populate_existing
# is in effect; want to maintain options from the original
# load. see test_expire->test_refresh_maintains_deferred_options
if isnew and (propagate_options or not populate_existing):
state.load_options = propagate_options
state.load_path = load_path
_populate_full(
context, row, state, dict_, isnew, load_path,
loaded_instance, populate_existing, populators)
if isnew:
if loaded_instance:
if load_evt:
state.manager.dispatch.load(state, context)
if persistent_evt:
loaded_as_persistent(context.session, state.obj())
elif refresh_evt:
state.manager.dispatch.refresh(
state, context, only_load_props)
if populate_existing or state.modified:
if refresh_state and only_load_props:
state._commit(dict_, only_load_props)
else:
state._commit_all(dict_, session_identity_map)
else:
# partial population routines, for objects that were already
# in the Session, but a row matches them; apply eager loaders
# on existing objects, etc.
unloaded = state.unloaded
isnew = state not in context.partials
if not isnew or unloaded or populators["eager"]:
# state is having a partial set of its attributes
# refreshed. Populate those attributes,
# and add to the "context.partials" collection.
to_load = _populate_partial(
context, row, state, dict_, isnew, load_path,
unloaded, populators)
if isnew:
if refresh_evt:
state.manager.dispatch.refresh(
state, context, to_load)
state._commit(dict_, to_load)
return instance
if mapper.polymorphic_map and not _polymorphic_from and not refresh_state:
# if we are doing polymorphic, dispatch to a different _instance()
# method specific to the subclass mapper
_instance = _decorate_polymorphic_switch(
_instance, context, mapper, result, path,
polymorphic_discriminator, adapter)
return _instance
def _populate_full(
context, row, state, dict_, isnew, load_path,
loaded_instance, populate_existing, populators):
if isnew:
# first time we are seeing a row with this identity.
state.runid = context.runid
for key, getter in populators["quick"]:
dict_[key] = getter(row)
if populate_existing:
for key, set_callable in populators["expire"]:
dict_.pop(key, None)
if set_callable:
state.expired_attributes.add(key)
else:
for key, set_callable in populators["expire"]:
if set_callable:
state.expired_attributes.add(key)
for key, populator in populators["new"]:
populator(state, dict_, row)
for key, populator in populators["delayed"]:
populator(state, dict_, row)
elif load_path != state.load_path:
# new load path, e.g. object is present in more than one
# column position in a series of rows
state.load_path = load_path
# if we have data, and the data isn't in the dict, OK, let's put
# it in.
for key, getter in populators["quick"]:
if key not in dict_:
dict_[key] = getter(row)
# otherwise treat like an "already seen" row
for key, populator in populators["existing"]:
populator(state, dict_, row)
# TODO: allow "existing" populator to know this is
# a new path for the state:
# populator(state, dict_, row, new_path=True)
else:
# have already seen rows with this identity in this same path.
for key, populator in populators["existing"]:
populator(state, dict_, row)
# TODO: same path
# populator(state, dict_, row, new_path=False)
def _populate_partial(
context, row, state, dict_, isnew, load_path,
unloaded, populators):
if not isnew:
to_load = context.partials[state]
for key, populator in populators["existing"]:
if key in to_load:
populator(state, dict_, row)
else:
to_load = unloaded
context.partials[state] = to_load
for key, getter in populators["quick"]:
if key in to_load:
dict_[key] = getter(row)
for key, set_callable in populators["expire"]:
if key in to_load:
dict_.pop(key, None)
if set_callable:
state.expired_attributes.add(key)
for key, populator in populators["new"]:
if key in to_load:
populator(state, dict_, row)
for key, populator in populators["delayed"]:
if key in to_load:
populator(state, dict_, row)
for key, populator in populators["eager"]:
if key not in unloaded:
populator(state, dict_, row)
return to_load
def _validate_version_id(mapper, state, dict_, row, adapter):
version_id_col = mapper.version_id_col
if version_id_col is None:
return
if adapter:
version_id_col = adapter.columns[version_id_col]
if mapper._get_state_attr_by_column(
state, dict_, mapper.version_id_col) != row[version_id_col]:
raise orm_exc.StaleDataError(
"Instance '%s' has version id '%s' which "
"does not match database-loaded version id '%s'."
% (state_str(state), mapper._get_state_attr_by_column(
state, dict_, mapper.version_id_col),
row[version_id_col]))
def _decorate_polymorphic_switch(
instance_fn, context, mapper, result, path,
polymorphic_discriminator, adapter):
if polymorphic_discriminator is not None:
polymorphic_on = polymorphic_discriminator
else:
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is None:
return instance_fn
if adapter:
polymorphic_on = adapter.columns[polymorphic_on]
def configure_subclass_mapper(discriminator):
try:
sub_mapper = mapper.polymorphic_map[discriminator]
except KeyError:
raise AssertionError(
"No such polymorphic_identity %r is defined" %
discriminator)
else:
if sub_mapper is mapper:
return None
return _instance_processor(
sub_mapper, context, result,
path, adapter, _polymorphic_from=mapper)
polymorphic_instances = util.PopulateDict(
configure_subclass_mapper
)
def polymorphic_instance(row):
discriminator = row[polymorphic_on]
if discriminator is not None:
_instance = polymorphic_instances[discriminator]
if _instance:
return _instance(row)
return instance_fn(row)
return polymorphic_instance
def load_scalar_attributes(mapper, state, attribute_names):
"""initiate a column-based attribute refresh operation."""
# assert mapper is _state_mapper(state)
session = state.session
if not session:
raise orm_exc.DetachedInstanceError(
"Instance %s is not bound to a Session; "
"attribute refresh operation cannot proceed" %
(state_str(state)))
has_key = bool(state.key)
result = False
if mapper.inherits and not mapper.concrete:
# because we are using Core to produce a select() that we
# pass to the Query, we aren't calling setup() for mapped
# attributes; in 1.0 this means deferred attrs won't get loaded
# by default
statement = mapper._optimized_get_statement(state, attribute_names)
if statement is not None:
result = load_on_ident(
session.query(mapper).
options(
strategy_options.Load(mapper).undefer("*")
).from_statement(statement),
None,
only_load_props=attribute_names,
refresh_state=state
)
if result is False:
if has_key:
identity_key = state.key
else:
# this codepath is rare - only valid when inside a flush, and the
# object is becoming persistent but hasn't yet been assigned
# an identity_key.
# check here to ensure we have the attrs we need.
pk_attrs = [mapper._columntoproperty[col].key
for col in mapper.primary_key]
if state.expired_attributes.intersection(pk_attrs):
raise sa_exc.InvalidRequestError(
"Instance %s cannot be refreshed - it's not "
" persistent and does not "
"contain a full primary key." % state_str(state))
identity_key = mapper._identity_key_from_state(state)
if (_none_set.issubset(identity_key) and
not mapper.allow_partial_pks) or \
_none_set.issuperset(identity_key):
util.warn_limited(
"Instance %s to be refreshed doesn't "
"contain a full primary key - can't be refreshed "
"(and shouldn't be expired, either).",
state_str(state))
return
result = load_on_ident(
session.query(mapper),
identity_key,
refresh_state=state,
only_load_props=attribute_names)
# if instance is pending, a refresh operation
# may not complete (even if PK attributes are assigned)
if has_key and result is None:
raise orm_exc.ObjectDeletedError(state)

View File

@ -0,0 +1,271 @@
# orm/path_registry.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Path tracking utilities, representing mapper graph traversals.
"""
from .. import inspection
from .. import util
from .. import exc
from itertools import chain
from .base import class_mapper
import logging
log = logging.getLogger(__name__)
def _unreduce_path(path):
return PathRegistry.deserialize(path)
_WILDCARD_TOKEN = "*"
_DEFAULT_TOKEN = "_sa_default"
class PathRegistry(object):
"""Represent query load paths and registry functions.
Basically represents structures like:
(<User mapper>, "orders", <Order mapper>, "items", <Item mapper>)
These structures are generated by things like
query options (joinedload(), subqueryload(), etc.) and are
used to compose keys stored in the query._attributes dictionary
for various options.
They are then re-composed at query compile/result row time as
the query is formed and as rows are fetched, where they again
serve to compose keys to look up options in the context.attributes
dictionary, which is copied from query._attributes.
The path structure has a limited amount of caching, where each
"root" ultimately pulls from a fixed registry associated with
the first mapper, that also contains elements for each of its
property keys. However paths longer than two elements, which
are the exception rather than the rule, are generated on an
as-needed basis.
"""
is_token = False
is_root = False
def __eq__(self, other):
return other is not None and \
self.path == other.path
def set(self, attributes, key, value):
log.debug("set '%s' on path '%s' to '%s'", key, self, value)
attributes[(key, self.path)] = value
def setdefault(self, attributes, key, value):
log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value)
attributes.setdefault((key, self.path), value)
def get(self, attributes, key, value=None):
key = (key, self.path)
if key in attributes:
return attributes[key]
else:
return value
def __len__(self):
return len(self.path)
@property
def length(self):
return len(self.path)
def pairs(self):
path = self.path
for i in range(0, len(path), 2):
yield path[i], path[i + 1]
def contains_mapper(self, mapper):
for path_mapper in [
self.path[i] for i in range(0, len(self.path), 2)
]:
if path_mapper.is_mapper and \
path_mapper.isa(mapper):
return True
else:
return False
def contains(self, attributes, key):
return (key, self.path) in attributes
def __reduce__(self):
return _unreduce_path, (self.serialize(), )
def serialize(self):
path = self.path
return list(zip(
[m.class_ for m in [path[i] for i in range(0, len(path), 2)]],
[path[i].key for i in range(1, len(path), 2)] + [None]
))
@classmethod
def deserialize(cls, path):
if path is None:
return None
p = tuple(chain(*[(class_mapper(mcls),
class_mapper(mcls).attrs[key]
if key is not None else None)
for mcls, key in path]))
if p and p[-1] is None:
p = p[0:-1]
return cls.coerce(p)
@classmethod
def per_mapper(cls, mapper):
return EntityRegistry(
cls.root, mapper
)
@classmethod
def coerce(cls, raw):
return util.reduce(lambda prev, next: prev[next], raw, cls.root)
def token(self, token):
if token.endswith(':' + _WILDCARD_TOKEN):
return TokenRegistry(self, token)
elif token.endswith(":" + _DEFAULT_TOKEN):
return TokenRegistry(self.root, token)
else:
raise exc.ArgumentError("invalid token: %s" % token)
def __add__(self, other):
return util.reduce(
lambda prev, next: prev[next],
other.path, self)
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.path, )
class RootRegistry(PathRegistry):
"""Root registry, defers to mappers so that
paths are maintained per-root-mapper.
"""
path = ()
has_entity = False
is_aliased_class = False
is_root = True
def __getitem__(self, entity):
return entity._path_registry
PathRegistry.root = RootRegistry()
class TokenRegistry(PathRegistry):
def __init__(self, parent, token):
self.token = token
self.parent = parent
self.path = parent.path + (token,)
has_entity = False
is_token = True
def generate_for_superclasses(self):
if not self.parent.is_aliased_class and not self.parent.is_root:
for ent in self.parent.mapper.iterate_to_root():
yield TokenRegistry(self.parent.parent[ent], self.token)
else:
yield self
def __getitem__(self, entity):
raise NotImplementedError()
class PropRegistry(PathRegistry):
def __init__(self, parent, prop):
# restate this path in terms of the
# given MapperProperty's parent.
insp = inspection.inspect(parent[-1])
if not insp.is_aliased_class or insp._use_mapper_path:
parent = parent.parent[prop.parent]
elif insp.is_aliased_class and insp.with_polymorphic_mappers:
if prop.parent is not insp.mapper and \
prop.parent in insp.with_polymorphic_mappers:
subclass_entity = parent[-1]._entity_for_mapper(prop.parent)
parent = parent.parent[subclass_entity]
self.prop = prop
self.parent = parent
self.path = parent.path + (prop,)
self._wildcard_path_loader_key = (
"loader",
self.parent.path + self.prop._wildcard_token
)
self._default_path_loader_key = self.prop._default_path_loader_key
self._loader_key = ("loader", self.path)
def __str__(self):
return " -> ".join(
str(elem) for elem in self.path
)
@util.memoized_property
def has_entity(self):
return hasattr(self.prop, "mapper")
@util.memoized_property
def entity(self):
return self.prop.mapper
@property
def mapper(self):
return self.entity
@property
def entity_path(self):
return self[self.entity]
def __getitem__(self, entity):
if isinstance(entity, (int, slice)):
return self.path[entity]
else:
return EntityRegistry(
self, entity
)
class EntityRegistry(PathRegistry, dict):
is_aliased_class = False
has_entity = True
def __init__(self, parent, entity):
self.key = entity
self.parent = parent
self.is_aliased_class = entity.is_aliased_class
self.entity = entity
self.path = parent.path + (entity,)
self.entity_path = self
@property
def mapper(self):
return inspection.inspect(self.entity).mapper
def __bool__(self):
return True
__nonzero__ = __bool__
def __getitem__(self, entity):
if isinstance(entity, (int, slice)):
return self.path[entity]
else:
return dict.__getitem__(self, entity)
def __missing__(self, key):
self[key] = item = PropRegistry(self, key)
return item

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,203 @@
# sql/annotation.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""The :class:`.Annotated` class and related routines; creates hash-equivalent
copies of SQL constructs which contain context-specific markers and
associations.
"""
from .. import util
from . import operators
class Annotated(object):
"""clones a ClauseElement and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
in hashed collections.
A reference to the original element is maintained, for the important
reason of keeping its hash value current. When GC'ed, the
hash value may be reused, causing conflicts.
.. note:: The rationale for Annotated producing a brand new class,
rather than placing the functionality directly within ClauseElement,
is **performance**. The __hash__() method is absent on plain
ClauseElement which leads to significantly reduced function call
overhead, as the use of sets and dictionaries against ClauseElement
objects is prevalent, but most are not "annotated".
"""
def __new__(cls, *args):
if not args:
# clone constructor
return object.__new__(cls)
else:
element, values = args
# pull appropriate subclass from registry of annotated
# classes
try:
cls = annotated_classes[element.__class__]
except KeyError:
cls = _new_annotation_type(element.__class__, cls)
return object.__new__(cls)
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
self.__element = element
self._annotations = values
self._hash = hash(element)
def _annotate(self, values):
_values = self._annotations.copy()
_values.update(values)
return self._with_annotations(_values)
def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone._annotations = values
return clone
def _deannotate(self, values=None, clone=True):
if values is None:
return self.__element
else:
_values = self._annotations.copy()
for v in values:
_values.pop(v, None)
return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
return self.__element.__class__._compiler_dispatch(
self, visitor, **kw)
@property
def _constructor(self):
return self.__element._constructor
def _clone(self):
clone = self.__element._clone()
if clone is self.__element:
# detect immutable, don't change anything
return self
else:
# update the clone with any changes that have occurred
# to this object's __dict__.
clone.__dict__.update(self.__dict__)
return self.__class__(clone, self._annotations)
def __hash__(self):
return self._hash
def __eq__(self, other):
if isinstance(self.__element, operators.ColumnOperators):
return self.__element.__class__.__eq__(self, other)
else:
return hash(other) == hash(self)
# hard-generate Annotated subclasses. this technique
# is used instead of on-the-fly types (i.e. type.__new__())
# so that the resulting objects are pickleable.
annotated_classes = {}
def _deep_annotate(element, annotations, exclude=None):
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
Elements within the exclude collection will be cloned but not annotated.
"""
def clone(elem):
if exclude and \
hasattr(elem, 'proxy_set') and \
elem.proxy_set.intersection(exclude):
newelem = elem._clone()
elif annotations != elem._annotations:
newelem = elem._annotate(annotations)
else:
newelem = elem
newelem._copy_internals(clone=clone)
return newelem
if element is not None:
element = clone(element)
return element
def _deep_deannotate(element, values=None):
"""Deep copy the given element, removing annotations."""
cloned = util.column_dict()
def clone(elem):
# if a values dict is given,
# the elem must be cloned each time it appears,
# as there may be different annotations in source
# elements that are remaining. if totally
# removing all annotations, can assume the same
# slate...
if values or elem not in cloned:
newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
if not values:
cloned[elem] = newelem
return newelem
else:
return cloned[elem]
if element is not None:
element = clone(element)
return element
def _shallow_annotate(element, annotations):
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
Basically used to apply a "dont traverse" annotation to a
selectable, without digging throughout the whole
structure wasting time.
"""
element = element._annotate(annotations)
element._copy_internals()
return element
def _new_annotation_type(cls, base_cls):
if issubclass(cls, Annotated):
return cls
elif cls in annotated_classes:
return annotated_classes[cls]
for super_ in cls.__mro__:
# check if an Annotated subclass more specific than
# the given base_cls is already registered, such
# as AnnotatedColumnElement.
if super_ in annotated_classes:
base_cls = annotated_classes[super_]
break
annotated_classes[cls] = anno_cls = type(
"Annotated%s" % cls.__name__,
(base_cls, cls), {})
globals()["Annotated%s" % cls.__name__] = anno_cls
return anno_cls
def _prepare_annotations(target_hierarchy, base_cls):
stack = [target_hierarchy]
while stack:
cls = stack.pop()
stack.extend(cls.__subclasses__())
_new_annotation_type(cls, base_cls)

633
sqlalchemy/sql/base.py Normal file
View File

@ -0,0 +1,633 @@
# sql/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Foundational utilities common to many sql modules.
"""
from .. import util, exc
import itertools
from .visitors import ClauseVisitor
import re
import collections
PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT')
NO_ARG = util.symbol('NO_ARG')
class Immutable(object):
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
def unique_params(self, *optionaldict, **kwargs):
raise NotImplementedError("Immutable objects do not support copying")
def params(self, *optionaldict, **kwargs):
raise NotImplementedError("Immutable objects do not support copying")
def _clone(self):
return self
def _from_objects(*elements):
return itertools.chain(*[element._from_objects for element in elements])
@util.decorator
def _generative(fn, *args, **kw):
"""Mark a method as generative."""
self = args[0]._generate()
fn(self, *args[1:], **kw)
return self
class _DialectArgView(collections.MutableMapping):
"""A dictionary view of dialect-level arguments in the form
<dialectname>_<argument_name>.
"""
def __init__(self, obj):
self.obj = obj
def _key(self, key):
try:
dialect, value_key = key.split("_", 1)
except ValueError:
raise KeyError(key)
else:
return dialect, value_key
def __getitem__(self, key):
dialect, value_key = self._key(key)
try:
opt = self.obj.dialect_options[dialect]
except exc.NoSuchModuleError:
raise KeyError(key)
else:
return opt[value_key]
def __setitem__(self, key, value):
try:
dialect, value_key = self._key(key)
except KeyError:
raise exc.ArgumentError(
"Keys must be of the form <dialectname>_<argname>")
else:
self.obj.dialect_options[dialect][value_key] = value
def __delitem__(self, key):
dialect, value_key = self._key(key)
del self.obj.dialect_options[dialect][value_key]
def __len__(self):
return sum(len(args._non_defaults) for args in
self.obj.dialect_options.values())
def __iter__(self):
return (
util.safe_kwarg("%s_%s" % (dialect_name, value_name))
for dialect_name in self.obj.dialect_options
for value_name in
self.obj.dialect_options[dialect_name]._non_defaults
)
class _DialectArgDict(collections.MutableMapping):
"""A dictionary view of dialect-level arguments for a specific
dialect.
Maintains a separate collection of user-specified arguments
and dialect-specified default arguments.
"""
def __init__(self):
self._non_defaults = {}
self._defaults = {}
def __len__(self):
return len(set(self._non_defaults).union(self._defaults))
def __iter__(self):
return iter(set(self._non_defaults).union(self._defaults))
def __getitem__(self, key):
if key in self._non_defaults:
return self._non_defaults[key]
else:
return self._defaults[key]
def __setitem__(self, key, value):
self._non_defaults[key] = value
def __delitem__(self, key):
del self._non_defaults[key]
class DialectKWArgs(object):
"""Establish the ability for a class to have dialect-specific arguments
with defaults and constructor validation.
The :class:`.DialectKWArgs` interacts with the
:attr:`.DefaultDialect.construct_arguments` present on a dialect.
.. seealso::
:attr:`.DefaultDialect.construct_arguments`
"""
@classmethod
def argument_for(cls, dialect_name, argument_name, default):
"""Add a new kind of dialect-specific keyword argument for this class.
E.g.::
Index.argument_for("mydialect", "length", None)
some_index = Index('a', 'b', mydialect_length=5)
The :meth:`.DialectKWArgs.argument_for` method is a per-argument
way adding extra arguments to the
:attr:`.DefaultDialect.construct_arguments` dictionary. This
dictionary provides a list of argument names accepted by various
schema-level constructs on behalf of a dialect.
New dialects should typically specify this dictionary all at once as a
data member of the dialect class. The use case for ad-hoc addition of
argument names is typically for end-user code that is also using
a custom compilation scheme which consumes the additional arguments.
:param dialect_name: name of a dialect. The dialect must be
locatable, else a :class:`.NoSuchModuleError` is raised. The
dialect must also include an existing
:attr:`.DefaultDialect.construct_arguments` collection, indicating
that it participates in the keyword-argument validation and default
system, else :class:`.ArgumentError` is raised. If the dialect does
not include this collection, then any keyword argument can be
specified on behalf of this dialect already. All dialects packaged
within SQLAlchemy include this collection, however for third party
dialects, support may vary.
:param argument_name: name of the parameter.
:param default: default value of the parameter.
.. versionadded:: 0.9.4
"""
construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
if construct_arg_dictionary is None:
raise exc.ArgumentError(
"Dialect '%s' does have keyword-argument "
"validation and defaults enabled configured" %
dialect_name)
if cls not in construct_arg_dictionary:
construct_arg_dictionary[cls] = {}
construct_arg_dictionary[cls][argument_name] = default
@util.memoized_property
def dialect_kwargs(self):
"""A collection of keyword arguments specified as dialect-specific
options to this construct.
The arguments are present here in their original ``<dialect>_<kwarg>``
format. Only arguments that were actually passed are included;
unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
contains all options known by this dialect including defaults.
The collection is also writable; keys are accepted of the
form ``<dialect>_<kwarg>`` where the value will be assembled
into the list of options.
.. versionadded:: 0.9.2
.. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
collection is now writable.
.. seealso::
:attr:`.DialectKWArgs.dialect_options` - nested dictionary form
"""
return _DialectArgView(self)
@property
def kwargs(self):
"""A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
return self.dialect_kwargs
@util.dependencies("sqlalchemy.dialects")
def _kw_reg_for_dialect(dialects, dialect_name):
dialect_cls = dialects.registry.load(dialect_name)
if dialect_cls.construct_arguments is None:
return None
return dict(dialect_cls.construct_arguments)
_kw_registry = util.PopulateDict(_kw_reg_for_dialect)
def _kw_reg_for_dialect_cls(self, dialect_name):
construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
d = _DialectArgDict()
if construct_arg_dictionary is None:
d._defaults.update({"*": None})
else:
for cls in reversed(self.__class__.__mro__):
if cls in construct_arg_dictionary:
d._defaults.update(construct_arg_dictionary[cls])
return d
@util.memoized_property
def dialect_options(self):
"""A collection of keyword arguments specified as dialect-specific
options to this construct.
This is a two-level nested registry, keyed to ``<dialect_name>``
and ``<argument_name>``. For example, the ``postgresql_where``
argument would be locatable as::
arg = my_object.dialect_options['postgresql']['where']
.. versionadded:: 0.9.2
.. seealso::
:attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form
"""
return util.PopulateDict(
util.portable_instancemethod(self._kw_reg_for_dialect_cls)
)
def _validate_dialect_kwargs(self, kwargs):
# validate remaining kwargs that they all specify DB prefixes
if not kwargs:
return
for k in kwargs:
m = re.match('^(.+?)_(.+)$', k)
if not m:
raise TypeError(
"Additional arguments should be "
"named <dialectname>_<argument>, got '%s'" % k)
dialect_name, arg_name = m.group(1, 2)
try:
construct_arg_dictionary = self.dialect_options[dialect_name]
except exc.NoSuchModuleError:
util.warn(
"Can't validate argument %r; can't "
"locate any SQLAlchemy dialect named %r" %
(k, dialect_name))
self.dialect_options[dialect_name] = d = _DialectArgDict()
d._defaults.update({"*": None})
d._non_defaults[arg_name] = kwargs[k]
else:
if "*" not in construct_arg_dictionary and \
arg_name not in construct_arg_dictionary:
raise exc.ArgumentError(
"Argument %r is not accepted by "
"dialect %r on behalf of %r" % (
k,
dialect_name, self.__class__
))
else:
construct_arg_dictionary[arg_name] = kwargs[k]
class Generative(object):
"""Allow a ClauseElement to generate itself via the
@_generative decorator.
"""
def _generate(self):
s = self.__class__.__new__(self.__class__)
s.__dict__ = self.__dict__.copy()
return s
class Executable(Generative):
"""Mark a ClauseElement as supporting execution.
:class:`.Executable` is a superclass for all "statement" types
of objects, including :func:`select`, :func:`delete`, :func:`update`,
:func:`insert`, :func:`text`.
"""
supports_execution = True
_execution_options = util.immutabledict()
_bind = None
@_generative
def execution_options(self, **kw):
""" Set non-SQL options for the statement which take effect during
execution.
Execution options can be set on a per-statement or
per :class:`.Connection` basis. Additionally, the
:class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide
access to execution options which they in turn configure upon
connections.
The :meth:`execution_options` method is generative. A new
instance of this statement is returned that contains the options::
statement = select([table.c.x, table.c.y])
statement = statement.execution_options(autocommit=True)
Note that only a subset of possible execution options can be applied
to a statement - these include "autocommit" and "stream_results",
but not "isolation_level" or "compiled_cache".
See :meth:`.Connection.execution_options` for a full list of
possible options.
.. seealso::
:meth:`.Connection.execution_options()`
:meth:`.Query.execution_options()`
"""
if 'isolation_level' in kw:
raise exc.ArgumentError(
"'isolation_level' execution option may only be specified "
"on Connection.execution_options(), or "
"per-engine using the isolation_level "
"argument to create_engine()."
)
if 'compiled_cache' in kw:
raise exc.ArgumentError(
"'compiled_cache' execution option may only be specified "
"on Connection.execution_options(), not per statement."
)
self._execution_options = self._execution_options.union(kw)
def execute(self, *multiparams, **params):
"""Compile and execute this :class:`.Executable`."""
e = self.bind
if e is None:
label = getattr(self, 'description', self.__class__.__name__)
msg = ('This %s is not directly bound to a Connection or Engine.'
'Use the .execute() method of a Connection or Engine '
'to execute this construct.' % label)
raise exc.UnboundExecutionError(msg)
return e._execute_clauseelement(self, multiparams, params)
def scalar(self, *multiparams, **params):
"""Compile and execute this :class:`.Executable`, returning the
result's scalar representation.
"""
return self.execute(*multiparams, **params).scalar()
@property
def bind(self):
"""Returns the :class:`.Engine` or :class:`.Connection` to
which this :class:`.Executable` is bound, or None if none found.
This is a traversal which checks locally, then
checks among the "from" clauses of associated objects
until a bound engine or connection is found.
"""
if self._bind is not None:
return self._bind
for f in _from_objects(self):
if f is self:
continue
engine = f.bind
if engine is not None:
return engine
else:
return None
class SchemaEventTarget(object):
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.
"""
def _set_parent(self, parent):
"""Associate with this SchemaEvent's parent object."""
def _set_parent_with_dispatch(self, parent):
self.dispatch.before_parent_attach(self, parent)
self._set_parent(parent)
self.dispatch.after_parent_attach(self, parent)
class SchemaVisitor(ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
__traverse_options__ = {'schema_visitor': True}
class ColumnCollection(util.OrderedProperties):
"""An ordered dictionary that stores a list of ColumnElement
instances.
Overrides the ``__eq__()`` method to produce SQL clauses between
sets of correlated columns.
"""
__slots__ = '_all_columns'
def __init__(self, *columns):
super(ColumnCollection, self).__init__()
object.__setattr__(self, '_all_columns', [])
for c in columns:
self.add(c)
def __str__(self):
return repr([str(c) for c in self])
def replace(self, column):
"""add the given column to this collection, removing unaliased
versions of this column as well as existing columns with the
same key.
e.g.::
t = Table('sometable', metadata, Column('col1', Integer))
t.columns.replace(Column('col1', Integer, key='columnone'))
will remove the original 'col1' from the collection, and add
the new column under the name 'columnname'.
Used by schema.Column to override columns during table reflection.
"""
remove_col = None
if column.name in self and column.key != column.name:
other = self[column.name]
if other.name == other.key:
remove_col = other
del self._data[other.key]
if column.key in self._data:
remove_col = self._data[column.key]
self._data[column.key] = column
if remove_col is not None:
self._all_columns[:] = [column if c is remove_col
else c for c in self._all_columns]
else:
self._all_columns.append(column)
def add(self, column):
"""Add a column to this collection.
The key attribute of the column will be used as the hash key
for this dictionary.
"""
if not column.key:
raise exc.ArgumentError(
"Can't add unnamed column to column collection")
self[column.key] = column
def __delitem__(self, key):
raise NotImplementedError()
def __setattr__(self, key, object):
raise NotImplementedError()
def __setitem__(self, key, value):
if key in self:
# this warning is primarily to catch select() statements
# which have conflicting column names in their exported
# columns collection
existing = self[key]
if not existing.shares_lineage(value):
util.warn('Column %r on table %r being replaced by '
'%r, which has the same key. Consider '
'use_labels for select() statements.' %
(key, getattr(existing, 'table', None), value))
# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
util.memoized_property.reset(value, "proxy_set")
self._all_columns.append(value)
self._data[key] = value
def clear(self):
raise NotImplementedError()
def remove(self, column):
del self._data[column.key]
self._all_columns[:] = [
c for c in self._all_columns if c is not column]
def update(self, iter):
cols = list(iter)
all_col_set = set(self._all_columns)
self._all_columns.extend(
c for label, c in cols if c not in all_col_set)
self._data.update((label, c) for label, c in cols)
def extend(self, iter):
cols = list(iter)
all_col_set = set(self._all_columns)
self._all_columns.extend(c for c in cols if c not in all_col_set)
self._data.update((c.key, c) for c in cols)
__hash__ = None
@util.dependencies("sqlalchemy.sql.elements")
def __eq__(self, elements, other):
l = []
for c in getattr(other, "_all_columns", other):
for local in self._all_columns:
if c.shares_lineage(local):
l.append(c == local)
return elements.and_(*l)
def __contains__(self, other):
if not isinstance(other, util.string_types):
raise exc.ArgumentError("__contains__ requires a string argument")
return util.OrderedProperties.__contains__(self, other)
def __getstate__(self):
return {'_data': self._data,
'_all_columns': self._all_columns}
def __setstate__(self, state):
object.__setattr__(self, '_data', state['_data'])
object.__setattr__(self, '_all_columns', state['_all_columns'])
def contains_column(self, col):
return col in set(self._all_columns)
def as_immutable(self):
return ImmutableColumnCollection(self._data, self._all_columns)
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
def __init__(self, data, all_columns):
util.ImmutableProperties.__init__(self, data)
object.__setattr__(self, '_all_columns', all_columns)
extend = remove = util.ImmutableProperties._immutable
class ColumnSet(util.ordered_column_set):
def contains_column(self, col):
return col in self
def extend(self, cols):
for col in cols:
self.add(col)
def __add__(self, other):
return list(self) + list(other)
@util.dependencies("sqlalchemy.sql.elements")
def __eq__(self, elements, other):
l = []
for c in other:
for local in self:
if c.shares_lineage(local):
l.append(c == local)
return elements.and_(*l)
def __hash__(self):
return hash(tuple(x for x in self))
def _bind_or_error(schemaitem, msg=None):
bind = schemaitem.bind
if not bind:
name = schemaitem.__class__.__name__
label = getattr(schemaitem, 'fullname',
getattr(schemaitem, 'name', None))
if label:
item = '%s object %r' % (name, label)
else:
item = '%s object' % name
if msg is None:
msg = "%s is not bound to an Engine or Connection. "\
"Execution can not proceed without a database to execute "\
"against." % item
raise exc.UnboundExecutionError(msg)
return bind

692
sqlalchemy/sql/crud.py Normal file
View File

@ -0,0 +1,692 @@
# sql/crud.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Functions used by compiler.py to determine the parameters rendered
within INSERT and UPDATE statements.
"""
from .. import util
from .. import exc
from . import dml
from . import elements
import operator
REQUIRED = util.symbol('REQUIRED', """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`.Connection.execute`.
This symbol is typically used when a :func:`.expression.insert`
or :func:`.expression.update` statement is compiled without parameter
values present.
""")
ISINSERT = util.symbol('ISINSERT')
ISUPDATE = util.symbol('ISUPDATE')
ISDELETE = util.symbol('ISDELETE')
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
restore_isinsert = compiler.isinsert
restore_isupdate = compiler.isupdate
restore_isdelete = compiler.isdelete
should_restore = (
restore_isinsert or restore_isupdate or restore_isdelete
) or len(compiler.stack) > 1
if local_stmt_type is ISINSERT:
compiler.isupdate = False
compiler.isinsert = True
elif local_stmt_type is ISUPDATE:
compiler.isupdate = True
compiler.isinsert = False
elif local_stmt_type is ISDELETE:
if not should_restore:
compiler.isdelete = True
else:
assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
try:
if local_stmt_type in (ISINSERT, ISUPDATE):
return _get_crud_params(compiler, stmt, **kw)
finally:
if should_restore:
compiler.isinsert = restore_isinsert
compiler.isupdate = restore_isupdate
compiler.isdelete = restore_isdelete
def _get_crud_params(compiler, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
Also generates the Compiled object's postfetch, prefetch, and
returning column collections, used for default handling and ultimately
populating the ResultProxy's prefetch_cols() and postfetch_cols()
collections.
"""
compiler.postfetch = []
compiler.insert_prefetch = []
compiler.update_prefetch = []
compiler.returning = []
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and stmt.parameters is None:
return [
(c, _create_bind_param(
compiler, c, None, required=True))
for c in stmt.table.columns
]
if stmt._has_multi_parameters:
stmt_parameters = stmt.parameters[0]
else:
stmt_parameters = stmt.parameters
# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
_column_as_key, _getattr_col_key, _col_bind_name = \
_key_getters_for_crud_column(compiler, stmt)
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
else:
parameters = dict((_column_as_key(key), REQUIRED)
for key in compiler.column_keys
if not stmt_parameters or
key not in stmt_parameters)
# create a list of column assignment clauses as tuples
values = []
if stmt_parameters is not None:
_get_stmt_parameters_params(
compiler,
parameters, stmt_parameters, _column_as_key, values, kw)
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# statements
if compiler.isupdate and stmt._extra_froms and stmt_parameters:
_get_multitable_params(
compiler, stmt, stmt_parameters, check_columns,
_col_bind_name, _getattr_col_key, values, kw)
if compiler.isinsert and stmt.select_names:
_scan_insert_from_select_cols(
compiler, stmt, parameters,
_getattr_col_key, _column_as_key,
_col_bind_name, check_columns, values, kw)
else:
_scan_cols(
compiler, stmt, parameters,
_getattr_col_key, _column_as_key,
_col_bind_name, check_columns, values, kw)
if parameters and stmt_parameters:
check = set(parameters).intersection(
_column_as_key(k) for k in stmt_parameters
).difference(check_columns)
if check:
raise exc.CompileError(
"Unconsumed column names: %s" %
(", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
values = _extend_values_for_multiparams(compiler, stmt, values, kw)
return values
def _create_bind_param(
compiler, col, value, process=True,
required=False, name=None, **kw):
if name is None:
name = col.key
bindparam = elements.BindParameter(
name, value, type_=col.type, required=required)
bindparam._is_crud = True
if process:
bindparam = bindparam._compiler_dispatch(compiler, **kw)
return bindparam
def _key_getters_for_crud_column(compiler, stmt):
if compiler.isupdate and stmt._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
# the "main" table of the statement remains unqualified,
# allowing the most compatibility with a non-multi-table
# statement.
_et = set(stmt._extra_froms)
def _column_as_key(key):
str_key = elements._column_as_key(key)
if hasattr(key, 'table') and key.table in _et:
return (key.table.name, str_key)
else:
return str_key
def _getattr_col_key(col):
if col.table in _et:
return (col.table.name, col.key)
else:
return col.key
def _col_bind_name(col):
if col.table in _et:
return "%s_%s" % (col.table.name, col.key)
else:
return col.key
else:
_column_as_key = elements._column_as_key
_getattr_col_key = _col_bind_name = operator.attrgetter("key")
return _column_as_key, _getattr_col_key, _col_bind_name
def _scan_insert_from_select_cols(
compiler, stmt, parameters, _getattr_col_key,
_column_as_key, _col_bind_name, check_columns, values, kw):
need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid = \
_get_returning_modifiers(compiler, stmt)
cols = [stmt.table.c[_column_as_key(name)]
for name in stmt.select_names]
compiler._insert_from_select = stmt.select
add_select_cols = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
if col not in col_set and col.default:
cols.append(col)
for c in cols:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
parameters.pop(col_key)
values.append((c, None))
else:
_append_param_insert_select_hasdefault(
compiler, stmt, c, add_select_cols, kw)
if add_select_cols:
values.extend(add_select_cols)
compiler._insert_from_select = compiler._insert_from_select._generate()
compiler._insert_from_select._raw_columns = \
tuple(compiler._insert_from_select._raw_columns) + tuple(
expr for col, expr in add_select_cols)
def _scan_cols(
compiler, stmt, parameters, _getattr_col_key,
_column_as_key, _col_bind_name, check_columns, values, kw):
need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid = \
_get_returning_modifiers(compiler, stmt)
if stmt._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in stmt._parameter_ordering
]
ordered_keys = set(parameter_ordering)
cols = [
stmt.table.c[key] for key in parameter_ordering
] + [
c for c in stmt.table.c if c.key not in ordered_keys
]
else:
cols = stmt.table.columns
for c in cols:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
_append_param_parameter(
compiler, stmt, c, col_key, parameters, _col_bind_name,
implicit_returning, implicit_return_defaults, values, kw)
elif compiler.isinsert:
if c.primary_key and \
need_pks and \
(
implicit_returning or
not postfetch_lastrowid or
c is not stmt.table._autoincrement_column
):
if implicit_returning:
_append_param_insert_pk_returning(
compiler, stmt, c, values, kw)
else:
_append_param_insert_pk(compiler, stmt, c, values, kw)
elif c.default is not None:
_append_param_insert_hasdefault(
compiler, stmt, c, implicit_return_defaults,
values, kw)
elif c.server_default is not None:
if implicit_return_defaults and \
c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
elif implicit_return_defaults and \
c in implicit_return_defaults:
compiler.returning.append(c)
elif c.primary_key and \
c is not stmt.table._autoincrement_column and \
not c.nullable:
_warn_pk_with_no_anticipated_value(c)
elif compiler.isupdate:
_append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw)
def _append_param_parameter(
compiler, stmt, c, col_key, parameters, _col_bind_name,
implicit_returning, implicit_return_defaults, values, kw):
value = parameters.pop(col_key)
if elements._is_literal(value):
value = _create_bind_param(
compiler, c, value, required=value is REQUIRED,
name=_col_bind_name(c)
if not stmt._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
if isinstance(value, elements.BindParameter) and \
value.type._isnull:
value = value._clone()
value.type = c.type
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
elif implicit_return_defaults and \
c in implicit_return_defaults:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
values.append((c, value))
def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
"""Create a primary key expression in the INSERT statement and
possibly a RETURNING clause for it.
If the column has a Python-side default, we will create a bound
parameter for it and "pre-execute" the Python function. If
the column has a SQL expression default, or is a sequence,
we will add it directly into the INSERT statement and add a
RETURNING element to get the new value. If the column has a
server side default or is marked as the "autoincrement" column,
we will add a RETRUNING element to get at the value.
If all the above tests fail, that indicates a primary key column with no
noted default generation capabilities that has no parameter passed;
raise an exception.
"""
if c.default is not None:
if c.default.is_sequence:
if compiler.dialect.supports_sequences and \
(not c.default.optional or
not compiler.dialect.sequences_optional):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
(c, compiler.process(
c.default.arg.self_group(), **kw))
)
compiler.returning.append(c)
else:
values.append(
(c, _create_insert_prefetch_bind_param(compiler, c))
)
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
_warn_pk_with_no_anticipated_value(c)
def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
compiler.insert_prefetch.append(c)
return param
def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
compiler.update_prefetch.append(c)
return param
class _multiparam_column(elements.ColumnElement):
def __init__(self, original, index):
self.key = "%s_m%d" % (original.key, index + 1)
self.original = original
self.default = original.default
self.type = original.type
def __eq__(self, other):
return isinstance(other, _multiparam_column) and \
other.key == self.key and \
other.original == self.original
def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
if not c.default:
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
"a Python-side value or SQL expression is required" % c)
elif c.default.is_clause_element:
return compiler.process(c.default.arg.self_group(), **kw)
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
return _create_insert_prefetch_bind_param(compiler, col)
else:
return _create_update_prefetch_bind_param(compiler, col)
def _append_param_insert_pk(compiler, stmt, c, values, kw):
"""Create a bound parameter in the INSERT statement to receive a
'prefetched' default value.
The 'prefetched' value indicates that we are to invoke a Python-side
default function or expliclt SQL expression before the INSERT statement
proceeds, so that we have a primary key value available.
if the column has no noted default generation capabilities, it has
no value passed in either; raise an exception.
"""
if (
(
# column has a Python-side default
c.default is not None and
(
# and it won't be a Sequence
not c.default.is_sequence or
compiler.dialect.supports_sequences
)
)
or
(
# column is the "autoincrement column"
c is stmt.table._autoincrement_column and
(
# and it's either a "sequence" or a
# pre-executable "autoincrement" sequence
compiler.dialect.supports_sequences or
compiler.dialect.preexecute_autoincrement_sequences
)
)
):
values.append(
(c, _create_insert_prefetch_bind_param(compiler, c))
)
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
_warn_pk_with_no_anticipated_value(c)
def _append_param_insert_hasdefault(
compiler, stmt, c, implicit_return_defaults, values, kw):
if c.default.is_sequence:
if compiler.dialect.supports_sequences and \
(not c.default.optional or
not compiler.dialect.sequences_optional):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
if implicit_return_defaults and \
c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
proc = compiler.process(c.default.arg.self_group(), **kw)
values.append((c, proc))
if implicit_return_defaults and \
c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
values.append(
(c, _create_insert_prefetch_bind_param(compiler, c))
)
def _append_param_insert_select_hasdefault(
compiler, stmt, c, values, kw):
if c.default.is_sequence:
if compiler.dialect.supports_sequences and \
(not c.default.optional or
not compiler.dialect.sequences_optional):
proc = c.default
values.append((c, proc.next_value()))
elif c.default.is_clause_element:
proc = c.default.arg.self_group()
values.append((c, proc))
else:
values.append(
(c, _create_insert_prefetch_bind_param(compiler, c, process=False))
)
def _append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw):
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
(c, compiler.process(
c.onupdate.arg.self_group(), **kw))
)
if implicit_return_defaults and \
c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
values.append(
(c, _create_update_prefetch_bind_param(compiler, c))
)
elif c.server_onupdate is not None:
if implicit_return_defaults and \
c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
elif implicit_return_defaults and \
stmt._return_defaults is not True and \
c in implicit_return_defaults:
compiler.returning.append(c)
def _get_multitable_params(
compiler, stmt, stmt_parameters, check_columns,
_col_bind_name, _getattr_col_key, values, kw):
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
for c, param in stmt_parameters.items()
)
affected_tables = set()
for t in stmt._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
if elements._is_literal(value):
value = _create_bind_param(
compiler, c, value, required=value is REQUIRED,
name=_col_bind_name(c))
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
values.append((c, value))
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
for c in t.c:
if c in normalized_params:
continue
elif (c.onupdate is not None and not
c.onupdate.is_sequence):
if c.onupdate.is_clause_element:
values.append(
(c, compiler.process(
c.onupdate.arg.self_group(),
**kw)
)
)
compiler.postfetch.append(c)
else:
values.append(
(c, _create_update_prefetch_bind_param(
compiler, c, name=_col_bind_name(c)))
)
elif c.server_onupdate is not None:
compiler.postfetch.append(c)
def _extend_values_for_multiparams(compiler, stmt, values, kw):
values_0 = values
values = [values]
values.extend(
[
(
c,
(_create_bind_param(
compiler, c, row[c.key],
name="%s_m%d" % (c.key, i + 1), **kw
) if elements._is_literal(row[c.key])
else compiler.process(
row[c.key].self_group(), **kw))
if c.key in row else
_process_multiparam_default_bind(compiler, stmt, c, i, kw)
)
for (c, param) in values_0
]
for i, row in enumerate(stmt.parameters[1:])
)
return values
def _get_stmt_parameters_params(
compiler, parameters, stmt_parameters, _column_as_key, values, kw):
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
else:
# a non-Column expression on the left side;
# add it to values() in an "as-is" state,
# coercing right side to bound param
if elements._is_literal(v):
v = compiler.process(
elements.BindParameter(None, v, type_=k.type),
**kw)
else:
v = compiler.process(v.self_group(), **kw)
values.append((k, v))
def _get_returning_modifiers(compiler, stmt):
need_pks = compiler.isinsert and \
not compiler.inline and \
not stmt._returning and \
not stmt._has_multi_parameters
implicit_returning = need_pks and \
compiler.dialect.implicit_returning and \
stmt.table.implicit_returning
if compiler.isinsert:
implicit_return_defaults = (implicit_returning and
stmt._return_defaults)
elif compiler.isupdate:
implicit_return_defaults = (compiler.dialect.implicit_returning and
stmt.table.implicit_returning and
stmt._return_defaults)
else:
# this line is unused, currently we are always
# isinsert or isupdate
implicit_return_defaults = False # pragma: no cover
if implicit_return_defaults:
if stmt._return_defaults is True:
implicit_return_defaults = set(stmt.table.c)
else:
implicit_return_defaults = set(stmt._return_defaults)
postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
return need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid
def _warn_pk_with_no_anticipated_value(c):
msg = (
"Column '%s.%s' is marked as a member of the "
"primary key for table '%s', "
"but has no Python-side or server-side default generator indicated, "
"nor does it indicate 'autoincrement=True' or 'nullable=True', "
"and no explicit value is passed. "
"Primary key columns typically may not store NULL."
%
(c.table.fullname, c.name, c.table.fullname))
if len(c.table.primary_key) > 1:
msg += (
" Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
"indicated explicitly for composite (e.g. multicolumn) primary "
"keys if AUTO_INCREMENT/SERIAL/IDENTITY "
"behavior is expected for one of the columns in the primary key. "
"CREATE TABLE statements are impacted by this change as well on "
"most backends.")
util.warn(msg)

1100
sqlalchemy/sql/ddl.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,308 @@
# sql/default_comparator.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Default implementation of SQL comparison operations.
"""
from .. import exc, util
from . import type_api
from . import operators
from .elements import BindParameter, True_, False_, BinaryExpression, \
Null, _const_expr, _clause_element_as_expr, \
ClauseList, ColumnElement, TextClause, UnaryExpression, \
collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
Slice, Visitable, _literal_as_binds
from .selectable import SelectBase, Alias, Selectable, ScalarSelect
def _boolean_compare(expr, op, obj, negate=None, reverse=False,
_python_is_types=(util.NoneType, bool),
result_type = None,
**kwargs):
if result_type is None:
result_type = type_api.BOOLEANTYPE
if isinstance(obj, _python_is_types + (Null, True_, False_)):
# allow x ==/!= True/False to be treated as a literal.
# this comes out to "== / != true/false" or "1/0" if those
# constants aren't supported and works on all platforms
if op in (operators.eq, operators.ne) and \
isinstance(obj, (bool, True_, False_)):
return BinaryExpression(expr,
_literal_as_text(obj),
op,
type_=result_type,
negate=negate, modifiers=kwargs)
elif op in (operators.is_distinct_from, operators.isnot_distinct_from):
return BinaryExpression(expr,
_literal_as_text(obj),
op,
type_=result_type,
negate=negate, modifiers=kwargs)
else:
# all other None/True/False uses IS, IS NOT
if op in (operators.eq, operators.is_):
return BinaryExpression(expr, _const_expr(obj),
operators.is_,
negate=operators.isnot)
elif op in (operators.ne, operators.isnot):
return BinaryExpression(expr, _const_expr(obj),
operators.isnot,
negate=operators.is_)
else:
raise exc.ArgumentError(
"Only '=', '!=', 'is_()', 'isnot()', "
"'is_distinct_from()', 'isnot_distinct_from()' "
"operators can be used with None/True/False")
else:
obj = _check_literal(expr, op, obj)
if reverse:
return BinaryExpression(obj,
expr,
op,
type_=result_type,
negate=negate, modifiers=kwargs)
else:
return BinaryExpression(expr,
obj,
op,
type_=result_type,
negate=negate, modifiers=kwargs)
def _binary_operate(expr, op, obj, reverse=False, result_type=None,
**kw):
obj = _check_literal(expr, op, obj)
if reverse:
left, right = obj, expr
else:
left, right = expr, obj
if result_type is None:
op, result_type = left.comparator._adapt_expression(
op, right.comparator)
return BinaryExpression(
left, right, op, type_=result_type, modifiers=kw)
def _conjunction_operate(expr, op, other, **kw):
if op is operators.and_:
return and_(expr, other)
elif op is operators.or_:
return or_(expr, other)
else:
raise NotImplementedError()
def _scalar(expr, op, fn, **kw):
return fn(expr)
def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
if isinstance(seq_or_selectable, ScalarSelect):
return _boolean_compare(expr, op, seq_or_selectable,
negate=negate_op)
elif isinstance(seq_or_selectable, SelectBase):
# TODO: if we ever want to support (x, y, z) IN (select x,
# y, z from table), we would need a multi-column version of
# as_scalar() to produce a multi- column selectable that
# does not export itself as a FROM clause
return _boolean_compare(
expr, op, seq_or_selectable.as_scalar(),
negate=negate_op, **kw)
elif isinstance(seq_or_selectable, (Selectable, TextClause)):
return _boolean_compare(expr, op, seq_or_selectable,
negate=negate_op, **kw)
elif isinstance(seq_or_selectable, ClauseElement):
raise exc.InvalidRequestError(
'in_() accepts'
' either a list of expressions '
'or a selectable: %r' % seq_or_selectable)
# Handle non selectable arguments as sequences
args = []
for o in seq_or_selectable:
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
raise exc.InvalidRequestError(
'in_() accepts'
' either a list of expressions '
'or a selectable: %r' % o)
elif o is None:
o = Null()
else:
o = expr._bind_param(op, o)
args.append(o)
if len(args) == 0:
# Special case handling for empty IN's, behave like
# comparison against zero row selectable. We use != to
# build the contradiction as it handles NULL values
# appropriately, i.e. "not (x IN ())" should not return NULL
# values for x.
util.warn('The IN-predicate on "%s" was invoked with an '
'empty sequence. This results in a '
'contradiction, which nonetheless can be '
'expensive to evaluate. Consider alternative '
'strategies for improved performance.' % expr)
if op is operators.in_op:
return expr != expr
else:
return expr == expr
return _boolean_compare(expr, op,
ClauseList(*args).self_group(against=op),
negate=negate_op)
def _getitem_impl(expr, op, other, **kw):
if isinstance(expr.type, type_api.INDEXABLE):
other = _check_literal(expr, op, other)
return _binary_operate(expr, op, other, **kw)
else:
_unsupported_impl(expr, op, other, **kw)
def _unsupported_impl(expr, op, *arg, **kw):
raise NotImplementedError("Operator '%s' is not supported on "
"this expression" % op.__name__)
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
if hasattr(expr, 'negation_clause'):
return expr.negation_clause
else:
return expr._negate()
def _neg_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__neg__`."""
return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
def _match_impl(expr, op, other, **kw):
"""See :meth:`.ColumnOperators.match`."""
return _boolean_compare(
expr, operators.match_op,
_check_literal(
expr, operators.match_op, other),
result_type=type_api.MATCHTYPE,
negate=operators.notmatch_op
if op is operators.match_op else operators.match_op,
**kw
)
def _distinct_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.distinct`."""
return UnaryExpression(expr, operator=operators.distinct_op,
type_=expr.type)
def _between_impl(expr, op, cleft, cright, **kw):
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
expr,
ClauseList(
_check_literal(expr, operators.and_, cleft),
_check_literal(expr, operators.and_, cright),
operator=operators.and_,
group=False, group_contents=False),
op,
negate=operators.notbetween_op
if op is operators.between_op
else operators.between_op,
modifiers=kw)
def _collate_impl(expr, op, other, **kw):
return collate(expr, other)
# a mapping of operators with the method they use, along with
# their negated operator for comparison operators
operator_lookup = {
"and_": (_conjunction_operate,),
"or_": (_conjunction_operate,),
"inv": (_inv_impl,),
"add": (_binary_operate,),
"mul": (_binary_operate,),
"sub": (_binary_operate,),
"div": (_binary_operate,),
"mod": (_binary_operate,),
"truediv": (_binary_operate,),
"custom_op": (_binary_operate,),
"json_path_getitem_op": (_binary_operate, ),
"json_getitem_op": (_binary_operate, ),
"concat_op": (_binary_operate,),
"lt": (_boolean_compare, operators.ge),
"le": (_boolean_compare, operators.gt),
"ne": (_boolean_compare, operators.eq),
"gt": (_boolean_compare, operators.le),
"ge": (_boolean_compare, operators.lt),
"eq": (_boolean_compare, operators.ne),
"is_distinct_from": (_boolean_compare, operators.isnot_distinct_from),
"isnot_distinct_from": (_boolean_compare, operators.is_distinct_from),
"like_op": (_boolean_compare, operators.notlike_op),
"ilike_op": (_boolean_compare, operators.notilike_op),
"notlike_op": (_boolean_compare, operators.like_op),
"notilike_op": (_boolean_compare, operators.ilike_op),
"contains_op": (_boolean_compare, operators.notcontains_op),
"startswith_op": (_boolean_compare, operators.notstartswith_op),
"endswith_op": (_boolean_compare, operators.notendswith_op),
"desc_op": (_scalar, UnaryExpression._create_desc),
"asc_op": (_scalar, UnaryExpression._create_asc),
"nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst),
"nullslast_op": (_scalar, UnaryExpression._create_nullslast),
"in_op": (_in_impl, operators.notin_op),
"notin_op": (_in_impl, operators.in_op),
"is_": (_boolean_compare, operators.is_),
"isnot": (_boolean_compare, operators.isnot),
"collate": (_collate_impl,),
"match_op": (_match_impl,),
"notmatch_op": (_match_impl,),
"distinct_op": (_distinct_impl,),
"between_op": (_between_impl, ),
"notbetween_op": (_between_impl, ),
"neg": (_neg_impl,),
"getitem": (_getitem_impl,),
"lshift": (_unsupported_impl,),
"rshift": (_unsupported_impl,),
"contains": (_unsupported_impl,),
}
def _check_literal(expr, operator, other, bindparam_type=None):
if isinstance(other, (ColumnElement, TextClause)):
if isinstance(other, BindParameter) and \
other.type._isnull:
other = other._clone()
other.type = expr.type
return other
elif hasattr(other, '__clause_element__'):
other = other.__clause_element__()
elif isinstance(other, type_api.TypeEngine.Comparator):
other = other.expr
if isinstance(other, (SelectBase, Alias)):
return other.as_scalar()
elif not isinstance(other, Visitable):
return expr._bind_param(operator, other, type_=bindparam_type)
else:
return other

851
sqlalchemy/sql/dml.py Normal file
View File

@ -0,0 +1,851 @@
# sql/dml.py
# Copyright (C) 2009-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
"""
from .base import Executable, _generative, _from_objects, DialectKWArgs, \
ColumnCollection
from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
_column_as_key
from .selectable import _interpret_as_from, _interpret_as_select, \
HasPrefixes, HasCTE
from .. import util
from .. import exc
class UpdateBase(
HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.
"""
__visit_name__ = 'update_base'
_execution_options = \
Executable._execution_options.union({'autocommit': True})
_hints = util.immutabledict()
_parameter_ordering = None
_prefixes = ()
named_with_column = False
def _process_colparams(self, parameters):
def process_single(p):
if isinstance(p, (list, tuple)):
return dict(
(c.key, pval)
for c, pval in zip(self.table.c, p)
)
else:
return p
if self._preserve_parameter_order and parameters is not None:
if not isinstance(parameters, list) or \
(parameters and not isinstance(parameters[0], tuple)):
raise ValueError(
"When preserve_parameter_order is True, "
"values() only accepts a list of 2-tuples")
self._parameter_ordering = [key for key, value in parameters]
return dict(parameters), False
if (isinstance(parameters, (list, tuple)) and parameters and
isinstance(parameters[0], (list, tuple, dict))):
if not self._supports_multi_parameters:
raise exc.InvalidRequestError(
"This construct does not support "
"multiple parameter sets.")
return [process_single(p) for p in parameters], True
else:
return process_single(parameters), False
def params(self, *arg, **kw):
"""Set the parameters for the statement.
This method raises ``NotImplementedError`` on the base class,
and is overridden by :class:`.ValuesBase` to provide the
SET/VALUES clause of UPDATE and INSERT.
"""
raise NotImplementedError(
"params() is not supported for INSERT/UPDATE/DELETE statements."
" To set the values for an INSERT or UPDATE statement, use"
" stmt.values(**parameters).")
def bind(self):
"""Return a 'bind' linked to this :class:`.UpdateBase`
or a :class:`.Table` associated with it.
"""
return self._bind or self.table.bind
def _set_bind(self, bind):
self._bind = bind
bind = property(bind, _set_bind)
@_generative
def returning(self, *cols):
r"""Add a :term:`RETURNING` or equivalent clause to this statement.
e.g.::
stmt = table.update().\
where(table.c.data == 'value').\
values(status='X').\
returning(table.c.server_flag,
table.c.updated_timestamp)
for server_flag, updated_timestamp in connection.execute(stmt):
print(server_flag, updated_timestamp)
The given collection of column expressions should be derived from
the table that is
the target of the INSERT, UPDATE, or DELETE. While :class:`.Column`
objects are typical, the elements can also be expressions::
stmt = table.insert().returning(
(table.c.first_name + " " + table.c.last_name).
label('fullname'))
Upon compilation, a RETURNING clause, or database equivalent,
will be rendered within the statement. For INSERT and UPDATE,
the values are the newly inserted/updated values. For DELETE,
the values are those of the rows which were deleted.
Upon execution, the values of the columns to be returned are made
available via the result set and can be iterated using
:meth:`.ResultProxy.fetchone` and similar. For DBAPIs which do not
natively support returning values (i.e. cx_oracle), SQLAlchemy will
approximate this behavior at the result level so that a reasonable
amount of behavioral neutrality is provided.
Note that not all databases/DBAPIs
support RETURNING. For those backends with no support,
an exception is raised upon compilation and/or execution.
For those who do support it, the functionality across backends
varies greatly, including restrictions on executemany()
and other statements which return multiple rows. Please
read the documentation notes for the database in use in
order to determine the availability of RETURNING.
.. seealso::
:meth:`.ValuesBase.return_defaults` - an alternative method tailored
towards efficient fetching of server-side defaults and triggers
for single-row INSERTs or UPDATEs.
"""
self._returning = cols
@_generative
def with_hint(self, text, selectable=None, dialect_name="*"):
"""Add a table hint for a single table to this
INSERT/UPDATE/DELETE statement.
.. note::
:meth:`.UpdateBase.with_hint` currently applies only to
Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use
:meth:`.UpdateBase.prefix_with`.
The text of the hint is rendered in the appropriate
location for the database backend in use, relative
to the :class:`.Table` that is the subject of this
statement, or optionally to that of the given
:class:`.Table` passed as the ``selectable`` argument.
The ``dialect_name`` option will limit the rendering of a particular
hint to a particular backend. Such as, to add a hint
that only takes effect for SQL Server::
mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")
.. versionadded:: 0.7.6
:param text: Text of the hint.
:param selectable: optional :class:`.Table` that specifies
an element of the FROM clause within an UPDATE or DELETE
to be the subject of the hint - applies only to certain backends.
:param dialect_name: defaults to ``*``, if specified as the name
of a particular dialect, will apply these hints only when
that dialect is in use.
"""
if selectable is None:
selectable = self.table
self._hints = self._hints.union(
{(selectable, dialect_name): text})
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to
INSERT and UPDATE constructs."""
__visit_name__ = 'values_base'
_supports_multi_parameters = False
_has_multi_parameters = False
_preserve_parameter_order = False
select = None
_post_values_clause = None
def __init__(self, table, values, prefixes):
self.table = _interpret_as_from(table)
self.parameters, self._has_multi_parameters = \
self._process_colparams(values)
if prefixes:
self._setup_prefixes(prefixes)
@_generative
def values(self, *args, **kwargs):
r"""specify a fixed VALUES clause for an INSERT statement, or the SET
clause for an UPDATE.
Note that the :class:`.Insert` and :class:`.Update` constructs support
per-execution time formatting of the VALUES and/or SET clauses,
based on the arguments passed to :meth:`.Connection.execute`.
However, the :meth:`.ValuesBase.values` method can be used to "fix" a
particular set of parameters into the statement.
Multiple calls to :meth:`.ValuesBase.values` will produce a new
construct, each one with the parameter list modified to include
the new parameters sent. In the typical case of a single
dictionary of parameters, the newly passed keys will replace
the same keys in the previous construct. In the case of a list-based
"multiple values" construct, each new list of values is extended
onto the existing list of values.
:param \**kwargs: key value pairs representing the string key
of a :class:`.Column` mapped to the value to be rendered into the
VALUES or SET clause::
users.insert().values(name="some name")
users.update().where(users.c.id==5).values(name="some name")
:param \*args: As an alternative to passing key/value parameters,
a dictionary, tuple, or list of dictionaries or tuples can be passed
as a single positional argument in order to form the VALUES or
SET clause of the statement. The forms that are accepted vary
based on whether this is an :class:`.Insert` or an :class:`.Update`
construct.
For either an :class:`.Insert` or :class:`.Update` construct, a
single dictionary can be passed, which works the same as that of
the kwargs form::
users.insert().values({"name": "some name"})
users.update().values({"name": "some new name"})
Also for either form but more typically for the :class:`.Insert`
construct, a tuple that contains an entry for every column in the
table is also accepted::
users.insert().values((5, "some name"))
The :class:`.Insert` construct also supports being passed a list
of dictionaries or full-table-tuples, which on the server will
render the less common SQL syntax of "multiple values" - this
syntax is supported on backends such as SQLite, PostgreSQL, MySQL,
but not necessarily others::
users.insert().values([
{"name": "some name"},
{"name": "some other name"},
{"name": "yet another name"},
])
The above form would render a multiple VALUES statement similar to::
INSERT INTO users (name) VALUES
(:name_1),
(:name_2),
(:name_3)
It is essential to note that **passing multiple values is
NOT the same as using traditional executemany() form**. The above
syntax is a **special** syntax not typically used. To emit an
INSERT statement against multiple rows, the normal method is
to pass a multiple values list to the :meth:`.Connection.execute`
method, which is supported by all database backends and is generally
more efficient for a very large number of parameters.
.. seealso::
:ref:`execute_multiple` - an introduction to
the traditional Core method of multiple parameter set
invocation for INSERTs and other statements.
.. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES
clause, even a list of length one,
implies that the :paramref:`.Insert.inline` flag is set to
True, indicating that the statement will not attempt to fetch
the "last inserted primary key" or other defaults. The
statement deals with an arbitrary number of rows, so the
:attr:`.ResultProxy.inserted_primary_key` accessor does not
apply.
.. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports
columns with Python side default values and callables in the
same way as that of an "executemany" style of invocation; the
callable is invoked for each row. See :ref:`bug_3288`
for other details.
The :class:`.Update` construct supports a special form which is a
list of 2-tuples, which when provided must be passed in conjunction
with the
:paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
parameter.
This form causes the UPDATE statement to render the SET clauses
using the order of parameters given to :meth:`.Update.values`, rather
than the ordering of columns given in the :class:`.Table`.
.. versionadded:: 1.0.10 - added support for parameter-ordered
UPDATE statements via the
:paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
flag.
.. seealso::
:ref:`updates_order_parameters` - full example of the
:paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
flag
.. seealso::
:ref:`inserts_and_updates` - SQL Expression
Language Tutorial
:func:`~.expression.insert` - produce an ``INSERT`` statement
:func:`~.expression.update` - produce an ``UPDATE`` statement
"""
if self.select is not None:
raise exc.InvalidRequestError(
"This construct already inserts from a SELECT")
if self._has_multi_parameters and kwargs:
raise exc.InvalidRequestError(
"This construct already has multiple parameter sets.")
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
"dictionaries/tuples is accepted positionally.")
v = args[0]
else:
v = {}
if self.parameters is None:
self.parameters, self._has_multi_parameters = \
self._process_colparams(v)
else:
if self._has_multi_parameters:
self.parameters = list(self.parameters)
p, self._has_multi_parameters = self._process_colparams(v)
if not self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
"formats in one statement")
self.parameters.extend(p)
else:
self.parameters = self.parameters.copy()
p, self._has_multi_parameters = self._process_colparams(v)
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
"formats in one statement")
self.parameters.update(p)
if kwargs:
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't pass kwargs and multiple parameter sets "
"simultaneously")
else:
self.parameters.update(kwargs)
@_generative
def return_defaults(self, *cols):
"""Make use of a :term:`RETURNING` clause for the purpose
of fetching server-side expressions and defaults.
E.g.::
stmt = table.insert().values(data='newdata').return_defaults()
result = connection.execute(stmt)
server_created_at = result.returned_defaults['created_at']
When used against a backend that supports RETURNING, all column
values generated by SQL expression or server-side-default will be
added to any existing RETURNING clause, provided that
:meth:`.UpdateBase.returning` is not used simultaneously. The column
values will then be available on the result using the
:attr:`.ResultProxy.returned_defaults` accessor as a dictionary,
referring to values keyed to the :class:`.Column` object as well as
its ``.key``.
This method differs from :meth:`.UpdateBase.returning` in these ways:
1. :meth:`.ValuesBase.return_defaults` is only intended for use with
an INSERT or an UPDATE statement that matches exactly one row.
While the RETURNING construct in the general sense supports
multiple rows for a multi-row UPDATE or DELETE statement, or for
special cases of INSERT that return multiple rows (e.g. INSERT from
SELECT, multi-valued VALUES clause),
:meth:`.ValuesBase.return_defaults` is intended only for an
"ORM-style" single-row INSERT/UPDATE statement. The row returned
by the statement is also consumed implicitly when
:meth:`.ValuesBase.return_defaults` is used. By contrast,
:meth:`.UpdateBase.returning` leaves the RETURNING result-set
intact with a collection of any number of rows.
2. It is compatible with the existing logic to fetch auto-generated
primary key values, also known as "implicit returning". Backends
that support RETURNING will automatically make use of RETURNING in
order to fetch the value of newly generated primary keys; while the
:meth:`.UpdateBase.returning` method circumvents this behavior,
:meth:`.ValuesBase.return_defaults` leaves it intact.
3. It can be called against any backend. Backends that don't support
RETURNING will skip the usage of the feature, rather than raising
an exception. The return value of
:attr:`.ResultProxy.returned_defaults` will be ``None``
:meth:`.ValuesBase.return_defaults` is used by the ORM to provide
an efficient implementation for the ``eager_defaults`` feature of
:func:`.mapper`.
:param cols: optional list of column key names or :class:`.Column`
objects. If omitted, all column expressions evaluated on the server
are added to the returning list.
.. versionadded:: 0.9.0
.. seealso::
:meth:`.UpdateBase.returning`
:attr:`.ResultProxy.returned_defaults`
"""
self._return_defaults = cols or True
class Insert(ValuesBase):
"""Represent an INSERT construct.
The :class:`.Insert` object is created using the
:func:`~.expression.insert()` function.
.. seealso::
:ref:`coretutorial_insert_expressions`
"""
__visit_name__ = 'insert'
_supports_multi_parameters = True
def __init__(self,
table,
values=None,
inline=False,
bind=None,
prefixes=None,
returning=None,
return_defaults=False,
**dialect_kw):
"""Construct an :class:`.Insert` object.
Similar functionality is available via the
:meth:`~.TableClause.insert` method on
:class:`~.schema.Table`.
:param table: :class:`.TableClause` which is the subject of the
insert.
:param values: collection of values to be inserted; see
:meth:`.Insert.values` for a description of allowed formats here.
Can be omitted entirely; a :class:`.Insert` construct will also
dynamically render the VALUES clause at execution time based on
the parameters passed to :meth:`.Connection.execute`.
:param inline: if True, no attempt will be made to retrieve the
SQL-generated default values to be provided within the statement;
in particular,
this allows SQL expressions to be rendered 'inline' within the
statement without the need to pre-execute them beforehand; for
backends that support "returning", this turns off the "implicit
returning" feature for the statement.
If both `values` and compile-time bind parameters are present, the
compile-time bind parameters override the information specified
within `values` on a per-key basis.
The keys within `values` can be either
:class:`~sqlalchemy.schema.Column` objects or their string
identifiers. Each key may reference one of:
* a literal data value (i.e. string, number, etc.);
* a Column object;
* a SELECT statement.
If a ``SELECT`` statement is specified which references this
``INSERT`` statement's table, the statement will be correlated
against the ``INSERT`` statement.
.. seealso::
:ref:`coretutorial_insert_expressions` - SQL Expression Tutorial
:ref:`inserts_and_updates` - SQL Expression Tutorial
"""
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self.select = self.select_names = None
self.include_insert_from_select_defaults = False
self.inline = inline
self._returning = returning
self._validate_dialect_kwargs(dialect_kw)
self._return_defaults = return_defaults
def get_children(self, **kwargs):
if self.select is not None:
return self.select,
else:
return ()
@_generative
def from_select(self, names, select, include_defaults=True):
"""Return a new :class:`.Insert` construct which represents
an ``INSERT...FROM SELECT`` statement.
e.g.::
sel = select([table1.c.a, table1.c.b]).where(table1.c.c > 5)
ins = table2.insert().from_select(['a', 'b'], sel)
:param names: a sequence of string column names or :class:`.Column`
objects representing the target columns.
:param select: a :func:`.select` construct, :class:`.FromClause`
or other construct which resolves into a :class:`.FromClause`,
such as an ORM :class:`.Query` object, etc. The order of
columns returned from this FROM clause should correspond to the
order of columns sent as the ``names`` parameter; while this
is not checked before passing along to the database, the database
would normally raise an exception if these column lists don't
correspond.
:param include_defaults: if True, non-server default values and
SQL expressions as specified on :class:`.Column` objects
(as documented in :ref:`metadata_defaults_toplevel`) not
otherwise specified in the list of names will be rendered
into the INSERT and SELECT statements, so that these values are also
included in the data to be inserted.
.. note:: A Python-side default that uses a Python callable function
will only be invoked **once** for the whole statement, and **not
per row**.
.. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders
Python-side and SQL expression column defaults into the
SELECT statement for columns otherwise not included in the
list of column names.
.. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
implies that the :paramref:`.insert.inline` flag is set to
True, indicating that the statement will not attempt to fetch
the "last inserted primary key" or other defaults. The statement
deals with an arbitrary number of rows, so the
:attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
.. versionadded:: 0.8.3
"""
if self.parameters:
raise exc.InvalidRequestError(
"This construct already inserts value expressions")
self.parameters, self._has_multi_parameters = \
self._process_colparams(
dict((_column_as_key(n), Null()) for n in names))
self.select_names = names
self.inline = True
self.include_insert_from_select_defaults = include_defaults
self.select = _interpret_as_select(select)
def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
self.parameters = self.parameters.copy()
if self.select is not None:
self.select = _clone(self.select)
class Update(ValuesBase):
"""Represent an Update construct.
The :class:`.Update` object is created using the :func:`update()`
function.
"""
__visit_name__ = 'update'
def __init__(self,
table,
whereclause=None,
values=None,
inline=False,
bind=None,
prefixes=None,
returning=None,
return_defaults=False,
preserve_parameter_order=False,
**dialect_kw):
r"""Construct an :class:`.Update` object.
E.g.::
from sqlalchemy import update
stmt = update(users).where(users.c.id==5).\
values(name='user #5')
Similar functionality is available via the
:meth:`~.TableClause.update` method on
:class:`.Table`::
stmt = users.update().\
where(users.c.id==5).\
values(name='user #5')
:param table: A :class:`.Table` object representing the database
table to be updated.
:param whereclause: Optional SQL expression describing the ``WHERE``
condition of the ``UPDATE`` statement. Modern applications
may prefer to use the generative :meth:`~Update.where()`
method to specify the ``WHERE`` clause.
The WHERE clause can refer to multiple tables.
For databases which support this, an ``UPDATE FROM`` clause will
be generated, or on MySQL, a multi-table update. The statement
will fail on databases that don't have support for multi-table
update statements. A SQL-standard method of referring to
additional tables in the WHERE clause is to use a correlated
subquery::
users.update().values(name='ed').where(
users.c.name==select([addresses.c.email_address]).\
where(addresses.c.user_id==users.c.id).\
as_scalar()
)
.. versionchanged:: 0.7.4
The WHERE clause can refer to multiple tables.
:param values:
Optional dictionary which specifies the ``SET`` conditions of the
``UPDATE``. If left as ``None``, the ``SET``
conditions are determined from those parameters passed to the
statement during the execution and/or compilation of the
statement. When compiled standalone without any parameters,
the ``SET`` clause generates for all columns.
Modern applications may prefer to use the generative
:meth:`.Update.values` method to set the values of the
UPDATE statement.
:param inline:
if True, SQL defaults present on :class:`.Column` objects via
the ``default`` keyword will be compiled 'inline' into the statement
and not pre-executed. This means that their values will not
be available in the dictionary returned from
:meth:`.ResultProxy.last_updated_params`.
:param preserve_parameter_order: if True, the update statement is
expected to receive parameters **only** via the :meth:`.Update.values`
method, and they must be passed as a Python ``list`` of 2-tuples.
The rendered UPDATE statement will emit the SET clause for each
referenced column maintaining this order.
.. versionadded:: 1.0.10
.. seealso::
:ref:`updates_order_parameters` - full example of the
:paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` flag
If both ``values`` and compile-time bind parameters are present, the
compile-time bind parameters override the information specified
within ``values`` on a per-key basis.
The keys within ``values`` can be either :class:`.Column`
objects or their string identifiers (specifically the "key" of the
:class:`.Column`, normally but not necessarily equivalent to
its "name"). Normally, the
:class:`.Column` objects used here are expected to be
part of the target :class:`.Table` that is the table
to be updated. However when using MySQL, a multiple-table
UPDATE statement can refer to columns from any of
the tables referred to in the WHERE clause.
The values referred to in ``values`` are typically:
* a literal data value (i.e. string, number, etc.)
* a SQL expression, such as a related :class:`.Column`,
a scalar-returning :func:`.select` construct,
etc.
When combining :func:`.select` constructs within the values
clause of an :func:`.update` construct,
the subquery represented by the :func:`.select` should be
*correlated* to the parent table, that is, providing criterion
which links the table inside the subquery to the outer table
being updated::
users.update().values(
name=select([addresses.c.email_address]).\
where(addresses.c.user_id==users.c.id).\
as_scalar()
)
.. seealso::
:ref:`inserts_and_updates` - SQL Expression
Language Tutorial
"""
self._preserve_parameter_order = preserve_parameter_order
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self._returning = returning
if whereclause is not None:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
self.inline = inline
self._validate_dialect_kwargs(dialect_kw)
self._return_defaults = return_defaults
def get_children(self, **kwargs):
if self._whereclause is not None:
return self._whereclause,
else:
return ()
def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
self._whereclause = clone(self._whereclause, **kw)
self.parameters = self.parameters.copy()
@_generative
def where(self, whereclause):
"""return a new update() construct with the given expression added to
its WHERE clause, joined to the existing clause via AND, if any.
"""
if self._whereclause is not None:
self._whereclause = and_(self._whereclause,
_literal_as_text(whereclause))
else:
self._whereclause = _literal_as_text(whereclause)
@property
def _extra_froms(self):
# TODO: this could be made memoized
# if the memoization is reset on each generative call.
froms = []
seen = set([self.table])
if self._whereclause is not None:
for item in _from_objects(self._whereclause):
if not seen.intersection(item._cloned_set):
froms.append(item)
seen.update(item._cloned_set)
return froms
class Delete(UpdateBase):
"""Represent a DELETE construct.
The :class:`.Delete` object is created using the :func:`delete()`
function.
"""
__visit_name__ = 'delete'
def __init__(self,
table,
whereclause=None,
bind=None,
returning=None,
prefixes=None,
**dialect_kw):
"""Construct :class:`.Delete` object.
Similar functionality is available via the
:meth:`~.TableClause.delete` method on
:class:`~.schema.Table`.
:param table: The table to delete rows from.
:param whereclause: A :class:`.ClauseElement` describing the ``WHERE``
condition of the ``DELETE`` statement. Note that the
:meth:`~Delete.where()` generative method may be used instead.
.. seealso::
:ref:`deletes` - SQL Expression Tutorial
"""
self._bind = bind
self.table = _interpret_as_from(table)
self._returning = returning
if prefixes:
self._setup_prefixes(prefixes)
if whereclause is not None:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
self._validate_dialect_kwargs(dialect_kw)
def get_children(self, **kwargs):
if self._whereclause is not None:
return self._whereclause,
else:
return ()
@_generative
def where(self, whereclause):
"""Add the given WHERE clause to a newly returned delete construct."""
if self._whereclause is not None:
self._whereclause = and_(self._whereclause,
_literal_as_text(whereclause))
else:
self._whereclause = _literal_as_text(whereclause)
def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
self._whereclause = clone(self._whereclause, **kw)

4403
sqlalchemy/sql/elements.py Normal file

File diff suppressed because it is too large Load Diff

146
sqlalchemy/sql/naming.py Normal file
View File

@ -0,0 +1,146 @@
# sqlalchemy/naming.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Establish constraint and index naming conventions.
"""
from .schema import Constraint, ForeignKeyConstraint, PrimaryKeyConstraint, \
UniqueConstraint, CheckConstraint, Index, Table, Column
from .. import event, events
from .. import exc
from .elements import _truncated_label, _defer_name, _defer_none_name, conv
import re
class ConventionDict(object):
def __init__(self, const, table, convention):
self.const = const
self._is_fk = isinstance(const, ForeignKeyConstraint)
self.table = table
self.convention = convention
self._const_name = const.name
def _key_table_name(self):
return self.table.name
def _column_X(self, idx):
if self._is_fk:
fk = self.const.elements[idx]
return fk.parent
else:
return list(self.const.columns)[idx]
def _key_constraint_name(self):
if isinstance(self._const_name, (type(None), _defer_none_name)):
raise exc.InvalidRequestError(
"Naming convention including "
"%(constraint_name)s token requires that "
"constraint is explicitly named."
)
if not isinstance(self._const_name, conv):
self.const.name = None
return self._const_name
def _key_column_X_name(self, idx):
return self._column_X(idx).name
def _key_column_X_label(self, idx):
return self._column_X(idx)._label
def _key_referred_table_name(self):
fk = self.const.elements[0]
refs = fk.target_fullname.split(".")
if len(refs) == 3:
refschema, reftable, refcol = refs
else:
reftable, refcol = refs
return reftable
def _key_referred_column_X_name(self, idx):
fk = self.const.elements[idx]
refs = fk.target_fullname.split(".")
if len(refs) == 3:
refschema, reftable, refcol = refs
else:
reftable, refcol = refs
return refcol
def __getitem__(self, key):
if key in self.convention:
return self.convention[key](self.const, self.table)
elif hasattr(self, '_key_%s' % key):
return getattr(self, '_key_%s' % key)()
else:
col_template = re.match(r".*_?column_(\d+)_.+", key)
if col_template:
idx = col_template.group(1)
attr = "_key_" + key.replace(idx, "X")
idx = int(idx)
if hasattr(self, attr):
return getattr(self, attr)(idx)
raise KeyError(key)
_prefix_dict = {
Index: "ix",
PrimaryKeyConstraint: "pk",
CheckConstraint: "ck",
UniqueConstraint: "uq",
ForeignKeyConstraint: "fk"
}
def _get_convention(dict_, key):
for super_ in key.__mro__:
if super_ in _prefix_dict and _prefix_dict[super_] in dict_:
return dict_[_prefix_dict[super_]]
elif super_ in dict_:
return dict_[super_]
else:
return None
def _constraint_name_for_table(const, table):
metadata = table.metadata
convention = _get_convention(metadata.naming_convention, type(const))
if isinstance(const.name, conv):
return const.name
elif convention is not None and \
not isinstance(const.name, conv) and \
(
const.name is None or
"constraint_name" in convention or
isinstance(const.name, _defer_name)):
return conv(
convention % ConventionDict(const, table,
metadata.naming_convention)
)
elif isinstance(convention, _defer_none_name):
return None
@event.listens_for(Constraint, "after_parent_attach")
@event.listens_for(Index, "after_parent_attach")
def _constraint_name(const, table):
if isinstance(table, Column):
# for column-attached constraint, set another event
# to link the column attached to the table as this constraint
# associated with the table.
event.listen(table, "after_parent_attach",
lambda col, table: _constraint_name(const, table)
)
elif isinstance(table, Table):
if isinstance(const.name, (conv, _defer_name)):
return
newname = _constraint_name_for_table(const, table)
if newname is not None:
const.name = newname

4027
sqlalchemy/sql/schema.py Normal file

File diff suppressed because it is too large Load Diff

3716
sqlalchemy/sql/selectable.py Normal file

File diff suppressed because it is too large Load Diff

2619
sqlalchemy/sql/sqltypes.py Normal file

File diff suppressed because it is too large Load Diff

1307
sqlalchemy/sql/type_api.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,36 @@
# testing/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .warnings import assert_warnings
from . import config
from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
fails_on, fails_on_everything_except, skip, only_on, exclude, \
against as _against, _server_version, only_if, fails
def against(*queries):
return _against(config._current, *queries)
from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
assert_raises_message, AssertsCompiledSQL, ComparesTables, \
AssertsExecutionResults, expect_deprecated, expect_warnings, \
in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false
from .util import run_as_contextmanager, rowset, fail, \
provide_metadata, adict, force_drop_names, \
teardown_events
crashes = skip
from .config import db
from .config import requirements as requires
from . import mock

View File

@ -0,0 +1,520 @@
# testing/assertions.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import
from . import util as testutil
from sqlalchemy import pool, orm, util
from sqlalchemy.engine import default, url
from sqlalchemy.util import decorator, compat
from sqlalchemy import types as sqltypes, schema, exc as sa_exc
import warnings
import re
from .exclusions import db_spec
from . import assertsql
from . import config
from .util import fail
import contextlib
from . import mock
def expect_warnings(*messages, **kw):
"""Context manager which expects one or more warnings.
With no arguments, squelches all SAWarnings emitted via
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
pass string expressions that will match selected warnings via regex;
all non-matching warnings are sent through.
The expect version **asserts** that the warnings were in fact seen.
Note that the test suite sets SAWarning warnings to raise exceptions.
"""
return _expect_warnings(sa_exc.SAWarning, messages, **kw)
@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
"""Context manager which expects one or more warnings on specific
dialects.
The expect version **asserts** that the warnings were in fact seen.
"""
spec = db_spec(db)
if isinstance(db, util.string_types) and not spec(config._current):
yield
else:
with expect_warnings(*messages, **kw):
yield
def emits_warning(*messages):
"""Decorator form of expect_warnings().
Note that emits_warning does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with expect_warnings(assert_=False, *messages):
return fn(*args, **kw)
return decorate
def expect_deprecated(*messages, **kw):
return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
def emits_warning_on(db, *messages):
"""Mark a test as emitting a warning on a specific dialect.
With no arguments, squelches all SAWarning failures. Or pass one or more
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
Note that emits_warning_on does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, assert_=False, *messages):
return fn(*args, **kw)
return decorate
def uses_deprecated(*messages):
"""Mark a test as immune from fatal deprecation warnings.
With no arguments, squelches all SADeprecationWarning failures.
Or pass one or more strings; these will be matched to the root
of the warning description by warnings.filterwarnings().
As a special case, you may pass a function name prefixed with //
and it will be re-written as needed to match the standard warning
verbiage emitted by the sqlalchemy.util.deprecated decorator.
Note that uses_deprecated does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
return decorate
@contextlib.contextmanager
def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
py2konly=False):
if regex:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
else:
filters = messages
seen = set(filters)
real_warn = warnings.warn
def our_warn(msg, exception, *arg, **kw):
if not issubclass(exception, exc_cls):
return real_warn(msg, exception, *arg, **kw)
if not filters:
return
for filter_ in filters:
if (regex and filter_.match(msg)) or \
(not regex and filter_ == msg):
seen.discard(filter_)
break
else:
real_warn(msg, exception, *arg, **kw)
with mock.patch("warnings.warn", our_warn):
yield
if assert_ and (not py2konly or not compat.py3k):
assert not seen, "Warnings were not seen: %s" % \
", ".join("%r" % (s.pattern if regex else s) for s in seen)
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.
Hardcoded at the moment, a modular system can be built here
to support things like PG prepared transactions, tables all
dropped, etc.
"""
_assert_no_stray_pool_connections()
_STRAY_CONNECTION_FAILURES = 0
def _assert_no_stray_pool_connections():
global _STRAY_CONNECTION_FAILURES
# lazy gc on cPython means "do nothing." pool connections
# shouldn't be in cycles, should go away.
testutil.lazy_gc()
# however, once in awhile, on an EC2 machine usually,
# there's a ref in there. usually just one.
if pool._refs:
# OK, let's be somewhat forgiving.
_STRAY_CONNECTION_FAILURES += 1
print("Encountered a stray connection in test cleanup: %s"
% str(pool._refs))
# then do a real GC sweep. We shouldn't even be here
# so a single sweep should really be doing it, otherwise
# there's probably a real unreachable cycle somewhere.
testutil.gc_collect()
# if we've already had two of these occurrences, or
# after a hard gc sweep we still have pool._refs?!
# now we have to raise.
if pool._refs:
err = str(pool._refs)
# but clean out the pool refs collection directly,
# reset the counter,
# so the error doesn't at least keep happening.
pool._refs.clear()
_STRAY_CONNECTION_FAILURES = 0
warnings.warn(
"Stray connection refused to leave "
"after gc.collect(): %s" % err)
elif _STRAY_CONNECTION_FAILURES > 10:
assert False, "Encountered more than 10 stray connections"
_STRAY_CONNECTION_FAILURES = 0
def eq_regex(a, b, msg=None):
assert re.match(b, a), msg or "%r !~ %r" % (a, b)
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
def ne_(a, b, msg=None):
"""Assert a != b, with repr messaging on failure."""
assert a != b, msg or "%r == %r" % (a, b)
def le_(a, b, msg=None):
"""Assert a <= b, with repr messaging on failure."""
assert a <= b, msg or "%r != %r" % (a, b)
def is_true(a, msg=None):
is_(a, True, msg=msg)
def is_false(a, msg=None):
is_(a, False, msg=msg)
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
def is_not_(a, b, msg=None):
"""Assert a is not b, with repr messaging on failure."""
assert a is not b, msg or "%r is %r" % (a, b)
def in_(a, b, msg=None):
"""Assert a in b, with repr messaging on failure."""
assert a in b, msg or "%r not in %r" % (a, b)
def not_in_(a, b, msg=None):
"""Assert a in not b, with repr messaging on failure."""
assert a not in b, msg or "%r is in %r" % (a, b)
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
a, fragment)
def eq_ignore_whitespace(a, b, msg=None):
a = re.sub(r'^\s+?|\n', "", a)
a = re.sub(r' {2,}', " ", a)
b = re.sub(r'^\s+?|\n', "", b)
b = re.sub(r' {2,}', " ", b)
assert a == b, msg or "%r != %r" % (a, b)
def assert_raises(except_cls, callable_, *args, **kw):
try:
callable_(*args, **kw)
success = False
except except_cls:
success = True
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls as e:
assert re.search(
msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
print(util.text_type(e).encode('utf-8'))
class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None,
checkparams=None, dialect=None,
checkpositional=None,
check_prefetch=None,
use_default_dialect=False,
allow_dialect_select=False,
literal_binds=False,
schema_translate_map=None):
if use_default_dialect:
dialect = default.DefaultDialect()
elif allow_dialect_select:
dialect = None
else:
if dialect is None:
dialect = getattr(self, '__dialect__', None)
if dialect is None:
dialect = config.db.dialect
elif dialect == 'default':
dialect = default.DefaultDialect()
elif dialect == 'default_enhanced':
dialect = default.StrCompileDialect()
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
kw = {}
compile_kwargs = {}
if schema_translate_map:
kw['schema_translate_map'] = schema_translate_map
if params is not None:
kw['column_keys'] = list(params)
if literal_binds:
compile_kwargs['literal_binds'] = True
if isinstance(clause, orm.Query):
context = clause._compile_context()
context.statement.use_labels = True
clause = context.statement
if compile_kwargs:
kw['compile_kwargs'] = compile_kwargs
c = clause.compile(dialect=dialect, **kw)
param_str = repr(getattr(c, 'params', {}))
if util.py3k:
param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
print(
("\nSQL String:\n" +
util.text_type(c) +
param_str).encode('utf-8'))
else:
print(
"\nSQL String:\n" +
util.text_type(c).encode('utf-8') +
param_str)
cc = re.sub(r'[\n\t]', '', util.text_type(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
if checkparams is not None:
eq_(c.construct_params(params), checkparams)
if checkpositional is not None:
p = c.construct_params(params)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
if check_prefetch is not None:
eq_(c.prefetch, check_prefetch)
class ComparesTables(object):
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
assert isinstance(reflected_c.type, type(c.type)), \
msg % (reflected_c.type, c.type)
else:
self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
eq_(
set([f.column.name for f in c.foreign_keys]),
set([f.column.name for f in reflected_c.foreign_keys])
)
if c.server_default:
assert isinstance(reflected_c.server_default,
schema.FetchedValue)
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
assert c1.type._compare_type_affinity(c2.type),\
"On column %r, type '%s' doesn't correspond to type '%s'" % \
(c1.name, c1.type, c2.type)
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
print(repr(result))
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list):
self.assert_(len(result) == len(list),
"result list is not the same size as test list, " +
"for class " + class_.__name__)
for i in range(0, len(list)):
self.assert_row(class_, result[i], list[i])
def assert_row(self, class_, rowobj, desc):
self.assert_(rowobj.__class__ is class_,
"item class is not " + repr(class_))
for key, value in desc.items():
if isinstance(value, tuple):
if isinstance(value[1], list):
self.assert_list(getattr(rowobj, key), value[0], value[1])
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
self.assert_(getattr(rowobj, key) == value,
"attribute %s value %s does not match %s" % (
key, getattr(rowobj, key), value))
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
The algorithm is very expensive but not a big deal for the small
numbers of rows that the test suite manipulates.
"""
class immutabledict(dict):
def __hash__(self):
return id(self)
found = util.IdentitySet(result)
expected = set([immutabledict(e) for e in expected])
for wrong in util.itertools_filterfalse(lambda o:
isinstance(o, cls), found):
fail('Unexpected type "%s", expected "%s"' % (
type(wrong).__name__, cls.__name__))
if len(found) != len(expected):
fail('Unexpected object count "%s", expected "%s"' % (
len(found), len(expected)))
NOVALUE = object()
def _compare_item(obj, spec):
for key, value in spec.items():
if isinstance(value, tuple):
try:
self.assert_unordered_result(
getattr(obj, key), value[0], *value[1])
except AssertionError:
return False
else:
if getattr(obj, key, NOVALUE) != value:
return False
return True
for expected_item in expected:
for found_item in found:
if _compare_item(found_item, expected_item):
found.remove(found_item)
break
else:
fail(
"Expected %s instance with attributes %s not found." % (
cls.__name__, repr(expected_item)))
return True
def sql_execution_asserter(self, db=None):
if db is None:
from . import db as db
return assertsql.assert_engine(db)
def assert_sql_execution(self, db, callable_, *rules):
with self.sql_execution_asserter(db) as asserter:
callable_()
asserter.assert_(*rules)
def assert_sql(self, db, callable_, rules):
newrules = []
for rule in rules:
if isinstance(rule, dict):
newrule = assertsql.AllOf(*[
assertsql.CompiledSQL(k, v) for k, v in rule.items()
])
else:
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
db, callable_, assertsql.CountStatements(count))
@contextlib.contextmanager
def assert_execution(self, *rules):
assertsql.asserter.add_rules(rules)
try:
yield
assertsql.asserter.statement_complete()
finally:
assertsql.asserter.clear_rules()
def assert_statement_count(self, count):
return self.assert_execution(assertsql.CountStatements(count))

View File

@ -0,0 +1,377 @@
# testing/assertsql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from ..engine.default import DefaultDialect
from .. import util
import re
import collections
import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
from sqlalchemy.engine import url
class AssertRule(object):
is_consumed = False
errormessage = None
consume_statement = True
def process_statement(self, execute_observed):
pass
def no_more_statements(self):
assert False, 'All statements are complete, but pending '\
'assertion rules remain'
class SQLMatchRule(AssertRule):
pass
class CursorSQL(SQLMatchRule):
consume_statement = False
def __init__(self, statement, params=None):
self.statement = statement
self.params = params
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
self.params is not None and self.params != stmt.parameters):
self.errormessage = \
"Testing for exact SQL %s parameters %s received %s %s" % (
self.statement, self.params,
stmt.statement, stmt.parameters
)
else:
execute_observed.statements.pop(0)
self.is_consumed = True
if not execute_observed.statements:
self.consume_statement = True
class CompiledSQL(SQLMatchRule):
def __init__(self, statement, params=None, dialect='default'):
self.statement = statement
self.params = params
self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
if self.dialect == 'default':
return DefaultDialect()
else:
# ugh
if self.dialect == 'postgresql':
params = {'implicit_returning': True}
else:
params = {}
return url.URL(self.dialect).get_dialect()(**params)
def _received_statement(self, execute_observed):
"""reconstruct the statement and params in terms
of a target dialect, which for CompiledSQL is just DefaultDialect."""
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
compiled = \
context.compiled.statement.compile(
dialect=compare_dialect,
schema_translate_map=context.
execution_options.get('schema_translate_map'))
else:
compiled = (
context.compiled.statement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
inline=context.compiled.inline,
schema_translate_map=context.
execution_options.get('schema_translate_map'))
)
_received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [compiled.construct_params()]
else:
_received_parameters = [
compiled.construct_params(m) for m in parameters]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
_received_statement, _received_parameters = \
self._received_statement(execute_observed)
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
if equivalent:
if params is not None:
all_params = list(params)
all_received = list(_received_parameters)
while all_params and all_received:
param = dict(all_params.pop(0))
for idx, received in enumerate(list(all_received)):
# do a positive compare only
for param_key in param:
# a key in param did not match current
# 'received'
if param_key not in received or \
received[param_key] != param[param_key]:
break
else:
# all keys in param matched 'received';
# onto next param
del all_received[idx]
break
else:
# param did not match any entry
# in all_received
equivalent = False
break
if all_params or all_received:
equivalent = False
if equivalent:
self.is_consumed = True
self.errormessage = None
else:
self.errormessage = self._failure_message(params) % {
'received_statement': _received_statement,
'received_parameters': _received_parameters
}
def _all_params(self, context):
if self.params:
if util.callable(self.params):
params = self.params(context)
else:
params = self.params
if not isinstance(params, list):
params = [params]
return params
else:
return None
def _failure_message(self, expected_params):
return (
'Testing for compiled statement %r partial params %r, '
'received %%(received_statement)r with params '
'%%(received_parameters)r' % (
self.statement.replace('%', '%%'), expected_params
)
)
class RegexSQL(CompiledSQL):
def __init__(self, regex, params=None):
SQLMatchRule.__init__(self)
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
self.dialect = 'default'
def _failure_message(self, expected_params):
return (
'Testing for compiled statement ~%r partial params %r, '
'received %%(received_statement)r with params '
'%%(received_parameters)r' % (
self.orig_regex, expected_params
)
)
def _compare_sql(self, execute_observed, received_statement):
return bool(self.regex.match(received_statement))
class DialectSQL(CompiledSQL):
def _compile_dialect(self, execute_observed):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
stmt = re.sub(r'[\n\t]', '', real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
received_stmt, received_params = super(DialectSQL, self).\
_received_statement(execute_observed)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
if self._compare_no_space(real_stmt.statement, received_stmt):
break
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
"statements actually invoked" % received_stmt)
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle
if paramstyle == 'pyformat':
stmt = re.sub(
r':([\w_]+)', r"%(\1)s", stmt)
else:
# positional params
repl = None
if paramstyle == 'qmark':
repl = "?"
elif paramstyle == 'format':
repl = r"%s"
elif paramstyle == 'numeric':
repl = None
stmt = re.sub(r':([\w_]+)', repl, stmt)
return received_statement == stmt
class CountStatements(AssertRule):
def __init__(self, count):
self.count = count
self._statement_count = 0
def process_statement(self, execute_observed):
self._statement_count += 1
def no_more_statements(self):
if self.count != self._statement_count:
assert False, 'desired statement count %d does not match %d' \
% (self.count, self._statement_count)
class AllOf(AssertRule):
def __init__(self, *rules):
self.rules = set(rules)
def process_statement(self, execute_observed):
for rule in list(self.rules):
rule.errormessage = None
rule.process_statement(execute_observed)
if rule.is_consumed:
self.rules.discard(rule)
if not self.rules:
self.is_consumed = True
break
elif not rule.errormessage:
# rule is not done yet
self.errormessage = None
break
else:
self.errormessage = list(self.rules)[0].errormessage
class Or(AllOf):
def process_statement(self, execute_observed):
for rule in self.rules:
rule.process_statement(execute_observed)
if rule.is_consumed:
self.is_consumed = True
break
else:
self.errormessage = list(self.rules)[0].errormessage
class SQLExecuteObserved(object):
def __init__(self, context, clauseelement, multiparams, params):
self.context = context
self.clauseelement = clauseelement
self.parameters = _distill_params(multiparams, params)
self.statements = []
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
["statement", "parameters", "context", "executemany"])
):
pass
class SQLAsserter(object):
def __init__(self):
self.accumulated = []
def _close(self):
self._final = self.accumulated
del self.accumulated
def assert_(self, *rules):
rules = list(rules)
observed = list(self._final)
while observed and rules:
rule = rules[0]
rule.process_statement(observed[0])
if rule.is_consumed:
rules.pop(0)
elif rule.errormessage:
assert False, rule.errormessage
if rule.consume_statement:
observed.pop(0)
if not observed and rules:
rules[0].no_more_statements()
elif not rules and observed:
assert False, "Additional SQL statements remain"
@contextlib.contextmanager
def assert_engine(engine):
asserter = SQLAsserter()
orig = []
@event.listens_for(engine, "before_execute")
def connection_execute(conn, clauseelement, multiparams, params):
# grab the original statement + params before any cursor
# execution
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
def cursor_execute(conn, cursor, statement, parameters,
context, executemany):
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
if asserter.accumulated and \
asserter.accumulated[-1].context is context:
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
statement, parameters, context, executemany)
)
try:
yield asserter
finally:
event.remove(engine, "after_cursor_execute", cursor_execute)
event.remove(engine, "before_execute", connection_execute)
asserter._close()

View File

@ -0,0 +1,97 @@
# testing/config.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import collections
requirements = None
db = None
db_url = None
db_opts = None
file_config = None
test_schema = None
test_schema_2 = None
_current = None
try:
from unittest import SkipTest as _skip_test_exception
except ImportError:
_skip_test_exception = None
class Config(object):
def __init__(self, db, db_opts, options, file_config):
self.db = db
self.db_opts = db_opts
self.options = options
self.file_config = file_config
self.test_schema = "test_schema"
self.test_schema_2 = "test_schema_2"
_stack = collections.deque()
_configs = {}
@classmethod
def register(cls, db, db_opts, options, file_config):
"""add a config as one of the global configs.
If there are no configs set up yet, this config also
gets set as the "_current".
"""
cfg = Config(db, db_opts, options, file_config)
cls._configs[cfg.db.name] = cfg
cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg
cls._configs[cfg.db] = cfg
return cfg
@classmethod
def set_as_current(cls, config, namespace):
global db, _current, db_url, test_schema, test_schema_2, db_opts
_current = config
db_url = config.db.url
db_opts = config.db_opts
test_schema = config.test_schema
test_schema_2 = config.test_schema_2
namespace.db = db = config.db
@classmethod
def push_engine(cls, db, namespace):
assert _current, "Can't push without a default Config set up"
cls.push(
Config(
db, _current.db_opts, _current.options, _current.file_config),
namespace
)
@classmethod
def push(cls, config, namespace):
cls._stack.append(_current)
cls.set_as_current(config, namespace)
@classmethod
def reset(cls, namespace):
if cls._stack:
cls.set_as_current(cls._stack[0], namespace)
cls._stack.clear()
@classmethod
def all_configs(cls):
for cfg in set(cls._configs.values()):
yield cfg
@classmethod
def all_dbs(cls):
for cfg in cls.all_configs():
yield cfg.db
def skip_test(self, msg):
skip_test(msg)
def skip_test(msg):
raise _skip_test_exception(msg)

View File

@ -0,0 +1,349 @@
# testing/engines.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import
import weakref
from . import config
from .util import decorator
from .. import event, pool
import re
import warnings
class ConnectionKiller(object):
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
self.conns = set()
def add_engine(self, engine):
self.testing_engines[engine] = True
def connect(self, dbapi_conn, con_record):
self.conns.add((dbapi_conn, con_record))
def checkout(self, dbapi_con, con_record, con_proxy):
self.proxy_refs[con_proxy] = True
def invalidate(self, dbapi_con, con_record, exception):
self.conns.discard((dbapi_con, con_record))
def _safe(self, fn):
try:
fn()
except Exception as e:
warnings.warn(
"testing_reaper couldn't "
"rollback/close connection: %s" % e)
def rollback_all(self):
for rec in list(self.proxy_refs):
if rec is not None and rec.is_valid:
self._safe(rec.rollback)
def close_all(self):
for rec in list(self.proxy_refs):
if rec is not None and rec.is_valid:
self._safe(rec._close)
def _after_test_ctx(self):
# this can cause a deadlock with pg8000 - pg8000 acquires
# prepared statement lock inside of rollback() - if async gc
# is collecting in finalize_fairy, deadlock.
# not sure if this should be if pypy/jython only.
# note that firebird/fdb definitely needs this though
for conn, rec in list(self.conns):
self._safe(conn.rollback)
def _stop_test_ctx(self):
if config.options.low_connections:
self._stop_test_ctx_minimal()
else:
self._stop_test_ctx_aggressive()
def _stop_test_ctx_minimal(self):
self.close_all()
self.conns = set()
for rec in list(self.testing_engines):
if rec is not config.db:
rec.dispose()
def _stop_test_ctx_aggressive(self):
self.close_all()
for conn, rec in list(self.conns):
self._safe(conn.close)
rec.connection = None
self.conns = set()
for rec in list(self.testing_engines):
rec.dispose()
def assert_all_closed(self):
for rec in self.proxy_refs:
if rec.is_valid:
assert False
testing_reaper = ConnectionKiller()
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
if hasattr(bind, 'close'):
bind.close()
if not config.db.dialect.supports_alter:
from . import assertions
with assertions.expect_warnings(
"Can't sort tables", assert_=False):
metadata.drop_all(bind)
else:
metadata.drop_all(bind)
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
fn(*args, **kw)
finally:
testing_reaper.assert_all_closed()
@decorator
def rollback_open_connections(fn, *args, **kw):
"""Decorator that rolls back all open connections after fn execution."""
try:
fn(*args, **kw)
finally:
testing_reaper.rollback_all()
@decorator
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
testing_reaper.close_all()
fn(*args, **kw)
@decorator
def close_open_connections(fn, *args, **kw):
"""Decorator that closes all connections after fn execution."""
try:
fn(*args, **kw)
finally:
testing_reaper.close_all()
def all_dialects(exclude=None):
import sqlalchemy.databases as d
for name in d.__all__:
# TEMPORARY
if exclude and name in exclude:
continue
mod = getattr(d, name, None)
if not mod:
mod = getattr(__import__(
'sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
class ReconnectFixture(object):
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
def __getattr__(self, key):
return getattr(self.dbapi, key)
def connect(self, *args, **kwargs):
conn = self.dbapi.connect(*args, **kwargs)
self.connections.append(conn)
return conn
def _safe(self, fn):
try:
fn()
except Exception as e:
warnings.warn(
"ReconnectFixture couldn't "
"close connection: %s" % e)
def shutdown(self):
# TODO: this doesn't cover all cases
# as nicely as we'd like, namely MySQLdb.
# would need to implement R. Brewer's
# proxy server idea to get better
# coverage.
for c in list(self.connections):
self._safe(c.close)
self.connections = []
def reconnecting_engine(url=None, options=None):
url = url or config.db.url
dbapi = config.db.dialect.dbapi
if not options:
options = {}
options['module'] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
def dispose():
engine.dialect.dbapi.shutdown()
_dispose()
engine.test_shutdown = engine.dialect.dbapi.shutdown
engine.dispose = dispose
return engine
def testing_engine(url=None, options=None):
"""Produce an engine configured by --options with optional overrides."""
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
if not options:
use_reaper = True
else:
use_reaper = options.pop('use_reaper', True)
url = url or config.db.url
url = make_url(url)
if options is None:
if config.db is None or url.drivername == config.db.url.drivername:
options = config.db_opts
else:
options = {}
elif config.db is not None and url.drivername == config.db.url.drivername:
default_opt = config.db_opts.copy()
default_opt.update(options)
engine = create_engine(url, **options)
engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
if use_reaper:
event.listen(engine.pool, 'connect', testing_reaper.connect)
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
testing_reaper.add_engine(engine)
return engine
def mock_engine(dialect_name=None):
"""Provides a mocking engine based on the current testing.db.
This is normally used to test DDL generation flow as emitted
by an Engine.
It should not be used in other cases, as assert_compile() and
assert_sql_execution() are much better choices with fewer
moving parts.
"""
from sqlalchemy import create_engine
if not dialect_name:
dialect_name = config.db.name
buffer = []
def executor(sql, *a, **kw):
buffer.append(sql)
def assert_sql(stmts):
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
assert recv == stmts, recv
def print_sql():
d = engine.dialect
return "\n".join(
str(s.compile(dialect=d))
for s in engine.mock
)
engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
engine.mock = buffer
engine.assert_sql = assert_sql
engine.print_sql = print_sql
return engine
class DBAPIProxyCursor(object):
"""Proxy a DBAPI cursor.
Tests can provide subclasses of this to intercept
DBAPI-level cursor operations.
"""
def __init__(self, engine, conn, *args, **kwargs):
self.engine = engine
self.connection = conn
self.cursor = conn.cursor(*args, **kwargs)
def execute(self, stmt, parameters=None, **kw):
if parameters:
return self.cursor.execute(stmt, parameters, **kw)
else:
return self.cursor.execute(stmt, **kw)
def executemany(self, stmt, params, **kw):
return self.cursor.executemany(stmt, params, **kw)
def __getattr__(self, key):
return getattr(self.cursor, key)
class DBAPIProxyConnection(object):
"""Proxy a DBAPI connection.
Tests can provide subclasses of this to intercept
DBAPI-level connection operations.
"""
def __init__(self, engine, cursor_cls):
self.conn = self._sqla_unwrap = engine.pool._creator()
self.engine = engine
self.cursor_cls = cursor_cls
def cursor(self, *args, **kwargs):
return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
def close(self):
self.conn.close()
def __getattr__(self, key):
return getattr(self.conn, key)
def proxying_engine(conn_cls=DBAPIProxyConnection,
cursor_cls=DBAPIProxyCursor):
"""Produce an engine that provides proxy hooks for
common methods.
"""
def mock_conn():
return conn_cls(config.db, cursor_cls)
return testing_engine(options={'creator': mock_conn})

View File

@ -0,0 +1,101 @@
# testing/entities.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
_repr_stack = set()
class BasicEntity(object):
def __init__(self, **kw):
for key, value in kw.items():
setattr(self, key, value)
def __repr__(self):
if id(self) in _repr_stack:
return object.__repr__(self)
_repr_stack.add(id(self))
try:
return "%s(%s)" % (
(self.__class__.__name__),
', '.join(["%s=%r" % (key, getattr(self, key))
for key in sorted(self.__dict__.keys())
if not key.startswith('_')]))
finally:
_repr_stack.remove(id(self))
_recursion_stack = set()
class ComparableEntity(BasicEntity):
def __hash__(self):
return hash(self.__class__)
def __ne__(self, other):
return not self.__eq__(other)
def __eq__(self, other):
"""'Deep, sparse compare.
Deeply compare two entities, following the non-None attributes of the
non-persisted object, if possible.
"""
if other is self:
return True
elif not self.__class__ == other.__class__:
return False
if id(self) in _recursion_stack:
return True
_recursion_stack.add(id(self))
try:
# pick the entity that's not SA persisted as the source
try:
self_key = sa.orm.attributes.instance_state(self).key
except sa.orm.exc.NO_STATE:
self_key = None
if other is None:
a = self
b = other
elif self_key is not None:
a = other
b = self
else:
a = self
b = other
for attr in list(a.__dict__):
if attr.startswith('_'):
continue
value = getattr(a, attr)
try:
# handle lazy loader errors
battr = getattr(b, attr)
except (AttributeError, sa_exc.UnboundExecutionError):
return False
if hasattr(value, '__iter__'):
if hasattr(value, '__getitem__') and not hasattr(
value, 'keys'):
if list(value) != list(battr):
return False
else:
if set(value) != set(battr):
return False
else:
if value is not None and value != battr:
return False
return True
finally:
_recursion_stack.remove(id(self))

Some files were not shown because too many files have changed in this diff Show More