Updated SQLAlchemy + the new files
sqlalchemy/cextension/processors.c
@@ -1,6 +1,7 @@
/*
processors.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
Copyright (C) 2010-2017 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com

This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -9,26 +10,30 @@ the MIT License: http://www.opensource.org/licenses/mit-license.php
#include <Python.h>
#include <datetime.h>

#define MODULE_NAME "cprocessors"
#define MODULE_DOC "Module containing C versions of data processing functions."

#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
#endif

static PyObject *
int_to_boolean(PyObject *self, PyObject *arg)
{
long l = 0;
int l = 0;
PyObject *res;

if (arg == Py_None)
Py_RETURN_NONE;

l = PyInt_AsLong(arg);
l = PyObject_IsTrue(arg);
if (l == 0) {
res = Py_False;
} else if (l == 1) {
res = Py_True;
} else if ((l == -1) && PyErr_Occurred()) {
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else {
PyErr_SetString(PyExc_ValueError,
"int_to_boolean only accepts None, 0 or 1");
return NULL;
}

@@ -57,15 +62,51 @@ to_float(PyObject *self, PyObject *arg)
static PyObject *
str_to_datetime(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int year, month, day, hour, minute, second, microsecond = 0;
PyObject *err_repr;

if (arg == Py_None)
Py_RETURN_NONE;

#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
if (str == NULL)
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}

/* microseconds are optional */
/*
@@ -73,9 +114,31 @@ str_to_datetime(PyObject *self, PyObject *arg)
not accept "2000-01-01 00:00:00.". I don't know which is better, but they
should be coherent.
*/
if (sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
&hour, &minute, &second, &microsecond) < 6) {
PyErr_SetString(PyExc_ValueError, "Couldn't parse datetime string.");
numparsed = sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
&hour, &minute, &second, &microsecond);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 6) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyDateTime_FromDateAndTime(year, month, day,
@@ -85,15 +148,50 @@ str_to_datetime(PyObject *self, PyObject *arg)
static PyObject *
str_to_time(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int hour, minute, second, microsecond = 0;
PyObject *err_repr;

if (arg == Py_None)
Py_RETURN_NONE;

#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
if (str == NULL)
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;

#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}

/* microseconds are optional */
/*
@@ -101,9 +199,31 @@ str_to_time(PyObject *self, PyObject *arg)
not accept "00:00:00.". I don't know which is better, but they should be
coherent.
*/
if (sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
&microsecond) < 3) {
PyErr_SetString(PyExc_ValueError, "Couldn't parse time string.");
numparsed = sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
&microsecond);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyTime_FromTime(hour, minute, second, microsecond);
@@ -112,18 +232,74 @@ str_to_time(PyObject *self, PyObject *arg)
static PyObject *
str_to_date(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int year, month, day;
PyObject *err_repr;

if (arg == Py_None)
Py_RETURN_NONE;

#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
if (str == NULL)
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}

if (sscanf(str, "%4u-%2u-%2u", &year, &month, &day) != 3) {
PyErr_SetString(PyExc_ValueError, "Couldn't parse date string.");
numparsed = sscanf(str, "%4u-%2u-%2u", &year, &month, &day);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed != 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyDate_FromDate(year, month, day);
@@ -159,17 +335,35 @@ UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
PyObject *encoding, *errors = NULL;
static char *kwlist[] = {"encoding", "errors", NULL};

#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTupleAndKeywords(args, kwds, "U|U:__init__", kwlist,
&encoding, &errors))
return -1;
#else
if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist,
&encoding, &errors))
return -1;
#endif

#if PY_MAJOR_VERSION >= 3
encoding = PyUnicode_AsASCIIString(encoding);
#else
Py_INCREF(encoding);
#endif
self->encoding = encoding;

if (errors) {
#if PY_MAJOR_VERSION >= 3
errors = PyUnicode_AsASCIIString(errors);
#else
Py_INCREF(errors);
#endif
} else {
#if PY_MAJOR_VERSION >= 3
errors = PyBytes_FromString("strict");
#else
errors = PyString_FromString("strict");
#endif
if (errors == NULL)
return -1;
}
@@ -188,28 +382,88 @@ UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
if (value == Py_None)
Py_RETURN_NONE;

#if PY_MAJOR_VERSION >= 3
if (PyBytes_AsStringAndSize(value, &str, &len))
return NULL;

encoding = PyBytes_AS_STRING(self->encoding);
errors = PyBytes_AS_STRING(self->errors);
#else
if (PyString_AsStringAndSize(value, &str, &len))
return NULL;

encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors);
#endif

return PyUnicode_Decode(str, len, encoding, errors);
}

static PyObject *
UnicodeResultProcessor_conditional_process(UnicodeResultProcessor *self, PyObject *value)
{
const char *encoding, *errors;
char *str;
Py_ssize_t len;

if (value == Py_None)
Py_RETURN_NONE;

#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(value) == 1) {
Py_INCREF(value);
return value;
}

if (PyBytes_AsStringAndSize(value, &str, &len))
return NULL;

encoding = PyBytes_AS_STRING(self->encoding);
errors = PyBytes_AS_STRING(self->errors);
#else

if (PyUnicode_Check(value) == 1) {
Py_INCREF(value);
return value;
}

if (PyString_AsStringAndSize(value, &str, &len))
return NULL;


encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors);
#endif

return PyUnicode_Decode(str, len, encoding, errors);
}

static void
UnicodeResultProcessor_dealloc(UnicodeResultProcessor *self)
{
Py_XDECREF(self->encoding);
Py_XDECREF(self->errors);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
}

static PyMethodDef UnicodeResultProcessor_methods[] = {
{"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
"The value processor itself."},
{"conditional_process", (PyCFunction)UnicodeResultProcessor_conditional_process, METH_O,
"Conditional version of the value processor."},
{NULL} /* Sentinel */
};

static PyTypeObject UnicodeResultProcessorType = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */
sizeof(UnicodeResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
(destructor)UnicodeResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
@@ -255,7 +509,11 @@ DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
{
PyObject *type, *format;

#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "OU", &type, &format))
#else
if (!PyArg_ParseTuple(args, "OS", &type, &format))
#endif
return -1;

Py_INCREF(type);
@@ -275,22 +533,40 @@ DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
if (value == Py_None)
Py_RETURN_NONE;

if (PyFloat_CheckExact(value)) {
/* Decimal does not accept float values directly */
args = PyTuple_Pack(1, value);
if (args == NULL)
return NULL;
/* Decimal does not accept float values directly */
/* SQLite can also give us an integer here (see [ticket:2432]) */
/* XXX: starting with Python 3.1, we could use Decimal.from_float(f),
but the result wouldn't be the same */

str = PyString_Format(self->format, args);
if (str == NULL)
return NULL;
args = PyTuple_Pack(1, value);
if (args == NULL)
return NULL;

result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
Py_DECREF(str);
return result;
} else {
return PyObject_CallFunctionObjArgs(self->type, value, NULL);
}
#if PY_MAJOR_VERSION >= 3
str = PyUnicode_Format(self->format, args);
#else
str = PyString_Format(self->format, args);
#endif

Py_DECREF(args);
if (str == NULL)
return NULL;

result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
Py_DECREF(str);
return result;
}

static void
DecimalResultProcessor_dealloc(DecimalResultProcessor *self)
{
Py_XDECREF(self->type);
Py_XDECREF(self->format);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
}

static PyMethodDef DecimalResultProcessor_methods[] = {
@@ -300,12 +576,11 @@ static PyMethodDef DecimalResultProcessor_methods[] = {
};

static PyTypeObject DecimalResultProcessorType = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.DecimalResultProcessor", /* tp_name */
sizeof(DecimalResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
(destructor)DecimalResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
@@ -341,11 +616,6 @@ static PyTypeObject DecimalResultProcessorType = {
0, /* tp_new */
};

#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif


static PyMethodDef module_methods[] = {
{"int_to_boolean", int_to_boolean, METH_O,
"Convert an integer to a boolean."},
@@ -362,23 +632,53 @@ static PyMethodDef module_methods[] = {
{NULL, NULL, 0, NULL} /* Sentinel */
};

#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif


#if PY_MAJOR_VERSION >= 3

static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};

#define INITERROR return NULL

PyMODINIT_FUNC
PyInit_cprocessors(void)

#else

#define INITERROR return

PyMODINIT_FUNC
initcprocessors(void)

#endif

{
PyObject *m;

UnicodeResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&UnicodeResultProcessorType) < 0)
return;
INITERROR;

DecimalResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&DecimalResultProcessorType) < 0)
return;
INITERROR;

m = Py_InitModule3("cprocessors", module_methods,
"Module containing C versions of data processing functions.");
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL)
return;
INITERROR;

PyDateTime_IMPORT;

@@ -389,5 +689,8 @@ initcprocessors(void)
Py_INCREF(&DecimalResultProcessorType);
PyModule_AddObject(m, "DecimalResultProcessor",
(PyObject *)&DecimalResultProcessorType);
}

#if PY_MAJOR_VERSION >= 3
return m;
#endif
}
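For reference, a minimal sketch of how the functions above are reached from Python once the extension compiles. The module and type names follow the MODULE_NAME and tp_name strings in the code; the exact behavior shown is inferred from the C branches, not stated by the commit itself:

    # Illustrative only; assumes the extension built as sqlalchemy.cprocessors.
    from sqlalchemy import cprocessors

    cprocessors.int_to_boolean(None)                    # None passes through untouched
    cprocessors.int_to_boolean(1)                       # -> True (PyObject_IsTrue path)
    cprocessors.str_to_date("2010-01-05")               # -> datetime.date(2010, 1, 5)
    cprocessors.str_to_datetime("2010-01-05 12:30:00")  # microseconds are optional
    cprocessors.str_to_time("12:30:00.25")              # fractional part read by sscanf %6u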
sqlalchemy/cextension/resultproxy.c
@@ -1,6 +1,7 @@
/*
resultproxy.c
Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
Copyright (C) 2010-2017 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com

This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -8,6 +9,18 @@ the MIT License: http://www.opensource.org/licenses/mit-license.php

#include <Python.h>

#define MODULE_NAME "cresultproxy"
#define MODULE_DOC "Module containing C versions of core ResultProxy classes."

#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
typedef Py_ssize_t (*lenfunc)(PyObject *);
#define PyInt_FromSsize_t(x) PyInt_FromLong(x)
typedef intargfunc ssizeargfunc;
#endif


/***********
* Structs *
@@ -69,8 +82,8 @@ BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
Py_INCREF(parent);
self->parent = parent;

if (!PyTuple_CheckExact(row)) {
PyErr_SetString(PyExc_TypeError, "row must be a tuple");
if (!PySequence_Check(row)) {
PyErr_SetString(PyExc_TypeError, "row must be a sequence");
return -1;
}
Py_INCREF(row);
@@ -100,11 +113,11 @@ BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
static PyObject *
BaseRowProxy_reduce(PyObject *self)
{
PyObject *method, *state;
PyObject *module, *reconstructor, *cls;
PyObject *method, *state;
PyObject *module, *reconstructor, *cls;

method = PyObject_GetAttrString(self, "__getstate__");
if (method == NULL)
method = PyObject_GetAttrString(self, "__getstate__");
if (method == NULL)
return NULL;

state = PyObject_CallObject(method, NULL);
@@ -112,7 +125,7 @@ BaseRowProxy_reduce(PyObject *self)
if (state == NULL)
return NULL;

module = PyImport_ImportModule("sqlalchemy.engine.base");
module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL)
return NULL;

@@ -140,7 +153,11 @@ BaseRowProxy_dealloc(BaseRowProxy *self)
Py_XDECREF(self->row);
Py_XDECREF(self->processors);
Py_XDECREF(self->keymap);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject *)self);
#else
self->ob_type->tp_free((PyObject *)self);
#endif
}

static PyObject *
@@ -148,13 +165,15 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
{
Py_ssize_t num_values, num_processors;
PyObject **valueptr, **funcptr, **resultptr;
PyObject *func, *result, *processed_value;
PyObject *func, *result, *processed_value, *values_fastseq;

num_values = Py_SIZE(values);
num_processors = Py_SIZE(processors);
num_values = PySequence_Length(values);
num_processors = PyList_Size(processors);
if (num_values != num_processors) {
PyErr_SetString(PyExc_RuntimeError,
"number of values in row differ from number of column processors");
PyErr_Format(PyExc_RuntimeError,
"number of values in row (%d) differ from number of column "
"processors (%d)",
(int)num_values, (int)num_processors);
return NULL;
}

@@ -166,9 +185,11 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
if (result == NULL)
return NULL;

/* we don't need to use PySequence_Fast as long as values, processors and
* result are simple tuple or lists. */
valueptr = PySequence_Fast_ITEMS(values);
values_fastseq = PySequence_Fast(values, "row must be a sequence");
if (values_fastseq == NULL)
return NULL;

valueptr = PySequence_Fast_ITEMS(values_fastseq);
funcptr = PySequence_Fast_ITEMS(processors);
resultptr = PySequence_Fast_ITEMS(result);
while (--num_values >= 0) {
@@ -177,6 +198,7 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
NULL);
if (processed_value == NULL) {
Py_DECREF(values_fastseq);
Py_DECREF(result);
return NULL;
}
@@ -189,6 +211,7 @@ BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
funcptr++;
resultptr++;
}
Py_DECREF(values_fastseq);
return result;
}

@@ -199,19 +222,12 @@ BaseRowProxy_values(BaseRowProxy *self)
self->processors, 0);
}

static PyTupleObject *
BaseRowProxy_tuplevalues(BaseRowProxy *self)
{
return (PyTupleObject *)BaseRowProxy_processvalues(self->row,
self->processors, 1);
}

static PyObject *
BaseRowProxy_iter(BaseRowProxy *self)
{
PyObject *values, *result;

values = (PyObject *)BaseRowProxy_tuplevalues(self);
values = BaseRowProxy_processvalues(self->row, self->processors, 1);
if (values == NULL)
return NULL;

@@ -226,26 +242,39 @@ BaseRowProxy_iter(BaseRowProxy *self)
static Py_ssize_t
BaseRowProxy_length(BaseRowProxy *self)
{
return Py_SIZE(self->row);
return PySequence_Length(self->row);
}

static PyObject *
BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
{
PyObject *processors, *values;
PyObject *processor, *value;
PyObject *record, *result, *indexobject;
PyObject *exc_module, *exception;
PyObject *processor, *value, *processed_value;
PyObject *row, *record, *result, *indexobject;
PyObject *exc_module, *exception, *cstr_obj;
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
#endif
char *cstr_key;
long index;
int key_fallback = 0;
int tuple_check = 0;

#if PY_MAJOR_VERSION < 3
if (PyInt_CheckExact(key)) {
index = PyInt_AS_LONG(key);
} else if (PyLong_CheckExact(key)) {
if (index < 0)
index += BaseRowProxy_length(self);
} else
#endif

if (PyLong_CheckExact(key)) {
index = PyLong_AsLong(key);
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
if (index < 0)
index += BaseRowProxy_length(self);
} else if (PySlice_Check(key)) {
values = PyObject_GetItem(self->row, key);
if (values == NULL)
@@ -268,12 +297,17 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
"O", key);
if (record == NULL)
return NULL;
key_fallback = 1;
}

indexobject = PyTuple_GetItem(record, 1);
indexobject = PyTuple_GetItem(record, 2);
if (indexobject == NULL)
return NULL;

if (key_fallback) {
Py_DECREF(record);
}

if (indexobject == Py_None) {
exc_module = PyImport_ImportModule("sqlalchemy.exc");
if (exc_module == NULL)
@@ -285,17 +319,47 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
if (exception == NULL)
return NULL;

cstr_key = PyString_AsString(key);
if (cstr_key == NULL)
cstr_obj = PyTuple_GetItem(record, 1);
if (cstr_obj == NULL)
return NULL;

cstr_obj = PyObject_Str(cstr_obj);
if (cstr_obj == NULL)
return NULL;

/*
FIXME: raise encoding error exception (in both versions below)
if the key contains non-ascii chars, instead of an
InvalidRequestError without any message like in the
python version.
*/


#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(cstr_obj);
if (bytes == NULL)
return NULL;
cstr_key = PyBytes_AS_STRING(bytes);
#else
cstr_key = PyString_AsString(cstr_obj);
#endif
if (cstr_key == NULL) {
Py_DECREF(cstr_obj);
return NULL;
}
Py_DECREF(cstr_obj);

PyErr_Format(exception,
"Ambiguous column name '%s' in result set! "
"try 'use_labels' option on select statement.", cstr_key);
"Ambiguous column name '%.200s' in "
"result set column descriptions", cstr_key);
return NULL;
}

#if PY_MAJOR_VERSION >= 3
index = PyLong_AsLong(indexobject);
#else
index = PyInt_AsLong(indexobject);
#endif
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
@@ -304,22 +368,53 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
if (processor == NULL)
return NULL;

value = PyTuple_GetItem(self->row, index);
row = self->row;
if (PyTuple_CheckExact(row)) {
value = PyTuple_GetItem(row, index);
tuple_check = 1;
}
else {
value = PySequence_GetItem(row, index);
tuple_check = 0;
}

if (value == NULL)
return NULL;

if (processor != Py_None) {
return PyObject_CallFunctionObjArgs(processor, value, NULL);
processed_value = PyObject_CallFunctionObjArgs(processor, value, NULL);
if (!tuple_check) {
Py_DECREF(value);
}
return processed_value;
} else {
Py_INCREF(value);
if (tuple_check) {
Py_INCREF(value);
}
return value;
}
}

static PyObject *
BaseRowProxy_getitem(PyObject *self, Py_ssize_t i)
{
PyObject *index;

#if PY_MAJOR_VERSION >= 3
index = PyLong_FromSsize_t(i);
#else
index = PyInt_FromSsize_t(i);
#endif
return BaseRowProxy_subscript((BaseRowProxy*)self, index);
}

static PyObject *
BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
{
PyObject *tmp;
#if PY_MAJOR_VERSION >= 3
PyObject *err_bytes;
#endif

if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
@@ -329,7 +424,28 @@ BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
else
return tmp;

return BaseRowProxy_subscript(self, name);
tmp = BaseRowProxy_subscript(self, name);
if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) {

#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(name);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyBytes_AS_STRING(err_bytes)
);
#else
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyString_AsString(name)
);
#endif
return NULL;
}
return tmp;
}

/***********************
@@ -354,7 +470,7 @@ BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
return -1;
}

module = PyImport_ImportModule("sqlalchemy.engine.base");
module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL)
return -1;

@@ -393,9 +509,9 @@ BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
return -1;
}

if (!PyTuple_CheckExact(value)) {
if (!PySequence_Check(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'row' attribute value must be a tuple");
"The 'row' attribute value must be a sequence");
return -1;
}

@@ -487,8 +603,8 @@ static PyGetSetDef BaseRowProxy_getseters[] = {
static PyMethodDef BaseRowProxy_methods[] = {
{"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
"Return the values represented by this BaseRowProxy as a list."},
{"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
"Pickle support method."},
{"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
"Pickle support method."},
{NULL} /* Sentinel */
};

@@ -496,7 +612,7 @@ static PySequenceMethods BaseRowProxy_as_sequence = {
(lenfunc)BaseRowProxy_length, /* sq_length */
0, /* sq_concat */
0, /* sq_repeat */
0, /* sq_item */
(ssizeargfunc)BaseRowProxy_getitem, /* sq_item */
0, /* sq_slice */
0, /* sq_ass_item */
0, /* sq_ass_slice */
@@ -512,8 +628,7 @@ static PyMappingMethods BaseRowProxy_as_mapping = {
};

static PyTypeObject BaseRowProxyType = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */
sizeof(BaseRowProxy), /* tp_basicsize */
0, /* tp_itemsize */
@@ -553,34 +668,60 @@ static PyTypeObject BaseRowProxyType = {
0 /* tp_new */
};


#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif


static PyMethodDef module_methods[] = {
{"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS,
"reconstruct a RowProxy instance from its pickled form."},
{NULL, NULL, 0, NULL} /* Sentinel */
};

#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif


#if PY_MAJOR_VERSION >= 3

static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};

#define INITERROR return NULL

PyMODINIT_FUNC
PyInit_cresultproxy(void)

#else

#define INITERROR return

PyMODINIT_FUNC
initcresultproxy(void)

#endif

{
PyObject *m;

BaseRowProxyType.tp_new = PyType_GenericNew;
if (PyType_Ready(&BaseRowProxyType) < 0)
return;
INITERROR;

m = Py_InitModule3("cresultproxy", module_methods,
"Module containing C versions of core ResultProxy classes.");
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL)
return;
INITERROR;

Py_INCREF(&BaseRowProxyType);
PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);

#if PY_MAJOR_VERSION >= 3
return m;
#endif
}
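For context, a rough sketch of the contract the C BaseRowProxy implements, pieced together from the code above. The constructor signature and keymap layout are assumptions drawn from the struct fields and the PyTuple_GetItem(record, 2) index lookup, not from documentation:

    # Hypothetical illustration; real instances are built inside
    # sqlalchemy.engine.result, not by hand.
    row = BaseRowProxy(parent, ('1', 'jack'), [int, None], keymap)
    row[0]      # processors[0] applied to row[0]  -> 1
    row[1]      # processor is None                -> 'jack'
    len(row)    # PySequence_Length(self->row)     -> 2
    row.name    # attribute access falls back to __getitem__ via the keymap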
sqlalchemy/cextension/utils.c (new file, 225 lines)
@@ -0,0 +1,225 @@
/*
utils.c
Copyright (C) 2012-2017 the SQLAlchemy authors and contributors <see AUTHORS file>

This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/

#include <Python.h>

#define MODULE_NAME "cutils"
#define MODULE_DOC "Module containing C versions of utility functions."

/*
Given arguments from the calling form *multiparams, **params,
return a list of bind parameter structures, usually a list of
dictionaries.

In the case of 'raw' execution which accepts positional parameters,
it may be a list of tuples or lists.

*/
static PyObject *
distill_params(PyObject *self, PyObject *args)
{
PyObject *multiparams, *params;
PyObject *enclosing_list, *double_enclosing_list;
PyObject *zero_element, *zero_element_item;
Py_ssize_t multiparam_size, zero_element_length;

if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, &params)) {
return NULL;
}

if (multiparams != Py_None) {
multiparam_size = PyTuple_Size(multiparams);
if (multiparam_size < 0) {
return NULL;
}
}
else {
multiparam_size = 0;
}

if (multiparam_size == 0) {
if (params != Py_None && PyDict_Size(params) != 0) {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(params);
if (PyList_SetItem(enclosing_list, 0, params) == -1) {
Py_DECREF(params);
Py_DECREF(enclosing_list);
return NULL;
}
}
else {
enclosing_list = PyList_New(0);
if (enclosing_list == NULL) {
return NULL;
}
}
return enclosing_list;
}
else if (multiparam_size == 1) {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) {
zero_element_length = PySequence_Length(zero_element);

if (zero_element_length != 0) {
zero_element_item = PySequence_GetItem(zero_element, 0);
if (zero_element_item == NULL) {
return NULL;
}
}
else {
zero_element_item = NULL;
}

if (zero_element_length == 0 ||
(
PyObject_HasAttrString(zero_element_item, "__iter__") &&
!PyObject_HasAttrString(zero_element_item, "strip")
)
) {
/*
* execute(stmt, [{}, {}, {}, ...])
* execute(stmt, [(), (), (), ...])
*/
Py_XDECREF(zero_element_item);
Py_INCREF(zero_element);
return zero_element;
}
else {
/*
* execute(stmt, ("value", "value"))
*/
Py_XDECREF(zero_element_item);
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
else if (PyObject_HasAttrString(zero_element, "keys")) {
/*
* execute(stmt, {"key":"value"})
*/
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
} else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
double_enclosing_list = PyList_New(1);
if (double_enclosing_list == NULL) {
Py_DECREF(enclosing_list);
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
return double_enclosing_list;
}
}
else {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyObject_HasAttrString(zero_element, "__iter__") &&
!PyObject_HasAttrString(zero_element, "strip")
) {
Py_INCREF(multiparams);
return multiparams;
}
else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(multiparams);
if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) {
Py_DECREF(multiparams);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
}

static PyMethodDef module_methods[] = {
{"_distill_params", distill_params, METH_VARARGS,
"Distill an execute() parameter structure."},
{NULL, NULL, 0, NULL} /* Sentinel */
};

#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif

#if PY_MAJOR_VERSION >= 3

static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#endif


#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
PyInit_cutils(void)
#else
PyMODINIT_FUNC
initcutils(void)
#endif
{
PyObject *m;

#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif

#if PY_MAJOR_VERSION >= 3
if (m == NULL)
return NULL;
return m;
#else
if (m == NULL)
return;
#endif
}
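The branches above reduce to a small set of output shapes; a sketch of the mapping with illustrative values, inferred from the C control flow rather than quoted from documentation:

    # _distill_params(multiparams, params) -> list of bind parameter structures
    _distill_params((), {})                        # -> []
    _distill_params((), {"x": 5})                  # -> [{"x": 5}]
    _distill_params(({"x": 5},), {})               # -> [{"x": 5}]
    _distill_params(([{"x": 5}, {"x": 6}],), {})   # -> [{"x": 5}, {"x": 6}]  (list passed through)
    _distill_params(("a", "b"), {})                # -> [("a", "b")]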
sqlalchemy/databases/__init__.py
@@ -1,31 +1,30 @@
# __init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
# databases/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.postgresql import base as postgresql
"""Include imports from the sqlalchemy.dialects package for backwards
compatibility with pre 0.6 versions.

"""
from ..dialects.sqlite import base as sqlite
from ..dialects.postgresql import base as postgresql
postgres = postgresql
from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.dialects.oracle import base as oracle
from sqlalchemy.dialects.firebird import base as firebird
from sqlalchemy.dialects.maxdb import base as maxdb
from sqlalchemy.dialects.informix import base as informix
from sqlalchemy.dialects.mssql import base as mssql
from sqlalchemy.dialects.access import base as access
from sqlalchemy.dialects.sybase import base as sybase
from ..dialects.mysql import base as mysql
from ..dialects.oracle import base as oracle
from ..dialects.firebird import base as firebird
from ..dialects.mssql import base as mssql
from ..dialects.sybase import base as sybase


__all__ = (
'access',
'firebird',
'informix',
'maxdb',
'mssql',
'mysql',
'postgresql',
'sqlite',
'oracle',
'sybase',
)
)
sqlalchemy/dialects/firebird/__init__.py
@@ -1,16 +1,21 @@
from sqlalchemy.dialects.firebird import base, kinterbasdb
# firebird/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

base.dialect = kinterbasdb.dialect
from sqlalchemy.dialects.firebird import base, kinterbasdb, fdb

base.dialect = fdb.dialect

from sqlalchemy.dialects.firebird.base import \
SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \
TEXT, NUMERIC, FLOAT, TIMESTAMP, VARCHAR, CHAR, BLOB,\
dialect


__all__ = (
'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME',
'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME',
'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
'dialect'
)
@@ -1,16 +1,17 @@
|
||||
# firebird.py
|
||||
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
|
||||
# firebird/base.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
Support for the Firebird database.
|
||||
r"""
|
||||
|
||||
Connectivity is usually supplied via the kinterbasdb_ DBAPI module.
|
||||
.. dialect:: firebird
|
||||
:name: Firebird
|
||||
|
||||
Dialects
|
||||
~~~~~~~~
|
||||
Firebird Dialects
|
||||
-----------------
|
||||
|
||||
Firebird offers two distinct dialects_ (not to be confused with a
|
||||
SQLAlchemy ``Dialect``):
|
||||
@@ -27,7 +28,7 @@ support for dialect 1 is not well tested and probably has
|
||||
incompatibilities.
|
||||
|
||||
Locking Behavior
|
||||
~~~~~~~~~~~~~~~~
|
||||
----------------
|
||||
|
||||
Firebird locks tables aggressively. For this reason, a DROP TABLE may
|
||||
hang until other transactions are released. SQLAlchemy does its best
|
||||
@@ -47,20 +48,20 @@ The above use case can be alleviated by calling ``first()`` on the
|
||||
all remaining cursor/connection resources.
|
||||
|
||||
RETURNING support
|
||||
~~~~~~~~~~~~~~~~~
|
||||
-----------------
|
||||
|
||||
Firebird 2.0 supports returning a result set from inserts, and 2.1
|
||||
extends that to deletes and updates. This is generically exposed by
|
||||
the SQLAlchemy ``returning()`` method, such as::
|
||||
|
||||
# INSERT..RETURNING
|
||||
result = table.insert().returning(table.c.col1, table.c.col2).\\
|
||||
result = table.insert().returning(table.c.col1, table.c.col2).\
|
||||
values(name='foo')
|
||||
print result.fetchall()
|
||||
|
||||
# UPDATE..RETURNING
|
||||
raises = empl.update().returning(empl.c.id, empl.c.salary).\\
|
||||
where(empl.c.sales>100).\\
|
||||
raises = empl.update().returning(empl.c.id, empl.c.salary).\
|
||||
where(empl.c.sales>100).\
|
||||
values(dict(salary=empl.c.salary * 1.1))
|
||||
print raises.fetchall()
|
||||
|
||||
@@ -69,18 +70,17 @@ the SQLAlchemy ``returning()`` method, such as::
|
||||
|
||||
"""
|
||||
|
||||
import datetime, re
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import schema as sa_schema
|
||||
from sqlalchemy import exc, types as sqltypes, sql, util
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.engine import base, default, reflection
|
||||
from sqlalchemy.sql import compiler
|
||||
from sqlalchemy.sql.elements import quoted_name
|
||||
|
||||
|
||||
from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
|
||||
FLOAT, INTEGER, NUMERIC, SMALLINT,
|
||||
TEXT, TIME, TIMESTAMP, VARCHAR)
|
||||
from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC,
|
||||
SMALLINT, TEXT, TIME, TIMESTAMP, Integer)
|
||||
|
||||
|
||||
RESERVED_WORDS = set([
|
||||
@@ -120,65 +120,144 @@ RESERVED_WORDS = set([
|
||||
"union", "unique", "update", "upper", "user", "using", "value",
|
||||
"values", "varchar", "variable", "varying", "view", "wait", "when",
|
||||
"where", "while", "with", "work", "write", "year",
|
||||
])
|
||||
])
|
||||
|
||||
|
||||
class _StringType(sqltypes.String):
|
||||
"""Base for Firebird string types."""
|
||||
|
||||
def __init__(self, charset=None, **kw):
|
||||
self.charset = charset
|
||||
super(_StringType, self).__init__(**kw)
|
||||
|
||||
|
||||
class VARCHAR(_StringType, sqltypes.VARCHAR):
|
||||
"""Firebird VARCHAR type"""
|
||||
__visit_name__ = 'VARCHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
super(VARCHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class CHAR(_StringType, sqltypes.CHAR):
|
||||
"""Firebird CHAR type"""
|
||||
__visit_name__ = 'CHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
super(CHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class _FBDateTime(sqltypes.DateTime):
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if type(value) == datetime.date:
|
||||
return datetime.datetime(value.year, value.month, value.day)
|
||||
else:
|
||||
return value
|
||||
return process
|
||||
|
||||
colspecs = {
|
||||
sqltypes.DateTime: _FBDateTime
|
||||
}
|
||||
|
||||
ischema_names = {
|
||||
'SHORT': SMALLINT,
|
||||
'LONG': BIGINT,
|
||||
'QUAD': FLOAT,
|
||||
'FLOAT': FLOAT,
|
||||
'DATE': DATE,
|
||||
'TIME': TIME,
|
||||
'TEXT': TEXT,
|
||||
'INT64': NUMERIC,
|
||||
'DOUBLE': FLOAT,
|
||||
'TIMESTAMP': TIMESTAMP,
|
||||
'SHORT': SMALLINT,
|
||||
'LONG': INTEGER,
|
||||
'QUAD': FLOAT,
|
||||
'FLOAT': FLOAT,
|
||||
'DATE': DATE,
|
||||
'TIME': TIME,
|
||||
'TEXT': TEXT,
|
||||
'INT64': BIGINT,
|
||||
'DOUBLE': FLOAT,
|
||||
'TIMESTAMP': TIMESTAMP,
|
||||
'VARYING': VARCHAR,
|
||||
'CSTRING': CHAR,
|
||||
'BLOB': BLOB,
|
||||
}
|
||||
'BLOB': BLOB,
|
||||
}
|
||||
|
||||
|
||||
# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
|
||||
# as bind/result functionality is required)
|
||||
# TODO: date conversion types (should be implemented as _FBDateTime,
|
||||
# _FBDate, etc. as bind/result functionality is required)
|
||||
|
||||
class FBTypeCompiler(compiler.GenericTypeCompiler):
|
||||
def visit_boolean(self, type_):
|
||||
return self.visit_SMALLINT(type_)
|
||||
def visit_boolean(self, type_, **kw):
|
||||
return self.visit_SMALLINT(type_, **kw)
|
||||
|
||||
def visit_datetime(self, type_):
|
||||
return self.visit_TIMESTAMP(type_)
|
||||
def visit_datetime(self, type_, **kw):
|
||||
return self.visit_TIMESTAMP(type_, **kw)
|
||||
|
||||
def visit_TEXT(self, type_):
|
||||
def visit_TEXT(self, type_, **kw):
|
||||
return "BLOB SUB_TYPE 1"
|
||||
|
||||
def visit_BLOB(self, type_):
|
||||
def visit_BLOB(self, type_, **kw):
|
||||
return "BLOB SUB_TYPE 0"
|
||||
|
||||
def _extend_string(self, type_, basic):
|
||||
charset = getattr(type_, 'charset', None)
|
||||
if charset is None:
|
||||
return basic
|
||||
else:
|
||||
return '%s CHARACTER SET %s' % (basic, charset)
|
||||
|
||||
def visit_CHAR(self, type_, **kw):
|
||||
basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
|
||||
return self._extend_string(type_, basic)
|
||||
|
||||
def visit_VARCHAR(self, type_, **kw):
|
||||
if not type_.length:
|
||||
raise exc.CompileError(
|
||||
"VARCHAR requires a length on dialect %s" %
|
||||
self.dialect.name)
|
||||
basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
|
||||
return self._extend_string(type_, basic)
|
||||
|
||||
|
||||
class FBCompiler(sql.compiler.SQLCompiler):
|
||||
"""Firebird specific idiosincrasies"""
|
||||
"""Firebird specific idiosyncrasies"""
|
||||
|
||||
def visit_mod(self, binary, **kw):
|
||||
# Firebird lacks a builtin modulo operator, but there is
|
||||
# an equivalent function in the ib_udf library.
|
||||
return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
|
||||
ansi_bind_rules = True
|
||||
|
||||
# def visit_contains_op_binary(self, binary, operator, **kw):
|
||||
# cant use CONTAINING b.c. it's case insensitive.
|
||||
|
||||
# def visit_notcontains_op_binary(self, binary, operator, **kw):
|
||||
# cant use NOT CONTAINING b.c. it's case insensitive.
|
||||
|
||||
def visit_now_func(self, fn, **kw):
|
||||
return "CURRENT_TIMESTAMP"
|
||||
|
||||
def visit_startswith_op_binary(self, binary, operator, **kw):
|
||||
return '%s STARTING WITH %s' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw))
|
||||
|
||||
def visit_notstartswith_op_binary(self, binary, operator, **kw):
|
||||
return '%s NOT STARTING WITH %s' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw))
|
||||
|
||||
def visit_mod_binary(self, binary, operator, **kw):
|
||||
return "mod(%s, %s)" % (
|
||||
self.process(binary.left, **kw),
|
||||
self.process(binary.right, **kw))
|
||||
|
||||
def visit_alias(self, alias, asfrom=False, **kwargs):
|
||||
if self.dialect._version_two:
|
||||
return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs)
|
||||
return super(FBCompiler, self).\
|
||||
visit_alias(alias, asfrom=asfrom, **kwargs)
|
||||
else:
|
||||
# Override to not use the AS keyword which FB 1.5 does not like
|
||||
if asfrom:
|
||||
alias_name = isinstance(alias.name, expression._generated_label) and \
|
||||
self._truncated_identifier("alias", alias.name) or alias.name
|
||||
alias_name = isinstance(alias.name,
|
||||
expression._truncated_label) and \
|
||||
self._truncated_identifier("alias",
|
||||
alias.name) or alias.name
|
||||
|
||||
return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \
|
||||
self.preparer.format_alias(alias, alias_name)
|
||||
return self.process(
|
||||
alias.original, asfrom=asfrom, **kwargs) + \
|
||||
" " + \
|
||||
self.preparer.format_alias(alias, alias_name)
|
||||
else:
|
||||
return self.process(alias.original, **kwargs)
|
||||
|
||||
@@ -200,8 +279,12 @@ class FBCompiler(sql.compiler.SQLCompiler):
|
||||
visit_char_length_func = visit_length_func
|
||||
|
||||
def function_argspec(self, func, **kw):
|
||||
# TODO: this probably will need to be
|
||||
# narrowed to a fixed list, some no-arg functions
|
||||
# may require parens - see similar example in the oracle
|
||||
# dialect
|
||||
if func.clauses is not None and len(func.clauses):
|
||||
return self.process(func.clause_expr)
|
||||
return self.process(func.clause_expr, **kw)
|
||||
else:
|
||||
return ""
|
||||
|
||||
@@ -211,41 +294,37 @@ class FBCompiler(sql.compiler.SQLCompiler):
|
||||
def visit_sequence(self, seq):
|
||||
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
|
||||
|
||||
def get_select_precolumns(self, select):
|
||||
def get_select_precolumns(self, select, **kw):
|
||||
"""Called when building a ``SELECT`` statement, position is just
|
||||
before column list Firebird puts the limit and offset right
|
||||
after the ``SELECT``...
|
||||
"""
|
||||
|
||||
result = ""
|
||||
if select._limit:
|
||||
result += "FIRST %d " % select._limit
|
||||
if select._offset:
|
||||
result +="SKIP %d " % select._offset
|
||||
if select._limit_clause is not None:
|
||||
result += "FIRST %s " % self.process(select._limit_clause, **kw)
|
||||
if select._offset_clause is not None:
|
||||
result += "SKIP %s " % self.process(select._offset_clause, **kw)
|
||||
if select._distinct:
|
||||
result += "DISTINCT "
|
||||
return result
|
||||
|
||||
def limit_clause(self, select):
|
||||
def limit_clause(self, select, **kw):
|
||||
"""Already taken care of in the `get_select_precolumns` method."""
|
||||
|
||||
return ""
|
||||
|
||||
def returning_clause(self, stmt, returning_cols):
|
||||
|
||||
columns = [
|
||||
self.process(
|
||||
self.label_select_column(None, c, asfrom=False),
|
||||
within_columns_clause=True,
|
||||
result_map=self.result_map
|
||||
)
|
||||
for c in expression._select_iterables(returning_cols)
|
||||
]
|
||||
self._label_select_column(None, c, True, False, {})
|
||||
for c in expression._select_iterables(returning_cols)
|
||||
]
|
||||
|
||||
return 'RETURNING ' + ', '.join(columns)
|
||||
|
||||
|
||||
class FBDDLCompiler(sql.compiler.DDLCompiler):
|
||||
"""Firebird syntactic idiosincrasies"""
|
||||
"""Firebird syntactic idiosyncrasies"""
|
||||
|
||||
def visit_create_sequence(self, create):
|
||||
"""Generate a ``CREATE GENERATOR`` statement for the sequence."""
|
||||
@@ -253,39 +332,50 @@ class FBDDLCompiler(sql.compiler.DDLCompiler):
|
||||
# no syntax for these
|
||||
# http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
|
||||
if create.element.start is not None:
|
||||
raise NotImplemented("Firebird SEQUENCE doesn't support START WITH")
|
||||
raise NotImplemented(
|
||||
"Firebird SEQUENCE doesn't support START WITH")
|
||||
if create.element.increment is not None:
|
||||
raise NotImplemented("Firebird SEQUENCE doesn't support INCREMENT BY")
|
||||
raise NotImplemented(
|
||||
"Firebird SEQUENCE doesn't support INCREMENT BY")
|
||||
|
||||
if self.dialect._version_two:
|
||||
return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
|
||||
return "CREATE SEQUENCE %s" % \
|
||||
self.preparer.format_sequence(create.element)
|
||||
else:
|
||||
return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element)
|
||||
return "CREATE GENERATOR %s" % \
|
||||
self.preparer.format_sequence(create.element)
|
||||
|
||||
def visit_drop_sequence(self, drop):
|
||||
"""Generate a ``DROP GENERATOR`` statement for the sequence."""
|
||||
|
||||
if self.dialect._version_two:
|
||||
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
|
||||
return "DROP SEQUENCE %s" % \
|
||||
self.preparer.format_sequence(drop.element)
|
||||
else:
|
||||
return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)
|
||||
return "DROP GENERATOR %s" % \
|
||||
self.preparer.format_sequence(drop.element)
|
||||
|
||||
|
||||
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
|
||||
"""Install Firebird specific reserved words."""
|
||||
|
||||
reserved_words = RESERVED_WORDS
|
||||
illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(
|
||||
['_'])
|
||||
|
||||
def __init__(self, dialect):
|
||||
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
|
||||
|
||||
|
||||
class FBExecutionContext(default.DefaultExecutionContext):
|
||||
def fire_sequence(self, seq):
|
||||
def fire_sequence(self, seq, type_):
|
||||
"""Get the next value from the sequence using ``gen_id()``."""
|
||||
|
||||
return self._execute_scalar("SELECT gen_id(%s, 1) FROM rdb$database" % \
|
||||
self.dialect.identifier_preparer.format_sequence(seq))
|
||||
return self._execute_scalar(
|
||||
"SELECT gen_id(%s, 1) FROM rdb$database" %
|
||||
self.dialect.identifier_preparer.format_sequence(seq),
|
||||
type_
|
||||
)


class FBDialect(default.DefaultDialect):
@@ -305,7 +395,6 @@ class FBDialect(default.DefaultDialect):
    requires_name_normalize = True
    supports_empty_insert = False

    statement_compiler = FBCompiler
    ddl_compiler = FBDDLCompiler
    preparer = FBIdentifierPreparer
@@ -315,6 +404,8 @@ class FBDialect(default.DefaultDialect):
    colspecs = colspecs
    ischema_names = ischema_names

    construct_arguments = []

    # defaults to dialect ver. 3,
    # will be autodetected off upon
    # first connect
@@ -322,7 +413,13 @@ class FBDialect(default.DefaultDialect):

    def initialize(self, connection):
        super(FBDialect, self).initialize(connection)
        self._version_two = self.server_version_info > (2, )
        self._version_two = ('firebird' in self.server_version_info and
                             self.server_version_info >= (2, )
                             ) or \
                            ('interbase' in self.server_version_info and
                             self.server_version_info >= (6, )
                             )

        if not self._version_two:
            # TODO: whatever other pre < 2.0 stuff goes here
            self.ischema_names = ischema_names.copy()
@@ -330,8 +427,9 @@ class FBDialect(default.DefaultDialect):
            self.colspecs = {
                sqltypes.DateTime: sqltypes.DATE
            }
        else:
            self.implicit_returning = True

        self.implicit_returning = self._version_two and \
            self.__dict__.get('implicit_returning', True)

    def normalize_name(self, name):
        # Remove trailing spaces: FB uses a CHAR() type,
@@ -340,8 +438,10 @@ class FBDialect(default.DefaultDialect):
        if name is None:
            return None
        elif name.upper() == name and \
            not self.identifier_preparer._requires_quotes(name.lower()):
                not self.identifier_preparer._requires_quotes(name.lower()):
            return name.lower()
        elif name.lower() == name:
            return quoted_name(name, quote=True)
        else:
            return name

@@ -349,16 +449,17 @@ class FBDialect(default.DefaultDialect):
        if name is None:
            return None
        elif name.lower() == name and \
            not self.identifier_preparer._requires_quotes(name.lower()):
                not self.identifier_preparer._requires_quotes(name.lower()):
            return name.upper()
        else:
            return name
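Taken together, ``normalize_name`` and ``denormalize_name`` implement Firebird's case-folding convention: unquoted identifiers come back upper case from the server and are folded to lower case Python-side, and the reverse on the way in. A simplified standalone sketch of that round trip (it deliberately omits the ``_requires_quotes`` check the real dialect performs)::

    def normalize(name):
        # Firebird reports unquoted identifiers in upper case; fold to lower.
        return name.lower() if name is not None and name.upper() == name else name

    def denormalize(name):
        # Plain lower-case Python-side names go back to upper case.
        return name.upper() if name is not None and name.lower() == name else name

    assert denormalize(normalize("CUSTOMERS")) == "CUSTOMERS"
    assert normalize("MixedCase") == "MixedCase"  # mixed-case names pass through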

    def has_table(self, connection, table_name, schema=None):
        """Return ``True`` if the given table exists, ignoring the `schema`."""
        """Return ``True`` if the given table exists, ignoring
        the `schema`."""

        tblqry = """
        SELECT 1 FROM rdb$database
        SELECT 1 AS has_table FROM rdb$database
        WHERE EXISTS (SELECT rdb$relation_name
                      FROM rdb$relations
                      WHERE rdb$relation_name=?)
@@ -370,7 +471,7 @@ class FBDialect(default.DefaultDialect):
        """Return ``True`` if the given sequence (generator) exists."""

        genqry = """
        SELECT 1 FROM rdb$database
        SELECT 1 AS has_sequence FROM rdb$database
        WHERE EXISTS (SELECT rdb$generator_name
                      FROM rdb$generators
                      WHERE rdb$generator_name=?)
@@ -380,18 +481,34 @@ class FBDialect(default.DefaultDialect):

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        # there are two queries commonly mentioned for this.
        # this one, using view_blr, is at the Firebird FAQ among other places:
        # http://www.firebirdfaq.org/faq174/
        s = """
        SELECT DISTINCT rdb$relation_name
        FROM rdb$relation_fields
        WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
        select rdb$relation_name
        from rdb$relations
        where rdb$view_blr is null
        and (rdb$system_flag is null or rdb$system_flag = 0);
        """

        # the other query is this one. It's not clear if there's really
        # any difference between these two. This link:
        # http://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
        # states them as interchangeable. Some discussion at [ticket:2898]
        # SELECT DISTINCT rdb$relation_name
        # FROM rdb$relation_fields
        # WHERE rdb$system_flag=0 AND rdb$view_context IS NULL

        return [self.normalize_name(row[0]) for row in connection.execute(s)]

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        # see http://www.firebirdfaq.org/faq174/
        s = """
        SELECT distinct rdb$view_name
        FROM rdb$view_relations
        select rdb$relation_name
        from rdb$relations
        where rdb$view_blr is not null
        and (rdb$system_flag is null or rdb$system_flag = 0);
        """
        return [self.normalize_name(row[0]) for row in connection.execute(s)]

@@ -410,7 +527,7 @@ class FBDialect(default.DefaultDialect):
            return None

    @reflection.cache
    def get_primary_keys(self, connection, table_name, schema=None, **kw):
    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        # Query to extract the PK/FK constrained fields of the given table
        keyqry = """
        SELECT se.rdb$field_name AS fname
@@ -422,10 +539,12 @@ class FBDialect(default.DefaultDialect):
        # get primary key fields
        c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
        pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
        return pkfields
        return {'constrained_columns': pkfields, 'name': None}

    @reflection.cache
    def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw):
    def get_column_sequence(self, connection,
                            table_name, column_name,
                            schema=None, **kw):
        tablename = self.denormalize_name(table_name)
        colname = self.denormalize_name(column_name)
        # Heuristic-query to determine the generator associated to a PK field
@@ -436,14 +555,15 @@ class FBDialect(default.DefaultDialect):
             ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
                AND trigdep.rdb$depended_on_type=14
                AND trigdep.rdb$dependent_type=2
        JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name
        JOIN rdb$triggers trig ON
            trig.rdb$trigger_name=tabdep.rdb$dependent_name
        WHERE tabdep.rdb$depended_on_name=?
          AND tabdep.rdb$depended_on_type=0
          AND trig.rdb$trigger_type=1
          AND tabdep.rdb$field_name=?
          AND (SELECT count(*)
               FROM rdb$dependencies trigdep2
               WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
               FROM rdb$dependencies trigdep2
               WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
        """
        genr = connection.execute(genqry, [tablename, colname]).first()
        if genr is not None:
@@ -453,24 +573,29 @@ class FBDialect(default.DefaultDialect):
    def get_columns(self, connection, table_name, schema=None, **kw):
        # Query to extract the details of all the fields of the given table
        tblqry = """
        SELECT DISTINCT r.rdb$field_name AS fname,
        SELECT r.rdb$field_name AS fname,
                        r.rdb$null_flag AS null_flag,
                        t.rdb$type_name AS ftype,
                        f.rdb$field_sub_type AS stype,
                        f.rdb$field_length/COALESCE(cs.rdb$bytes_per_character,1) AS flen,
                        f.rdb$field_length/
                            COALESCE(cs.rdb$bytes_per_character,1) AS flen,
                        f.rdb$field_precision AS fprec,
                        f.rdb$field_scale AS fscale,
                        COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
                        COALESCE(r.rdb$default_source,
                                 f.rdb$default_source) AS fdefault
        FROM rdb$relation_fields r
             JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
             JOIN rdb$types t
              ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
              ON t.rdb$type=f.rdb$field_type AND
                 t.rdb$field_name='RDB$FIELD_TYPE'
             LEFT JOIN rdb$character_sets cs ON f.rdb$character_set_id=cs.rdb$character_set_id
             LEFT JOIN rdb$character_sets cs ON
                 f.rdb$character_set_id=cs.rdb$character_set_id
        WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
        ORDER BY r.rdb$field_position
        """
        # get the PK, used to determine the eventual associated sequence
        pkey_cols = self.get_primary_keys(connection, table_name)
        pk_constraint = self.get_pk_constraint(connection, table_name)
        pkey_cols = pk_constraint['constrained_columns']

        tablename = self.denormalize_name(table_name)
        # get all of the fields for this table
@@ -490,8 +615,10 @@ class FBDialect(default.DefaultDialect):
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (colspec, name))
                coltype = sqltypes.NULLTYPE
            elif colspec == 'INT64':
                coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1)
            elif issubclass(coltype, Integer) and row['fprec'] != 0:
                coltype = NUMERIC(
                    precision=row['fprec'],
                    scale=row['fscale'] * -1)
            elif colspec in ('VARYING', 'CSTRING'):
                coltype = coltype(row['flen'])
            elif colspec == 'TEXT':
@@ -502,25 +629,29 @@ class FBDialect(default.DefaultDialect):
                else:
                    coltype = BLOB()
            else:
                coltype = coltype(row)
                coltype = coltype()

            # does it have a default value?
            defvalue = None
            if row['fdefault'] is not None:
                # the value comes down as "DEFAULT 'value'": there may be
                # more than one whitespace around the "DEFAULT" keyword
                # and it may also be lower case
                # (see also http://tracker.firebirdsql.org/browse/CORE-356)
                defexpr = row['fdefault'].lstrip()
                assert defexpr[:8].rstrip()=='DEFAULT', "Unrecognized default value: %s" % defexpr
                assert defexpr[:8].rstrip().upper() == \
                    'DEFAULT', "Unrecognized default value: %s" % \
                    defexpr
                defvalue = defexpr[8:].strip()
                if defvalue == 'NULL':
                    # Redundant
                    defvalue = None
            col_d = {
                'name' : name,
                'type' : coltype,
                'nullable' : not bool(row['null_flag']),
                'default' : defvalue
                'name': name,
                'type': coltype,
                'nullable': not bool(row['null_flag']),
                'default': defvalue,
                'autoincrement': 'auto',
            }

            if orig_colname.lower() == orig_colname:
@@ -528,7 +659,7 @@ class FBDialect(default.DefaultDialect):

            # if the PK is a single field, try to see if its linked to
            # a sequence thru a trigger
            if len(pkey_cols)==1 and name==pkey_cols[0]:
            if len(pkey_cols) == 1 and name == pkey_cols[0]:
                seq_d = self.get_column_sequence(connection, tablename, name)
                if seq_d is not None:
                    col_d['sequence'] = seq_d
@@ -547,7 +678,8 @@ class FBDialect(default.DefaultDialect):
        FROM rdb$relation_constraints rc
             JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
             JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
             JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
             JOIN rdb$index_segments cse ON
                 cse.rdb$index_name=ix1.rdb$index_name
             JOIN rdb$index_segments se
                  ON se.rdb$index_name=ix2.rdb$index_name
                     AND se.rdb$field_position=cse.rdb$field_position
@@ -557,12 +689,12 @@ class FBDialect(default.DefaultDialect):
        tablename = self.denormalize_name(table_name)

        c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
        fks = util.defaultdict(lambda:{
            'name' : None,
            'constrained_columns' : [],
            'referred_schema' : None,
            'referred_table' : None,
            'referred_columns' : []
        fks = util.defaultdict(lambda: {
            'name': None,
            'constrained_columns': [],
            'referred_schema': None,
            'referred_table': None,
            'referred_columns': []
        })

        for row in c:
@@ -571,10 +703,11 @@ class FBDialect(default.DefaultDialect):
            if not fk['name']:
                fk['name'] = cname
                fk['referred_table'] = self.normalize_name(row['targetrname'])
            fk['constrained_columns'].append(self.normalize_name(row['fname']))
            fk['constrained_columns'].append(
                self.normalize_name(row['fname']))
            fk['referred_columns'].append(
                self.normalize_name(row['targetfname']))
        return fks.values()
                self.normalize_name(row['targetfname']))
        return list(fks.values())

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
@@ -586,10 +719,11 @@ class FBDialect(default.DefaultDialect):
             JOIN rdb$index_segments ic
                  ON ix.rdb$index_name=ic.rdb$index_name
             LEFT OUTER JOIN rdb$relation_constraints
                  ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name
                  ON rdb$relation_constraints.rdb$index_name =
                     ic.rdb$index_name
        WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
          AND rdb$relation_constraints.rdb$constraint_type IS NULL
        ORDER BY index_name, field_name
        ORDER BY index_name, ic.rdb$field_position
        """
        c = connection.execute(qry, [self.denormalize_name(table_name)])

@@ -601,19 +735,7 @@ class FBDialect(default.DefaultDialect):
                indexrec['column_names'] = []
                indexrec['unique'] = bool(row['unique_flag'])

            indexrec['column_names'].append(self.normalize_name(row['field_name']))
            indexrec['column_names'].append(
                self.normalize_name(row['field_name']))

        return indexes.values()

    def do_execute(self, cursor, statement, parameters, **kwargs):
        # kinterbase does not accept a None, but wants an empty list
        # when there are no arguments.
        cursor.execute(statement, parameters or [])

    def do_rollback(self, connection):
        # Use the retaining feature, that keeps the transaction going
        connection.rollback(True)

    def do_commit(self, connection):
        # Use the retaining feature, that keeps the transaction going
        connection.commit(True)
        return list(indexes.values())

sqlalchemy/dialects/firebird/fdb.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# firebird/fdb.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
.. dialect:: firebird+fdb
    :name: fdb
    :dbapi: pyodbc
    :connectstring: firebird+fdb://user:password@host:port/path/to/db\
[?key=value&key=value...]
    :url: http://pypi.python.org/pypi/fdb/

fdb is a kinterbasdb compatible DBAPI for Firebird.

.. versionadded:: 0.8 - Support for the fdb Firebird driver.

.. versionchanged:: 0.9 - The fdb dialect is now the default dialect
   under the ``firebird://`` URL space, as ``fdb`` is now the official
   Python driver for Firebird.

Arguments
----------

The ``fdb`` dialect is based on the
:mod:`sqlalchemy.dialects.firebird.kinterbasdb` dialect, however does not
accept every argument that Kinterbasdb does.

* ``enable_rowcount`` - True by default, setting this to False disables
  the usage of "cursor.rowcount" with the
  Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
  after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
  ResultProxy will return -1 for result.rowcount. The rationale here is
  that Kinterbasdb requires a second round trip to the database when
  .rowcount is called - since SQLA's resultproxy automatically closes
  the cursor after a non-result-returning statement, rowcount must be
  called, if at all, before the result object is returned. Additionally,
  cursor.rowcount may not return correct results with older versions
  of Firebird, and setting this flag to False will also cause the
  SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
  per-execution basis using the ``enable_rowcount`` option with
  :meth:`.Connection.execution_options`::

      conn = engine.connect().execution_options(enable_rowcount=True)
      r = conn.execute(stmt)
      print r.rowcount

* ``retaining`` - False by default. Setting this to True will pass the
  ``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
  methods of the DBAPI connection, which can improve performance in some
  situations, but apparently with significant caveats.
  Please read the fdb and/or kinterbasdb DBAPI documentation in order to
  understand the implications of this flag.

  .. versionadded:: 0.8.2 - ``retaining`` keyword argument specifying
     transaction retaining behavior - in 0.8 it defaults to ``True``
     for backwards compatibility.

  .. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
     In 0.8 it defaulted to ``True``.

  .. seealso::

      http://pythonhosted.org/fdb/usage-guide.html#retaining-transactions
      - information on the "retaining" flag.

"""

from .kinterbasdb import FBDialect_kinterbasdb
from ... import util


class FBDialect_fdb(FBDialect_kinterbasdb):

    def __init__(self, enable_rowcount=True,
                 retaining=False, **kwargs):
        super(FBDialect_fdb, self).__init__(
            enable_rowcount=enable_rowcount,
            retaining=retaining, **kwargs)

    @classmethod
    def dbapi(cls):
        return __import__('fdb')

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        if opts.get('port'):
            opts['host'] = "%s/%s" % (opts['host'], opts['port'])
            del opts['port']
        opts.update(url.query)

        util.coerce_kw_type(opts, 'type_conv', int)

        return ([], opts)
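The port handling in ``create_connect_args`` folds the port into the host string using the ``host/port`` convention. A tiny standalone sketch of that transformation (the dict mimics what ``url.translate_connect_args()`` might return; the values are illustrative)::

    opts = {'user': 'scott', 'password': 'tiger',
            'host': 'localhost', 'port': 3050,
            'database': '/tmp/test.fdb'}
    if opts.get('port'):
        opts['host'] = "%s/%s" % (opts['host'], opts['port'])
        del opts['port']
    print(opts['host'])  # localhost/3050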

    def _get_server_version_info(self, connection):
        """Get the version of the Firebird server used by a connection.

        Returns a tuple of (`major`, `minor`, `build`), three integers
        representing the version of the attached server.
        """

        # This is the simpler approach (the other uses the services api),
        # that for backward compatibility reasons returns a string like
        #   LI-V6.3.3.12981 Firebird 2.0
        # where the first version is a fake one resembling the old
        # Interbase signature.

        isc_info_firebird_version = 103
        fbconn = connection.connection

        version = fbconn.db_info(isc_info_firebird_version)

        return self._parse_version_info(version)

dialect = FBDialect_fdb
@@ -1,69 +1,119 @@
# kinterbasdb.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
# firebird/kinterbasdb.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
The most common way to connect to a Firebird engine is implemented by
kinterbasdb__, currently maintained__ directly by the Firebird people.
.. dialect:: firebird+kinterbasdb
    :name: kinterbasdb
    :dbapi: kinterbasdb
    :connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db\
[?key=value&key=value...]
    :url: http://firebirdsql.org/index.php?op=devel&sub=python

The connection URL is of the form
``firebird[+kinterbasdb]://user:password@host:port/path/to/db[?key=value&key=value...]``.
Arguments
----------

Kinterbasedb backend specific keyword arguments are:
The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect.
In addition, it also accepts the following:

type_conv
  select the kind of mapping done on the types: by default SQLAlchemy
  uses 200 with Unicode, datetime and decimal support (see details__).
* ``type_conv`` - select the kind of mapping done on the types: by default
  SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
  the linked documents below for further information.

concurrency_level
  set the backend policy with regards to threading issues: by default
  SQLAlchemy uses policy 1 (see details__).
* ``concurrency_level`` - set the backend policy with regards to threading
  issues: by default SQLAlchemy uses policy 1. See the linked documents
  below for further information.

.. seealso::

    http://sourceforge.net/projects/kinterbasdb

    http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation

    http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency

__ http://sourceforge.net/projects/kinterbasdb
__ http://firebirdsql.org/index.php?op=devel&sub=python
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
__ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
"""

from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
from sqlalchemy import util, types as sqltypes
from .base import FBDialect, FBExecutionContext
from ... import util, types as sqltypes
from re import match
import decimal


class _FBNumeric_kinterbasdb(sqltypes.Numeric):

class _kinterbasdb_numeric(object):
    def bind_processor(self, dialect):
        def process(value):
            if value is not None:
                if isinstance(value, decimal.Decimal):
                    return str(value)
                else:
                    return value
        return process



class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
    pass


class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
    pass


class FBExecutionContext_kinterbasdb(FBExecutionContext):
    @property
    def rowcount(self):
        if self.execution_options.get('enable_rowcount',
                                      self.dialect.enable_rowcount):
            return self.cursor.rowcount
        else:
            return -1


class FBDialect_kinterbasdb(FBDialect):
    driver = 'kinterbasdb'
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False

    execution_ctx_cls = FBExecutionContext_kinterbasdb

    supports_native_decimal = True


    colspecs = util.update_copy(
        FBDialect.colspecs,
        {
            sqltypes.Numeric:_FBNumeric_kinterbasdb
            sqltypes.Numeric: _FBNumeric_kinterbasdb,
            sqltypes.Float: _FBFloat_kinterbasdb,
        }

    )

    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
        super(FBDialect_kinterbasdb, self).__init__(**kwargs)

    )

    def __init__(self, type_conv=200, concurrency_level=1,
                 enable_rowcount=True,
                 retaining=False, **kwargs):
        super(FBDialect_kinterbasdb, self).__init__(**kwargs)
        self.enable_rowcount = enable_rowcount
        self.type_conv = type_conv
        self.concurrency_level = concurrency_level
        self.retaining = retaining
        if enable_rowcount:
            self.supports_sane_rowcount = True

    @classmethod
    def dbapi(cls):
        k = __import__('kinterbasdb')
        return k
        return __import__('kinterbasdb')

    def do_execute(self, cursor, statement, parameters, context=None):
        # kinterbase does not accept a None, but wants an empty list
        # when there are no arguments.
        cursor.execute(statement, parameters or [])

    def do_rollback(self, dbapi_connection):
        dbapi_connection.rollback(self.retaining)

    def do_commit(self, dbapi_connection):
        dbapi_connection.commit(self.retaining)

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
@@ -72,17 +122,22 @@ class FBDialect_kinterbasdb(FBDialect):
            del opts['port']
        opts.update(url.query)

        util.coerce_kw_type(opts, 'type_conv', int)

        type_conv = opts.pop('type_conv', self.type_conv)
        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
        concurrency_level = opts.pop('concurrency_level',
                                     self.concurrency_level)

        if self.dbapi is not None:
            initialized = getattr(self.dbapi, 'initialized', None)
            if initialized is None:
                # CVS rev 1.96 changed the name of the attribute:
                # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
                # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
                # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
                initialized = getattr(self.dbapi, '_initialized', False)
            if not initialized:
                self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
                self.dbapi.init(type_conv=type_conv,
                                concurrency_level=concurrency_level)
        return ([], opts)

    def _get_server_version_info(self, connection):
@@ -96,24 +151,33 @@ class FBDialect_kinterbasdb(FBDialect):
        # that for backward compatibility reasons returns a string like
        #   LI-V6.3.3.12981 Firebird 2.0
        # where the first version is a fake one resembling the old
        # Interbase signature. This is more than enough for our purposes,
        # as this is mainly (only?) used by the testsuite.

        from re import match
        # Interbase signature.

        fbconn = connection.connection
        version = fbconn.server_version
        m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
        if not m:
            raise AssertionError("Could not determine version from string '%s'" % version)
        return tuple([int(x) for x in m.group(5, 6, 4)])

    def is_disconnect(self, e):
        if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)):
        return self._parse_version_info(version)

    def _parse_version_info(self, version):
        m = match(
            r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version)
        if not m:
            raise AssertionError(
                "Could not determine version from string '%s'" % version)

        if m.group(5) != None:
            return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird'])
        else:
            return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase'])
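To make the two branches above concrete, here is a standalone check of the same pattern against the sample version string quoted in the comments; it runs as-is and needs no database::

    from re import match

    # The example version string from the comments above.
    version = "LI-V6.3.3.12981 Firebird 2.0"
    m = match(r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version)
    if m.group(5) is not None:
        # second version pair present: a real Firebird server
        info = tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird'])
    else:
        # only the Interbase-style signature: an Interbase server
        info = tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase'])
    print(info)  # (2, 0, 12981, 'firebird')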

    def is_disconnect(self, e, connection, cursor):
        if isinstance(e, (self.dbapi.OperationalError,
                          self.dbapi.ProgrammingError)):
            msg = str(e)
            return ('Unable to complete network request to host' in msg or
                    'Invalid connection state' in msg or
                    'Invalid cursor state' in msg)
                    'Invalid cursor state' in msg or
                    'connection shutdown' in msg)
        else:
            return False

@@ -1,4 +1,12 @@
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc
# mssql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, \
    pymssql, zxjdbc, mxodbc

base.dialect = pyodbc.dialect

@@ -11,9 +19,9 @@ from sqlalchemy.dialects.mssql.base import \


__all__ = (
    'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
    'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
    'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME',
    'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
    'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
    'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP',
    'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect'
)
)

@@ -1,15 +1,34 @@
"""
The adodbapi dialect is not implemented for 0.6 at this time.
# mssql/adodbapi.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
.. dialect:: mssql+adodbapi
    :name: adodbapi
    :dbapi: adodbapi
    :connectstring: mssql+adodbapi://<username>:<password>@<dsnname>
    :url: http://adodbapi.sourceforge.net/

.. note::

    The adodbapi dialect is not implemented SQLAlchemy versions 0.6 and
    above at this time.

"""
import datetime
from sqlalchemy import types as sqltypes, util
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
import sys


class MSDateTime_adodbapi(MSDateTime):
    def result_processor(self, dialect, coltype):
        def process(value):
            # adodbapi will return datetimes with empty time values as datetime.date() objects.
            # adodbapi will return datetimes with empty time
            # values as datetime.date() objects.
            # Promote them back to full datetime.datetime()
            if type(value) is datetime.date:
                return datetime.datetime(value.year, value.month, value.day)
@@ -23,7 +42,7 @@ class MSDialect_adodbapi(MSDialect):
    supports_unicode = sys.maxunicode == 65535
    supports_unicode_statements = True
    driver = 'adodbapi'


    @classmethod
    def import_dbapi(cls):
        import adodbapi as module
@@ -32,28 +51,37 @@ class MSDialect_adodbapi(MSDialect):
    colspecs = util.update_copy(
        MSDialect.colspecs,
        {
            sqltypes.DateTime:MSDateTime_adodbapi
            sqltypes.DateTime: MSDateTime_adodbapi
        }
    )

    def create_connect_args(self, url):
        keys = url.query
        def check_quote(token):
            if ";" in str(token):
                token = "'%s'" % token
            return token

        keys = dict(
            (k, check_quote(v)) for k, v in url.query.items()
        )

        connectors = ["Provider=SQLOLEDB"]
        if 'port' in keys:
            connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
            connectors.append("Data Source=%s, %s" %
                              (keys.get("host"), keys.get("port")))
        else:
            connectors.append ("Data Source=%s" % keys.get("host"))
            connectors.append ("Initial Catalog=%s" % keys.get("database"))
            connectors.append("Data Source=%s" % keys.get("host"))
        connectors.append("Initial Catalog=%s" % keys.get("database"))
        user = keys.get("user")
        if user:
            connectors.append("User Id=%s" % user)
            connectors.append("Password=%s" % keys.get("password", ""))
        else:
            connectors.append("Integrated Security=SSPI")
        return [[";".join (connectors)], {}]
        return [[";".join(connectors)], {}]

    def is_disconnect(self, e):
        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
    def is_disconnect(self, e, connection, cursor):
        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
            "'connection failure'" in str(e)

dialect = MSDialect_adodbapi

File diff suppressed because it is too large
@@ -1,83 +1,136 @@
from sqlalchemy import Table, MetaData, Column, ForeignKey
from sqlalchemy.types import String, Unicode, Integer, TypeDecorator
# mssql/information_schema.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

# TODO: should be using the sys. catalog with SQL Server, not information
# schema

from ... import Table, MetaData, Column
from ...types import String, Unicode, UnicodeText, Integer, TypeDecorator
from ... import cast
from ... import util
from ...sql import expression
from ...ext.compiler import compiles

ischema = MetaData()


class CoerceUnicode(TypeDecorator):
    impl = Unicode

    def process_bind_param(self, value, dialect):
        if isinstance(value, str):
        if util.py2k and isinstance(value, util.binary_type):
            value = value.decode(dialect.encoding)
        return value

    def bind_expression(self, bindvalue):
        return _cast_on_2005(bindvalue)


class _cast_on_2005(expression.ColumnElement):
    def __init__(self, bindvalue):
        self.bindvalue = bindvalue


@compiles(_cast_on_2005)
def _compile(element, compiler, **kw):
    from . import base
    if compiler.dialect.server_version_info < base.MS_2005_VERSION:
        return compiler.process(element.bindvalue, **kw)
    else:
        return compiler.process(cast(element.bindvalue, Unicode), **kw)

schemata = Table("SCHEMATA", ischema,
    Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
    Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
    Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
    schema="INFORMATION_SCHEMA")
                 Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
                 Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
                 Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
                 schema="INFORMATION_SCHEMA")

tables = Table("TABLES", ischema,
    Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
    schema="INFORMATION_SCHEMA")
               Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
               Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
               Column("TABLE_NAME", CoerceUnicode, key="table_name"),
               Column(
                   "TABLE_TYPE", String(convert_unicode=True),
                   key="table_type"),
               schema="INFORMATION_SCHEMA")

columns = Table("COLUMNS", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
    Column("IS_NULLABLE", Integer, key="is_nullable"),
    Column("DATA_TYPE", String, key="data_type"),
    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
    Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
    Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
    Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
    Column("COLUMN_DEFAULT", Integer, key="column_default"),
    Column("COLLATION_NAME", String, key="collation_name"),
    schema="INFORMATION_SCHEMA")
                Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
                Column("TABLE_NAME", CoerceUnicode, key="table_name"),
                Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
                Column("IS_NULLABLE", Integer, key="is_nullable"),
                Column("DATA_TYPE", String, key="data_type"),
                Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
                Column("CHARACTER_MAXIMUM_LENGTH", Integer,
                       key="character_maximum_length"),
                Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
                Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
                Column("COLUMN_DEFAULT", Integer, key="column_default"),
                Column("COLLATION_NAME", String, key="collation_name"),
                schema="INFORMATION_SCHEMA")

constraints = Table("TABLE_CONSTRAINTS", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"),
    schema="INFORMATION_SCHEMA")
                    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
                    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
                    Column("CONSTRAINT_NAME", CoerceUnicode,
                           key="constraint_name"),
                    Column("CONSTRAINT_TYPE", String(
                        convert_unicode=True), key="constraint_type"),
                    schema="INFORMATION_SCHEMA")

column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    schema="INFORMATION_SCHEMA")
                           Column("TABLE_SCHEMA", CoerceUnicode,
                                  key="table_schema"),
                           Column("TABLE_NAME", CoerceUnicode,
                                  key="table_name"),
                           Column("COLUMN_NAME", CoerceUnicode,
                                  key="column_name"),
                           Column("CONSTRAINT_NAME", CoerceUnicode,
                                  key="constraint_name"),
                           schema="INFORMATION_SCHEMA")

key_constraints = Table("KEY_COLUMN_USAGE", ischema,
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
    schema="INFORMATION_SCHEMA")
                        Column("TABLE_SCHEMA", CoerceUnicode,
                               key="table_schema"),
                        Column("TABLE_NAME", CoerceUnicode,
                               key="table_name"),
                        Column("COLUMN_NAME", CoerceUnicode,
                               key="column_name"),
                        Column("CONSTRAINT_NAME", CoerceUnicode,
                               key="constraint_name"),
                        Column("ORDINAL_POSITION", Integer,
                               key="ordinal_position"),
                        schema="INFORMATION_SCHEMA")

ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
    Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
    Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
    Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"),  # TODO: is CATLOG misspelled ?
    Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"),
    Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"),
    Column("MATCH_OPTION", String, key="match_option"),
    Column("UPDATE_RULE", String, key="update_rule"),
    Column("DELETE_RULE", String, key="delete_rule"),
    schema="INFORMATION_SCHEMA")
                        Column("CONSTRAINT_CATALOG", CoerceUnicode,
                               key="constraint_catalog"),
                        Column("CONSTRAINT_SCHEMA", CoerceUnicode,
                               key="constraint_schema"),
                        Column("CONSTRAINT_NAME", CoerceUnicode,
                               key="constraint_name"),
                        # TODO: is CATLOG misspelled ?
                        Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode,
                               key="unique_constraint_catalog"),

                        Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode,
                               key="unique_constraint_schema"),
                        Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode,
                               key="unique_constraint_name"),
                        Column("MATCH_OPTION", String, key="match_option"),
                        Column("UPDATE_RULE", String, key="update_rule"),
                        Column("DELETE_RULE", String, key="delete_rule"),
                        schema="INFORMATION_SCHEMA")

views = Table("VIEWS", ischema,
    Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
    Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
    Column("CHECK_OPTION", String, key="check_option"),
    Column("IS_UPDATABLE", String, key="is_updatable"),
    schema="INFORMATION_SCHEMA")

              Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
              Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
              Column("TABLE_NAME", CoerceUnicode, key="table_name"),
              Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
              Column("CHECK_OPTION", String, key="check_option"),
              Column("IS_UPDATABLE", String, key="is_updatable"),
              schema="INFORMATION_SCHEMA")

@@ -1,55 +1,105 @@
# mssql/mxodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
Support for MS-SQL via mxODBC.
.. dialect:: mssql+mxodbc
    :name: mxODBC
    :dbapi: mxodbc
    :connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
    :url: http://www.egenix.com/

mxODBC is available at:

    http://www.egenix.com/

This was tested with mxODBC 3.1.2 and the SQL Server Native
Client connected to MSSQL 2005 and 2008 Express Editions.

Connecting
~~~~~~~~~~

Connection is via DSN::

    mssql+mxodbc://<username>:<password>@<dsnname>

Execution Modes
~~~~~~~~~~~~~~~
---------------

mxODBC features two styles of statement execution, using the ``cursor.execute()``
and ``cursor.executedirect()`` methods (the second being an extension to the
DBAPI specification). The former makes use of the native
parameter binding services of the ODBC driver, while the latter uses string escaping.
The primary advantage to native parameter binding is that the same statement, when
executed many times, is only prepared once. Whereas the primary advantage to the
latter is that the rules for bind parameter placement are relaxed. MS-SQL has very
strict rules for native binds, including that they cannot be placed within the argument
lists of function calls, anywhere outside the FROM, or even within subqueries within the
FROM clause - making the usage of bind parameters within SELECT statements impossible for
all but the most simplistic statements. For this reason, the mxODBC dialect uses the
"native" mode by default only for INSERT, UPDATE, and DELETE statements, and uses the
escaped string mode for all other statements. This behavior can be controlled completely
via :meth:`~sqlalchemy.sql.expression.Executable.execution_options`
using the ``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a value of
``True`` will unconditionally use native bind parameters and a value of ``False`` will
uncondtionally use string-escaped parameters.
mxODBC features two styles of statement execution, using the
``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
an extension to the DBAPI specification). The former makes use of a particular
API call specific to the SQL Server Native Client ODBC driver known
SQLDescribeParam, while the latter does not.

mxODBC apparently only makes repeated use of a single prepared statement
when SQLDescribeParam is used. The advantage to prepared statement reuse is
one of performance. The disadvantage is that SQLDescribeParam has a limited
set of scenarios in which bind parameters are understood, including that they
cannot be placed within the argument lists of function calls, anywhere outside
the FROM, or even within subqueries within the FROM clause - making the usage
of bind parameters within SELECT statements impossible for all but the most
simplistic statements.

For this reason, the mxODBC dialect uses the "native" mode by default only for
INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
all other statements.

This behavior can be controlled via
:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
value of ``True`` will unconditionally use native bind parameters and a value
of ``False`` will unconditionally use string-escaped parameters.

"""

import re
import sys

from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.connectors.mxodbc import MxODBCConnector
from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc
from sqlalchemy.dialects.mssql.base import (MSExecutionContext, MSDialect,
                                            MSSQLCompiler, MSSQLStrictCompiler,
                                            _MSDateTime, _MSDate, TIME)
from ... import types as sqltypes
from ...connectors.mxodbc import MxODBCConnector
from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc
from .base import (MSDialect,
                   MSSQLStrictCompiler,
                   VARBINARY,
                   _MSDateTime, _MSDate, _MSTime)


class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
    """Include pyodbc's numeric processor.
    """


class _MSDate_mxodbc(_MSDate):
    def bind_processor(self, dialect):
        def process(value):
            if value is not None:
                return "%s-%s-%s" % (value.year, value.month, value.day)
            else:
                return None
        return process


class _MSTime_mxodbc(_MSTime):
    def bind_processor(self, dialect):
        def process(value):
            if value is not None:
                return "%s:%s:%s" % (value.hour, value.minute, value.second)
            else:
                return None
        return process


class _VARBINARY_mxodbc(VARBINARY):

    """
    mxODBC Support for VARBINARY column types.

    This handles the special case for null VARBINARY values,
    which maps None values to the mx.ODBC.Manager.BinaryNull symbol.
    """

    def bind_processor(self, dialect):
        if dialect.dbapi is None:
            return None

        DBAPIBinary = dialect.dbapi.Binary

        def process(value):
            if value is not None:
                return DBAPIBinary(value)
            else:
                # should pull from mx.ODBC.Manager.BinaryNull
                return dialect.dbapi.BinaryNull
        return process


class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
    """
@@ -57,27 +107,33 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
    SELECT SCOPE_IDENTITY in cases where OUTPUT clause
    does not work (tables with insert triggers).
    """
    #todo - investigate whether the pyodbc execution context
    # todo - investigate whether the pyodbc execution context
    #       is really only being used in cases where OUTPUT
    #       won't work.


class MSDialect_mxodbc(MxODBCConnector, MSDialect):

    # TODO: may want to use this only if FreeTDS is not in use,
    # since FreeTDS doesn't seem to use native binds.
    statement_compiler = MSSQLStrictCompiler

    # this is only needed if "native ODBC" mode is used,
    # which is now disabled by default.
    # statement_compiler = MSSQLStrictCompiler

    execution_ctx_cls = MSExecutionContext_mxodbc

    # flag used by _MSNumeric_mxodbc
    _need_decimal_fix = True

    colspecs = {
        #sqltypes.Numeric : _MSNumeric,
        sqltypes.DateTime : _MSDateTime,
        sqltypes.Date : _MSDate,
        sqltypes.Time : TIME,
        sqltypes.Numeric: _MSNumeric_mxodbc,
        sqltypes.DateTime: _MSDateTime,
        sqltypes.Date: _MSDate_mxodbc,
        sqltypes.Time: _MSTime_mxodbc,
        VARBINARY: _VARBINARY_mxodbc,
        sqltypes.LargeBinary: _VARBINARY_mxodbc,
    }

    def __init__(self, description_encoding='latin-1', **params):
    def __init__(self, description_encoding=None, **params):
        super(MSDialect_mxodbc, self).__init__(**params)
        self.description_encoding = description_encoding

dialect = MSDialect_mxodbc


@@ -1,41 +1,27 @@
"""
Support for the pymssql dialect.

This dialect supports pymssql 1.0 and greater.

pymssql is available at:

    http://pymssql.sourceforge.net/

Connecting
^^^^^^^^^^

Sample connect string::

    mssql+pymssql://<username>:<password>@<freetds_name>

Adding "?charset=utf8" or similar will cause pymssql to return
strings as Python unicode objects. This can potentially improve
performance in some scenarios as decoding of strings is
handled natively.

Limitations
^^^^^^^^^^^

pymssql inherits a lot of limitations from FreeTDS, including:

* no support for multibyte schema identifiers
* poor support for large decimals
* poor support for binary fields
* poor support for VARCHAR/CHAR fields over 255 characters

Please consult the pymssql documentation for further information.
# mssql/pymssql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
from sqlalchemy.dialects.mssql.base import MSDialect
from sqlalchemy import types as sqltypes, util, processors
.. dialect:: mssql+pymssql
    :name: pymssql
    :dbapi: pymssql
    :connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?\
charset=utf8
    :url: http://pymssql.org/

pymssql is a Python module that provides a Python DBAPI interface around
`FreeTDS <http://www.freetds.org/>`_. Compatible builds are available for
Linux, MacOSX and Windows platforms.

"""
|
||||
from .base import MSDialect
|
||||
from ... import types as sqltypes, util, processors
|
||||
import re
|
||||
import decimal
|
||||
|
||||
|
||||
class _MSNumeric_pymssql(sqltypes.Numeric):
|
||||
def result_processor(self, dialect, type_):
|
||||
@@ -44,29 +30,31 @@ class _MSNumeric_pymssql(sqltypes.Numeric):
|
||||
else:
|
||||
return sqltypes.Numeric.result_processor(self, dialect, type_)
|
||||
|
||||
|
||||
class MSDialect_pymssql(MSDialect):
|
||||
supports_sane_rowcount = False
|
||||
max_identifier_length = 30
|
||||
driver = 'pymssql'
|
||||
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MSDialect.colspecs,
|
||||
{
|
||||
sqltypes.Numeric:_MSNumeric_pymssql,
|
||||
sqltypes.Float:sqltypes.Float,
|
||||
sqltypes.Numeric: _MSNumeric_pymssql,
|
||||
sqltypes.Float: sqltypes.Float,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
module = __import__('pymssql')
|
||||
# pymmsql doesn't have a Binary method. we use string
|
||||
# TODO: monkeypatching here is less than ideal
|
||||
module.Binary = str
|
||||
|
||||
# pymmsql < 2.1.1 doesn't have a Binary method. we use string
|
||||
client_ver = tuple(int(x) for x in module.__version__.split("."))
|
||||
if client_ver < (2, 1, 1):
|
||||
# TODO: monkeypatching here is less than ideal
|
||||
module.Binary = lambda x: x if hasattr(x, 'decode') else str(x)
|
||||
|
||||
if client_ver < (1, ):
|
||||
util.warn("The pymssql dialect expects at least "
|
||||
"the 1.0 series of the pymssql DBAPI.")
|
||||
"the 1.0 series of the pymssql DBAPI.")
|
||||
return module
|
||||
|
||||
def __init__(self, **params):
|
||||
@@ -75,7 +63,8 @@ class MSDialect_pymssql(MSDialect):
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
vers = connection.scalar("select @@version")
|
||||
m = re.match(r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
|
||||
m = re.match(
|
||||
r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers)
|
||||
if m:
|
||||
return tuple(int(x) for x in m.group(1, 2, 3, 4))
|
||||
else:
|
||||
@@ -84,18 +73,25 @@ class MSDialect_pymssql(MSDialect):
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username='user')
|
||||
opts.update(url.query)
|
||||
opts.pop('port', None)
|
||||
port = opts.pop('port', None)
|
||||
if port and 'host' in opts:
|
||||
opts['host'] = "%s:%s" % (opts['host'], port)
|
||||
return [[], opts]
|
||||
|
||||
def is_disconnect(self, e):
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
for msg in (
|
||||
"Adaptive Server connection timed out",
|
||||
"Net-Lib error during Connection reset by peer",
|
||||
"message 20003", # connection timeout
|
||||
"Error 10054",
|
||||
"Not connected to any MS SQL server",
|
||||
"Connection is closed"
|
||||
"Connection is closed",
|
||||
"message 20006", # Write to the server failed
|
||||
"message 20017", # Unexpected EOF from the server
|
||||
):
|
||||
if msg in str(e):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
dialect = MSDialect_pymssql
|
||||
dialect = MSDialect_pymssql

@@ -1,99 +1,135 @@
"""
Support for MS-SQL via pyodbc.
# mssql/pyodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

pyodbc is available at:
r"""
.. dialect:: mssql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
:url: http://pypi.python.org/pypi/pyodbc/

http://pypi.python.org/pypi/pyodbc/
Connecting to PyODBC
--------------------

Connecting
^^^^^^^^^^
The URL here is to be translated to PyODBC connection strings, as
detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_.

Examples of pyodbc connection string URLs:
DSN Connections
^^^^^^^^^^^^^^^

* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``.
The connection string that is created will appear like::
A DSN-based connection is **preferred** overall when using ODBC. A
basic DSN-based connection looks like::

dsn=mydsn;Trusted_Connection=Yes
engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")

* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
connection string that is created will appear like::
Which above, will pass the following connection string to PyODBC::

dsn=mydsn;UID=user;PWD=pass

* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
information, plus the additional connection configuration option
``LANGUAGE``. The connection string that is created will appear
like::
If the username and password are omitted, the DSN form will also add
the ``Trusted_Connection=yes`` directive to the ODBC string.

dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
Hostname Connections
^^^^^^^^^^^^^^^^^^^^

* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection string
dynamically created that would appear like::
Hostname-based connections are **not preferred**, however are supported.
The ODBC driver name must be explicitly specified::

DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=SQL+Server+Native+Client+10.0")

* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection
string that is dynamically created, which also includes the port
information using the comma syntax. If your connection string
requires the port information to be passed as a ``port`` keyword
see the next example. This will create the following connection
string::
.. versionchanged:: 1.0.0 Hostname-based PyODBC connections now require the
SQL Server driver name specified explicitly. SQLAlchemy cannot
choose an optimal default here as it varies based on platform
and installed drivers.

DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
Other keywords interpreted by the Pyodbc dialect to be passed to
``pyodbc.connect()`` in both the DSN and hostname cases include:
``odbc_autotranslate``, ``ansi``, ``unicode_results``, ``autocommit``.

* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection
string that is dynamically created that includes the port
information as a separate ``port`` keyword. This will create the
following connection string::
Pass through exact Pyodbc string
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
A PyODBC connection string can also be sent exactly as specified in
`ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_
into the driver using the parameter ``odbc_connect``. The delimiters must be URL escaped, however,
as illustrated below using ``urllib.quote_plus``::

If you require a connection string that is outside the options
presented above, use the ``odbc_connect`` keyword to pass in a
urlencoded connection string. What gets passed in will be urldecoded
and passed directly.
import urllib
params = urllib.quote_plus("DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password")

For example::
engine = create_engine("mssql+pyodbc:///?odbc_connect=%s" % params)

mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb

would create the following connection string::
Unicode Binds
-------------

dsn=mydsn;Database=db
The current state of PyODBC on a unix backend with FreeTDS and/or
EasySoft is poor regarding unicode; different OS platforms and versions of
UnixODBC versus IODBC versus FreeTDS/EasySoft versus PyODBC itself
dramatically alter how strings are received. The PyODBC dialect attempts to
use all the information it knows to determine whether or not a Python unicode
literal can be passed directly to the PyODBC driver or not; while SQLAlchemy
can encode these to bytestrings first, some users have reported that PyODBC
mis-handles bytestrings for certain encodings and requires a Python unicode
object, while the author has observed widespread cases where a Python unicode
is completely misinterpreted by PyODBC, particularly when dealing with
the information schema tables used in table reflection, and the value
must first be encoded to a bytestring.

Encoding your connection string can be easily accomplished through
the python shell. For example::
It is for this reason that whether or not unicode literals for bound
parameters be sent to PyODBC can be controlled using the
``supports_unicode_binds`` parameter to ``create_engine()``. When
left at its default of ``None``, the PyODBC dialect will use its
best guess as to whether or not the driver deals with unicode literals
well. When ``False``, unicode literals will be encoded first, and when
``True`` unicode literals will be passed straight through. This is an interim
flag that hopefully should not be needed when the unicode situation stabilizes
for unix + PyODBC.

>>> import urllib
>>> urllib.quote_plus('dsn=mydsn;Database=db')
'dsn%3Dmydsn%3BDatabase%3Ddb'
.. versionadded:: 0.7.7
``supports_unicode_binds`` parameter to ``create_engine()``\ .
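
A hedged example of pinning the flag rather than relying on the dialect's guess (the DSN name is hypothetical)::

    from sqlalchemy import create_engine

    # False: encode unicode literals to bytestrings before PyODBC
    # sees them; True would pass them straight through instead
    engine = create_engine(
        "mssql+pyodbc://scott:tiger@some_dsn",
        supports_unicode_binds=False,
    )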

Rowcount Support
----------------

Pyodbc only has partial support for rowcount. See the notes at
:ref:`mssql_rowcount_versioning` for important notes when using ORM
versioning.

"""

from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes, util
from .base import MSExecutionContext, MSDialect, VARBINARY
from ...connectors.pyodbc import PyODBCConnector
from ... import types as sqltypes, util, exc
import decimal
import re


class _ms_numeric_pyodbc(object):

class _MSNumeric_pyodbc(sqltypes.Numeric):
"""Turns Decimals with adjusted() < 0 or > 7 into strings.

This is the only method that is proven to work with Pyodbc+MSSQL
without crashing (floats can be used but seem to cause sporadic
crashes).

The routines here are needed for older pyodbc versions
as well as current mxODBC versions.

"""

def bind_processor(self, dialect):
super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect)

super_process = super(_ms_numeric_pyodbc, self).\
bind_processor(dialect)

if not dialect._need_decimal_fix:
return super_process

def process(value):
if self.asdecimal and \
isinstance(value, decimal.Decimal):

adjusted = value.adjusted()
if adjusted < 0:
return self._small_dec_to_string(value)
@@ -105,72 +141,106 @@ class _MSNumeric_pyodbc(sqltypes.Numeric):
else:
return value
return process


# these routines needed for older versions of pyodbc.
# as of 2.1.8 this logic is integrated.

def _small_dec_to_string(self, value):
return "%s0.%s%s" % (
(value < 0 and '-' or ''),
'0' * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value._int]))
(value < 0 and '-' or ''),
'0' * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value.as_tuple()[1]]))

def _large_dec_to_string(self, value):
_int = value.as_tuple()[1]
if 'E' in str(value):
result = "%s%s%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in value._int]),
"0" * (value.adjusted() - (len(value._int)-1)))
(value < 0 and '-' or ''),
"".join([str(s) for s in _int]),
"0" * (value.adjusted() - (len(_int) - 1)))
else:
if (len(value._int) - 1) > value.adjusted():
if (len(_int) - 1) > value.adjusted():
result = "%s%s.%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in value._int][0:value.adjusted() + 1]),
"".join([str(s) for s in value._int][value.adjusted() + 1:]))
(value < 0 and '-' or ''),
"".join(
[str(s) for s in _int][0:value.adjusted() + 1]),
"".join(
[str(s) for s in _int][value.adjusted() + 1:]))
else:
result = "%s%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in value._int][0:value.adjusted() + 1]))
(value < 0 and '-' or ''),
"".join(
[str(s) for s in _int][0:value.adjusted() + 1]))
return result
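
As a rough standalone illustration of what the small-decimal helper above produces (the value is chosen arbitrarily)::

    import decimal

    # adjusted() < 0: Decimal("0.0001").adjusted() == -4, so the
    # helper writes the leading zeros out longhand instead of the
    # scientific notation that Pyodbc+MSSQL cannot digest
    small = decimal.Decimal("0.0001")
    digits = "".join(str(d) for d in small.as_tuple()[1])
    print("0." + "0" * (abs(small.adjusted()) - 1) + digits)  # 0.0001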


class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
pass


class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
pass


class _VARBINARY_pyodbc(VARBINARY):
def bind_processor(self, dialect):
if dialect.dbapi is None:
return None

DBAPIBinary = dialect.dbapi.Binary

def process(value):
if value is not None:
return DBAPIBinary(value)
else:
# pyodbc-specific
return dialect.dbapi.BinaryNull
return process


class MSExecutionContext_pyodbc(MSExecutionContext):
_embedded_scope_identity = False


def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same statement.

"""where appropriate, issue "select scope_identity()" in the same
statement.

Background on why "scope_identity()" is preferable to "@@identity":
http://msdn.microsoft.com/en-us/library/ms190315.aspx

Background on why we attempt to embed "scope_identity()" into the same
statement as the INSERT:
http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?

"""

super(MSExecutionContext_pyodbc, self).pre_exec()

# don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES"
# don't embed the scope_identity select into an
# "INSERT .. DEFAULT VALUES"
if self._select_lastrowid and \
self.dialect.use_scope_identity and \
len(self.parameters[0]):
self._embedded_scope_identity = True

self.statement += "; select scope_identity()"

def post_exec(self):
if self._embedded_scope_identity:
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with no data (due to triggers, etc.)
# We may have to skip over a number of result sets with
# no data (due to triggers, etc.)
while True:
try:
# fetchall() ensures the cursor is consumed
# fetchall() ensures the cursor is consumed
# without closing it (FreeTDS particularly)
row = self.cursor.fetchall()[0]
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error, e:
except self.dialect.dbapi.Error as e:
# no way around this - nextset() consumes the previous set
# so we need to just keep flipping
self.cursor.nextset()

self._lastrowid = int(row[0])
else:
super(MSExecutionContext_pyodbc, self).post_exec()
@@ -180,18 +250,43 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):

execution_ctx_cls = MSExecutionContext_pyodbc

pyodbc_driver_name = 'SQL Server'

colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric:_MSNumeric_pyodbc
sqltypes.Numeric: _MSNumeric_pyodbc,
sqltypes.Float: _MSFloat_pyodbc,
VARBINARY: _VARBINARY_pyodbc,
sqltypes.LargeBinary: _VARBINARY_pyodbc,
}
)

def __init__(self, description_encoding='latin-1', **params):

def __init__(self, description_encoding=None, **params):
if 'description_encoding' in params:
self.description_encoding = params.pop('description_encoding')
super(MSDialect_pyodbc, self).__init__(**params)
self.description_encoding = description_encoding
self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')

self.use_scope_identity = self.use_scope_identity and \
self.dbapi and \
hasattr(self.dbapi.Cursor, 'nextset')
self._need_decimal_fix = self.dbapi and \
self._dbapi_version() < (2, 1, 8)

def _get_server_version_info(self, connection):
try:
raw = connection.scalar("SELECT SERVERPROPERTY('ProductVersion')")
except exc.DBAPIError:
# SQL Server docs indicate this function isn't present prior to
# 2008; additionally, unknown combinations of pyodbc aren't
# able to run this query.
return super(MSDialect_pyodbc, self).\
_get_server_version_info(connection)
else:
version = []
r = re.compile(r'[.\-]')
for n in r.split(raw):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)

dialect = MSDialect_pyodbc
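
The version parsing above splits the ProductVersion string on dots and dashes, keeping non-numeric fields as strings; a quick standalone check with an assumed sample value::

    import re

    raw = "10.50.1600.1"  # hypothetical SERVERPROPERTY('ProductVersion')
    version = []
    for n in re.split(r'[.\-]', raw):
        try:
            version.append(int(n))
        except ValueError:
            version.append(n)
    print(tuple(version))  # (10, 50, 1600, 1)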

@@ -1,26 +1,26 @@
"""Support for the Microsoft SQL Server database via the zxjdbc JDBC
connector.

JDBC Driver
-----------

Requires the jTDS driver, available from: http://jtds.sourceforge.net/

Connecting
----------

URLs are of the standard form of
``mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]``.

Additional arguments which may be specified either as query string
arguments on the URL, or as keyword arguments to
:func:`~sqlalchemy.create_engine()` will be passed as Connection
properties to the underlying JDBC driver.
# mssql/zxjdbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.mssql.base import MSDialect, MSExecutionContext
from sqlalchemy.engine import base
.. dialect:: mssql+zxjdbc
:name: zxJDBC for Jython
:dbapi: zxjdbc
:connectstring: mssql+zxjdbc://user:pass@host:port/dbname\
[?key=value&key=value...]
:driverurl: http://jtds.sourceforge.net/

.. note:: Jython is not supported by current versions of SQLAlchemy. The
zxjdbc dialect should be considered as experimental.

"""
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import MSDialect, MSExecutionContext
from ... import engine


class MSExecutionContext_zxjdbc(MSExecutionContext):

@@ -40,15 +40,17 @@ class MSExecutionContext_zxjdbc(MSExecutionContext):
try:
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error, e:
except self.dialect.dbapi.Error:
self.cursor.nextset()
self._lastrowid = int(row[0])

if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning:
self._result_proxy = base.FullyBufferedResultProxy(self)
if (self.isinsert or self.isupdate or self.isdelete) and \
self.compiled.returning:
self._result_proxy = engine.FullyBufferedResultProxy(self)

if self._enable_identity_insert:
table = self.dialect.identifier_preparer.format_table(self.compiled.statement.table)
table = self.dialect.identifier_preparer.format_table(
self.compiled.statement.table)
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)


@@ -59,6 +61,9 @@ class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
execution_ctx_cls = MSExecutionContext_zxjdbc

def _get_server_version_info(self, connection):
return tuple(int(x) for x in connection.connection.dbversion.split('.'))
return tuple(
int(x)
for x in connection.connection.dbversion.split('.')
)

dialect = MSDialect_zxjdbc

@@ -1,17 +1,31 @@
from sqlalchemy.dialects.mysql import base, mysqldb, oursql, pyodbc, zxjdbc, mysqlconnector
# mysql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from . import base, mysqldb, oursql, \
pyodbc, zxjdbc, mysqlconnector, pymysql,\
gaerdbms, cymysql

# default dialect
base.dialect = mysqldb.dialect

from sqlalchemy.dialects.mysql.base import \
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, DOUBLE, ENUM, DECIMAL,\
FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, MEDIUMTEXT, NCHAR, \
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT,\
from .base import \
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \
DECIMAL, DOUBLE, ENUM, DECIMAL,\
FLOAT, INTEGER, INTEGER, JSON, LONGBLOB, LONGTEXT, MEDIUMBLOB, \
MEDIUMINT, MEDIUMTEXT, NCHAR, \
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \
TINYBLOB, TINYINT, TINYTEXT,\
VARBINARY, VARCHAR, YEAR, dialect


__all__ = (
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE',
'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT',
'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP',
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect'
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME',
'DECIMAL', 'DOUBLE', 'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER',
'JSON', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', 'MEDIUMTEXT',
'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME',
'TIMESTAMP', 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR',
'YEAR', 'dialect'
)

File diff suppressed because it is too large

sqlalchemy/dialects/mysql/cymysql.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# mysql/cymysql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""

.. dialect:: mysql+cymysql
:name: CyMySQL
:dbapi: cymysql
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>\
[?<options>]
:url: https://github.com/nakagami/CyMySQL

"""
import re

from .mysqldb import MySQLDialect_mysqldb
from .base import (BIT, MySQLDialect)
from ... import util


class _cymysqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
"""

def process(value):
if value is not None:
v = 0
for i in util.iterbytes(value):
v = v << 8 | i
return v
return value
return process
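
The shift-and-or loop above folds MySQL's variable-length BIT byte string into an integer, big-endian; mirrored as a standalone sketch::

    def bit_to_long(value):
        # same fold as process() above; bytearray() iterates ints on
        # both Python 2 and 3, much like util.iterbytes()
        v = 0
        for i in bytearray(value):
            v = v << 8 | i
        return v

    print(bit_to_long(b'\x02\x01'))  # 513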


class MySQLDialect_cymysql(MySQLDialect_mysqldb):
driver = 'cymysql'

description_encoding = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_unicode_statements = True

colspecs = util.update_copy(
MySQLDialect.colspecs,
{
BIT: _cymysqlBIT,
}
)

@classmethod
def dbapi(cls):
return __import__('cymysql')

def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.server_version):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)

def _detect_charset(self, connection):
return connection.connection.charset

def _extract_error_code(self, exception):
return exception.errno

def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in \
(2006, 2013, 2014, 2045, 2055)
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True
else:
return False

dialect = MySQLDialect_cymysql
sqlalchemy/dialects/mysql/enumerated.py (new file, 311 lines)
@@ -0,0 +1,311 @@
# mysql/enumerated.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import re

from .types import _StringType
from ... import exc, sql, util
from ... import types as sqltypes


class _EnumeratedValues(_StringType):
def _init_values(self, values, kw):
self.quoting = kw.pop('quoting', 'auto')

if self.quoting == 'auto' and len(values):
# What quoting character are we using?
q = None
for e in values:
if len(e) == 0:
self.quoting = 'unquoted'
break
elif q is None:
q = e[0]

if len(e) == 1 or e[0] != q or e[-1] != q:
self.quoting = 'unquoted'
break
else:
self.quoting = 'quoted'

if self.quoting == 'quoted':
util.warn_deprecated(
'Manually quoting %s value literals is deprecated. Supply '
'unquoted values and use the quoting= option in cases of '
'ambiguity.' % self.__class__.__name__)

values = self._strip_values(values)

self._enumerated_values = values
length = max([len(v) for v in values] + [0])
return values, length

@classmethod
def _strip_values(cls, values):
strip_values = []
for a in values:
if a[0:1] == '"' or a[0:1] == "'":
# strip enclosing quotes and unquote interior
a = a[1:-1].replace(a[0] * 2, a[0])
strip_values.append(a)
return strip_values


class ENUM(sqltypes.Enum, _EnumeratedValues):
"""MySQL ENUM type."""

__visit_name__ = 'ENUM'

def __init__(self, *enums, **kw):
"""Construct an ENUM.

E.g.::

Column('myenum', ENUM("foo", "bar", "baz"))

:param enums: The range of valid values for this ENUM. Values will be
quoted when generating the schema according to the quoting flag (see
below). This object may also be a PEP-435-compliant enumerated
type.

.. versionadded: 1.1 added support for PEP-435-compliant enumerated
types.

:param strict: This flag has no effect.

.. versionchanged:: The MySQL ENUM type as well as the base Enum
type now validates all Python data values.

:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.

:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.

:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.

:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.

:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.

:param quoting: Defaults to 'auto': automatically determine enum value
quoting. If all enum values are surrounded by the same quoting
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.

'quoted': values in enums are already quoted, they will be used
directly when generating the schema - this usage is deprecated.

'unquoted': values in enums are not quoted, they will be escaped and
surrounded by single quotes when generating the schema.

Previous versions of this type always required manually quoted
values to be supplied; future versions will always quote the string
literals for you. This is a transitional option.

"""

kw.pop('strict', None)
validate_strings = kw.pop("validate_strings", False)
sqltypes.Enum.__init__(
self, validate_strings=validate_strings, *enums)
kw.pop('metadata', None)
kw.pop('schema', None)
kw.pop('name', None)
kw.pop('quote', None)
kw.pop('native_enum', None)
kw.pop('inherit_schema', None)
kw.pop('_create_events', None)
_StringType.__init__(self, length=self.length, **kw)

def _setup_for_values(self, values, objects, kw):
values, length = self._init_values(values, kw)
return sqltypes.Enum._setup_for_values(self, values, objects, kw)

def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
# the blank string. Return it straight.
if elem == "":
return elem
else:
return super(ENUM, self)._object_value_for_elem(elem)

def __repr__(self):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum])

def adapt(self, cls, **kw):
return sqltypes.Enum.adapt(self, cls, **kw)
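
Typical use of the ENUM type defined above; the table and column names are illustrative::

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects.mysql import ENUM

    metadata = MetaData()
    shirts = Table(
        'shirts', metadata,
        Column('size', ENUM('small', 'medium', 'large', charset='utf8')),
    )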


class SET(_EnumeratedValues):
"""MySQL SET type."""

__visit_name__ = 'SET'

def __init__(self, *values, **kw):
"""Construct a SET.

E.g.::

Column('myset', SET("foo", "bar", "baz"))


The list of potential values is required in the case that this
set will be used to generate DDL for a table, or if the
:paramref:`.SET.retrieve_as_bitwise` flag is set to True.

:param values: The range of valid values for this SET.

:param convert_unicode: Same flag as that of
:paramref:`.String.convert_unicode`.

:param collation: same as that of :paramref:`.String.collation`

:param charset: same as that of :paramref:`.VARCHAR.charset`.

:param ascii: same as that of :paramref:`.VARCHAR.ascii`.

:param unicode: same as that of :paramref:`.VARCHAR.unicode`.

:param binary: same as that of :paramref:`.VARCHAR.binary`.

:param quoting: Defaults to 'auto': automatically determine set value
quoting. If all values are surrounded by the same quoting
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.

'quoted': values in enums are already quoted, they will be used
directly when generating the schema - this usage is deprecated.

'unquoted': values in enums are not quoted, they will be escaped and
surrounded by single quotes when generating the schema.

Previous versions of this type always required manually quoted
values to be supplied; future versions will always quote the string
literals for you. This is a transitional option.

.. versionadded:: 0.9.0

:param retrieve_as_bitwise: if True, the data for the set type will be
persisted and selected using an integer value, where a set is coerced
into a bitwise mask for persistence. MySQL allows this mode which
has the advantage of being able to store values unambiguously,
such as the blank string ``''``. The datatype will appear
as the expression ``col + 0`` in a SELECT statement, so that the
value is coerced into an integer value in result sets.
This flag is required if one wishes
to persist a set that can store the blank string ``''`` as a value.

.. warning::

When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
essential that the list of set values is expressed in the
**exact same order** as exists on the MySQL database.

.. versionadded:: 1.0.0

"""
self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False)
values, length = self._init_values(values, kw)
self.values = tuple(values)
if not self.retrieve_as_bitwise and '' in values:
raise exc.ArgumentError(
"Can't use the blank value '' in a SET without "
"setting retrieve_as_bitwise=True")
if self.retrieve_as_bitwise:
self._bitmap = dict(
(value, 2 ** idx)
for idx, value in enumerate(self.values)
)
self._bitmap.update(
(2 ** idx, value)
for idx, value in enumerate(self.values)
)
kw.setdefault('length', length)
super(SET, self).__init__(**kw)

def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0,
self
)
else:
return colexpr

def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:
def process(value):
if value is not None:
value = int(value)

return set(
util.map_bits(self._bitmap.__getitem__, value)
)
else:
return None
else:
super_convert = super(SET, self).result_processor(dialect, coltype)

def process(value):
if isinstance(value, util.string_types):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
return set(re.findall(r'[^,]+', value))
else:
# mysql-connector-python does a naive
# split(",") which throws in an empty string
if value is not None:
value.discard('')
return value
return process

def bind_processor(self, dialect):
super_convert = super(SET, self).bind_processor(dialect)
if self.retrieve_as_bitwise:
def process(value):
if value is None:
return None
elif isinstance(value, util.int_types + util.string_types):
if super_convert:
return super_convert(value)
else:
return value
else:
int_value = 0
for v in value:
int_value |= self._bitmap[v]
return int_value
else:

def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(
value, util.int_types + util.string_types):
value = ",".join(value)

if super_convert:
return super_convert(value)
else:
return value
return process

def adapt(self, impltype, **kw):
kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise
return util.constructor_copy(
self, impltype,
*self.values,
**kw
)
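
A usage sketch for the bitwise mode described in the docstring above; names are illustrative, and the value order must match the MySQL column definition::

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects.mysql import SET

    metadata = MetaData()
    flags = Table(
        'flags', metadata,
        # the blank string '' can only be stored with
        # retrieve_as_bitwise=True, per the ArgumentError above
        Column('permissions', SET('read', 'write', '',
                                  retrieve_as_bitwise=True)),
    )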
sqlalchemy/dialects/mysql/gaerdbms.py (new file, 102 lines)
@@ -0,0 +1,102 @@
# mysql/gaerdbms.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+gaerdbms
:name: Google Cloud SQL
:dbapi: rdbms
:connectstring: mysql+gaerdbms:///<dbname>?instance=<instancename>
:url: https://developers.google.com/appengine/docs/python/cloud-sql/\
developers-guide

This dialect is based primarily on the :mod:`.mysql.mysqldb` dialect with
minimal changes.

.. versionadded:: 0.7.8

.. deprecated:: 1.0 This dialect is **no longer necessary** for
Google Cloud SQL; the MySQLdb dialect can be used directly.
Cloud SQL now recommends creating connections via the
mysql dialect using the URL format

``mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>``


Pooling
-------

Google App Engine connections appear to be randomly recycled,
so the dialect does not pool connections. The :class:`.NullPool`
implementation is installed within the :class:`.Engine` by
default.

"""

import os

from .mysqldb import MySQLDialect_mysqldb
from ...pool import NullPool
import re
from sqlalchemy.util import warn_deprecated


def _is_dev_environment():
return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/')


class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):

@classmethod
def dbapi(cls):

warn_deprecated(
"Google Cloud SQL now recommends creating connections via the "
"MySQLdb dialect directly, using the URL format "
"mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/"
"<projectid>:<instancename>"
)

# from django:
# http://code.google.com/p/googleappengine/source/
# browse/trunk/python/google/storage/speckle/
# python/django/backend/base.py#118
# see also [ticket:2649]
# see also http://stackoverflow.com/q/14224679/34549
from google.appengine.api import apiproxy_stub_map

if _is_dev_environment():
from google.appengine.api import rdbms_mysqldb
return rdbms_mysqldb
elif apiproxy_stub_map.apiproxy.GetStub('rdbms'):
from google.storage.speckle.python.api import rdbms_apiproxy
return rdbms_apiproxy
else:
from google.storage.speckle.python.api import rdbms_googleapi
return rdbms_googleapi

@classmethod
def get_pool_class(cls, url):
# Cloud SQL connections die at any moment
return NullPool

def create_connect_args(self, url):
opts = url.translate_connect_args()
if not _is_dev_environment():
# 'dsn' and 'instance' are because we are skipping
# the traditional google.api.rdbms wrapper
opts['dsn'] = ''
opts['instance'] = url.query['instance']
return [], opts

def _extract_error_code(self, exception):
match = re.compile(r"^(\d+)L?:|^\((\d+)L?,").match(str(exception))
# The rdbms api will wrap then re-raise some types of errors
# making this regex return no matches.
code = match.group(1) or match.group(2) if match else None
if code:
return int(code)

dialect = MySQLDialect_gaerdbms
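
A standalone check of the error-code pattern used in _extract_error_code() above (the sample messages are assumed, not taken from the rdbms api)::

    import re

    pattern = re.compile(r"^(\d+)L?:|^\((\d+)L?,")
    for msg in ("1045: Access denied", "(2006L, 'server has gone away')"):
        m = pattern.match(msg)
        code = m.group(1) or m.group(2) if m else None
        print(int(code))  # 1045, then 2006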
sqlalchemy/dialects/mysql/json.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# mysql/json.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from __future__ import absolute_import

import json

from ...sql import elements
from ... import types as sqltypes
from ... import util


class JSON(sqltypes.JSON):
"""MySQL JSON type.

MySQL supports JSON as of version 5.7. Note that MariaDB does **not**
support JSON at the time of this writing.

The :class:`.mysql.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`.types.JSON`
datatype, by adapting the operations to render the ``JSON_EXTRACT``
function at the database level.

.. versionadded:: 1.1

"""

pass


class _FormatTypeMixin(object):
def _format_value(self, value):
raise NotImplementedError()

def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)

def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value

return process

def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)

def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value

return process


class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):

def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value


class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
"".join([
"[%s]" % elem if isinstance(elem, int)
else '."%s"' % elem for elem in value
])
)
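
The path formatter above renders integer elements as array indexes and everything else as quoted member names; mirrored as a standalone function::

    def format_path(value):
        # mirror of JSONPathType._format_value() above
        return "$%s" % (
            "".join([
                "[%s]" % elem if isinstance(elem, int)
                else '."%s"' % elem for elem in value
            ])
        )

    print(format_path(["document", 3, "items"]))  # $."document"[3]."items"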
@@ -1,28 +1,34 @@
"""Support for the MySQL database via the MySQL Connector/Python adapter.
# mysql/mysqlconnector.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

MySQL Connector/Python is available at:
"""
.. dialect:: mysql+mysqlconnector
:name: MySQL Connector/Python
:dbapi: myconnpy
:connectstring: mysql+mysqlconnector://<user>:<password>@\
<host>[:<port>]/<dbname>
:url: http://dev.mysql.com/downloads/connector/python/

https://launchpad.net/myconnpy

Connecting
-----------
Unicode
-------

Connect string format::

mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.

"""

from .base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer,
BIT)

from ... import util
import re

from sqlalchemy.dialects.mysql.base import (MySQLDialect,
MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
BIT)

from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors

class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):

@@ -31,17 +37,36 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):


class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)
def visit_mod_binary(self, binary, operator, **kw):
if self.dialect._mysqlconnector_double_percents:
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
else:
return self.process(binary.left, **kw) + " % " + \
self.process(binary.right, **kw)

def post_process_text(self, text):
return text.replace('%', '%%')
if self.dialect._mysqlconnector_double_percents:
return text.replace('%', '%%')
else:
return text

def escape_literal_column(self, text):
if self.dialect._mysqlconnector_double_percents:
return text.replace('%', '%%')
else:
return text


class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):

def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
if self.dialect._mysqlconnector_double_percents:
return value.replace("%", "%%")
else:
return value


class _myconnpyBIT(BIT):
def result_processor(self, dialect, coltype):
@@ -49,10 +74,12 @@ class _myconnpyBIT(BIT):

return None


class MySQLDialect_mysqlconnector(MySQLDialect):
driver = 'mysqlconnector'
supports_unicode_statements = True

supports_unicode_binds = True

supports_sane_rowcount = True
supports_sane_multi_rowcount = True

@@ -71,6 +98,10 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
}
)

@util.memoized_property
def supports_unicode_statements(self):
return util.py3k or self._mysqlconnector_version_info > (2, 0)

@classmethod
def dbapi(cls):
from mysql import connector
@@ -78,48 +109,75 @@ class MySQLDialect_mysqlconnector(MySQLDialect):

def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')

opts.update(url.query)

util.coerce_kw_type(opts, 'allow_local_infile', bool)
util.coerce_kw_type(opts, 'autocommit', bool)
util.coerce_kw_type(opts, 'buffered', bool)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'connection_timeout', int)
util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'consume_results', bool)
util.coerce_kw_type(opts, 'force_ipv6', bool)
util.coerce_kw_type(opts, 'get_warnings', bool)
util.coerce_kw_type(opts, 'pool_reset_session', bool)
util.coerce_kw_type(opts, 'pool_size', int)
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
opts['buffered'] = True
opts['raise_on_warnings'] = True
util.coerce_kw_type(opts, 'raw', bool)
util.coerce_kw_type(opts, 'ssl_verify_cert', bool)
util.coerce_kw_type(opts, 'use_pure', bool)
util.coerce_kw_type(opts, 'use_unicode', bool)

# unfortunately, MySQL/connector python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
opts.setdefault('buffered', True)

# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector.constants import ClientFlag
client_flags = opts.get('client_flags', ClientFlag.get_default())
client_flags = opts.get(
'client_flags', ClientFlag.get_default())
client_flags |= ClientFlag.FOUND_ROWS
opts['client_flags'] = client_flags
except:
except Exception:
pass
return [[], opts]

@util.memoized_property
def _mysqlconnector_version_info(self):
if self.dbapi and hasattr(self.dbapi, '__version__'):
m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?',
self.dbapi.__version__)
if m:
return tuple(
int(x)
for x in m.group(1, 2, 3)
if x is not None)

@util.memoized_property
def _mysqlconnector_double_percents(self):
return not util.py3k and self._mysqlconnector_version_info < (2, 0)

def _get_server_version_info(self, connection):
dbapi_con = connection.connection

from mysql.connector.constants import ClientFlag
dbapi_con.set_client_flag(ClientFlag.FOUND_ROWS)

version = dbapi_con.get_server_version()
return tuple(version)

def _detect_charset(self, connection):
return connection.connection.get_characterset_info()
return connection.connection.charset

def _extract_error_code(self, exception):
try:
return exception.orig.errno
except AttributeError:
return None
return exception.errno

def is_disconnect(self, e):
def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError,self.dbapi.InterfaceError)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return e.errno in errnos
return e.errno in errnos or \
"MySQL Connection not available." in str(e)
else:
return False

@@ -129,4 +187,17 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()

_isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
'READ COMMITTED', 'REPEATABLE READ',
'AUTOCOMMIT'])

def _set_isolation_level(self, connection, level):
if level == 'AUTOCOMMIT':
connection.autocommit = True
else:
connection.autocommit = False
super(MySQLDialect_mysqlconnector, self)._set_isolation_level(
connection, level)

dialect = MySQLDialect_mysqlconnector
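
The AUTOCOMMIT entry added to _isolation_lookup above can be requested per engine; a hedged example with a hypothetical URL::

    from sqlalchemy import create_engine

    engine = create_engine(
        "mysql+mysqlconnector://scott:tiger@localhost/test",
        isolation_level="AUTOCOMMIT",
    )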

@@ -1,85 +1,87 @@
"""Support for the MySQL database via the MySQL-python adapter.

MySQL-Python is available at:

http://sourceforge.net/projects/mysql-python

At least version 1.2.1 or 1.2.2 should be used.

Connecting
-----------

Connect string format::

mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>

Character Sets
--------------

Many MySQL server installations default to a ``latin1`` encoding for client
connections. All data sent through the connection will be converted into
``latin1``, even if you have ``utf8`` or another character set on your tables
and columns. With versions 4.1 and higher, you can change the connection
character set either through server configuration or by including the
``charset`` parameter in the URL used for ``create_engine``. The ``charset``
option is passed through to MySQL-Python and has the side-effect of also
enabling ``use_unicode`` in the driver by default. For regular encoded
strings, also pass ``use_unicode=0`` in the connection arguments::

# set client encoding to utf8; all strings come back as unicode
create_engine('mysql+mysqldb:///mydb?charset=utf8')

# set client encoding to utf8; all strings come back as utf8 str
create_engine('mysql+mysqldb:///mydb?charset=utf8&use_unicode=0')

Known Issues
-------------

MySQL-python at least as of version 1.2.2 has a serious memory leak related
to unicode conversion, a feature which is disabled via ``use_unicode=0``.
The recommended connection form with SQLAlchemy is::

engine = create_engine('mysql://scott:tiger@localhost/test?charset=utf8&use_unicode=0', pool_recycle=3600)

# mysql/mysqldb.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""

.. dialect:: mysql+mysqldb
:name: MySQL-Python
:dbapi: mysqldb
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://sourceforge.net/projects/mysql-python

.. _mysqldb_unicode:

Unicode
-------

Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.

Py3K Support
------------

Currently, MySQLdb only runs on Python 2 and development has been stopped.
`mysqlclient`_ is a fork of MySQLdb and provides Python 3 support as well
as some bugfixes.

.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python

Using MySQLdb with Google Cloud SQL
-----------------------------------

Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
using a URL like the following::

mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>

Server Side Cursors
-------------------

The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.

"""

from .base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from .base import TEXT
from ... import sql
from ... import util
import re

from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors

class MySQLExecutionContext_mysqldb(MySQLExecutionContext):

@property
def rowcount(self):
if hasattr(self, '_rowcount'):
return self._rowcount
else:
return self.cursor.rowcount


class MySQLCompiler_mysqldb(MySQLCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)

def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)

def post_process_text(self, text):
return text.replace('%', '%%')


class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer):

def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")


class MySQLDialect_mysqldb(MySQLDialect):
driver = 'mysqldb'
supports_unicode_statements = False
supports_unicode_statements = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True

@@ -89,13 +91,20 @@ class MySQLDialect_mysqldb(MySQLDialect):
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer_mysqldb

colspecs = util.update_copy(
MySQLDialect.colspecs,
{
}
)

def __init__(self, server_side_cursors=False, **kwargs):
super(MySQLDialect_mysqldb, self).__init__(**kwargs)
self.server_side_cursors = server_side_cursors

@util.langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
cursors = __import__('MySQLdb.cursors').cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
return False

@classmethod
def dbapi(cls):
return __import__('MySQLdb')
@@ -105,6 +114,30 @@ class MySQLDialect_mysqldb(MySQLDialect):
if context is not None:
context._rowcount = rowcount

def _check_unicode_returns(self, connection):
# work around issue fixed in
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8_bin collation and unicode returns

has_utf8_bin = self.server_version_info > (5, ) and \
connection.scalar(
"show collation where %s = 'utf8' and %s = 'utf8_bin'"
% (
self.identifier_preparer.quote("Charset"),
self.identifier_preparer.quote("Collation")
))
if has_utf8_bin:
additional_tests = [
sql.collate(sql.cast(
sql.literal_column(
"'test collated returns'"),
TEXT(charset='utf8')), "utf8_bin")
]
else:
additional_tests = []
return super(MySQLDialect_mysqldb, self)._check_unicode_returns(
connection, additional_tests)

def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user',
password='passwd')
@@ -112,11 +145,12 @@ class MySQLDialect_mysqldb(MySQLDialect):

util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'read_timeout', int)
util.coerce_kw_type(opts, 'client_flag', int)
util.coerce_kw_type(opts, 'local_infile', int)
# Note: using either of the below will cause all strings to be returned
# as Unicode, both in raw SQL operations and with column types like
# String and MSString.
# Note: using either of the below will cause all strings to be
# returned as Unicode, both in raw SQL operations and with column
# types like String and MSString.
util.coerce_kw_type(opts, 'use_unicode', bool)
util.coerce_kw_type(opts, 'charset', str)

@@ -124,7 +158,8 @@ class MySQLDialect_mysqldb(MySQLDialect):
# query string.

ssl = {}
for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']
for key in keys:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
@@ -137,17 +172,19 @@ class MySQLDialect_mysqldb(MySQLDialect):
client_flag = opts.get('client_flag', 0)
if self.dbapi is not None:
try:
from MySQLdb.constants import CLIENT as CLIENT_FLAGS
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + '.constants.CLIENT'
).constants.CLIENT
client_flag |= CLIENT_FLAGS.FOUND_ROWS
except:
pass
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts['client_flag'] = client_flag
return [[], opts]

def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.get_server_info()):
try:
version.append(int(n))
@@ -156,47 +193,36 @@ class MySQLDialect_mysqldb(MySQLDialect):
return tuple(version)

def _extract_error_code(self, exception):
try:
return exception.orig.args[0]
except AttributeError:
return None
return exception.args[0]

def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""

# Note: MySQL-python 1.2.1c7 seems to ignore changes made
# on a connection via set_character_set()
if self.server_version_info < (4, 1, 0):
try:
return connection.connection.character_set_name()
except AttributeError:
# < 1.2.1 final MySQL-python drivers have no charset support.
# a query is needed.
pass

# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])

if 'character_set_results' in opts:
return opts['character_set_results']
try:
return connection.connection.character_set_name()
# note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
cset_name = connection.connection.character_set_name
except AttributeError:
# Still no charset on < 1.2.1 final...
if 'character_set' in opts:
return opts['character_set']
else:
util.warn(
"Could not detect the connection character set with this "
"combination of MySQL server and MySQL-python. "
"MySQL-python >= 1.2.2 is recommended. Assuming latin1.")
return 'latin1'
util.warn(
"No 'character_set_name' can be detected with "
"this MySQL-Python version; "
"please upgrade to a recent version of MySQL-Python. "
"Assuming latin1.")
return 'latin1'
else:
return cset_name()

_isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
'READ COMMITTED', 'REPEATABLE READ',
'AUTOCOMMIT'])

def _set_isolation_level(self, connection, level):
if level == 'AUTOCOMMIT':
connection.autocommit(True)
else:
connection.autocommit(False)
super(MySQLDialect_mysqldb, self)._set_isolation_level(connection,
level)

dialect = MySQLDialect_mysqldb
|
||||
|
||||
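
An aside on the AUTOCOMMIT entry added above: the dialect toggles the
driver-level autocommit flag instead of emitting SET SESSION TRANSACTION
ISOLATION LEVEL. A minimal usage sketch (URL and credentials are
placeholders, not part of the patch):

    from sqlalchemy import create_engine

    # each statement on this engine's connections commits immediately;
    # the other four lookup values still go through the parent dialect
    engine = create_engine(
        "mysql+mysqldb://user:pass@localhost/test",
        isolation_level="AUTOCOMMIT")
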
@@ -1,46 +1,31 @@
"""Support for the MySQL database via the oursql adapter.
# mysql/oursql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

OurSQL is available at:
"""

    http://packages.python.org/oursql/

Connecting
-----------
.. dialect:: mysql+oursql
    :name: OurSQL
    :dbapi: oursql
    :connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
    :url: http://packages.python.org/oursql/

Connect string format::
Unicode
-------

    mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.

Character Sets
--------------

oursql defaults to using ``utf8`` as the connection charset, but other
encodings may be used instead.  Like the MySQL-Python driver, unicode support
can be completely disabled::

    # oursql sets the connection charset to utf8 automatically; all strings
    # come back as utf8 str
    create_engine('mysql+oursql:///mydb?use_unicode=0')

To not automatically use ``utf8`` and instead use whatever the connection
defaults to, there is a separate parameter::

    # use the default connection charset; all strings come back as unicode
    create_engine('mysql+oursql:///mydb?default_charset=1')

    # use latin1 as the connection charset; all strings come back as unicode
    create_engine('mysql+oursql:///mydb?charset=latin1')
"""

import re

from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext,
                                            MySQLCompiler, MySQLIdentifierPreparer)
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
from sqlalchemy import processors

from .base import (BIT, MySQLDialect, MySQLExecutionContext)
from ... import types as sqltypes, util


class _oursqlBIT(BIT):
@@ -55,18 +40,17 @@ class MySQLExecutionContext_oursql(MySQLExecutionContext):
    @property
    def plain_query(self):
        return self.execution_options.get('_oursql_plain_query', False)


class MySQLDialect_oursql(MySQLDialect):
    driver = 'oursql'
    # Py3K
    # description_encoding = None
    # Py2K
    supports_unicode_binds = True
    supports_unicode_statements = True
    # end Py2K

    if util.py2k:
        supports_unicode_binds = True
        supports_unicode_statements = True

    supports_native_decimal = True

    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    execution_ctx_cls = MySQLExecutionContext_oursql
@@ -84,7 +68,8 @@ class MySQLDialect_oursql(MySQLDialect):
        return __import__('oursql')

    def do_execute(self, cursor, statement, parameters, context=None):
        """Provide an implementation of *cursor.execute(statement, parameters)*."""
        """Provide an implementation of
        *cursor.execute(statement, parameters)*."""

        if context and context.plain_query:
            cursor.execute(statement, plain_query=True)
@@ -95,103 +80,109 @@ class MySQLDialect_oursql(MySQLDialect):
        connection.cursor().execute('BEGIN', plain_query=True)

    def _xa_query(self, connection, query, xid):
        # Py2K
        arg = connection.connection._escape_string(xid)
        # end Py2K
        # Py3K
        # charset = self._connection_charset
        # arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
        connection.execution_options(_oursql_plain_query=True).execute(query % arg)
        if util.py2k:
            arg = connection.connection._escape_string(xid)
        else:
            charset = self._connection_charset
            arg = connection.connection._escape_string(
                xid.encode(charset)).decode(charset)
        arg = "'%s'" % arg
        connection.execution_options(
            _oursql_plain_query=True).execute(query % arg)

    # Because mysql is bad, these methods have to be
    # reimplemented to use _PlainQuery. Basically, some queries
    # refuse to return any data if they're run through
    # the parameterized query API, or refuse to be parameterized
    # in the first place.
    def do_begin_twophase(self, connection, xid):
        self._xa_query(connection, 'XA BEGIN "%s"', xid)
        self._xa_query(connection, 'XA BEGIN %s', xid)

    def do_prepare_twophase(self, connection, xid):
        self._xa_query(connection, 'XA END "%s"', xid)
        self._xa_query(connection, 'XA PREPARE "%s"', xid)
        self._xa_query(connection, 'XA END %s', xid)
        self._xa_query(connection, 'XA PREPARE %s', xid)

    def do_rollback_twophase(self, connection, xid, is_prepared=True,
                             recover=False):
        if not is_prepared:
            self._xa_query(connection, 'XA END "%s"', xid)
        self._xa_query(connection, 'XA ROLLBACK "%s"', xid)
            self._xa_query(connection, 'XA END %s', xid)
        self._xa_query(connection, 'XA ROLLBACK %s', xid)

    def do_commit_twophase(self, connection, xid, is_prepared=True,
                           recover=False):
        if not is_prepared:
            self.do_prepare_twophase(connection, xid)
        self._xa_query(connection, 'XA COMMIT "%s"', xid)

        self._xa_query(connection, 'XA COMMIT %s', xid)
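
The plain-query XA helpers above bypass oursql's parameterized API because
the server refuses to parameterize XA statements. A hedged sketch of
driving them through the public two-phase API (URL, table and statement are
placeholders):

    from sqlalchemy import create_engine, text

    engine = create_engine('mysql+oursql://user:pass@localhost/test')
    with engine.connect() as conn:
        xid = engine.dialect.create_xid()   # generate a transaction id
        trans = conn.begin_twophase(xid)    # emits XA BEGIN '<xid>'
        conn.execute(text("UPDATE t SET x = 1"))
        trans.prepare()                     # XA END, then XA PREPARE
        trans.commit()                      # XA COMMIT
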
    # Q: why didn't we need all these "plain_query" overrides earlier ?
    # am i on a newer/older version of OurSQL ?
    def has_table(self, connection, table_name, schema=None):
        return MySQLDialect.has_table(self,
                                      connection.connect().\
                                      execution_options(_oursql_plain_query=True),
                                      table_name, schema)

    def get_table_options(self, connection, table_name, schema=None, **kw):
        return MySQLDialect.get_table_options(self,
                                              connection.connect().\
                                              execution_options(_oursql_plain_query=True),
                                              table_name,
                                              schema = schema,
                                              **kw
        return MySQLDialect.has_table(
            self,
            connection.connect().execution_options(_oursql_plain_query=True),
            table_name,
            schema
        )

    def get_table_options(self, connection, table_name, schema=None, **kw):
        return MySQLDialect.get_table_options(
            self,
            connection.connect().execution_options(_oursql_plain_query=True),
            table_name,
            schema=schema,
            **kw
        )

    def get_columns(self, connection, table_name, schema=None, **kw):
        return MySQLDialect.get_columns(self,
                                        connection.connect().\
                                        execution_options(_oursql_plain_query=True),
                                        table_name,
                                        schema=schema,
                                        **kw
        return MySQLDialect.get_columns(
            self,
            connection.connect().execution_options(_oursql_plain_query=True),
            table_name,
            schema=schema,
            **kw
        )

    def get_view_names(self, connection, schema=None, **kw):
        return MySQLDialect.get_view_names(self,
                                           connection.connect().\
                                           execution_options(_oursql_plain_query=True),
                                           schema=schema,
                                           **kw
        return MySQLDialect.get_view_names(
            self,
            connection.connect().execution_options(_oursql_plain_query=True),
            schema=schema,
            **kw
        )

    def get_table_names(self, connection, schema=None, **kw):
        return MySQLDialect.get_table_names(self,
                                            connection.connect().\
                                            execution_options(_oursql_plain_query=True),
                                            schema
        return MySQLDialect.get_table_names(
            self,
            connection.connect().execution_options(_oursql_plain_query=True),
            schema
        )

    def get_schema_names(self, connection, **kw):
        return MySQLDialect.get_schema_names(self,
                                             connection.connect().\
                                             execution_options(_oursql_plain_query=True),
                                             **kw
        return MySQLDialect.get_schema_names(
            self,
            connection.connect().execution_options(_oursql_plain_query=True),
            **kw
        )

    def initialize(self, connection):
        return MySQLDialect.initialize(
            self,
            connection.execution_options(_oursql_plain_query=True)
        )

            self,
            connection.execution_options(_oursql_plain_query=True)
            )

    def _show_create_table(self, connection, table, charset=None,
                           full_name=None):
        return MySQLDialect._show_create_table(self,
                                               connection.contextual_connect(close_with_result=True).
                                               execution_options(_oursql_plain_query=True),
                                               table, charset, full_name)
        return MySQLDialect._show_create_table(
            self,
            connection.contextual_connect(close_with_result=True).
            execution_options(_oursql_plain_query=True),
            table, charset, full_name
        )

    def is_disconnect(self, e):
        if isinstance(e, self.dbapi.ProgrammingError):
            return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
    def is_disconnect(self, e, connection, cursor):
        if isinstance(e, self.dbapi.ProgrammingError):
            return e.errno is None and 'cursor' not in e.args[1] \
                and e.args[1].endswith('closed')
        else:
            return e.errno in (2006, 2013, 2014, 2045, 2055)

@@ -203,6 +194,7 @@ class MySQLDialect_oursql(MySQLDialect):
        util.coerce_kw_type(opts, 'port', int)
        util.coerce_kw_type(opts, 'compress', bool)
        util.coerce_kw_type(opts, 'autoping', bool)
        util.coerce_kw_type(opts, 'raise_on_warnings', bool)

        util.coerce_kw_type(opts, 'default_charset', bool)
        if opts.pop('default_charset', False):
@@ -216,12 +208,22 @@ class MySQLDialect_oursql(MySQLDialect):
        # supports_sane_rowcount.
        opts.setdefault('found_rows', True)

        ssl = {}
        for key in ['ssl_ca', 'ssl_key', 'ssl_cert',
                    'ssl_capath', 'ssl_cipher']:
            if key in opts:
                ssl[key[4:]] = opts[key]
                util.coerce_kw_type(ssl, key[4:], str)
                del opts[key]
        if ssl:
            opts['ssl'] = ssl

        return [[], opts]

    def _get_server_version_info(self, connection):
        dbapi_con = connection.connection
        version = []
        r = re.compile('[.\-]')
        r = re.compile(r'[.\-]')
        for n in r.split(dbapi_con.server_info):
            try:
                version.append(int(n))
@@ -230,14 +232,11 @@ class MySQLDialect_oursql(MySQLDialect):
        return tuple(version)

    def _extract_error_code(self, exception):
        try:
            return exception.orig.errno
        except AttributeError:
            return None
        return exception.errno

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        return connection.connection.charset

    def _compat_fetchall(self, rp, charset=None):

70 sqlalchemy/dialects/mysql/pymysql.py Normal file
@@ -0,0 +1,70 @@
# mysql/pymysql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""

.. dialect:: mysql+pymysql
    :name: PyMySQL
    :dbapi: pymysql
    :connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>\
[?<options>]
    :url: http://www.pymysql.org/

Unicode
-------

Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.

MySQL-Python Compatibility
--------------------------

The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
and targets 100% compatibility.  Most behavioral notes for MySQL-python apply
to the pymysql driver as well.

"""

from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers, py3k


class MySQLDialect_pymysql(MySQLDialect_mysqldb):
    driver = 'pymysql'

    description_encoding = None

    # generally, these two values should be both True
    # or both False. PyMySQL unicode tests pass all the way back
    # to 0.4 either way. See [ticket:3337]
    supports_unicode_statements = True
    supports_unicode_binds = True

    def __init__(self, server_side_cursors=False, **kwargs):
        super(MySQLDialect_pymysql, self).__init__(**kwargs)
        self.server_side_cursors = server_side_cursors

    @langhelpers.memoized_property
    def supports_server_side_cursors(self):
        try:
            cursors = __import__('pymysql.cursors').cursors
            self._sscursor = cursors.SSCursor
            return True
        except (ImportError, AttributeError):
            return False

    @classmethod
    def dbapi(cls):
        return __import__('pymysql')

    if py3k:
        def _extract_error_code(self, exception):
            if isinstance(exception.args[0], Exception):
                exception = exception.args[0]
            return exception.args[0]

dialect = MySQLDialect_pymysql
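
Since the new pymysql dialect subclasses MySQLDialect_mysqldb wholesale,
it is selected purely by the URL scheme. A hedged usage sketch (credentials
are placeholders):

    from sqlalchemy import create_engine

    # pure-Python driver, so no MySQL client library is required;
    # server_side_cursors comes from the __init__ added above
    engine = create_engine(
        "mysql+pymysql://user:pass@localhost/test",
        connect_args={"charset": "utf8"})
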
@@ -1,32 +1,33 @@
"""Support for the MySQL database via the pyodbc adapter.

pyodbc is available at:

    http://pypi.python.org/pypi/pyodbc/

Connecting
----------

Connect string::

    mysql+pyodbc://<username>:<password>@<dsnname>

Limitations
-----------

The mysql-pyodbc dialect is subject to unresolved character encoding issues
which exist within the current ODBC drivers available.
(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage
of OurSQL, MySQLdb, or MySQL-connector/Python.
# mysql/pyodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""

from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.engine import base as engine_base
from sqlalchemy import util

.. dialect:: mysql+pyodbc
    :name: PyODBC
    :dbapi: pyodbc
    :connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
    :url: http://pypi.python.org/pypi/pyodbc/

.. note:: The PyODBC for MySQL dialect is not well supported, and
    is subject to unresolved character encoding issues
    which exist within the current ODBC drivers available.
    (see http://code.google.com/p/pyodbc/issues/detail?id=25).
    Other dialects for MySQL are recommended.

"""

from .base import MySQLDialect, MySQLExecutionContext
from ...connectors.pyodbc import PyODBCConnector
from ... import util
import re


class MySQLExecutionContext_pyodbc(MySQLExecutionContext):

    def get_lastrowid(self):
@@ -36,12 +37,13 @@ class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
        cursor.close()
        return lastrowid


class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
    supports_unicode_statements = False
    execution_ctx_cls = MySQLExecutionContext_pyodbc

    pyodbc_driver_name = "MySQL"

    def __init__(self, **kw):
        # deal with http://code.google.com/p/pyodbc/issues/detail?id=25
        kw.setdefault('convert_unicode', True)
@@ -62,11 +64,12 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
            if opts.get(key, None):
                return opts[key]

        util.warn("Could not detect the connection character set. Assuming latin1.")
        util.warn("Could not detect the connection character set. "
                  "Assuming latin1.")
        return 'latin1'

    def _extract_error_code(self, exception):
        m = re.compile(r"\((\d+)\)").search(str(exception.orig.args))
        m = re.compile(r"\((\d+)\)").search(str(exception.args))
        c = m.group(1)
        if c:
            return int(c)
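
A quick worked example of the error-code regex above; the ODBC message text
is hypothetical, but "(1045)" is where MySQL's native code typically
appears:

    import re

    msg = ("[28000] [MySQL][ODBC Driver]Access denied "
           "for user 'u'@'localhost' (1045) (SQLDriverConnect)")
    m = re.compile(r"\((\d+)\)").search(msg)
    assert m is not None and int(m.group(1)) == 1045
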
450 sqlalchemy/dialects/mysql/reflection.py Normal file
@@ -0,0 +1,450 @@
# mysql/reflection.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import re
from ... import log, util
from ... import types as sqltypes
from .enumerated import _EnumeratedValues, SET
from .types import DATETIME, TIME, TIMESTAMP


class ReflectedState(object):
    """Stores raw information about a SHOW CREATE TABLE statement."""

    def __init__(self):
        self.columns = []
        self.table_options = {}
        self.table_name = None
        self.keys = []
        self.constraints = []


@log.class_logger
class MySQLTableDefinitionParser(object):
    """Parses the results of a SHOW CREATE TABLE statement."""

    def __init__(self, dialect, preparer):
        self.dialect = dialect
        self.preparer = preparer
        self._prep_regexes()

    def parse(self, show_create, charset):
        state = ReflectedState()
        state.charset = charset
        for line in re.split(r'\r?\n', show_create):
            if line.startswith('  ' + self.preparer.initial_quote):
                self._parse_column(line, state)
            # a regular table options line
            elif line.startswith(') '):
                self._parse_table_options(line, state)
            # an ANSI-mode table options line
            elif line == ')':
                pass
            elif line.startswith('CREATE '):
                self._parse_table_name(line, state)
            # Not present in real reflection, but may be if
            # loading from a file.
            elif not line:
                pass
            else:
                type_, spec = self._parse_constraints(line)
                if type_ is None:
                    util.warn("Unknown schema content: %r" % line)
                elif type_ == 'key':
                    state.keys.append(spec)
                elif type_ == 'constraint':
                    state.constraints.append(spec)
                else:
                    pass
        return state
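
A hedged sketch of driving the parser by hand; the DDL string is
hypothetical, and a bare MySQLDialect stands in for the dialect that
reflection normally supplies:

    from sqlalchemy.dialects.mysql.base import MySQLDialect
    from sqlalchemy.dialects.mysql.reflection import \
        MySQLTableDefinitionParser

    ddl = (
        "CREATE TABLE `t` (\n"
        "  `id` int(11) NOT NULL AUTO_INCREMENT,\n"
        "  PRIMARY KEY (`id`)\n"
        ") ENGINE=InnoDB DEFAULT CHARSET=latin1"
    )
    dialect = MySQLDialect()
    parser = MySQLTableDefinitionParser(dialect, dialect.identifier_preparer)
    state = parser.parse(ddl, 'latin1')
    # state.table_name == 't'; state.columns has one entry for `id`;
    # state.keys holds the PRIMARY KEY spec
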
    def _parse_constraints(self, line):
        """Parse a KEY or CONSTRAINT line.

        :param line: A line of SHOW CREATE TABLE output
        """

        # KEY
        m = self._re_key.match(line)
        if m:
            spec = m.groupdict()
            # convert columns into name, length pairs
            spec['columns'] = self._parse_keyexprs(spec['columns'])
            return 'key', spec

        # CONSTRAINT
        m = self._re_constraint.match(line)
        if m:
            spec = m.groupdict()
            spec['table'] = \
                self.preparer.unformat_identifiers(spec['table'])
            spec['local'] = [c[0]
                             for c in self._parse_keyexprs(spec['local'])]
            spec['foreign'] = [c[0]
                               for c in self._parse_keyexprs(spec['foreign'])]
            return 'constraint', spec

        # PARTITION and SUBPARTITION
        m = self._re_partition.match(line)
        if m:
            # Punt!
            return 'partition', line

        # No match.
        return (None, line)

    def _parse_table_name(self, line, state):
        """Extract the table name.

        :param line: The first line of SHOW CREATE TABLE
        """

        regex, cleanup = self._pr_name
        m = regex.match(line)
        if m:
            state.table_name = cleanup(m.group('name'))

    def _parse_table_options(self, line, state):
        """Build a dictionary of all reflected table-level options.

        :param line: The final line of SHOW CREATE TABLE output.
        """

        options = {}

        if not line or line == ')':
            pass

        else:
            rest_of_line = line[:]
            for regex, cleanup in self._pr_options:
                m = regex.search(rest_of_line)
                if not m:
                    continue
                directive, value = m.group('directive'), m.group('val')
                if cleanup:
                    value = cleanup(value)
                options[directive.lower()] = value
                rest_of_line = regex.sub('', rest_of_line)

        for nope in ('auto_increment', 'data directory', 'index directory'):
            options.pop(nope, None)

        for opt, val in options.items():
            state.table_options['%s_%s' % (self.dialect.name, opt)] = val

    def _parse_column(self, line, state):
        """Extract column details.

        Falls back to a 'minimal support' variant if full parse fails.

        :param line: Any column-bearing line from SHOW CREATE TABLE
        """

        spec = None
        m = self._re_column.match(line)
        if m:
            spec = m.groupdict()
            spec['full'] = True
        else:
            m = self._re_column_loose.match(line)
            if m:
                spec = m.groupdict()
                spec['full'] = False
        if not spec:
            util.warn("Unknown column definition %r" % line)
            return
        if not spec['full']:
            util.warn("Incomplete reflection of column definition %r" % line)

        name, type_, args = spec['name'], spec['coltype'], spec['arg']

        try:
            col_type = self.dialect.ischema_names[type_]
        except KeyError:
            util.warn("Did not recognize type '%s' of column '%s'" %
                      (type_, name))
            col_type = sqltypes.NullType

        # Column type positional arguments eg. varchar(32)
        if args is None or args == '':
            type_args = []
        elif args[0] == "'" and args[-1] == "'":
            type_args = self._re_csv_str.findall(args)
        else:
            type_args = [int(v) for v in self._re_csv_int.findall(args)]

        # Column type keyword options
        type_kw = {}

        if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
            if type_args:
                type_kw['fsp'] = type_args.pop(0)

        for kw in ('unsigned', 'zerofill'):
            if spec.get(kw, False):
                type_kw[kw] = True
        for kw in ('charset', 'collate'):
            if spec.get(kw, False):
                type_kw[kw] = spec[kw]
        if issubclass(col_type, _EnumeratedValues):
            type_args = _EnumeratedValues._strip_values(type_args)

            if issubclass(col_type, SET) and '' in type_args:
                type_kw['retrieve_as_bitwise'] = True

        type_instance = col_type(*type_args, **type_kw)

        col_kw = {}

        # NOT NULL
        col_kw['nullable'] = True
        # this can be "NULL" in the case of TIMESTAMP
        if spec.get('notnull', False) == 'NOT NULL':
            col_kw['nullable'] = False

        # AUTO_INCREMENT
        if spec.get('autoincr', False):
            col_kw['autoincrement'] = True
        elif issubclass(col_type, sqltypes.Integer):
            col_kw['autoincrement'] = False

        # DEFAULT
        default = spec.get('default', None)

        if default == 'NULL':
            # eliminates the need to deal with this later.
            default = None

        col_d = dict(name=name, type=type_instance, default=default)
        col_d.update(col_kw)
        state.columns.append(col_d)

    def _describe_to_create(self, table_name, columns):
        """Re-format DESCRIBE output as a SHOW CREATE TABLE string.

        DESCRIBE is a much simpler reflection and is sufficient for
        reflecting views for runtime use.  This method formats DDL
        for columns only- keys are omitted.

        :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
          SHOW FULL COLUMNS FROM rows must be rearranged for use with
          this function.
        """

        buffer = []
        for row in columns:
            (name, col_type, nullable, default, extra) = \
                [row[i] for i in (0, 1, 2, 4, 5)]

            line = [' ']
            line.append(self.preparer.quote_identifier(name))
            line.append(col_type)
            if not nullable:
                line.append('NOT NULL')
            if default:
                if 'auto_increment' in default:
                    pass
                elif (col_type.startswith('timestamp') and
                      default.startswith('C')):
                    line.append('DEFAULT')
                    line.append(default)
                elif default == 'NULL':
                    line.append('DEFAULT')
                    line.append(default)
                else:
                    line.append('DEFAULT')
                    line.append("'%s'" % default.replace("'", "''"))
            if extra:
                line.append(extra)

            buffer.append(' '.join(line))

        return ''.join([('CREATE TABLE %s (\n' %
                         self.preparer.quote_identifier(table_name)),
                        ',\n'.join(buffer),
                        '\n) '])
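
For a sense of the round trip, a hedged sketch of _describe_to_create on
hypothetical DESCRIBE-style 6-tuples (field order: name, type, nullable,
key, default, extra), reusing the parser built above:

    rows = [
        ('id', 'int(11)', '', 'PRI', None, 'auto_increment'),
        ('name', 'varchar(30)', 'YES', '', None, ''),
    ]
    print(parser._describe_to_create('t', rows))
    # CREATE TABLE `t` (
    #   `id` int(11) NOT NULL auto_increment,
    #   `name` varchar(30)
    # )
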
    def _parse_keyexprs(self, identifiers):
        """Unpack '"col"(2),"col" ASC'-ish strings into components."""

        return self._re_keyexprs.findall(identifiers)

    def _prep_regexes(self):
        """Pre-compile regular expressions."""

        self._re_columns = []
        self._pr_options = []

        _final = self.preparer.final_quote

        quotes = dict(zip(('iq', 'fq', 'esc_fq'),
                          [re.escape(s) for s in
                           (self.preparer.initial_quote,
                            _final,
                            self.preparer._escape_identifier(_final))]))

        self._pr_name = _pr_compile(
            r'^CREATE (?:\w+ +)?TABLE +'
            r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes,
            self.preparer._unescape_identifier)

        # `col`,`col2`(32),`col3`(15) DESC
        #
        # Note: ASC and DESC aren't reflected, so we'll punt...
        self._re_keyexprs = _re_compile(
            r'(?:'
            r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)'
            r'(?:\((\d+)\))?(?=\,|$))+' % quotes)

        # 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
        self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27')

        # 123 or 123,456
        self._re_csv_int = _re_compile(r'\d+')

        # `colname` <type> [type opts]
        #  (NOT NULL | NULL)
        #   DEFAULT ('value' | CURRENT_TIMESTAMP...)
        #   COMMENT 'comment'
        #  COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
        #  STORAGE (DISK|MEMORY)
        self._re_column = _re_compile(
            r'  '
            r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
            r'(?P<coltype>\w+)'
            r'(?:\((?P<arg>(?:\d+|\d+,\d+|'
            r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?'
            r'(?: +(?P<unsigned>UNSIGNED))?'
            r'(?: +(?P<zerofill>ZEROFILL))?'
            r'(?: +CHARACTER SET +(?P<charset>[\w_]+))?'
            r'(?: +COLLATE +(?P<collate>[\w_]+))?'
            r'(?: +(?P<notnull>(?:NOT )?NULL))?'
            r'(?: +DEFAULT +(?P<default>'
            r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+'
            r'(?: +ON UPDATE \w+)?)'
            r'))?'
            r'(?: +(?P<autoincr>AUTO_INCREMENT))?'
            r'(?: +COMMENT +(?P<comment>(?:\x27\x27|[^\x27])+))?'
            r'(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?'
            r'(?: +STORAGE +(?P<storage>\w+))?'
            r'(?: +(?P<extra>.*))?'
            r',?$'
            % quotes
        )

        # Fallback, try to parse as little as possible
        self._re_column_loose = _re_compile(
            r'  '
            r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
            r'(?P<coltype>\w+)'
            r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?'
            r'.*?(?P<notnull>(?:NOT )NULL)?'
            % quotes
        )

        # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
        # (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
        # KEY_BLOCK_SIZE size | WITH PARSER name
        self._re_key = _re_compile(
            r'  '
            r'(?:(?P<type>\S+) )?KEY'
            r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?'
            r'(?: +USING +(?P<using_pre>\S+))?'
            r' +\((?P<columns>.+?)\)'
            r'(?: +USING +(?P<using_post>\S+))?'
            r'(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?'
            r'(?: +WITH PARSER +(?P<parser>\S+))?'
            r'(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?'
            r',?$'
            % quotes
        )

        # CONSTRAINT `name` FOREIGN KEY (`local_col`)
        # REFERENCES `remote` (`remote_col`)
        # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
        # ON DELETE CASCADE ON UPDATE RESTRICT
        #
        # unique constraints come back as KEYs
        kw = quotes.copy()
        kw['on'] = 'RESTRICT|CASCADE|SET NULL|NOACTION'
        self._re_constraint = _re_compile(
            r'  '
            r'CONSTRAINT +'
            r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
            r'FOREIGN KEY +'
            r'\((?P<local>[^\)]+?)\) REFERENCES +'
            r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s'
            r'(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +'
            r'\((?P<foreign>[^\)]+?)\)'
            r'(?: +(?P<match>MATCH \w+))?'
            r'(?: +ON DELETE (?P<ondelete>%(on)s))?'
            r'(?: +ON UPDATE (?P<onupdate>%(on)s))?'
            % kw
        )

        # PARTITION
        #
        # punt!
        self._re_partition = _re_compile(r'(?:.*)(?:SUB)?PARTITION(?:.*)')

        # Table-level options (COLLATE, ENGINE, etc.)
        # Do the string options first, since they have quoted
        # strings we need to get rid of.
        for option in _options_of_type_string:
            self._add_option_string(option)

        for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT',
                       'AVG_ROW_LENGTH', 'CHARACTER SET',
                       'DEFAULT CHARSET', 'CHECKSUM',
                       'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD',
                       'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT',
                       'KEY_BLOCK_SIZE'):
            self._add_option_word(option)

        self._add_option_regex('UNION', r'\([^\)]+\)')
        self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK')
        self._add_option_regex(
            'RAID_TYPE',
            r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+')

    _optional_equals = r'(?:\s*(?:=\s*)|\s+)'

    def _add_option_string(self, directive):
        regex = (r'(?P<directive>%s)%s'
                 r"'(?P<val>(?:[^']|'')*?)'(?!')" %
                 (re.escape(directive), self._optional_equals))
        self._pr_options.append(_pr_compile(
            regex, lambda v: v.replace("\\\\", "\\").replace("''", "'")
        ))

    def _add_option_word(self, directive):
        regex = (r'(?P<directive>%s)%s'
                 r'(?P<val>\w+)' %
                 (re.escape(directive), self._optional_equals))
        self._pr_options.append(_pr_compile(regex))

    def _add_option_regex(self, directive, regex):
        regex = (r'(?P<directive>%s)%s'
                 r'(?P<val>%s)' %
                 (re.escape(directive), self._optional_equals, regex))
        self._pr_options.append(_pr_compile(regex))

_options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY',
                           'PASSWORD', 'CONNECTION')


def _pr_compile(regex, cleanup=None):
    """Prepare a 2-tuple of compiled regex and callable."""

    return (_re_compile(regex), cleanup)


def _re_compile(regex):
    """Compile a string to regex, I and UNICODE."""

    return re.compile(regex, re.I | re.UNICODE)
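
A small worked example of the option machinery above: _add_option_word
builds a directive/value regex, so a table-options fragment parses like
this (the DDL fragment is hypothetical; module path assumes this patch's
layout):

    from sqlalchemy.dialects.mysql.reflection import _pr_compile

    regex, cleanup = _pr_compile(
        r'(?P<directive>ENGINE)(?:\s*(?:=\s*)|\s+)(?P<val>\w+)')
    m = regex.search(') ENGINE=InnoDB DEFAULT CHARSET=latin1')
    assert (m.group('directive'), m.group('val')) == ('ENGINE', 'InnoDB')
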
766 sqlalchemy/dialects/mysql/types.py Normal file
@@ -0,0 +1,766 @@
# mysql/types.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import datetime
from ... import exc, util
from ... import types as sqltypes


class _NumericType(object):
    """Base for MySQL numeric types.

    This is the base both for NUMERIC as well as INTEGER, hence
    it's a mixin.

    """

    def __init__(self, unsigned=False, zerofill=False, **kw):
        self.unsigned = unsigned
        self.zerofill = zerofill
        super(_NumericType, self).__init__(**kw)

    def __repr__(self):
        return util.generic_repr(self,
                                 to_inspect=[_NumericType, sqltypes.Numeric])


class _FloatType(_NumericType, sqltypes.Float):
    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        if isinstance(self, (REAL, DOUBLE)) and \
            (
                (precision is None and scale is not None) or
                (precision is not None and scale is None)
        ):
            raise exc.ArgumentError(
                "You must specify both precision and scale or omit "
                "both altogether.")
        super(_FloatType, self).__init__(
            precision=precision, asdecimal=asdecimal, **kw)
        self.scale = scale

    def __repr__(self):
        return util.generic_repr(self, to_inspect=[_FloatType,
                                                   _NumericType,
                                                   sqltypes.Float])


class _IntegerType(_NumericType, sqltypes.Integer):
    def __init__(self, display_width=None, **kw):
        self.display_width = display_width
        super(_IntegerType, self).__init__(**kw)

    def __repr__(self):
        return util.generic_repr(self, to_inspect=[_IntegerType,
                                                   _NumericType,
                                                   sqltypes.Integer])


class _StringType(sqltypes.String):
    """Base for MySQL string types."""

    def __init__(self, charset=None, collation=None,
                 ascii=False, binary=False, unicode=False,
                 national=False, **kw):
        self.charset = charset

        # allow collate= or collation=
        kw.setdefault('collation', kw.pop('collate', collation))

        self.ascii = ascii
        self.unicode = unicode
        self.binary = binary
        self.national = national
        super(_StringType, self).__init__(**kw)

    def __repr__(self):
        return util.generic_repr(self,
                                 to_inspect=[_StringType, sqltypes.String])


class _MatchType(sqltypes.Float, sqltypes.MatchType):
    def __init__(self, **kw):
        # TODO: float arguments?
        sqltypes.Float.__init__(self)
        sqltypes.MatchType.__init__(self)


class NUMERIC(_NumericType, sqltypes.NUMERIC):
    """MySQL NUMERIC type."""

    __visit_name__ = 'NUMERIC'

    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a NUMERIC.

        :param precision: Total digits in this number. If scale and precision
          are both None, values are stored to limits allowed by the server.

        :param scale: The number of digits after the decimal point.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(NUMERIC, self).__init__(precision=precision,
                                      scale=scale, asdecimal=asdecimal, **kw)


class DECIMAL(_NumericType, sqltypes.DECIMAL):
    """MySQL DECIMAL type."""

    __visit_name__ = 'DECIMAL'

    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a DECIMAL.

        :param precision: Total digits in this number. If scale and precision
          are both None, values are stored to limits allowed by the server.

        :param scale: The number of digits after the decimal point.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(DECIMAL, self).__init__(precision=precision, scale=scale,
                                      asdecimal=asdecimal, **kw)


class DOUBLE(_FloatType):
    """MySQL DOUBLE type."""

    __visit_name__ = 'DOUBLE'

    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a DOUBLE.

        .. note::

            The :class:`.DOUBLE` type by default converts from float
            to Decimal, using a truncation that defaults to 10 digits.
            Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
            to change this scale, or ``asdecimal=False`` to return values
            directly as Python floating points.

        :param precision: Total digits in this number. If scale and precision
          are both None, values are stored to limits allowed by the server.

        :param scale: The number of digits after the decimal point.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(DOUBLE, self).__init__(precision=precision, scale=scale,
                                     asdecimal=asdecimal, **kw)


class REAL(_FloatType, sqltypes.REAL):
    """MySQL REAL type."""

    __visit_name__ = 'REAL'

    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a REAL.

        .. note::

            The :class:`.REAL` type by default converts from float
            to Decimal, using a truncation that defaults to 10 digits.
            Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
            to change this scale, or ``asdecimal=False`` to return values
            directly as Python floating points.

        :param precision: Total digits in this number. If scale and precision
          are both None, values are stored to limits allowed by the server.

        :param scale: The number of digits after the decimal point.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(REAL, self).__init__(precision=precision, scale=scale,
                                   asdecimal=asdecimal, **kw)


class FLOAT(_FloatType, sqltypes.FLOAT):
    """MySQL FLOAT type."""

    __visit_name__ = 'FLOAT'

    def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
        """Construct a FLOAT.

        :param precision: Total digits in this number. If scale and precision
          are both None, values are stored to limits allowed by the server.

        :param scale: The number of digits after the decimal point.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(FLOAT, self).__init__(precision=precision, scale=scale,
                                    asdecimal=asdecimal, **kw)

    def bind_processor(self, dialect):
        return None


class INTEGER(_IntegerType, sqltypes.INTEGER):
    """MySQL INTEGER type."""

    __visit_name__ = 'INTEGER'

    def __init__(self, display_width=None, **kw):
        """Construct an INTEGER.

        :param display_width: Optional, maximum display width for this number.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(INTEGER, self).__init__(display_width=display_width, **kw)


class BIGINT(_IntegerType, sqltypes.BIGINT):
    """MySQL BIGINTEGER type."""

    __visit_name__ = 'BIGINT'

    def __init__(self, display_width=None, **kw):
        """Construct a BIGINTEGER.

        :param display_width: Optional, maximum display width for this number.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(BIGINT, self).__init__(display_width=display_width, **kw)


class MEDIUMINT(_IntegerType):
    """MySQL MEDIUMINTEGER type."""

    __visit_name__ = 'MEDIUMINT'

    def __init__(self, display_width=None, **kw):
        """Construct a MEDIUMINTEGER

        :param display_width: Optional, maximum display width for this number.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(MEDIUMINT, self).__init__(display_width=display_width, **kw)


class TINYINT(_IntegerType):
    """MySQL TINYINT type."""

    __visit_name__ = 'TINYINT'

    def __init__(self, display_width=None, **kw):
        """Construct a TINYINT.

        :param display_width: Optional, maximum display width for this number.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(TINYINT, self).__init__(display_width=display_width, **kw)


class SMALLINT(_IntegerType, sqltypes.SMALLINT):
    """MySQL SMALLINTEGER type."""

    __visit_name__ = 'SMALLINT'

    def __init__(self, display_width=None, **kw):
        """Construct a SMALLINTEGER.

        :param display_width: Optional, maximum display width for this number.

        :param unsigned: a boolean, optional.

        :param zerofill: Optional. If true, values will be stored as strings
          left-padded with zeros. Note that this does not affect the values
          returned by the underlying database API, which continue to be
          numeric.

        """
        super(SMALLINT, self).__init__(display_width=display_width, **kw)


class BIT(sqltypes.TypeEngine):
    """MySQL BIT type.

    This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
    for MyISAM, MEMORY, InnoDB and BDB.  For older versions, use a
    MSTinyInteger() type.

    """

    __visit_name__ = 'BIT'

    def __init__(self, length=None):
        """Construct a BIT.

        :param length: Optional, number of bits.

        """
        self.length = length

    def result_processor(self, dialect, coltype):
        """Convert a MySQL's 64 bit, variable length binary string to a long.

        TODO: this is MySQL-db, pyodbc specific.  OurSQL and mysqlconnector
        already do this, so this logic should be moved to those dialects.

        """

        def process(value):
            if value is not None:
                v = 0
                for i in value:
                    if not isinstance(i, int):
                        i = ord(i)  # convert byte to int on Python 2
                    v = v << 8 | i
                return v
            return value
        return process
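
A worked example of the BIT result processor above: the driver hands back a
binary string, which folds into an integer one byte at a time:

    from sqlalchemy.dialects.mysql import BIT

    process = BIT().result_processor(None, None)  # args unused here
    # b'\x02\x01' -> (0 << 8 | 2) << 8 | 1 == 513
    assert process(b'\x02\x01') == 513
    assert process(None) is None
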
class TIME(sqltypes.TIME):
    """MySQL TIME type. """

    __visit_name__ = 'TIME'

    def __init__(self, timezone=False, fsp=None):
        """Construct a MySQL TIME type.

        :param timezone: not used by the MySQL dialect.
        :param fsp: fractional seconds precision value.
         MySQL 5.6 supports storage of fractional seconds;
         this parameter will be used when emitting DDL
         for the TIME type.

         .. note::

            DBAPI driver support for fractional seconds may
            be limited; current support includes
            MySQL Connector/Python.

        .. versionadded:: 0.8 The MySQL-specific TIME
           type as well as fractional seconds support.

        """
        super(TIME, self).__init__(timezone=timezone)
        self.fsp = fsp

    def result_processor(self, dialect, coltype):
        time = datetime.time

        def process(value):
            # convert from a timedelta value
            if value is not None:
                microseconds = value.microseconds
                seconds = value.seconds
                minutes = seconds // 60
                return time(minutes // 60,
                            minutes % 60,
                            seconds - minutes * 60,
                            microsecond=microseconds)
            else:
                return None
        return process
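
Another quick worked example: MySQL drivers hand TIME values back as
timedeltas, and the processor above rebuilds a datetime.time from the
seconds arithmetic:

    import datetime

    from sqlalchemy.dialects.mysql import TIME

    process = TIME().result_processor(None, None)
    td = datetime.timedelta(hours=1, minutes=2, seconds=3)
    assert process(td) == datetime.time(1, 2, 3)
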
class TIMESTAMP(sqltypes.TIMESTAMP):
    """MySQL TIMESTAMP type.

    """

    __visit_name__ = 'TIMESTAMP'

    def __init__(self, timezone=False, fsp=None):
        """Construct a MySQL TIMESTAMP type.

        :param timezone: not used by the MySQL dialect.
        :param fsp: fractional seconds precision value.
         MySQL 5.6.4 supports storage of fractional seconds;
         this parameter will be used when emitting DDL
         for the TIMESTAMP type.

         .. note::

            DBAPI driver support for fractional seconds may
            be limited; current support includes
            MySQL Connector/Python.

        .. versionadded:: 0.8.5 Added MySQL-specific :class:`.mysql.TIMESTAMP`
           with fractional seconds support.

        """
        super(TIMESTAMP, self).__init__(timezone=timezone)
        self.fsp = fsp


class DATETIME(sqltypes.DATETIME):
    """MySQL DATETIME type.

    """

    __visit_name__ = 'DATETIME'

    def __init__(self, timezone=False, fsp=None):
        """Construct a MySQL DATETIME type.

        :param timezone: not used by the MySQL dialect.
        :param fsp: fractional seconds precision value.
         MySQL 5.6.4 supports storage of fractional seconds;
         this parameter will be used when emitting DDL
         for the DATETIME type.

         .. note::

            DBAPI driver support for fractional seconds may
            be limited; current support includes
            MySQL Connector/Python.

        .. versionadded:: 0.8.5 Added MySQL-specific :class:`.mysql.DATETIME`
           with fractional seconds support.

        """
        super(DATETIME, self).__init__(timezone=timezone)
        self.fsp = fsp


class YEAR(sqltypes.TypeEngine):
    """MySQL YEAR type, for single byte storage of years 1901-2155."""

    __visit_name__ = 'YEAR'

    def __init__(self, display_width=None):
        self.display_width = display_width


class TEXT(_StringType, sqltypes.TEXT):
    """MySQL TEXT type, for text up to 2^16 characters."""

    __visit_name__ = 'TEXT'

    def __init__(self, length=None, **kw):
        """Construct a TEXT.

        :param length: Optional, if provided the server may optimize storage
          by substituting the smallest TEXT type sufficient to store
          ``length`` characters.

        :param charset: Optional, a column-level character set for this string
          value.  Takes precedence to 'ascii' or 'unicode' short-hand.

        :param collation: Optional, a column-level collation for this string
          value.  Takes precedence to 'binary' short-hand.

        :param ascii: Defaults to False: short-hand for the ``latin1``
          character set, generates ASCII in schema.

        :param unicode: Defaults to False: short-hand for the ``ucs2``
          character set, generates UNICODE in schema.

        :param national: Optional. If true, use the server's configured
          national character set.

        :param binary: Defaults to False: short-hand, pick the binary
          collation type that matches the column's character set.  Generates
          BINARY in schema.  This does not affect the type of data stored,
          only the collation of character data.

        """
        super(TEXT, self).__init__(length=length, **kw)


class TINYTEXT(_StringType):
    """MySQL TINYTEXT type, for text up to 2^8 characters."""

    __visit_name__ = 'TINYTEXT'

    def __init__(self, **kwargs):
        """Construct a TINYTEXT.

        :param charset: Optional, a column-level character set for this string
          value.  Takes precedence to 'ascii' or 'unicode' short-hand.

        :param collation: Optional, a column-level collation for this string
          value.  Takes precedence to 'binary' short-hand.

        :param ascii: Defaults to False: short-hand for the ``latin1``
          character set, generates ASCII in schema.

        :param unicode: Defaults to False: short-hand for the ``ucs2``
          character set, generates UNICODE in schema.

        :param national: Optional. If true, use the server's configured
          national character set.

        :param binary: Defaults to False: short-hand, pick the binary
          collation type that matches the column's character set.  Generates
          BINARY in schema.  This does not affect the type of data stored,
          only the collation of character data.

        """
        super(TINYTEXT, self).__init__(**kwargs)


class MEDIUMTEXT(_StringType):
    """MySQL MEDIUMTEXT type, for text up to 2^24 characters."""

    __visit_name__ = 'MEDIUMTEXT'

    def __init__(self, **kwargs):
        """Construct a MEDIUMTEXT.

        :param charset: Optional, a column-level character set for this string
          value.  Takes precedence to 'ascii' or 'unicode' short-hand.

        :param collation: Optional, a column-level collation for this string
          value.  Takes precedence to 'binary' short-hand.

        :param ascii: Defaults to False: short-hand for the ``latin1``
          character set, generates ASCII in schema.

        :param unicode: Defaults to False: short-hand for the ``ucs2``
          character set, generates UNICODE in schema.

        :param national: Optional. If true, use the server's configured
          national character set.

        :param binary: Defaults to False: short-hand, pick the binary
          collation type that matches the column's character set.  Generates
          BINARY in schema.  This does not affect the type of data stored,
          only the collation of character data.

        """
        super(MEDIUMTEXT, self).__init__(**kwargs)


class LONGTEXT(_StringType):
    """MySQL LONGTEXT type, for text up to 2^32 characters."""

    __visit_name__ = 'LONGTEXT'

    def __init__(self, **kwargs):
        """Construct a LONGTEXT.

        :param charset: Optional, a column-level character set for this string
          value.  Takes precedence to 'ascii' or 'unicode' short-hand.

        :param collation: Optional, a column-level collation for this string
          value.  Takes precedence to 'binary' short-hand.

        :param ascii: Defaults to False: short-hand for the ``latin1``
          character set, generates ASCII in schema.

        :param unicode: Defaults to False: short-hand for the ``ucs2``
          character set, generates UNICODE in schema.

        :param national: Optional. If true, use the server's configured
          national character set.

        :param binary: Defaults to False: short-hand, pick the binary
          collation type that matches the column's character set.  Generates
          BINARY in schema.  This does not affect the type of data stored,
          only the collation of character data.

        """
        super(LONGTEXT, self).__init__(**kwargs)


class VARCHAR(_StringType, sqltypes.VARCHAR):
    """MySQL VARCHAR type, for variable-length character data."""

    __visit_name__ = 'VARCHAR'

    def __init__(self, length=None, **kwargs):
        """Construct a VARCHAR.

        :param charset: Optional, a column-level character set for this string
          value.  Takes precedence to 'ascii' or 'unicode' short-hand.

        :param collation: Optional, a column-level collation for this string
          value.  Takes precedence to 'binary' short-hand.

        :param ascii: Defaults to False: short-hand for the ``latin1``
          character set, generates ASCII in schema.

        :param unicode: Defaults to False: short-hand for the ``ucs2``
          character set, generates UNICODE in schema.

        :param national: Optional. If true, use the server's configured
          national character set.

        :param binary: Defaults to False: short-hand, pick the binary
          collation type that matches the column's character set.  Generates
          BINARY in schema.  This does not affect the type of data stored,
          only the collation of character data.

        """
        super(VARCHAR, self).__init__(length=length, **kwargs)
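
A hedged usage sketch of the column-level charset options these string
types accept (table and column names are placeholders):

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects.mysql import VARCHAR

    t = Table('t', MetaData(),
              Column('name', VARCHAR(30, charset='utf8', binary=True)))
    # DDL renders roughly as:
    #   name VARCHAR(30) CHARACTER SET utf8 BINARY
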
class CHAR(_StringType, sqltypes.CHAR):
|
||||
"""MySQL CHAR type, for fixed-length character data."""
|
||||
|
||||
__visit_name__ = 'CHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct a CHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
:param binary: Optional, use the default binary collation for the
|
||||
national character set. This does not affect the type of data
|
||||
stored, use a BINARY type for binary data.
|
||||
|
||||
:param collation: Optional, request a particular collation. Must be
|
||||
compatible with the national character set.
|
||||
|
||||
"""
|
||||
super(CHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _adapt_string_for_cast(self, type_):
|
||||
# copy the given string type into a CHAR
|
||||
# for the purposes of rendering a CAST expression
|
||||
type_ = sqltypes.to_instance(type_)
|
||||
if isinstance(type_, sqltypes.CHAR):
|
||||
return type_
|
||||
elif isinstance(type_, _StringType):
|
||||
return CHAR(
|
||||
length=type_.length,
|
||||
charset=type_.charset,
|
||||
collation=type_.collation,
|
||||
ascii=type_.ascii,
|
||||
binary=type_.binary,
|
||||
unicode=type_.unicode,
|
||||
national=False # not supported in CAST
|
||||
)
|
||||
else:
|
||||
return CHAR(length=type_.length)
|
||||
|
||||
|
||||
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
|
||||
"""MySQL NVARCHAR type.
|
||||
|
||||
For variable-length character data in the server's configured national
|
||||
character set.
|
||||
"""
|
||||
|
||||
__visit_name__ = 'NVARCHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct an NVARCHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
:param binary: Optional, use the default binary collation for the
|
||||
national character set. This does not affect the type of data
|
||||
stored, use a BINARY type for binary data.
|
||||
|
||||
:param collation: Optional, request a particular collation. Must be
|
||||
compatible with the national character set.
|
||||
|
||||
"""
|
||||
kwargs['national'] = True
|
||||
super(NVARCHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class NCHAR(_StringType, sqltypes.NCHAR):
|
||||
"""MySQL NCHAR type.
|
||||
|
||||
For fixed-length character data in the server's configured national
|
||||
character set.
|
||||
"""
|
||||
|
||||
__visit_name__ = 'NCHAR'
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct an NCHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
:param binary: Optional, use the default binary collation for the
|
||||
national character set. This does not affect the type of data
|
||||
stored, use a BINARY type for binary data.
|
||||
|
||||
:param collation: Optional, request a particular collation. Must be
|
||||
compatible with the national character set.
|
||||
|
||||
"""
|
||||
kwargs['national'] = True
|
||||
super(NCHAR, self).__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class TINYBLOB(sqltypes._Binary):
|
||||
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
|
||||
|
||||
__visit_name__ = 'TINYBLOB'
|
||||
|
||||
|
||||
class MEDIUMBLOB(sqltypes._Binary):
|
||||
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
|
||||
|
||||
__visit_name__ = 'MEDIUMBLOB'
|
||||
|
||||
|
||||
class LONGBLOB(sqltypes._Binary):
|
||||
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
|
||||
|
||||
__visit_name__ = 'LONGBLOB'
|
||||
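For orientation, a minimal sketch of how the MySQL string types above are
typically declared on a Table; the 'documents' table and its columns are
hypothetical, and this assumes the 1.1-era dialect API shown in this diff.

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects.mysql import CHAR, LONGTEXT, NVARCHAR, VARCHAR

    metadata = MetaData()

    # 'documents' is a hypothetical table, used only for illustration
    documents = Table(
        'documents', metadata,
        Column('id', Integer, primary_key=True),
        # an explicit charset/collation takes precedence over the
        # ascii/unicode/binary short-hands described in the docstrings above
        Column('title', VARCHAR(200, charset='utf8', collation='utf8_bin')),
        Column('code', CHAR(10, binary=True)),
        # NVARCHAR forces national=True, i.e. the server's national charset
        Column('label', NVARCHAR(50)),
        Column('body', LONGTEXT(charset='utf8')),
    )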
@@ -1,17 +1,21 @@
"""Support for the MySQL database via Jython's zxjdbc JDBC connector.
# mysql/zxjdbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

JDBC Driver
-----------
"""

The official MySQL JDBC driver is at
http://dev.mysql.com/downloads/connector/j/.
.. dialect:: mysql+zxjdbc
    :name: zxjdbc for Jython
    :dbapi: zxjdbc
    :connectstring: mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/\
<database>
    :driverurl: http://dev.mysql.com/downloads/connector/j/

Connecting
----------

Connect string format:

    mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
.. note:: Jython is not supported by current versions of SQLAlchemy. The
    zxjdbc dialect should be considered as experimental.

Character Sets
--------------
@@ -20,14 +24,15 @@ SQLAlchemy zxjdbc dialects pass unicode straight through to the
zxjdbc/JDBC layer. To allow multiple character sets to be sent from the
MySQL Connector/J JDBC driver, by default SQLAlchemy sets its
``characterEncoding`` connection property to ``UTF-8``. It may be
overriden via a ``create_engine`` URL parameter.
overridden via a ``create_engine`` URL parameter.

"""
import re

from sqlalchemy import types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext
from ... import types as sqltypes, util
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import BIT, MySQLDialect, MySQLExecutionContext


class _ZxJDBCBit(BIT):
    def result_processor(self, dialect, coltype):
@@ -37,7 +42,7 @@ class _ZxJDBCBit(BIT):
                return value
            if isinstance(value, bool):
                return int(value)
            v = 0L
            v = 0
            for i in value:
                v = v << 8 | (i & 0xff)
            value = v
@@ -82,7 +87,8 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
            if opts.get(key, None):
                return opts[key]

        util.warn("Could not detect the connection character set. Assuming latin1.")
        util.warn("Could not detect the connection character set. "
                  "Assuming latin1.")
        return 'latin1'

    def _driver_kwargs(self):
@@ -92,15 +98,15 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
    def _extract_error_code(self, exception):
        # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
        # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
        m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args))
        m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.args))
        c = m.group(1)
        if c:
            return int(c)

    def _get_server_version_info(self,connection):
    def _get_server_version_info(self, connection):
        dbapi_con = connection.connection
        version = []
        r = re.compile('[.\-]')
        r = re.compile(r'[.\-]')
        for n in r.split(dbapi_con.dbversion):
            try:
                version.append(int(n))
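As a standalone illustration of the error-code extraction above, the same
regex can be run against a sample message; the message text here is a
made-up example in the format the comment describes.

    import re

    # sample DBAPI error string in the format quoted in the comment above
    sample = ("(Error) Table 'test.u2' doesn't exist "
              "[SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()")

    m = re.compile(r"\[SQLCode\: (\d+)\]").search(sample)
    if m:
        print(int(m.group(1)))  # prints 1146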
@@ -1,17 +1,24 @@
# oracle/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc

base.dialect = cx_oracle.dialect

from sqlalchemy.dialects.oracle.base import \
    VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, NUMBER,\
    VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\
    BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\
    FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\
    VARCHAR2, NVARCHAR2
    VARCHAR2, NVARCHAR2, ROWID, dialect


__all__ = (
    'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'DATETIME', 'NUMBER',
    'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
    'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
    'VARCHAR2', 'NVARCHAR2'
    'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER',
    'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
    'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
    'VARCHAR2', 'NVARCHAR2', 'ROWID'
)
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,10 +1,19 @@
"""Support for the Oracle database via the zxjdbc JDBC connector.
# oracle/zxjdbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

JDBC Driver
-----------
"""
.. dialect:: oracle+zxjdbc
    :name: zxJDBC for Jython
    :dbapi: zxjdbc
    :connectstring: oracle+zxjdbc://user:pass@host/dbname
    :driverurl: http://www.oracle.com/technetwork/database/features/jdbc/index-091264.html

The official Oracle JDBC driver is at
http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html.
.. note:: Jython is not supported by current versions of SQLAlchemy. The
    zxjdbc dialect should be considered as experimental.

"""
import decimal
@@ -12,12 +21,16 @@ import re

from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext
from sqlalchemy.engine import base, default
from sqlalchemy.dialects.oracle.base import (OracleCompiler,
                                             OracleDialect,
                                             OracleExecutionContext)
from sqlalchemy.engine import result as _result
from sqlalchemy.sql import expression
import collections

SQLException = zxJDBC = None


class _ZxJDBCDate(sqltypes.Date):

    def result_processor(self, dialect, coltype):
@@ -32,7 +45,7 @@ class _ZxJDBCDate(sqltypes.Date):
class _ZxJDBCNumeric(sqltypes.Numeric):

    def result_processor(self, dialect, coltype):
        #XXX: does the dialect return Decimal or not???
        # XXX: does the dialect return Decimal or not???
        # if it does (in all cases), we could use a None processor as well as
        # the to_float generic processor
        if self.asdecimal:
@@ -53,10 +66,11 @@ class _ZxJDBCNumeric(sqltypes.Numeric):
class OracleCompiler_zxjdbc(OracleCompiler):

    def returning_clause(self, stmt, returning_cols):
        self.returning_cols = list(expression._select_iterables(returning_cols))
        self.returning_cols = list(
            expression._select_iterables(returning_cols))

        # within_columns_clause=False so that labels (foo AS bar) don't render
        columns = [self.process(c, within_columns_clause=False, result_map=self.result_map)
        columns = [self.process(c, within_columns_clause=False)
                   for c in self.returning_cols]

        if not hasattr(self, 'returning_parameters'):
@@ -64,14 +78,17 @@ class OracleCompiler_zxjdbc(OracleCompiler):

        binds = []
        for i, col in enumerate(self.returning_cols):
            dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
            dbtype = col.type.dialect_impl(
                self.dialect).get_dbapi_type(self.dialect.dbapi)
            self.returning_parameters.append((i + 1, dbtype))

            bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype))
            bindparam = sql.bindparam(
                "ret_%d" % i, value=ReturningParam(dbtype))
            self.binds[bindparam.key] = bindparam
            binds.append(self.bindparam_string(self._truncate_bindparam(bindparam)))
            binds.append(
                self.bindparam_string(self._truncate_bindparam(bindparam)))

        return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
        return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)


class OracleExecutionContext_zxjdbc(OracleExecutionContext):
@@ -88,15 +105,19 @@ class OracleExecutionContext_zxjdbc(OracleExecutionContext):
        try:
            try:
                rrs = self.statement.__statement__.getReturnResultSet()
                rrs.next()
            except SQLException, sqle:
                msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode())
                next(rrs)
            except SQLException as sqle:
                msg = '%s [SQLCode: %d]' % (
                    sqle.getMessage(), sqle.getErrorCode())
                if sqle.getSQLState() is not None:
                    msg += ' [SQLState: %s]' % sqle.getSQLState()
                raise zxJDBC.Error(msg)
            else:
                row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype)
                            for index, dbtype in self.compiled.returning_parameters)
                row = tuple(
                    self.cursor.datahandler.getPyObject(
                        rrs, index, dbtype)
                    for index, dbtype in
                    self.compiled.returning_parameters)
                return ReturningResultProxy(self, row)
        finally:
            if rrs is not None:
@@ -106,15 +127,15 @@ class OracleExecutionContext_zxjdbc(OracleExecutionContext):
                    pass
                self.statement.close()

        return base.ResultProxy(self)
        return _result.ResultProxy(self)

    def create_cursor(self):
        cursor = self._connection.connection.cursor()
        cursor = self._dbapi_connection.cursor()
        cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
        return cursor


class ReturningResultProxy(base.FullyBufferedResultProxy):
class ReturningResultProxy(_result.FullyBufferedResultProxy):

    """ResultProxy backed by the RETURNING ResultSet results."""

@@ -132,7 +153,7 @@ class ReturningResultProxy(base.FullyBufferedResultProxy):
        return ret

    def _buffer_rows(self):
        return [self._returning_row]
        return collections.deque([self._returning_row])


class ReturningParam(object):
@@ -157,8 +178,8 @@ class ReturningParam(object):

    def __repr__(self):
        kls = self.__class__
        return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self),
                                                   self.type)
        return '<%s.%s object at 0x%x type=%s>' % (
            kls.__module__, kls.__name__, id(self), self.type)


class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
@@ -171,7 +192,7 @@ class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
    colspecs = util.update_copy(
        OracleDialect.colspecs,
        {
            sqltypes.Date : _ZxJDBCDate,
            sqltypes.Date: _ZxJDBCDate,
            sqltypes.Numeric: _ZxJDBCNumeric
        }
    )
@@ -182,28 +203,33 @@ class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
        from java.sql import SQLException
        from com.ziclix.python.sql import zxJDBC
        from com.ziclix.python.sql.handler import OracleDataHandler
        class OracleReturningDataHandler(OracleDataHandler):

        class OracleReturningDataHandler(OracleDataHandler):
            """zxJDBC DataHandler that specially handles ReturningParam."""

            def setJDBCObject(self, statement, index, object, dbtype=None):
                if type(object) is ReturningParam:
                    statement.registerReturnParameter(index, object.type)
                elif dbtype is None:
                    OracleDataHandler.setJDBCObject(self, statement, index, object)
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object)
                else:
                    OracleDataHandler.setJDBCObject(self, statement, index, object, dbtype)
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object, dbtype)
        self.DataHandler = OracleReturningDataHandler

    def initialize(self, connection):
        super(OracleDialect_zxjdbc, self).initialize(connection)
        self.implicit_returning = connection.connection.driverversion >= '10.2'
        self.implicit_returning = \
            connection.connection.driverversion >= '10.2'

    def _create_jdbc_url(self, url):
        return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database)
        return 'jdbc:oracle:thin:@%s:%s:%s' % (
            url.host, url.port or 1521, url.database)

    def _get_server_version_info(self, connection):
        version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
        version = re.search(
            r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
        return tuple(int(x) for x in version.split('.'))

dialect = OracleDialect_zxjdbc
314 sqlalchemy/dialects/postgresql/array.py (new file)
@@ -0,0 +1,314 @@
# postgresql/array.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from .base import ischema_names
from ...sql import expression, operators
from ...sql.base import SchemaEventTarget
from ... import types as sqltypes

try:
    from uuid import UUID as _python_UUID
except ImportError:
    _python_UUID = None


def Any(other, arrexpr, operator=operators.eq):
    """A synonym for the :meth:`.ARRAY.Comparator.any` method.

    This method is legacy and is here for backwards-compatibility.

    .. seealso::

        :func:`.expression.any_`

    """

    return arrexpr.any(other, operator)


def All(other, arrexpr, operator=operators.eq):
    """A synonym for the :meth:`.ARRAY.Comparator.all` method.

    This method is legacy and is here for backwards-compatibility.

    .. seealso::

        :func:`.expression.all_`

    """

    return arrexpr.all(other, operator)


class array(expression.Tuple):

    """A PostgreSQL ARRAY literal.

    This is used to produce ARRAY literals in SQL expressions, e.g.::

        from sqlalchemy.dialects.postgresql import array
        from sqlalchemy.dialects import postgresql
        from sqlalchemy import select, func

        stmt = select([
            array([1,2]) + array([3,4,5])
        ])

        print stmt.compile(dialect=postgresql.dialect())

    Produces the SQL::

        SELECT ARRAY[%(param_1)s, %(param_2)s] ||
            ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1

    An instance of :class:`.array` will always have the datatype
    :class:`.ARRAY`. The "inner" type of the array is inferred from
    the values present, unless the ``type_`` keyword argument is passed::

        array(['foo', 'bar'], type_=CHAR)

    .. versionadded:: 0.8 Added the :class:`~.postgresql.array` literal type.

    See also:

    :class:`.postgresql.ARRAY`

    """
    __visit_name__ = 'array'

    def __init__(self, clauses, **kw):
        super(array, self).__init__(*clauses, **kw)
        self.type = ARRAY(self.type)

    def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
        if _assume_scalar or operator is operators.getitem:
            # if getitem->slice were called, Indexable produces
            # a Slice object from that
            assert isinstance(obj, int)
            return expression.BindParameter(
                None, obj, _compared_to_operator=operator,
                type_=type_,
                _compared_to_type=self.type, unique=True)

        else:
            return array([
                self._bind_param(operator, o, _assume_scalar=True, type_=type_)
                for o in obj])

    def self_group(self, against=None):
        if (against in (
                operators.any_op, operators.all_op, operators.getitem)):
            return expression.Grouping(self)
        else:
            return self
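A runnable sketch of the array literal defined above, compiled against the
PostgreSQL dialect; it only renders SQL, so no database connection is
assumed.

    from sqlalchemy import select
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import array

    # concatenate two ARRAY literals, as in the docstring example
    stmt = select([array([1, 2]) + array([3, 4, 5])])
    print(stmt.compile(dialect=postgresql.dialect()))
    # SELECT ARRAY[%(param_1)s, %(param_2)s] ||
    #     ARRAY[%(param_3)s, %(param_4)s, %(param_5)s] AS anon_1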
CONTAINS = operators.custom_op("@>", precedence=5)

CONTAINED_BY = operators.custom_op("<@", precedence=5)

OVERLAP = operators.custom_op("&&", precedence=5)


class ARRAY(SchemaEventTarget, sqltypes.ARRAY):

    """PostgreSQL ARRAY type.

    .. versionchanged:: 1.1 The :class:`.postgresql.ARRAY` type is now
       a subclass of the core :class:`.types.ARRAY` type.

    The :class:`.postgresql.ARRAY` type is constructed in the same way
    as the core :class:`.types.ARRAY` type; a member type is required, and a
    number of dimensions is recommended if the type is to be used for more
    than one dimension::

        from sqlalchemy.dialects import postgresql

        mytable = Table("mytable", metadata,
                Column("data", postgresql.ARRAY(Integer, dimensions=2))
            )

    The :class:`.postgresql.ARRAY` type provides all operations defined on the
    core :class:`.types.ARRAY` type, including support for "dimensions", indexed
    access, and simple matching such as :meth:`.types.ARRAY.Comparator.any`
    and :meth:`.types.ARRAY.Comparator.all`. :class:`.postgresql.ARRAY` class also
    provides PostgreSQL-specific methods for containment operations, including
    :meth:`.postgresql.ARRAY.Comparator.contains`,
    :meth:`.postgresql.ARRAY.Comparator.contained_by`,
    and :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.::

        mytable.c.data.contains([1, 2])

    The :class:`.postgresql.ARRAY` type may not be supported on all
    PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.

    Additionally, the :class:`.postgresql.ARRAY` type does not work directly in
    conjunction with the :class:`.ENUM` type. For a workaround, see the
    special type at :ref:`postgresql_array_of_enum`.

    .. seealso::

        :class:`.types.ARRAY` - base array type

        :class:`.postgresql.array` - produces a literal array value.

    """

    class Comparator(sqltypes.ARRAY.Comparator):

        """Define comparison operations for :class:`.ARRAY`.

        Note that these operations are in addition to those provided
        by the base :class:`.types.ARRAY.Comparator` class, including
        :meth:`.types.ARRAY.Comparator.any` and
        :meth:`.types.ARRAY.Comparator.all`.

        """

        def contains(self, other, **kwargs):
            """Boolean expression. Test if elements are a superset of the
            elements of the argument array expression.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other):
            """Boolean expression. Test if elements are a proper subset of the
            elements of the argument array expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean)

        def overlap(self, other):
            """Boolean expression. Test if array has elements in common with
            an argument array expression.
            """
            return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)

    comparator_factory = Comparator

    def __init__(self, item_type, as_tuple=False, dimensions=None,
                 zero_indexes=False):
        """Construct an ARRAY.

        E.g.::

          Column('myarray', ARRAY(Integer))

        Arguments are:

        :param item_type: The data type of items of this array. Note that
          dimensionality is irrelevant here, so multi-dimensional arrays like
          ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as
          ``ARRAY(ARRAY(Integer))`` or such.

        :param as_tuple=False: Specify whether return results
          should be converted to tuples from lists. DBAPIs such
          as psycopg2 return lists by default. When tuples are
          returned, the results are hashable.

        :param dimensions: if non-None, the ARRAY will assume a fixed
          number of dimensions. This will cause the DDL emitted for this
          ARRAY to include the exact number of bracket clauses ``[]``,
          and will also optimize the performance of the type overall.
          Note that PG arrays are always implicitly "non-dimensioned",
          meaning they can store any number of dimensions no matter how
          they were declared.

        :param zero_indexes=False: when True, index values will be converted
          between Python zero-based and PostgreSQL one-based indexes, e.g.
          a value of one will be added to all index values before passing
          to the database.

          .. versionadded:: 0.9.5

        """
        if isinstance(item_type, ARRAY):
            raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
                             "handles multi-dimensional arrays of basetype")
        if isinstance(item_type, type):
            item_type = item_type()
        self.item_type = item_type
        self.as_tuple = as_tuple
        self.dimensions = dimensions
        self.zero_indexes = zero_indexes

    @property
    def hashable(self):
        return self.as_tuple

    @property
    def python_type(self):
        return list

    def compare_values(self, x, y):
        return x == y

    def _set_parent(self, column):
        """Support SchemaEventTarget"""

        if isinstance(self.item_type, SchemaEventTarget):
            self.item_type._set_parent(column)

    def _set_parent_with_dispatch(self, parent):
        """Support SchemaEventTarget"""

        if isinstance(self.item_type, SchemaEventTarget):
            self.item_type._set_parent_with_dispatch(parent)

    def _proc_array(self, arr, itemproc, dim, collection):
        if dim is None:
            arr = list(arr)
        if dim == 1 or dim is None and (
                # this has to be (list, tuple), or at least
                # not hasattr('__iter__'), since Py3K strings
                # etc. have __iter__
                not arr or not isinstance(arr[0], (list, tuple))):
            if itemproc:
                return collection(itemproc(x) for x in arr)
            else:
                return collection(arr)
        else:
            return collection(
                self._proc_array(
                    x, itemproc,
                    dim - 1 if dim is not None else None,
                    collection)
                for x in arr
            )

    def bind_processor(self, dialect):
        item_proc = self.item_type.dialect_impl(dialect).\
            bind_processor(dialect)

        def process(value):
            if value is None:
                return value
            else:
                return self._proc_array(
                    value,
                    item_proc,
                    self.dimensions,
                    list)
        return process

    def result_processor(self, dialect, coltype):
        item_proc = self.item_type.dialect_impl(dialect).\
            result_processor(dialect, coltype)

        def process(value):
            if value is None:
                return value
            else:
                return self._proc_array(
                    value,
                    item_proc,
                    self.dimensions,
                    tuple if self.as_tuple else list)
        return process

ischema_names['_array'] = ARRAY
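To make the containment operators above concrete, a short sketch compiling
the three PostgreSQL-only comparisons; 'mytable' mirrors the hypothetical
table from the docstring and only SQL rendering is assumed.

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import postgresql

    metadata = MetaData()
    mytable = Table(
        'mytable', metadata,
        Column('data', postgresql.ARRAY(Integer)),
    )

    # @> : the column's elements are a superset of the argument's
    contains = mytable.c.data.contains([1, 2])
    # <@ : the column's elements are contained by the argument's
    contained = mytable.c.data.contained_by([1, 2, 3, 4])
    # && : the two arrays have at least one element in common
    overlaps = mytable.c.data.overlap([7, 8])

    print(contains.compile(dialect=postgresql.dialect()))
    # mytable.data @> %(data_1)s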
213 sqlalchemy/dialects/postgresql/dml.py (new file)
@@ -0,0 +1,213 @@
# postgresql/on_conflict.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from ...sql.elements import ClauseElement, _literal_as_binds
from ...sql.dml import Insert as StandardInsert
from ...sql.expression import alias
from ...sql import schema
from ...util.langhelpers import public_factory
from ...sql.base import _generative
from ... import util
from . import ext

__all__ = ('Insert', 'insert')


class Insert(StandardInsert):
    """PostgreSQL-specific implementation of INSERT.

    Adds methods for PG-specific syntaxes such as ON CONFLICT.

    .. versionadded:: 1.1

    """

    @util.memoized_property
    def excluded(self):
        """Provide the ``excluded`` namespace for an ON CONFLICT statement

        PG's ON CONFLICT clause allows reference to the row that would
        be inserted, known as ``excluded``. This attribute provides
        all columns in this row to be referenceable.

        .. seealso::

            :ref:`postgresql_insert_on_conflict` - example of how
            to use :attr:`.Insert.excluded`

        """
        return alias(self.table, name='excluded').columns

    @_generative
    def on_conflict_do_update(
            self,
            constraint=None, index_elements=None,
            index_where=None, set_=None, where=None):
        """
        Specifies a DO UPDATE SET action for ON CONFLICT clause.

        Either the ``constraint`` or ``index_elements`` argument is
        required, but only one of these can be specified.

        :param constraint:
         The name of a unique or exclusion constraint on the table,
         or the constraint object itself if it has a .name attribute.

        :param index_elements:
         A sequence consisting of string column names, :class:`.Column`
         objects, or other column expression objects that will be used
         to infer a target index.

        :param index_where:
         Additional WHERE criterion that can be used to infer a
         conditional target index.

        :param set_:
         Required argument. A dictionary or other mapping object
         with column names as keys and expressions or literals as values,
         specifying the ``SET`` actions to take.
         If the target :class:`.Column` specifies a ".key" attribute distinct
         from the column name, that key should be used.

         .. warning:: This dictionary does **not** take into account
            Python-specified default UPDATE values or generation functions,
            e.g. those specified using :paramref:`.Column.onupdate`.
            These values will not be exercised for an ON CONFLICT style of
            UPDATE, unless they are manually specified in the
            :paramref:`.Insert.on_conflict_do_update.set_` dictionary.

        :param where:
         Optional argument. If present, can be a literal SQL
         string or an acceptable expression for a ``WHERE`` clause
         that restricts the rows affected by ``DO UPDATE SET``. Rows
         not meeting the ``WHERE`` condition will not be updated
         (effectively a ``DO NOTHING`` for those rows).

        .. versionadded:: 1.1

        .. seealso::

            :ref:`postgresql_insert_on_conflict`

        """
        self._post_values_clause = OnConflictDoUpdate(
            constraint, index_elements, index_where, set_, where)
        return self

    @_generative
    def on_conflict_do_nothing(
            self,
            constraint=None, index_elements=None, index_where=None):
        """
        Specifies a DO NOTHING action for ON CONFLICT clause.

        The ``constraint`` and ``index_elements`` arguments
        are optional, but only one of these can be specified.

        :param constraint:
         The name of a unique or exclusion constraint on the table,
         or the constraint object itself if it has a .name attribute.

        :param index_elements:
         A sequence consisting of string column names, :class:`.Column`
         objects, or other column expression objects that will be used
         to infer a target index.

        :param index_where:
         Additional WHERE criterion that can be used to infer a
         conditional target index.

        .. versionadded:: 1.1

        .. seealso::

            :ref:`postgresql_insert_on_conflict`

        """
        self._post_values_clause = OnConflictDoNothing(
            constraint, index_elements, index_where)
        return self

insert = public_factory(Insert, '.dialects.postgresql.insert')
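A sketch of the two generative methods above in use; the 'users' table is
hypothetical and only the rendered SQL is shown, so no engine is assumed.

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import insert

    metadata = MetaData()
    users = Table(
        'users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String),
    )

    stmt = insert(users).values(id=1, name='spongebob')

    # upsert: on a primary-key conflict, update the name, referring to the
    # would-be-inserted row through the `excluded` namespace described above
    do_update = stmt.on_conflict_do_update(
        index_elements=['id'],
        set_=dict(name=stmt.excluded.name),
    )
    print(do_update.compile(dialect=postgresql.dialect()))

    # or silently skip conflicting rows
    do_nothing = stmt.on_conflict_do_nothing(index_elements=['id'])
    print(do_nothing.compile(dialect=postgresql.dialect()))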
class OnConflictClause(ClauseElement):
    def __init__(
            self,
            constraint=None,
            index_elements=None,
            index_where=None):

        if constraint is not None:
            if not isinstance(constraint, util.string_types) and \
                    isinstance(constraint, (
                        schema.Index, schema.Constraint,
                        ext.ExcludeConstraint)):
                constraint = getattr(constraint, 'name') or constraint

        if constraint is not None:
            if index_elements is not None:
                raise ValueError(
                    "'constraint' and 'index_elements' are mutually exclusive")

            if isinstance(constraint, util.string_types):
                self.constraint_target = constraint
                self.inferred_target_elements = None
                self.inferred_target_whereclause = None
            elif isinstance(constraint, schema.Index):
                index_elements = constraint.expressions
                index_where = \
                    constraint.dialect_options['postgresql'].get("where")
            elif isinstance(constraint, ext.ExcludeConstraint):
                index_elements = constraint.columns
                index_where = constraint.where
            else:
                index_elements = constraint.columns
                index_where = \
                    constraint.dialect_options['postgresql'].get("where")

        if index_elements is not None:
            self.constraint_target = None
            self.inferred_target_elements = index_elements
            self.inferred_target_whereclause = index_where
        elif constraint is None:
            self.constraint_target = self.inferred_target_elements = \
                self.inferred_target_whereclause = None


class OnConflictDoNothing(OnConflictClause):
    __visit_name__ = 'on_conflict_do_nothing'


class OnConflictDoUpdate(OnConflictClause):
    __visit_name__ = 'on_conflict_do_update'

    def __init__(
            self,
            constraint=None,
            index_elements=None,
            index_where=None,
            set_=None,
            where=None):
        super(OnConflictDoUpdate, self).__init__(
            constraint=constraint,
            index_elements=index_elements,
            index_where=index_where)

        if self.inferred_target_elements is None and \
                self.constraint_target is None:
            raise ValueError(
                "Either constraint or index_elements, "
                "but not both, must be specified unless DO NOTHING")

        if (not isinstance(set_, dict) or not set_):
            raise ValueError("set parameter must be a non-empty dictionary")
        self.update_values_to_set = [
            (key, value)
            for key, value in set_.items()
        ]
        self.update_whereclause = where
218 sqlalchemy/dialects/postgresql/ext.py (new file)
@@ -0,0 +1,218 @@
# postgresql/ext.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from ...sql import expression
from ...sql import elements
from ...sql import functions
from ...sql.schema import ColumnCollectionConstraint
from .array import ARRAY


class aggregate_order_by(expression.ColumnElement):
    """Represent a PostgreSQL aggregate order by expression.

    E.g.::

        from sqlalchemy.dialects.postgresql import aggregate_order_by
        expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
        stmt = select([expr])

    would represent the expression::

        SELECT array_agg(a ORDER BY b DESC) FROM table;

    Similarly::

        expr = func.string_agg(
            table.c.a,
            aggregate_order_by(literal_column("','"), table.c.a)
        )
        stmt = select([expr])

    Would represent::

        SELECT string_agg(a, ',' ORDER BY a) FROM table;

    .. versionadded:: 1.1

    .. seealso::

        :class:`.array_agg`

    """

    __visit_name__ = 'aggregate_order_by'

    def __init__(self, target, order_by):
        self.target = elements._literal_as_binds(target)
        self.order_by = elements._literal_as_binds(order_by)

    def self_group(self, against=None):
        return self

    def get_children(self, **kwargs):
        return self.target, self.order_by

    def _copy_internals(self, clone=elements._clone, **kw):
        self.target = clone(self.target, **kw)
        self.order_by = clone(self.order_by, **kw)

    @property
    def _from_objects(self):
        return self.target._from_objects + self.order_by._from_objects


class ExcludeConstraint(ColumnCollectionConstraint):
    """A table-level EXCLUDE constraint.

    Defines an EXCLUDE constraint as described in the `postgres
    documentation`__.

    __ http://www.postgresql.org/docs/9.0/\
static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
    """

    __visit_name__ = 'exclude_constraint'

    where = None

    def __init__(self, *elements, **kw):
        r"""
        Create an :class:`.ExcludeConstraint` object.

        E.g.::

            const = ExcludeConstraint(
                (Column('period'), '&&'),
                (Column('group'), '='),
                where=(Column('group') != 'some group')
            )

        The constraint is normally embedded into the :class:`.Table` construct
        directly, or added later using :meth:`.append_constraint`::

            some_table = Table(
                'some_table', metadata,
                Column('id', Integer, primary_key=True),
                Column('period', TSRANGE()),
                Column('group', String)
            )

            some_table.append_constraint(
                ExcludeConstraint(
                    (some_table.c.period, '&&'),
                    (some_table.c.group, '='),
                    where=some_table.c.group != 'some group',
                    name='some_table_excl_const'
                )
            )

        :param \*elements:
          A sequence of two tuples of the form ``(column, operator)`` where
          "column" is a SQL expression element or a raw SQL string, most
          typically a :class:`.Column` object,
          and "operator" is a string containing the operator to use.

          .. note::

                A plain string passed for the value of "column" is interpreted
                as an arbitrary SQL expression; when passing a plain string,
                any necessary quoting and escaping syntaxes must be applied
                manually. In order to specify a column name when a
                :class:`.Column` object is not available, while ensuring that
                any necessary quoting rules take effect, an ad-hoc
                :class:`.Column` or :func:`.sql.expression.column` object may
                be used.

        :param name:
          Optional, the in-database name of this constraint.

        :param deferrable:
          Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
          issuing DDL for this constraint.

        :param initially:
          Optional string. If set, emit INITIALLY <value> when issuing DDL
          for this constraint.

        :param using:
          Optional string. If set, emit USING <index_method> when issuing DDL
          for this constraint. Defaults to 'gist'.

        :param where:
          Optional SQL expression construct or literal SQL string.
          If set, emit WHERE <predicate> when issuing DDL
          for this constraint.

          .. note::

                A plain string passed here is interpreted as an arbitrary SQL
                expression; when passing a plain string, any necessary quoting
                and escaping syntaxes must be applied manually.

        """
        columns = []
        render_exprs = []
        self.operators = {}

        expressions, operators = zip(*elements)

        for (expr, column, strname, add_element), operator in zip(
                self._extract_col_expression_collection(expressions),
                operators
        ):
            if add_element is not None:
                columns.append(add_element)

            name = column.name if column is not None else strname

            if name is not None:
                # backwards compat
                self.operators[name] = operator

            expr = expression._literal_as_text(expr)

            render_exprs.append(
                (expr, name, operator)
            )

        self._render_exprs = render_exprs
        ColumnCollectionConstraint.__init__(
            self,
            *columns,
            name=kw.get('name'),
            deferrable=kw.get('deferrable'),
            initially=kw.get('initially')
        )
        self.using = kw.get('using', 'gist')
        where = kw.get('where')
        if where is not None:
            self.where = expression._literal_as_text(where)

    def copy(self, **kw):
        elements = [(col, self.operators[col])
                    for col in self.columns.keys()]
        c = self.__class__(*elements,
                           name=self.name,
                           deferrable=self.deferrable,
                           initially=self.initially,
                           where=self.where,
                           using=self.using)
        c.dispatch._update(self.dispatch)
        return c


def array_agg(*arg, **kw):
    """PostgreSQL-specific form of :class:`.array_agg`, ensures
    return type is :class:`.postgresql.ARRAY` and not
    the plain :class:`.types.ARRAY`.

    .. versionadded:: 1.1

    """
    kw['type_'] = ARRAY(functions._type_from_args(arg))
    return functions.func.array_agg(*arg, **kw)
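A sketch combining the two constructs above; 't' is a hypothetical
two-column table and only SQL rendering is assumed.

    from sqlalchemy import Column, MetaData, String, Table, select
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import aggregate_order_by, array_agg

    metadata = MetaData()
    t = Table('t', metadata, Column('a', String), Column('b', String))

    # array_agg(a ORDER BY b DESC), typed as postgresql.ARRAY(String)
    stmt = select([array_agg(aggregate_order_by(t.c.a, t.c.b.desc()))])
    print(stmt.compile(dialect=postgresql.dialect()))
    # SELECT array_agg(t.a ORDER BY t.b DESC) AS array_agg_1 FROM t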
420 sqlalchemy/dialects/postgresql/hstore.py (new file)
@@ -0,0 +1,420 @@
# postgresql/hstore.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import re

from .base import ischema_names
from .array import ARRAY
from ... import types as sqltypes
from ...sql import functions as sqlfunc
from ...sql import operators
from ... import util

__all__ = ('HSTORE', 'hstore')

idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]

GETITEM = operators.custom_op(
    "->", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

HAS_KEY = operators.custom_op(
    "?", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

HAS_ALL = operators.custom_op(
    "?&", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

HAS_ANY = operators.custom_op(
    "?|", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

CONTAINS = operators.custom_op(
    "@>", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

CONTAINED_BY = operators.custom_op(
    "<@", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)


class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
    """Represent the PostgreSQL HSTORE type.

    The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::

        data_table = Table('data_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', HSTORE)
        )

        with engine.connect() as conn:
            conn.execute(
                data_table.insert(),
                data = {"key1": "value1", "key2": "value2"}
            )

    :class:`.HSTORE` provides for a wide range of operations, including:

    * Index operations::

        data_table.c.data['some key'] == 'some value'

    * Containment operations::

        data_table.c.data.has_key('some key')

        data_table.c.data.has_all(['one', 'two', 'three'])

    * Concatenation::

        data_table.c.data + {"k1": "v1"}

    For a full list of special methods see
    :class:`.HSTORE.comparator_factory`.

    For usage with the SQLAlchemy ORM, it may be desirable to combine
    the usage of :class:`.HSTORE` with :class:`.MutableDict` dictionary
    now part of the :mod:`sqlalchemy.ext.mutable`
    extension. This extension will allow "in-place" changes to the
    dictionary, e.g. addition of new keys or replacement/removal of existing
    keys to/from the current dictionary, to produce events which will be
    detected by the unit of work::

        from sqlalchemy.ext.mutable import MutableDict

        class MyClass(Base):
            __tablename__ = 'data_table'

            id = Column(Integer, primary_key=True)
            data = Column(MutableDict.as_mutable(HSTORE))

        my_object = session.query(MyClass).one()

        # in-place mutation, requires Mutable extension
        # in order for the ORM to detect
        my_object.data['some_key'] = 'some value'

        session.commit()

    When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM
    will not be alerted to any changes to the contents of an existing
    dictionary, unless that dictionary value is re-assigned to the
    HSTORE-attribute itself, thus generating a change event.

    .. versionadded:: 0.8

    .. seealso::

        :class:`.hstore` - render the PostgreSQL ``hstore()`` function.


    """

    __visit_name__ = 'HSTORE'
    hashable = False
    text_type = sqltypes.Text()

    def __init__(self, text_type=None):
        """Construct a new :class:`.HSTORE`.

        :param text_type: the type that should be used for indexed values.
         Defaults to :class:`.types.Text`.

         .. versionadded:: 1.1.0

        """
        if text_type is not None:
            self.text_type = text_type

    class Comparator(
            sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator):
        """Define comparison operations for :class:`.HSTORE`."""

        def has_key(self, other):
            """Boolean expression. Test for presence of a key. Note that the
            key may be a SQLA expression.
            """
            return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)

        def has_all(self, other):
            """Boolean expression. Test for presence of all keys in jsonb
            """
            return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)

        def has_any(self, other):
            """Boolean expression. Test for presence of any key in jsonb
            """
            return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)

        def contains(self, other, **kwargs):
            """Boolean expression. Test if keys (or array) are a superset
            of/contained the keys of the argument jsonb expression.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other):
            """Boolean expression. Test if keys are a proper subset of the
            keys of the argument jsonb expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean)

        def _setup_getitem(self, index):
            return GETITEM, index, self.type.text_type

        def defined(self, key):
            """Boolean expression. Test for presence of a non-NULL value for
            the key. Note that the key may be a SQLA expression.
            """
            return _HStoreDefinedFunction(self.expr, key)

        def delete(self, key):
            """HStore expression. Returns the contents of this hstore with the
            given key deleted. Note that the key may be a SQLA expression.
            """
            if isinstance(key, dict):
                key = _serialize_hstore(key)
            return _HStoreDeleteFunction(self.expr, key)

        def slice(self, array):
            """HStore expression. Returns a subset of an hstore defined by
            array of keys.
            """
            return _HStoreSliceFunction(self.expr, array)

        def keys(self):
            """Text array expression. Returns array of keys."""
            return _HStoreKeysFunction(self.expr)

        def vals(self):
            """Text array expression. Returns array of values."""
            return _HStoreValsFunction(self.expr)

        def array(self):
            """Text array expression. Returns array of alternating keys and
            values.
            """
            return _HStoreArrayFunction(self.expr)

        def matrix(self):
            """Text array expression. Returns array of [key, value] pairs."""
            return _HStoreMatrixFunction(self.expr)

    comparator_factory = Comparator

    def bind_processor(self, dialect):
        if util.py2k:
            encoding = dialect.encoding

            def process(value):
                if isinstance(value, dict):
                    return _serialize_hstore(value).encode(encoding)
                else:
                    return value
        else:
            def process(value):
                if isinstance(value, dict):
                    return _serialize_hstore(value)
                else:
                    return value
        return process

    def result_processor(self, dialect, coltype):
        if util.py2k:
            encoding = dialect.encoding

            def process(value):
                if value is not None:
                    return _parse_hstore(value.decode(encoding))
                else:
                    return value
        else:
            def process(value):
                if value is not None:
                    return _parse_hstore(value)
                else:
                    return value
        return process


ischema_names['hstore'] = HSTORE
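To ground the comparator methods above, a sketch constructing a few HSTORE
expressions against the hypothetical 'data_table' from the docstring; only
expression compilation is assumed, no connection.

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import HSTORE

    metadata = MetaData()
    data_table = Table(
        'data_table', metadata,
        Column('id', Integer, primary_key=True),
        Column('data', HSTORE),
    )

    # -> indexing, ? containment, || concatenation
    value = data_table.c.data['some key'] == 'some value'
    present = data_table.c.data.has_key('some key')
    merged = data_table.c.data + {'k1': 'v1'}

    print(value.compile(dialect=postgresql.dialect()))
    # data_table.data -> %(data_1)s = %(param_1)s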
class hstore(sqlfunc.GenericFunction):
    """Construct an hstore value within a SQL expression using the
    PostgreSQL ``hstore()`` function.

    The :class:`.hstore` function accepts one or two arguments as described
    in the PostgreSQL documentation.

    E.g.::

        from sqlalchemy.dialects.postgresql import array, hstore

        select([hstore('key1', 'value1')])

        select([
            hstore(
                array(['key1', 'key2', 'key3']),
                array(['value1', 'value2', 'value3'])
            )
        ])

    .. versionadded:: 0.8

    .. seealso::

        :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.

    """
    type = HSTORE
    name = 'hstore'


class _HStoreDefinedFunction(sqlfunc.GenericFunction):
    type = sqltypes.Boolean
    name = 'defined'


class _HStoreDeleteFunction(sqlfunc.GenericFunction):
    type = HSTORE
    name = 'delete'


class _HStoreSliceFunction(sqlfunc.GenericFunction):
    type = HSTORE
    name = 'slice'


class _HStoreKeysFunction(sqlfunc.GenericFunction):
    type = ARRAY(sqltypes.Text)
    name = 'akeys'


class _HStoreValsFunction(sqlfunc.GenericFunction):
    type = ARRAY(sqltypes.Text)
    name = 'avals'


class _HStoreArrayFunction(sqlfunc.GenericFunction):
    type = ARRAY(sqltypes.Text)
    name = 'hstore_to_array'


class _HStoreMatrixFunction(sqlfunc.GenericFunction):
    type = ARRAY(sqltypes.Text)
    name = 'hstore_to_matrix'


#
# parsing. note that none of this is used with the psycopg2 backend,
# which provides its own native extensions.
#

# My best guess at the parsing rules of hstore literals, since no formal
# grammar is given. This is mostly reverse engineered from PG's input parser
# behavior.
HSTORE_PAIR_RE = re.compile(r"""
(
  "(?P<key> (\\ . | [^"])* )"       # Quoted key
)
[ ]* => [ ]*    # Pair operator, optional adjoining whitespace
(
    (?P<value_null> NULL )          # NULL value
  | "(?P<value> (\\ . | [^"])* )"   # Quoted value
)
""", re.VERBOSE)

HSTORE_DELIMITER_RE = re.compile(r"""
[ ]* , [ ]*
""", re.VERBOSE)


def _parse_error(hstore_str, pos):
    """format an unmarshalling error."""

    ctx = 20
    hslen = len(hstore_str)

    parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)]
    residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)]

    if len(parsed_tail) > ctx:
        parsed_tail = '[...]' + parsed_tail[1:]
    if len(residual) > ctx:
        residual = residual[:-1] + '[...]'

    return "After %r, could not parse residual at position %d: %r" % (
        parsed_tail, pos, residual)


def _parse_hstore(hstore_str):
    """Parse an hstore from its literal string representation.

    Attempts to approximate PG's hstore input parsing rules as closely as
    possible. Although currently this is not strictly necessary, since the
    current implementation of hstore's output syntax is stricter than what it
    accepts as input, the documentation makes no guarantees that will always
    be the case.

    """
    result = {}
    pos = 0
    pair_match = HSTORE_PAIR_RE.match(hstore_str)

    while pair_match is not None:
        key = pair_match.group('key').replace(r'\"', '"').replace(
            "\\\\", "\\")
        if pair_match.group('value_null'):
            value = None
        else:
            value = pair_match.group('value').replace(
                r'\"', '"').replace("\\\\", "\\")
        result[key] = value

        pos += pair_match.end()

        delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
        if delim_match is not None:
            pos += delim_match.end()

        pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])

    if pos != len(hstore_str):
        raise ValueError(_parse_error(hstore_str, pos))

    return result


def _serialize_hstore(val):
    """Serialize a dictionary into an hstore literal. Keys and values must
    both be strings (except None for values).

    """
    def esc(s, position):
        if position == 'value' and s is None:
            return 'NULL'
        elif isinstance(s, util.string_types):
            return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"')
        else:
            raise ValueError("%r in %s position is not a string." %
                             (s, position))

    return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value'))
                     for k, v in val.items())
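The two private helpers above round-trip cleanly for dictionaries of
strings; a quick sketch (these are underscore-prefixed internals, so the
import below is for illustration only and is not a public API).

    from sqlalchemy.dialects.postgresql.hstore import (
        _parse_hstore, _serialize_hstore)

    d = {'key1': 'value1', 'quoted "key"': None}
    literal = _serialize_hstore(d)
    print(literal)  # e.g. "key1"=>"value1", "quoted \"key\""=>NULL
    assert _parse_hstore(literal) == d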
301 sqlalchemy/dialects/postgresql/json.py (new file)
@@ -0,0 +1,301 @@
# postgresql/json.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import

import json
import collections

from .base import ischema_names, colspecs
from ... import types as sqltypes
from ...sql import operators
from ...sql import elements
from ... import util

__all__ = ('JSON', 'JSONB')

idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]

ASTEXT = operators.custom_op(
    "->>", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

JSONPATH_ASTEXT = operators.custom_op(
    "#>>", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)


HAS_KEY = operators.custom_op(
    "?", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

HAS_ALL = operators.custom_op(
    "?&", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

HAS_ANY = operators.custom_op(
    "?|", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

CONTAINS = operators.custom_op(
    "@>", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)

CONTAINED_BY = operators.custom_op(
    "<@", precedence=idx_precedence, natural_self_precedent=True,
    eager_grouping=True
)


class JSONPathType(sqltypes.JSON.JSONPathType):
    def bind_processor(self, dialect):
        super_proc = self.string_bind_processor(dialect)

        def process(value):
            assert isinstance(value, collections.Sequence)
            tokens = [util.text_type(elem) for elem in value]
            value = "{%s}" % (", ".join(tokens))
            if super_proc:
                value = super_proc(value)
            return value

        return process

    def literal_processor(self, dialect):
        super_proc = self.string_literal_processor(dialect)

        def process(value):
            assert isinstance(value, collections.Sequence)
            tokens = [util.text_type(elem) for elem in value]
            value = "{%s}" % (", ".join(tokens))
            if super_proc:
                value = super_proc(value)
            return value

        return process

colspecs[sqltypes.JSON.JSONPathType] = JSONPathType
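To see what the path type above produces, a small sketch of its bind
processing: a sequence of path elements is rendered into PostgreSQL's
``{...}`` path syntax. This assumes the 1.1-era internals shown here, where
``string_bind_processor`` may return None and the raw string passes through.

    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql.json import JSONPathType

    proc = JSONPathType().bind_processor(postgresql.dialect())
    print(proc(('key_1', 'key_2', 5)))  # {key_1, key_2, 5}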
class JSON(sqltypes.JSON):
    """Represent the PostgreSQL JSON type.

    This type is a specialization of the Core-level :class:`.types.JSON`
    type.  Be sure to read the documentation for :class:`.types.JSON` for
    important tips regarding treatment of NULL values and ORM use.

    .. versionchanged:: 1.1 :class:`.postgresql.JSON` is now a PostgreSQL-
       specific specialization of the new :class:`.types.JSON` type.

    The operators provided by the PostgreSQL version of :class:`.JSON`
    include:

    * Index operations (the ``->`` operator)::

        data_table.c.data['some key']

        data_table.c.data[5]


    * Index operations returning text (the ``->>`` operator)::

        data_table.c.data['some key'].astext == 'some value'

    * Index operations with CAST
      (equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::

        data_table.c.data['some key'].astext.cast(Integer) == 5

    * Path index operations (the ``#>`` operator)::

        data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]

    * Path index operations returning text (the ``#>>`` operator)::

        data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == \
            'some value'

    .. versionchanged:: 1.1  The :meth:`.ColumnElement.cast` operator on
       JSON objects now requires that the :attr:`.JSON.Comparator.astext`
       modifier be called explicitly, if the cast works only from a textual
       string.

    Index operations return an expression object whose type defaults to
    :class:`.JSON`, so that further JSON-oriented instructions
    may be called upon the result type.

    Custom serializers and deserializers are specified at the dialect level,
    that is using :func:`.create_engine`.  The reason for this is that when
    using psycopg2, the DBAPI only allows serializers at the per-cursor
    or per-connection level.   E.g.::

        engine = create_engine("postgresql://scott:tiger@localhost/test",
                               json_serializer=my_serialize_fn,
                               json_deserializer=my_deserialize_fn
                               )

    When using the psycopg2 dialect, the json_deserializer is registered
    against the database using ``psycopg2.extras.register_default_json``.

    .. seealso::

        :class:`.types.JSON` - Core level JSON type

        :class:`.JSONB`

    """

    astext_type = sqltypes.Text()

    def __init__(self, none_as_null=False, astext_type=None):
        """Construct a :class:`.JSON` type.

        :param none_as_null: if True, persist the value ``None`` as a
         SQL NULL value, not the JSON encoding of ``null``.   Note that
         when this flag is False, the :func:`.null` construct can still
         be used to persist a NULL value::

             from sqlalchemy import null
             conn.execute(table.insert(), data=null())

         .. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null`
            is now supported in order to persist a NULL value.

         .. seealso::

              :attr:`.JSON.NULL`

        :param astext_type: the type to use for the
         :attr:`.JSON.Comparator.astext`
         accessor on indexed attributes.  Defaults to :class:`.types.Text`.

         .. versionadded:: 1.1

        """
        super(JSON, self).__init__(none_as_null=none_as_null)
        if astext_type is not None:
            self.astext_type = astext_type

    class Comparator(sqltypes.JSON.Comparator):
        """Define comparison operations for :class:`.JSON`."""

        @property
        def astext(self):
            """On an indexed expression, use the "astext" (e.g. "->>")
            conversion when rendered in SQL.

            E.g.::

                select([data_table.c.data['some key'].astext])

            .. seealso::

                :meth:`.ColumnElement.cast`

            """
            if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
                return self.expr.left.operate(
                    JSONPATH_ASTEXT,
                    self.expr.right, result_type=self.type.astext_type)
            else:
                return self.expr.left.operate(
                    ASTEXT, self.expr.right,
                    result_type=self.type.astext_type)

    comparator_factory = Comparator


colspecs[sqltypes.JSON] = JSON
ischema_names['json'] = JSON

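As a usage sketch (untested here) tying the documented operators together;
``data_table`` is an assumed table with a JSON column named ``data``::

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects.postgresql import JSON

    metadata = MetaData()
    data_table = Table('data_table', metadata,
                       Column('id', Integer, primary_key=True),
                       Column('data', JSON))

    # -> : index into the document; the result type remains JSON
    expr = data_table.c.data['some key']

    # ->> plus CAST: compare an indexed element as an integer
    stmt = select([data_table]).where(
        data_table.c.data['some key'].astext.cast(Integer) == 5)

    # #>> : path index returning text
    stmt = select([data_table]).where(
        data_table.c.data[('key_1', 'key_2')].astext == 'some value')
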
class JSONB(JSON):
    """Represent the PostgreSQL JSONB type.

    The :class:`.JSONB` type stores arbitrary JSONB format data, e.g.::

        data_table = Table('data_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', JSONB)
        )

        with engine.connect() as conn:
            conn.execute(
                data_table.insert(),
                data = {"key1": "value1", "key2": "value2"}
            )

    The :class:`.JSONB` type includes all operations provided by
    :class:`.JSON`, including the same behaviors for indexing operations.
    It also adds additional operators specific to JSONB, including
    :meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`,
    :meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`,
    and :meth:`.JSONB.Comparator.contained_by`.

    Like the :class:`.JSON` type, the :class:`.JSONB` type does not detect
    in-place changes when used with the ORM, unless the
    :mod:`sqlalchemy.ext.mutable` extension is used.

    Custom serializers and deserializers
    are shared with the :class:`.JSON` class, using the ``json_serializer``
    and ``json_deserializer`` keyword arguments.  These must be specified
    at the dialect level using :func:`.create_engine`.  When using
    psycopg2, the serializers are associated with the jsonb type using
    ``psycopg2.extras.register_default_jsonb`` on a per-connection basis,
    in the same way that ``psycopg2.extras.register_default_json`` is used
    to register these handlers with the json type.

    .. versionadded:: 0.9.7

    .. seealso::

        :class:`.JSON`

    """

    __visit_name__ = 'JSONB'

    class Comparator(JSON.Comparator):
        """Define comparison operations for :class:`.JSONB`."""

        def has_key(self, other):
            """Boolean expression.  Test for presence of a key.  Note that
            the key may be a SQLA expression.
            """
            return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)

        def has_all(self, other):
            """Boolean expression.  Test for presence of all keys in jsonb.
            """
            return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)

        def has_any(self, other):
            """Boolean expression.  Test for presence of any key in jsonb.
            """
            return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)

        def contains(self, other, **kwargs):
            """Boolean expression.  Test if keys (or array) are a superset
            of, i.e. contain, the keys of the argument jsonb expression.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other):
            """Boolean expression.  Test if keys are a proper subset of the
            keys of the argument jsonb expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean)

    comparator_factory = Comparator

ischema_names['jsonb'] = JSONB

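A brief sketch of the JSONB-only operators, assuming a ``data_table`` as in
the docstring above whose ``data`` column is JSONB::

    from sqlalchemy import select

    # ?  : test for presence of a top-level key
    stmt = select([data_table]).where(data_table.c.data.has_key('key1'))

    # @> : test whether the document contains the given fragment
    stmt = select([data_table]).where(
        data_table.c.data.contains({"key1": "value1"}))

    # <@ : the inverse containment test
    stmt = select([data_table]).where(
        data_table.c.data.contained_by({"key1": "value1", "key2": "value2"}))
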
sqlalchemy/dialects/postgresql/psycopg2cffi.py | 61 lines (new file)
@@ -0,0 +1,61 @@
# postgresql/psycopg2cffi.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: postgresql+psycopg2cffi
    :name: psycopg2cffi
    :dbapi: psycopg2cffi
    :connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname\
[?key=value&key=value...]
    :url: http://pypi.python.org/pypi/psycopg2cffi/

``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C
layer.  This makes it suitable for use in e.g. PyPy.  Documentation
is as per ``psycopg2``.

.. versionadded:: 1.0.0

.. seealso::

    :mod:`sqlalchemy.dialects.postgresql.psycopg2`

"""
from .psycopg2 import PGDialect_psycopg2


class PGDialect_psycopg2cffi(PGDialect_psycopg2):
    driver = 'psycopg2cffi'
    supports_unicode_statements = True

    # psycopg2cffi's first release is 2.5.0, but reports
    # __version__ as 2.4.4.  Subsequent releases seem to have
    # fixed this.

    FEATURE_VERSION_MAP = dict(
        native_json=(2, 4, 4),
        native_jsonb=(2, 7, 1),
        sane_multi_rowcount=(2, 4, 4),
        array_oid=(2, 4, 4),
        hstore_adapter=(2, 4, 4)
    )

    @classmethod
    def dbapi(cls):
        return __import__('psycopg2cffi')

    @classmethod
    def _psycopg2_extensions(cls):
        root = __import__('psycopg2cffi', fromlist=['extensions'])
        return root.extensions

    @classmethod
    def _psycopg2_extras(cls):
        root = __import__('psycopg2cffi', fromlist=['extras'])
        return root.extras


dialect = PGDialect_psycopg2cffi

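Usage is identical to the psycopg2 dialect apart from the driver name in the
URL; the credentials below are placeholders::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2cffi://scott:tiger@localhost/test")
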
sqlalchemy/dialects/postgresql/pygresql.py | 243 lines (new file)
@@ -0,0 +1,243 @@
# postgresql/pygresql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
.. dialect:: postgresql+pygresql
    :name: pygresql
    :dbapi: pgdb
    :connectstring: postgresql+pygresql://user:password@host:port/dbname\
[?key=value&key=value...]
    :url: http://www.pygresql.org/
"""

import decimal
import re

from ... import exc, processors, util
from ...types import Numeric, JSON as Json
from ...sql.elements import Null
from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
    _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB


class _PGNumeric(Numeric):

    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect, coltype):
        if not isinstance(coltype, int):
            coltype = coltype.oid
        if self.asdecimal:
            if coltype in _FLOAT_TYPES:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal,
                    self._effective_decimal_return_scale)
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                # PyGreSQL returns Decimal natively for 1700 (numeric)
                return None
            else:
                raise exc.InvalidRequestError(
                    "Unknown PG numeric type: %d" % coltype)
        else:
            if coltype in _FLOAT_TYPES:
                # PyGreSQL returns float natively for 701 (float8)
                return None
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                return processors.to_float
            else:
                raise exc.InvalidRequestError(
                    "Unknown PG numeric type: %d" % coltype)


class _PGHStore(HSTORE):

    def bind_processor(self, dialect):
        if not dialect.has_native_hstore:
            return super(_PGHStore, self).bind_processor(dialect)
        hstore = dialect.dbapi.Hstore

        def process(value):
            if isinstance(value, dict):
                return hstore(value)
            return value

        return process

    def result_processor(self, dialect, coltype):
        if not dialect.has_native_hstore:
            return super(_PGHStore, self).result_processor(dialect, coltype)


class _PGJSON(JSON):

    def bind_processor(self, dialect):
        if not dialect.has_native_json:
            return super(_PGJSON, self).bind_processor(dialect)
        json = dialect.dbapi.Json

        def process(value):
            if value is self.NULL:
                value = None
            elif isinstance(value, Null) or (
                    value is None and self.none_as_null):
                return None
            if value is None or isinstance(value, (dict, list)):
                return json(value)
            return value

        return process

    def result_processor(self, dialect, coltype):
        if not dialect.has_native_json:
            return super(_PGJSON, self).result_processor(dialect, coltype)


class _PGJSONB(JSONB):

    def bind_processor(self, dialect):
        if not dialect.has_native_json:
            return super(_PGJSONB, self).bind_processor(dialect)
        json = dialect.dbapi.Json

        def process(value):
            if value is self.NULL:
                value = None
            elif isinstance(value, Null) or (
                    value is None and self.none_as_null):
                return None
            if value is None or isinstance(value, (dict, list)):
                return json(value)
            return value

        return process

    def result_processor(self, dialect, coltype):
        if not dialect.has_native_json:
            return super(_PGJSONB, self).result_processor(dialect, coltype)


class _PGUUID(UUID):

    def bind_processor(self, dialect):
        if not dialect.has_native_uuid:
            return super(_PGUUID, self).bind_processor(dialect)
        uuid = dialect.dbapi.Uuid

        def process(value):
            if value is None:
                return None
            if isinstance(value, (str, bytes)):
                if len(value) == 16:
                    return uuid(bytes=value)
                return uuid(value)
            if isinstance(value, int):
                return uuid(int=value)
            return value

        return process

    def result_processor(self, dialect, coltype):
        if not dialect.has_native_uuid:
            return super(_PGUUID, self).result_processor(dialect, coltype)
        if not self.as_uuid:
            def process(value):
                if value is not None:
                    return str(value)
            return process


class _PGCompiler(PGCompiler):

    def visit_mod_binary(self, binary, operator, **kw):
        return self.process(binary.left, **kw) + " %% " + \
            self.process(binary.right, **kw)

    def post_process_text(self, text):
        return text.replace('%', '%%')


class _PGIdentifierPreparer(PGIdentifierPreparer):

    def _escape_identifier(self, value):
        value = value.replace(self.escape_quote, self.escape_to_quote)
        return value.replace('%', '%%')


class PGDialect_pygresql(PGDialect):

    driver = 'pygresql'

    statement_compiler = _PGCompiler
    preparer = _PGIdentifierPreparer

    @classmethod
    def dbapi(cls):
        import pgdb
        return pgdb

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            Numeric: _PGNumeric,
            HSTORE: _PGHStore,
            Json: _PGJSON,
            JSON: _PGJSON,
            JSONB: _PGJSONB,
            UUID: _PGUUID,
        }
    )

    def __init__(self, **kwargs):
        super(PGDialect_pygresql, self).__init__(**kwargs)
        try:
            version = self.dbapi.version
            m = re.match(r'(\d+)\.(\d+)', version)
            version = (int(m.group(1)), int(m.group(2)))
        except (AttributeError, ValueError, TypeError):
            version = (0, 0)
        self.dbapi_version = version
        if version < (5, 0):
            has_native_hstore = has_native_json = has_native_uuid = False
            if version != (0, 0):
                util.warn("PyGreSQL is only fully supported by SQLAlchemy"
                          " since version 5.0.")
        else:
            self.supports_unicode_statements = True
            self.supports_unicode_binds = True
            has_native_hstore = has_native_json = has_native_uuid = True
        self.has_native_hstore = has_native_hstore
        self.has_native_json = has_native_json
        self.has_native_uuid = has_native_uuid

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['host'] = '%s:%s' % (
                opts.get('host', '').rsplit(':', 1)[0], opts.pop('port'))
        opts.update(url.query)
        return [], opts

    def is_disconnect(self, e, connection, cursor):
        if isinstance(e, self.dbapi.Error):
            if not connection:
                return False
            try:
                connection = connection.connection
            except AttributeError:
                pass
            else:
                if not connection:
                    return False
            try:
                return connection.closed
            except AttributeError:  # PyGreSQL < 5.0
                return connection._cnx is None
        return False


dialect = PGDialect_pygresql

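Likewise, a minimal connection sketch; note that ``create_connect_args``
above folds the port into the ``host`` argument, since pgdb expects a
``host:port`` string (credentials are placeholders)::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+pygresql://scott:tiger@localhost:5432/test")
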
sqlalchemy/dialects/postgresql/ranges.py | 168 lines (new file)
@@ -0,0 +1,168 @@
# Copyright (C) 2013-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from .base import ischema_names
from ... import types as sqltypes

__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE')


class RangeOperators(object):
    """
    This mixin provides functionality for the Range Operators
    listed in Table 9-44 of the `postgres documentation`__ for Range
    Functions and Operators. It is used by all the range types
    provided in the ``postgres`` dialect and can likely be used for
    any range types you create yourself.

    __ http://www.postgresql.org/docs/devel/static/functions-range.html

    No extra support is provided for the Range Functions listed in
    Table 9-45 of the postgres documentation. For these, the normal
    :func:`~sqlalchemy.sql.expression.func` object should be used.

    .. versionadded:: 0.8.2  Support for PostgreSQL RANGE operations.

    """

    class comparator_factory(sqltypes.Concatenable.Comparator):
        """Define comparison operations for range types."""

        def __ne__(self, other):
            "Boolean expression. Returns true if two ranges are not equal"
            return self.expr.op('<>')(other)

        def contains(self, other, **kw):
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self.expr.op('@>')(other)

        def contained_by(self, other):
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self.expr.op('<@')(other)

        def overlaps(self, other):
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self.expr.op('&&')(other)

        def strictly_left_of(self, other):
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            return self.expr.op('<<')(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other):
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            return self.expr.op('>>')(other)

        __rshift__ = strictly_right_of

        def not_extend_right_of(self, other):
            """Boolean expression. Returns true if the range in the column
            does not extend right of the range in the operand.
            """
            return self.expr.op('&<')(other)

        def not_extend_left_of(self, other):
            """Boolean expression. Returns true if the range in the column
            does not extend left of the range in the operand.
            """
            return self.expr.op('&>')(other)

        def adjacent_to(self, other):
            """Boolean expression. Returns true if the range in the column
            is adjacent to the range in the operand.
            """
            return self.expr.op('-|-')(other)

        def __add__(self, other):
            """Range expression. Returns the union of the two ranges.
            Will raise an exception if the resulting range is not
            contiguous.
            """
            return self.expr.op('+')(other)


class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
    """Represent the PostgreSQL INT4RANGE type.

    .. versionadded:: 0.8.2

    """

    __visit_name__ = 'INT4RANGE'

ischema_names['int4range'] = INT4RANGE


class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
    """Represent the PostgreSQL INT8RANGE type.

    .. versionadded:: 0.8.2

    """

    __visit_name__ = 'INT8RANGE'

ischema_names['int8range'] = INT8RANGE


class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
    """Represent the PostgreSQL NUMRANGE type.

    .. versionadded:: 0.8.2

    """

    __visit_name__ = 'NUMRANGE'

ischema_names['numrange'] = NUMRANGE


class DATERANGE(RangeOperators, sqltypes.TypeEngine):
    """Represent the PostgreSQL DATERANGE type.

    .. versionadded:: 0.8.2

    """

    __visit_name__ = 'DATERANGE'

ischema_names['daterange'] = DATERANGE


class TSRANGE(RangeOperators, sqltypes.TypeEngine):
    """Represent the PostgreSQL TSRANGE type.

    .. versionadded:: 0.8.2

    """

    __visit_name__ = 'TSRANGE'

ischema_names['tsrange'] = TSRANGE


class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
    """Represent the PostgreSQL TSTZRANGE type.

    .. versionadded:: 0.8.2

    """

    __visit_name__ = 'TSTZRANGE'

ischema_names['tstzrange'] = TSTZRANGE

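For illustration, a sketch of these operators in use against a hypothetical
table; range bounds are passed here as PostgreSQL range literals in string
form::

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects.postgresql import INT4RANGE

    metadata = MetaData()
    bookings = Table('bookings', metadata,
                     Column('id', Integer, primary_key=True),
                     Column('during', INT4RANGE))

    # @> : ranges that contain the value 5
    stmt = select([bookings]).where(bookings.c.during.contains(5))

    # && : ranges that overlap the literal range [2,7)
    stmt = select([bookings]).where(bookings.c.during.overlaps('[2,7)'))
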
@@ -1,14 +1,20 @@
from sqlalchemy.dialects.sqlite import base, pysqlite
# sqlite/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy.dialects.sqlite import base, pysqlite, pysqlcipher

# default dialect
base.dialect = pysqlite.dialect

from sqlalchemy.dialects.sqlite.base import (
    BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER, REAL,
    NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect,
)

from sqlalchemy.dialects.sqlite.base import \
    BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER,\
    NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect

__all__ = (
    'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'FLOAT', 'INTEGER',
    'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', 'TIMESTAMP', 'VARCHAR', 'dialect'
)
__all__ = ('BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL',
           'FLOAT', 'INTEGER', 'NUMERIC', 'SMALLINT', 'TEXT', 'TIME',
           'TIMESTAMP', 'VARCHAR', 'REAL', 'dialect')
File diff suppressed because it is too large
sqlalchemy/dialects/sqlite/pysqlcipher.py | 130 lines (new file)
@@ -0,0 +1,130 @@
# sqlite/pysqlcipher.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
.. dialect:: sqlite+pysqlcipher
    :name: pysqlcipher
    :dbapi: pysqlcipher
    :connectstring: sqlite+pysqlcipher://:passphrase/file_path[?kdf_iter=<iter>]
    :url: https://pypi.python.org/pypi/pysqlcipher

``pysqlcipher`` is a fork of the standard ``pysqlite`` driver to make
use of the `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.

``pysqlcipher3`` is a fork of ``pysqlcipher`` for Python 3.  This dialect
will attempt to import it if ``pysqlcipher`` is not present.

.. versionadded:: 1.1.4 - added fallback import for pysqlcipher3

.. versionadded:: 0.9.9 - added pysqlcipher dialect

Driver
------

The driver here is the `pysqlcipher <https://pypi.python.org/pypi/pysqlcipher>`_
driver, which makes use of the SQLCipher engine.  This system essentially
introduces new PRAGMA commands to SQLite which allow the setting of a
passphrase and other encryption parameters, allowing the database
file to be encrypted.

`pysqlcipher3` is a fork of `pysqlcipher` with support for Python 3;
the driver is the same.

Connect Strings
---------------

The format of the connect string is in every way the same as that
of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
"password" field is now accepted, which should contain a passphrase::

    e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')

For an absolute file path, two leading slashes should be used for the
database name::

    e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')

A selection of additional encryption-related pragmas supported by SQLCipher
as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be
passed in the query string, and will result in that PRAGMA being called for
each new connection.  Currently, ``cipher``, ``kdf_iter``,
``cipher_page_size`` and ``cipher_use_hmac`` are supported::

    e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')


Pooling Behavior
----------------

The driver makes a change to the default pool behavior of pysqlite
as described in :ref:`pysqlite_threading_pooling`.   The pysqlcipher driver
has been observed to be significantly slower on connection than the
pysqlite driver, most likely due to the encryption overhead, so the
dialect here defaults to using the :class:`.SingletonThreadPool`
implementation, instead of the :class:`.NullPool` pool used by pysqlite.
As always, the pool implementation is entirely configurable using the
:paramref:`.create_engine.poolclass` parameter; the :class:`.StaticPool` may
be more feasible for single-threaded use, or :class:`.NullPool` may be used
to prevent unencrypted connections from being held open for long periods of
time, at the expense of slower startup time for new connections.


"""
from __future__ import absolute_import
from .pysqlite import SQLiteDialect_pysqlite
from ...engine import url as _url
from ... import pool


class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
    driver = 'pysqlcipher'

    pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac')

    @classmethod
    def dbapi(cls):
        try:
            from pysqlcipher import dbapi2 as sqlcipher
        except ImportError as e:
            try:
                from pysqlcipher3 import dbapi2 as sqlcipher
            except ImportError:
                raise e
        return sqlcipher

    @classmethod
    def get_pool_class(cls, url):
        return pool.SingletonThreadPool

    def connect(self, *cargs, **cparams):
        passphrase = cparams.pop('passphrase', '')

        pragmas = dict(
            (key, cparams.pop(key, None)) for key in
            self.pragmas
        )

        conn = super(SQLiteDialect_pysqlcipher, self).\
            connect(*cargs, **cparams)
        conn.execute('pragma key="%s"' % passphrase)
        for prag, value in pragmas.items():
            if value is not None:
                conn.execute('pragma %s="%s"' % (prag, value))

        return conn

    def create_connect_args(self, url):
        super_url = _url.URL(
            url.drivername, username=url.username,
            host=url.host, database=url.database, query=url.query)
        c_args, opts = super(SQLiteDialect_pysqlcipher, self).\
            create_connect_args(super_url)
        opts['passphrase'] = url.password
        return c_args, opts

dialect = SQLiteDialect_pysqlcipher
@@ -1,56 +1,67 @@
"""Support for the SQLite database via pysqlite.
# sqlite/pysqlite.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

Note that pysqlite is the same driver as the ``sqlite3``
module included with the Python distribution.
r"""
.. dialect:: sqlite+pysqlite
    :name: pysqlite
    :dbapi: sqlite3
    :connectstring: sqlite+pysqlite:///file_path
    :url: http://docs.python.org/library/sqlite3.html

Note that ``pysqlite`` is the same driver as the ``sqlite3``
module included with the Python distribution.

Driver
------

When using Python 2.5 and above, the built in ``sqlite3`` driver is
already installed and no additional installation is needed.  Otherwise,
the ``pysqlite2`` driver needs to be present.  This is the same driver as
``sqlite3``, just with a different name.

The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
is loaded.  This allows an explicitly installed pysqlite driver to take
precedence over the built in one.  As with all dialects, a specific
DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to
control this explicitly::

    from sqlite3 import dbapi2 as sqlite
    e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)

Full documentation on pysqlite is available at:
`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_

Connect Strings
---------------

The file specification for the SQLite database is taken as the "database"
portion of the URL.  Note that the format of a SQLAlchemy url is::

    driver://user:pass@host/database

This means that the actual filename to be used starts with the characters to
the **right** of the third slash.  So connecting to a relative filepath
looks like::

    # relative path
    e = create_engine('sqlite:///path/to/database.db')

An absolute path, which is denoted by starting with a slash, means you
need **four** slashes::

    # absolute path
    e = create_engine('sqlite:////path/to/database.db')

To use a Windows path, regular drive specifications and backslashes can be
used.  Double backslashes are probably needed::

    # absolute path on Windows
    e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
    e = create_engine('sqlite:///C:\\path\\to\\database.db')

The sqlite ``:memory:`` identifier is the default if no filepath is
present.  Specify ``sqlite://`` and nothing else::

    # in-memory database
    e = create_engine('sqlite://')

@@ -58,89 +69,208 @@ The sqlite ``:memory:`` identifier is the default if no filepath is present. Sp

Compatibility with sqlite3 "native" date and datetime types
-----------------------------------------------------------

The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
sqlite3.PARSE_COLNAMES options, which have the effect that any column
or expression explicitly cast as "date" or "timestamp" will be converted
to a Python date or datetime object.  The date and datetime types provided
with the pysqlite dialect are not currently compatible with these options,
since they render the ISO date/datetime including microseconds, which
pysqlite's driver does not.  Additionally, SQLAlchemy does not at
this time automatically render the "cast" syntax required for the
freestanding functions "current_timestamp" and "current_date" to return
datetime/date types natively.  Unfortunately, pysqlite
does not provide the standard DBAPI types in ``cursor.description``,
leaving SQLAlchemy with no way to detect these types on the fly
without expensive per-row type checks.

Keeping in mind that pysqlite's parsing option is not recommended,
nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
can be forced if one configures "native_datetime=True" on create_engine()::

    engine = create_engine('sqlite://',
        connect_args={'detect_types':
            sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
        native_datetime=True
    )

With this flag enabled, the DATE and TIMESTAMP types (but note - not the
DATETIME or TIME types...confused yet?) will not perform any bind parameter
or result processing.  Execution of "func.current_date()" will return a
string. "func.current_timestamp()" is registered as returning a DATETIME
type in SQLAlchemy, so this function still receives SQLAlchemy-level result
processing.

Threading Behavior
------------------
.. _pysqlite_threading_pooling:

Pysqlite connections do not support being moved between threads, unless
the ``check_same_thread`` Pysqlite flag is set to ``False``.  In addition,
when using an in-memory SQLite database, the full database exists only
within the scope of a single connection.  It is reported that an in-memory
database does not support being shared between threads regardless of the
``check_same_thread`` flag - which means that a multithreaded
application **cannot** share data from a ``:memory:`` database across
threads unless access to the connection is limited to a single worker
thread which communicates through a queueing mechanism to concurrent
threads.
Threading/Pooling Behavior
---------------------------

To provide a default which accommodates SQLite's default threading
capabilities somewhat reasonably, the SQLite dialect will specify that the
:class:`~sqlalchemy.pool.SingletonThreadPool` be used by default.  This pool
maintains a single SQLite connection per thread that is held open up to a
count of five concurrent threads.  When more than five threads are used, a
cleanup mechanism will dispose of excess unused connections.
Pysqlite's default behavior is to prohibit the usage of a single connection
in more than one thread.   This is originally intended to work with older
versions of SQLite that did not support multithreaded operation under
various circumstances.  In particular, older SQLite versions
did not allow a ``:memory:`` database to be used in multiple threads
under any circumstances.

Two optional pool implementations that may be appropriate for particular
SQLite usage scenarios:
Pysqlite does include a now-undocumented flag known as
``check_same_thread`` which will disable this check, however note that
pysqlite connections are still not safe to use concurrently in multiple
threads.  In particular, any statement execution calls would need to be
externally mutexed, as Pysqlite does not provide for thread-safe propagation
of error messages among other things.   So while even ``:memory:`` databases
can be shared among threads in modern SQLite, Pysqlite doesn't provide
enough thread-safety to make this usage worth it.

* the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a
  multithreaded application using an in-memory database, assuming the
  threading issues inherent in pysqlite are somehow accommodated for.  This
  pool holds persistently onto a single connection which is never closed,
  and is returned for all requests.

* the :class:`sqlalchemy.pool.NullPool` might be appropriate for an
  application that makes use of a file-based sqlite database.  This pool
  disables any actual "pooling" behavior, and simply opens and closes real
  connections corresponding to the :func:`connect()` and :func:`close()`
  methods.  SQLite can "connect" to a particular file with very high
  efficiency, so this option may actually perform better without the extra
  overhead of :class:`SingletonThreadPool`.  NullPool will of course render
  a ``:memory:`` connection useless since the database would be lost as
  soon as the connection is "returned" to the pool.
SQLAlchemy sets up pooling to work with Pysqlite's default behavior:

* When a ``:memory:`` SQLite database is specified, the dialect by default
  will use :class:`.SingletonThreadPool`.  This pool maintains a single
  connection per thread, so that all access to the engine within the current
  thread use the same ``:memory:`` database - other threads would access a
  different ``:memory:`` database.
* When a file-based database is specified, the dialect will use
  :class:`.NullPool` as the source of connections.  This pool closes and
  discards connections which are returned to the pool immediately.  SQLite
  file-based connections have extremely low overhead, so pooling is not
  necessary.  The scheme also prevents a connection from being used again in
  a different thread and works best with SQLite's coarse-grained file
  locking.

.. versionchanged:: 0.7
    Default selection of :class:`.NullPool` for SQLite file-based databases.
    Previous versions select :class:`.SingletonThreadPool` by
    default for all SQLite databases.


Using a Memory Database in Multiple Threads
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To use a ``:memory:`` database in a multithreaded scenario, the same
connection object must be shared among threads, since the database exists
only within the scope of that connection.   The
:class:`.StaticPool` implementation will maintain a single connection
globally, and the ``check_same_thread`` flag can be passed to Pysqlite
as ``False``::

    from sqlalchemy.pool import StaticPool
    engine = create_engine('sqlite://',
                           connect_args={'check_same_thread': False},
                           poolclass=StaticPool)

Note that using a ``:memory:`` database in multiple threads requires a
recent version of SQLite.

Using Temporary Tables with SQLite
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Due to the way SQLite deals with temporary tables, if you wish to use a
temporary table in a file-based SQLite database across multiple checkouts
from the connection pool, such as when using an ORM :class:`.Session` where
the temporary table should continue to remain after :meth:`.Session.commit`
or :meth:`.Session.rollback` is called, a pool which maintains a single
connection must be used.   Use :class:`.SingletonThreadPool` if the scope is
only needed within the current thread, or :class:`.StaticPool` if scope is
needed within multiple threads for this case::

    # maintain the same connection per thread
    from sqlalchemy.pool import SingletonThreadPool
    engine = create_engine('sqlite:///mydb.db',
                           poolclass=SingletonThreadPool)


    # maintain the same connection across all threads
    from sqlalchemy.pool import StaticPool
    engine = create_engine('sqlite:///mydb.db',
                           poolclass=StaticPool)

Note that :class:`.SingletonThreadPool` should be configured for the number
of threads that are to be used; beyond that number, connections will be
closed out in a non deterministic way.

Unicode
-------

In contrast to SQLAlchemy's active handling of date and time types for
pysqlite, pysqlite's default behavior regarding Unicode is that all strings
are returned as Python unicode objects in all cases.  So even if the
:class:`~sqlalchemy.types.Unicode` type is *not* used, you will still always
receive unicode data back from a result set.  It is **strongly** recommended
that you do use the :class:`~sqlalchemy.types.Unicode` type to represent
strings, since it will raise a warning if a non-unicode Python string is
passed from the user application.  Mixing the usage of non-unicode objects
with returned unicode objects can quickly create confusion, particularly
when using the ORM, as internal data is not always represented by an actual
database result string.
The pysqlite driver only returns Python ``unicode`` objects in result sets,
never plain strings, and accommodates ``unicode`` objects within bound
parameter values in all cases.  Regardless of the SQLAlchemy string type in
use, string-based result values will be Python ``unicode`` in Python 2.
The :class:`.Unicode` type should still be used to indicate those columns
that require unicode, however, so that non-``unicode`` values passed
inadvertently will emit a warning.  Pysqlite will emit an error if a
non-``unicode`` string is passed containing non-ASCII characters.

.. _pysqlite_serializable:

Serializable isolation / Savepoints / Transactional DDL
-------------------------------------------------------

In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
driver's assortment of issues that prevent several features of SQLite
from working correctly.  The pysqlite DBAPI driver has several
long-standing bugs which impact the correctness of its transactional
behavior.   In its default mode of operation, SQLite features such as
SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
non-functional, and in order to use these features, workarounds must
be taken.

The issue is essentially that the driver attempts to second-guess the
user's intent, failing to start transactions and sometimes ending them
prematurely, in an effort to minimize the SQLite database's file locking
behavior, even though SQLite itself uses "shared" locks for read-only
activities.

SQLAlchemy chooses to not alter this behavior by default, as it is the
long-expected behavior of the pysqlite driver; if and when the pysqlite
driver attempts to repair these issues, that will be more of a driver
towards defaults for SQLAlchemy.

The good news is that with a few events, we can implement transactional
support fully, by disabling pysqlite's feature entirely and emitting BEGIN
ourselves.  This is achieved using two event listeners::

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite:///myfile.db")

    @event.listens_for(engine, "connect")
    def do_connect(dbapi_connection, connection_record):
        # disable pysqlite's emitting of the BEGIN statement entirely.
        # also stops it from emitting COMMIT before any DDL.
        dbapi_connection.isolation_level = None

    @event.listens_for(engine, "begin")
    def do_begin(conn):
        # emit our own BEGIN
        conn.execute("BEGIN")

Above, we intercept a new pysqlite connection and disable any transactional
integration.  Then, at the point at which SQLAlchemy knows that transaction
scope is to begin, we emit ``"BEGIN"`` ourselves.

When we take control of ``"BEGIN"``, we can also control directly SQLite's
locking modes, introduced at
`BEGIN TRANSACTION <http://sqlite.org/lang_transaction.html>`_,
by adding the desired locking mode to our ``"BEGIN"``::

    @event.listens_for(engine, "begin")
    def do_begin(conn):
        conn.execute("BEGIN EXCLUSIVE")

.. seealso::

    `BEGIN TRANSACTION <http://sqlite.org/lang_transaction.html>`_ -
    on the SQLite site

    `sqlite3 SELECT does not BEGIN a transaction <http://bugs.python.org/issue9924>`_ -
    on the Python bug tracker

    `sqlite3 module breaks transactions and potentially corrupts data <http://bugs.python.org/issue10740>`_ -
    on the Python bug tracker


"""

from sqlalchemy.dialects.sqlite.base import SQLiteDialect, DATETIME, DATE
from sqlalchemy import schema, exc, pool
from sqlalchemy.engine import default
from sqlalchemy import exc, pool
from sqlalchemy import types as sqltypes
from sqlalchemy import util

import os


class _SQLite_pysqliteTimeStamp(DATETIME):
    def bind_processor(self, dialect):
@@ -148,43 +278,44 @@ class _SQLite_pysqliteTimeStamp(DATETIME):
            return None
        else:
            return DATETIME.bind_processor(self, dialect)

    def result_processor(self, dialect, coltype):
        if dialect.native_datetime:
            return None
        else:
            return DATETIME.result_processor(self, dialect, coltype)


class _SQLite_pysqliteDate(DATE):
    def bind_processor(self, dialect):
        if dialect.native_datetime:
            return None
        else:
            return DATE.bind_processor(self, dialect)

    def result_processor(self, dialect, coltype):
        if dialect.native_datetime:
            return None
        else:
            return DATE.result_processor(self, dialect, coltype)


class SQLiteDialect_pysqlite(SQLiteDialect):
    default_paramstyle = 'qmark'
    poolclass = pool.SingletonThreadPool

    colspecs = util.update_copy(
        SQLiteDialect.colspecs,
        {
            sqltypes.Date: _SQLite_pysqliteDate,
            sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
        }
    )

    # Py3K
    #description_encoding = None

    if not util.py2k:
        description_encoding = None

    driver = 'pysqlite'

    def __init__(self, **kwargs):
        SQLiteDialect.__init__(self, **kwargs)

@@ -201,13 +332,20 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
    def dbapi(cls):
        try:
            from pysqlite2 import dbapi2 as sqlite
        except ImportError, e:
        except ImportError as e:
            try:
                from sqlite3 import dbapi2 as sqlite  # try 2.5+ stdlib name.
            except ImportError:
                raise e
        return sqlite

    @classmethod
    def get_pool_class(cls, url):
        if url.database and url.database != ':memory:':
            return pool.NullPool
        else:
            return pool.SingletonThreadPool

    def _get_server_version_info(self, connection):
        return self.dbapi.sqlite_version_info

@@ -220,6 +358,8 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
            " sqlite:///relative/path/to/file.db\n"
            " sqlite:////absolute/path/to/file.db" % (url,))
        filename = url.database or ':memory:'
        if filename != ':memory:':
            filename = os.path.abspath(filename)

        opts = url.query.copy()
        util.coerce_kw_type(opts, 'timeout', float)
@@ -230,7 +370,8 @@ class SQLiteDialect_pysqlite(SQLiteDialect):

        return ([filename], opts)

    def is_disconnect(self, e):
        return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
    def is_disconnect(self, e, connection, cursor):
        return isinstance(e, self.dbapi.ProgrammingError) and \
            "Cannot operate on a closed database." in str(e)

dialect = SQLiteDialect_pysqlite

@@ -1,20 +1,28 @@
# sybase/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy.dialects.sybase import base, pysybase, pyodbc


from base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
    TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
    BIGINT,INT, INTEGER, SMALLINT, BINARY,\
    VARBINARY,UNITEXT,UNICHAR,UNIVARCHAR,\
    IMAGE,BIT,MONEY,SMALLMONEY,TINYINT

# default dialect
base.dialect = pyodbc.dialect

from .base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
    TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
    BIGINT, INT, INTEGER, SMALLINT, BINARY,\
    VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\
    IMAGE, BIT, MONEY, SMALLMONEY, TINYINT,\
    dialect


__all__ = (
    'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR',
    'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC',
    'BIGINT', 'INT', 'INTEGER', 'SMALLINT', 'BINARY',
    'VARBINARY', 'UNITEXT', 'UNICHAR', 'UNIVARCHAR',
    'IMAGE', 'BIT', 'MONEY', 'SMALLMONEY', 'TINYINT',
    'dialect'
)

@@ -1,21 +1,29 @@
|
||||
# sybase/base.py
|
||||
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com
|
||||
# Copyright (C) 2010-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
# get_select_precolumns(), limit_clause() implementation
|
||||
# copyright (C) 2007 Fisch Asset Management
|
||||
# AG http://www.fam.ch, with coding by Alexander Houben
# copyright (C) 2007 Fisch Asset Management
# AG http://www.fam.ch, with coding by Alexander Houben
# alexander.houben@thor-solutions.ch
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Support for Sybase Adaptive Server Enterprise (ASE).

Note that this dialect is no longer specific to Sybase iAnywhere.
ASE is the primary support platform.

"""

"""
.. dialect:: sybase
    :name: Sybase

.. note::

    The Sybase dialect functions on current SQLAlchemy versions
    but is not regularly tested, and may have many issues and
    caveats not currently handled.

"""
import operator
import re

from sqlalchemy.sql import compiler, expression, text, bindparam
from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes
@@ -24,10 +32,10 @@ from sqlalchemy import schema as sa_schema
from sqlalchemy import util, sql, exc

from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
    TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
    BIGINT,INT, INTEGER, SMALLINT, BINARY,\
    VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
    UnicodeText
    TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
    BIGINT, INT, INTEGER, SMALLINT, BINARY,\
    VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
    UnicodeText, REAL

RESERVED_WORDS = set([
    "add", "all", "alter", "and",
@@ -86,165 +94,215 @@ RESERVED_WORDS = set([
    "when", "where", "while", "window",
    "with", "with_cube", "with_lparen", "with_rollup",
    "within", "work", "writetext",
    ])
])


class _SybaseUnitypeMixin(object):
    """these types appear to return a buffer object."""

    def result_processor(self, dialect, coltype):
        def process(value):
            if value is not None:
                return str(value)  #.decode("ucs-2")
                return str(value)  # decode("ucs-2")
            else:
                return None
        return process


class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
    __visit_name__ = 'UNICHAR'


class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
    __visit_name__ = 'UNIVARCHAR'


class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText):
    __visit_name__ = 'UNITEXT'


class TINYINT(sqltypes.Integer):
    __visit_name__ = 'TINYINT'


class BIT(sqltypes.TypeEngine):
    __visit_name__ = 'BIT'


class MONEY(sqltypes.TypeEngine):
    __visit_name__ = "MONEY"


class SMALLMONEY(sqltypes.TypeEngine):
    __visit_name__ = "SMALLMONEY"


class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
    __visit_name__ = "UNIQUEIDENTIFIER"


class IMAGE(sqltypes.LargeBinary):
    __visit_name__ = 'IMAGE'


class SybaseTypeCompiler(compiler.GenericTypeCompiler):
    def visit_large_binary(self, type_):
    def visit_large_binary(self, type_, **kw):
        return self.visit_IMAGE(type_)

    def visit_boolean(self, type_):
    def visit_boolean(self, type_, **kw):
        return self.visit_BIT(type_)

    def visit_unicode(self, type_):
    def visit_unicode(self, type_, **kw):
        return self.visit_NVARCHAR(type_)

    def visit_UNICHAR(self, type_):
    def visit_UNICHAR(self, type_, **kw):
        return "UNICHAR(%d)" % type_.length

    def visit_UNIVARCHAR(self, type_):
    def visit_UNIVARCHAR(self, type_, **kw):
        return "UNIVARCHAR(%d)" % type_.length

    def visit_UNITEXT(self, type_):
    def visit_UNITEXT(self, type_, **kw):
        return "UNITEXT"

    def visit_TINYINT(self, type_):
    def visit_TINYINT(self, type_, **kw):
        return "TINYINT"

    def visit_IMAGE(self, type_):
    def visit_IMAGE(self, type_, **kw):
        return "IMAGE"

    def visit_BIT(self, type_):
    def visit_BIT(self, type_, **kw):
        return "BIT"

    def visit_MONEY(self, type_):
    def visit_MONEY(self, type_, **kw):
        return "MONEY"

    def visit_SMALLMONEY(self, type_):
    def visit_SMALLMONEY(self, type_, **kw):
        return "SMALLMONEY"

    def visit_UNIQUEIDENTIFIER(self, type_):
    def visit_UNIQUEIDENTIFIER(self, type_, **kw):
        return "UNIQUEIDENTIFIER"


ischema_names = {
    'integer' : INTEGER,
    'unsigned int' : INTEGER, # TODO: unsigned flags
    'unsigned smallint' : SMALLINT, # TODO: unsigned flags
    'unsigned bigint' : BIGINT, # TODO: unsigned flags
    'bigint': BIGINT,
    'smallint' : SMALLINT,
    'tinyint' : TINYINT,
    'varchar' : VARCHAR,
    'long varchar' : TEXT, # TODO
    'char' : CHAR,
    'decimal' : DECIMAL,
    'numeric' : NUMERIC,
    'float' : FLOAT,
    'double' : NUMERIC, # TODO
    'binary' : BINARY,
    'varbinary' : VARBINARY,
    'bit': BIT,
    'image' : IMAGE,
    'timestamp': TIMESTAMP,
    'int': INTEGER,
    'integer': INTEGER,
    'smallint': SMALLINT,
    'tinyint': TINYINT,
    'unsigned bigint': BIGINT,  # TODO: unsigned flags
    'unsigned int': INTEGER,  # TODO: unsigned flags
    'unsigned smallint': SMALLINT,  # TODO: unsigned flags
    'numeric': NUMERIC,
    'decimal': DECIMAL,
    'dec': DECIMAL,
    'float': FLOAT,
    'double': NUMERIC,  # TODO
    'double precision': NUMERIC,  # TODO
    'real': REAL,
    'smallmoney': SMALLMONEY,
    'money': MONEY,
    'smalldatetime': DATETIME,
    'datetime': DATETIME,
    'date': DATE,
    'time': TIME,
    'char': CHAR,
    'character': CHAR,
    'varchar': VARCHAR,
    'character varying': VARCHAR,
    'char varying': VARCHAR,
    'unichar': UNICHAR,
    'unicode character': UNIVARCHAR,
    'nchar': NCHAR,
    'national char': NCHAR,
    'national character': NCHAR,
    'nvarchar': NVARCHAR,
    'nchar varying': NVARCHAR,
    'national char varying': NVARCHAR,
    'national character varying': NVARCHAR,
    'text': TEXT,
    'unitext': UNITEXT,
    'binary': BINARY,
    'varbinary': VARBINARY,
    'image': IMAGE,
    'bit': BIT,

    # not in documentation for ASE 15.7
    'long varchar': TEXT,  # TODO
    'timestamp': TIMESTAMP,
    'uniqueidentifier': UNIQUEIDENTIFIER,

}


class SybaseInspector(reflection.Inspector):

    def __init__(self, conn):
        reflection.Inspector.__init__(self, conn)

    def get_table_id(self, table_name, schema=None):
        """Return the table id from `table_name` and `schema`."""

        return self.dialect.get_table_id(self.bind, table_name, schema,
                                         info_cache=self.info_cache)


class SybaseExecutionContext(default.DefaultExecutionContext):
    _enable_identity_insert = False

    def set_ddl_autocommit(self, connection, value):
        """Must be implemented by subclasses to accommodate DDL executions.

        "connection" is the raw unwrapped DBAPI connection.  "value"
        is True or False.  when True, the connection should be configured
        such that a DDL can take place subsequently.  when False,
        a DDL has taken place and the connection should be resumed
        into non-autocommit mode.

        """
        raise NotImplementedError()
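
[Illustration — not part of this commit: a minimal hypothetical subclass satisfying this contract, assuming a DBAPI connection that exposes an ``autocommit`` attribute; the real implementations are the pyodbc and pysybase subclasses later in this diff.]

    class ExampleExecutionContext(SybaseExecutionContext):
        def set_ddl_autocommit(self, connection, value):
            # "connection" is the raw DBAPI connection; flip its
            # autocommit flag around DDL execution.
            connection.autocommit = bool(value)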

    def pre_exec(self):
        if self.isinsert:
            tbl = self.compiled.statement.table
            seq_column = tbl._autoincrement_column
            insert_has_sequence = seq_column is not None

            if insert_has_sequence:
                self._enable_identity_insert = seq_column.key in self.compiled_parameters[0]
                self._enable_identity_insert = \
                    seq_column.key in self.compiled_parameters[0]
            else:
                self._enable_identity_insert = False

            if self._enable_identity_insert:
                self.cursor.execute("SET IDENTITY_INSERT %s ON" %
                self.cursor.execute(
                    "SET IDENTITY_INSERT %s ON" %
                    self.dialect.identifier_preparer.format_table(tbl))

        if self.isddl:
            # TODO: to enhance this, we can detect "ddl in tran" on the
            # database settings.  this error message should be improved to
            # database settings.  this error message should be improved to
            # include a note about that.
            if not self.should_autocommit:
                raise exc.InvalidRequestError("The Sybase dialect only supports "
                                              "DDL in 'autocommit' mode at this time.")
                raise exc.InvalidRequestError(
                    "The Sybase dialect only supports "
                    "DDL in 'autocommit' mode at this time.")

            self.root_connection.engine.logger.info("AUTOCOMMIT (Assuming no Sybase 'ddl in tran')")
            self.root_connection.engine.logger.info(
                "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')")

            self.set_ddl_autocommit(self.root_connection.connection.connection, True)

            self.set_ddl_autocommit(
                self.root_connection.connection.connection,
                True)

    def post_exec(self):
       if self.isddl:
        if self.isddl:
            self.set_ddl_autocommit(self.root_connection, False)

       if self._enable_identity_insert:

        if self._enable_identity_insert:
            self.cursor.execute(
                        "SET IDENTITY_INSERT %s OFF" %
                            self.dialect.identifier_preparer.
                                format_table(self.compiled.statement.table)
                        )
                "SET IDENTITY_INSERT %s OFF" %
                self.dialect.identifier_preparer.
                format_table(self.compiled.statement.table)
            )

    def get_lastrowid(self):
        cursor = self.create_cursor()
@@ -253,46 +311,52 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
        cursor.close()
        return lastrowid


class SybaseSQLCompiler(compiler.SQLCompiler):
    ansi_bind_rules = True

    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        {
        'doy': 'dayofyear',
        'dow': 'weekday',
        'milliseconds': 'millisecond'
    })
            'doy': 'dayofyear',
            'dow': 'weekday',
            'milliseconds': 'millisecond'
        })

    def get_select_precolumns(self, select):
    def get_select_precolumns(self, select, **kw):
        s = select._distinct and "DISTINCT " or ""
        if select._limit:
            #if select._limit == 1:
                #s += "FIRST "
            #else:
                #s += "TOP %s " % (select._limit,)
            s += "TOP %s " % (select._limit,)
        if select._offset:
            if not select._limit:
                # FIXME: sybase doesn't allow an offset without a limit
                # so use a huge value for TOP here
                s += "TOP 1000000 "
            s += "START AT %s " % (select._offset+1,)
        # TODO: don't think Sybase supports
        # bind params for FIRST / TOP
        limit = select._limit
        if limit:
            # if select._limit == 1:
            # s += "FIRST "
            # else:
            # s += "TOP %s " % (select._limit,)
            s += "TOP %s " % (limit,)
        offset = select._offset
        if offset:
            raise NotImplementedError("Sybase ASE does not support OFFSET")
        return s
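
[Usage note — not part of this commit: with the rewrite above, a limit renders as Sybase's SELECT-level TOP keyword, while an offset now raises NotImplementedError. A quick sketch against the dialect defined in this file:]

    from sqlalchemy import select, table, column
    from sqlalchemy.dialects.sybase.base import SybaseDialect

    t = table('t', column('x'))
    stmt = select([t.c.x]).limit(5)
    # Renders "SELECT TOP 5 t.x FROM t"; adding .offset(10) would raise
    # NotImplementedError, since ASE has no OFFSET support.
    print(stmt.compile(dialect=SybaseDialect()))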

    def get_from_hint_text(self, table, text):
        return text

    def limit_clause(self, select):
    def limit_clause(self, select, **kw):
        # Limit in sybase is after the select keyword
        return ""

    def visit_extract(self, extract, **kw):
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
        return 'DATEPART("%s", %s)' % (
            field, self.process(extract.expr, **kw))

    def visit_now_func(self, fn, **kw):
        return "GETDATE()"

    def for_update_clause(self, select):
        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
        # "FOR UPDATE" is only allowed on "DECLARE CURSOR"
        # which SQLAlchemy doesn't use
        return ''

    def order_by_clause(self, select, **kw):
@@ -309,18 +373,22 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
class SybaseDDLCompiler(compiler.DDLCompiler):
    def get_column_specification(self, column, **kwargs):
        colspec = self.preparer.format_column(column) + " " + \
            self.dialect.type_compiler.process(column.type)
            self.dialect.type_compiler.process(
                column.type, type_expression=column)

        if column.table is None:
            raise exc.InvalidRequestError("The Sybase dialect requires Table-bound "\
                                          "columns in order to generate DDL")
            raise exc.CompileError(
                "The Sybase dialect requires Table-bound "
                "columns in order to generate DDL")
        seq_col = column.table._autoincrement_column

        # install a IDENTITY Sequence if we have an implicit IDENTITY column
        if seq_col is column:
            sequence = isinstance(column.default, sa_schema.Sequence) and column.default
            sequence = isinstance(column.default, sa_schema.Sequence) \
                and column.default
            if sequence:
                start, increment = sequence.start or 1, sequence.increment or 1
                start, increment = sequence.start or 1, \
                    sequence.increment or 1
            else:
                start, increment = 1, 1
            if (start, increment) == (1, 1):
@@ -329,28 +397,31 @@ class SybaseDDLCompiler(compiler.DDLCompiler):
                # TODO: need correct syntax for this
                colspec += " IDENTITY(%s,%s)" % (start, increment)
        else:
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

            if column.nullable is not None:
                if not column.nullable or column.primary_key:
                    colspec += " NOT NULL"
                else:
                    colspec += " NULL"

            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        return colspec

    def visit_drop_index(self, drop):
        index = drop.element
        return "\nDROP INDEX %s.%s" % (
            self.preparer.quote_identifier(index.table.name),
            self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
            )
            self._prepared_index_name(drop.element,
                                      include_schema=False)
        )


class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
    reserved_words = RESERVED_WORDS


class SybaseDialect(default.DefaultDialect):
    name = 'sybase'
    supports_unicode_statements = False
@@ -368,53 +439,383 @@ class SybaseDialect(default.DefaultDialect):
    statement_compiler = SybaseSQLCompiler
    ddl_compiler = SybaseDDLCompiler
    preparer = SybaseIdentifierPreparer
    inspector = SybaseInspector

    construct_arguments = []

    def _get_default_schema_name(self, connection):
        return connection.scalar(
            text("SELECT user_name() as user_name", typemap={'user_name':Unicode})
        )
            text("SELECT user_name() as user_name",
                 typemap={'user_name': Unicode})
        )

    def initialize(self, connection):
        super(SybaseDialect, self).initialize(connection)
        if self.server_version_info is not None and\
            self.server_version_info < (15, ):
                self.server_version_info < (15, ):
            self.max_identifier_length = 30
        else:
            self.max_identifier_length = 255

    def get_table_id(self, connection, table_name, schema=None, **kw):
        """Fetch the id for schema.table_name.

        Several reflection methods require the table id.  The idea for using
        this method is that it can be fetched one time and cached for
        subsequent calls.

        """

        table_id = None
        if schema is None:
            schema = self.default_schema_name

        TABLEID_SQL = text("""
          SELECT o.id AS id
          FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
          WHERE u.name = :schema_name
              AND o.name = :table_name
              AND o.type in ('U', 'V')
        """)

        if util.py2k:
            if isinstance(schema, unicode):
                schema = schema.encode("ascii")
            if isinstance(table_name, unicode):
                table_name = table_name.encode("ascii")
        result = connection.execute(TABLEID_SQL,
                                    schema_name=schema,
                                    table_name=table_name)
        table_id = result.scalar()
        if table_id is None:
            raise exc.NoSuchTableError(table_name)
        return table_id

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        table_id = self.get_table_id(connection, table_name, schema,
                                     info_cache=kw.get("info_cache"))

        COLUMN_SQL = text("""
          SELECT col.name AS name,
                 t.name AS type,
                 (col.status & 8) AS nullable,
                 (col.status & 128) AS autoincrement,
                 com.text AS 'default',
                 col.prec AS precision,
                 col.scale AS scale,
                 col.length AS length
          FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON
              col.cdefault = com.id
          WHERE col.usertype = t.usertype
              AND col.id = :table_id
          ORDER BY col.colid
        """)

        results = connection.execute(COLUMN_SQL, table_id=table_id)

        columns = []
        for (name, type_, nullable, autoincrement, default, precision, scale,
             length) in results:
            col_info = self._get_column_info(name, type_, bool(nullable),
                                             bool(autoincrement),
                                             default, precision, scale,
                                             length)
            columns.append(col_info)

        return columns

    def _get_column_info(self, name, type_, nullable, autoincrement, default,
                         precision, scale, length):

        coltype = self.ischema_names.get(type_, None)

        kwargs = {}

        if coltype in (NUMERIC, DECIMAL):
            args = (precision, scale)
        elif coltype == FLOAT:
            args = (precision,)
        elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR):
            args = (length,)
        else:
            args = ()

        if coltype:
            coltype = coltype(*args, **kwargs)
            # is this necessary
            # if is_array:
            #     coltype = ARRAY(coltype)
        else:
            util.warn("Did not recognize type '%s' of column '%s'" %
                      (type_, name))
            coltype = sqltypes.NULLTYPE

        if default:
            default = default.replace("DEFAULT", "").strip()
            default = re.sub("^'(.*)'$", lambda m: m.group(1), default)
        else:
            default = None

        column_info = dict(name=name, type=coltype, nullable=nullable,
                           default=default, autoincrement=autoincrement)
        return column_info

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):

        table_id = self.get_table_id(connection, table_name, schema,
                                     info_cache=kw.get("info_cache"))

        table_cache = {}
        column_cache = {}
        foreign_keys = []

        table_cache[table_id] = {"name": table_name, "schema": schema}

        COLUMN_SQL = text("""
          SELECT c.colid AS id, c.name AS name
          FROM syscolumns c
          WHERE c.id = :table_id
        """)

        results = connection.execute(COLUMN_SQL, table_id=table_id)
        columns = {}
        for col in results:
            columns[col["id"]] = col["name"]
        column_cache[table_id] = columns

        REFCONSTRAINT_SQL = text("""
          SELECT o.name AS name, r.reftabid AS reftable_id,
            r.keycnt AS 'count',
            r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
            r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6,
            r.fokey7 AS fokey7, r.fokey8 AS fokey8, r.fokey9 AS fokey9,
            r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12,
            r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15,
            r.fokey16 AS fokey16,
            r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3,
            r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6,
            r.refkey7 AS refkey7, r.refkey8 AS refkey8, r.refkey9 AS refkey9,
            r.refkey10 AS refkey10, r.refkey11 AS refkey11,
            r.refkey12 AS refkey12, r.refkey13 AS refkey13,
            r.refkey14 AS refkey14, r.refkey15 AS refkey15,
            r.refkey16 AS refkey16
          FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
          WHERE r.tableid = :table_id
        """)
        referential_constraints = connection.execute(
            REFCONSTRAINT_SQL, table_id=table_id).fetchall()

        REFTABLE_SQL = text("""
          SELECT o.name AS name, u.name AS 'schema'
          FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
          WHERE o.id = :table_id
        """)

        for r in referential_constraints:
            reftable_id = r["reftable_id"]

            if reftable_id not in table_cache:
                c = connection.execute(REFTABLE_SQL, table_id=reftable_id)
                reftable = c.fetchone()
                c.close()
                table_info = {"name": reftable["name"], "schema": None}
                if (schema is not None or
                        reftable["schema"] != self.default_schema_name):
                    table_info["schema"] = reftable["schema"]

                table_cache[reftable_id] = table_info
                results = connection.execute(COLUMN_SQL, table_id=reftable_id)
                reftable_columns = {}
                for col in results:
                    reftable_columns[col["id"]] = col["name"]
                column_cache[reftable_id] = reftable_columns

            reftable = table_cache[reftable_id]
            reftable_columns = column_cache[reftable_id]

            constrained_columns = []
            referred_columns = []
            for i in range(1, r["count"] + 1):
                constrained_columns.append(columns[r["fokey%i" % i]])
                referred_columns.append(reftable_columns[r["refkey%i" % i]])

            fk_info = {
                "constrained_columns": constrained_columns,
                "referred_schema": reftable["schema"],
                "referred_table": reftable["name"],
                "referred_columns": referred_columns,
                "name": r["name"]
            }

            foreign_keys.append(fk_info)

        return foreign_keys

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        table_id = self.get_table_id(connection, table_name, schema,
                                     info_cache=kw.get("info_cache"))

        INDEX_SQL = text("""
          SELECT object_name(i.id) AS table_name,
                 i.keycnt AS 'count',
                 i.name AS name,
                 (i.status & 0x2) AS 'unique',
                 index_col(object_name(i.id), i.indid, 1) AS col_1,
                 index_col(object_name(i.id), i.indid, 2) AS col_2,
                 index_col(object_name(i.id), i.indid, 3) AS col_3,
                 index_col(object_name(i.id), i.indid, 4) AS col_4,
                 index_col(object_name(i.id), i.indid, 5) AS col_5,
                 index_col(object_name(i.id), i.indid, 6) AS col_6,
                 index_col(object_name(i.id), i.indid, 7) AS col_7,
                 index_col(object_name(i.id), i.indid, 8) AS col_8,
                 index_col(object_name(i.id), i.indid, 9) AS col_9,
                 index_col(object_name(i.id), i.indid, 10) AS col_10,
                 index_col(object_name(i.id), i.indid, 11) AS col_11,
                 index_col(object_name(i.id), i.indid, 12) AS col_12,
                 index_col(object_name(i.id), i.indid, 13) AS col_13,
                 index_col(object_name(i.id), i.indid, 14) AS col_14,
                 index_col(object_name(i.id), i.indid, 15) AS col_15,
                 index_col(object_name(i.id), i.indid, 16) AS col_16
          FROM sysindexes i, sysobjects o
          WHERE o.id = i.id
            AND o.id = :table_id
            AND (i.status & 2048) = 0
            AND i.indid BETWEEN 1 AND 254
        """)

        results = connection.execute(INDEX_SQL, table_id=table_id)
        indexes = []
        for r in results:
            column_names = []
            for i in range(1, r["count"]):
                column_names.append(r["col_%i" % (i,)])
            index_info = {"name": r["name"],
                          "unique": bool(r["unique"]),
                          "column_names": column_names}
            indexes.append(index_info)

        return indexes

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        table_id = self.get_table_id(connection, table_name, schema,
                                     info_cache=kw.get("info_cache"))

        PK_SQL = text("""
          SELECT object_name(i.id) AS table_name,
                 i.keycnt AS 'count',
                 i.name AS name,
                 index_col(object_name(i.id), i.indid, 1) AS pk_1,
                 index_col(object_name(i.id), i.indid, 2) AS pk_2,
                 index_col(object_name(i.id), i.indid, 3) AS pk_3,
                 index_col(object_name(i.id), i.indid, 4) AS pk_4,
                 index_col(object_name(i.id), i.indid, 5) AS pk_5,
                 index_col(object_name(i.id), i.indid, 6) AS pk_6,
                 index_col(object_name(i.id), i.indid, 7) AS pk_7,
                 index_col(object_name(i.id), i.indid, 8) AS pk_8,
                 index_col(object_name(i.id), i.indid, 9) AS pk_9,
                 index_col(object_name(i.id), i.indid, 10) AS pk_10,
                 index_col(object_name(i.id), i.indid, 11) AS pk_11,
                 index_col(object_name(i.id), i.indid, 12) AS pk_12,
                 index_col(object_name(i.id), i.indid, 13) AS pk_13,
                 index_col(object_name(i.id), i.indid, 14) AS pk_14,
                 index_col(object_name(i.id), i.indid, 15) AS pk_15,
                 index_col(object_name(i.id), i.indid, 16) AS pk_16
          FROM sysindexes i, sysobjects o
          WHERE o.id = i.id
            AND o.id = :table_id
            AND (i.status & 2048) = 2048
            AND i.indid BETWEEN 1 AND 254
        """)

        results = connection.execute(PK_SQL, table_id=table_id)
        pks = results.fetchone()
        results.close()

        constrained_columns = []
        if pks:
            for i in range(1, pks["count"] + 1):
                constrained_columns.append(pks["pk_%i" % (i,)])
            return {"constrained_columns": constrained_columns,
                    "name": pks["name"]}
        else:
            return {"constrained_columns": [], "name": None}

    @reflection.cache
    def get_schema_names(self, connection, **kw):

        SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u")

        schemas = connection.execute(SCHEMA_SQL)

        return [s["name"] for s in schemas]

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        if schema is None:
            schema = self.default_schema_name

        result = connection.execute(
            text("select sysobjects.name from sysobjects, sysusers "
                 "where sysobjects.uid=sysusers.uid and "
                 "sysusers.name=:schemaname and "
                 "sysobjects.type='U'",
                 bindparams=[
                     bindparam('schemaname', schema)
                 ])
        )
        return [r[0] for r in result]
        TABLE_SQL = text("""
          SELECT o.name AS name
          FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
          WHERE u.name = :schema_name
            AND o.type = 'U'
        """)

    def has_table(self, connection, tablename, schema=None):
        if util.py2k:
            if isinstance(schema, unicode):
                schema = schema.encode("ascii")

        tables = connection.execute(TABLE_SQL, schema_name=schema)

        return [t["name"] for t in tables]

    @reflection.cache
    def get_view_definition(self, connection, view_name, schema=None, **kw):
        if schema is None:
            schema = self.default_schema_name

        result = connection.execute(
            text("select sysobjects.name from sysobjects, sysusers "
                 "where sysobjects.uid=sysusers.uid and "
                 "sysobjects.name=:tablename and "
                 "sysusers.name=:schemaname and "
                 "sysobjects.type='U'",
                 bindparams=[
                     bindparam('tablename', tablename),
                     bindparam('schemaname', schema)
                 ])
        )
        return result.scalar() is not None
        VIEW_DEF_SQL = text("""
          SELECT c.text
          FROM syscomments c JOIN sysobjects o ON c.id = o.id
          WHERE o.name = :view_name
            AND o.type = 'V'
        """)

    def reflecttable(self, connection, table, include_columns):
        raise NotImplementedError()
        if util.py2k:
            if isinstance(view_name, unicode):
                view_name = view_name.encode("ascii")

        view = connection.execute(VIEW_DEF_SQL, view_name=view_name)

        return view.scalar()

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        if schema is None:
            schema = self.default_schema_name

        VIEW_SQL = text("""
          SELECT o.name AS name
          FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
          WHERE u.name = :schema_name
            AND o.type = 'V'
        """)

        if util.py2k:
            if isinstance(schema, unicode):
                schema = schema.encode("ascii")
        views = connection.execute(VIEW_SQL, schema_name=schema)

        return [v["name"] for v in views]

    def has_table(self, connection, table_name, schema=None):
        try:
            self.get_table_id(connection, table_name, schema)
        except exc.NoSuchTableError:
            return False
        else:
            return True
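
[Illustration — not part of this commit: the reflection methods above plug into the standard Inspector interface, with ``get_table_id()`` cached via ``info_cache``. Assuming a reachable ASE database:]

    from sqlalchemy import create_engine, inspect

    engine = create_engine("sybase+pysybase://user:pass@dsn/mydb")
    insp = inspect(engine)                   # a SybaseInspector
    print(insp.get_table_id("some_table"))   # sysobjects id, cached between calls
    print(insp.get_columns("some_table"))
    print(insp.get_pk_constraint("some_table"))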

@@ -1,16 +1,32 @@
# sybase/mxodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Support for Sybase via mxodbc.

This dialect is a stub only and is likely non functional at this time.
.. dialect:: sybase+mxodbc
    :name: mxODBC
    :dbapi: mxodbc
    :connectstring: sybase+mxodbc://<username>:<password>@<dsnname>
    :url: http://www.egenix.com/

.. note::

    This dialect is a stub only and is likely non functional at this time.


"""
from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
from sqlalchemy.dialects.sybase.base import SybaseDialect
from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
from sqlalchemy.connectors.mxodbc import MxODBCConnector


class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
    pass


class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
    execution_ctx_cls = SybaseExecutionContext_mxodbc


@@ -1,17 +1,23 @@
# sybase/pyodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
Support for Sybase via pyodbc.
.. dialect:: sybase+pyodbc
    :name: PyODBC
    :dbapi: pyodbc
    :connectstring: sybase+pyodbc://<username>:<password>@<dsnname>\
[/<database>]
    :url: http://pypi.python.org/pypi/pyodbc/

http://pypi.python.org/pypi/pyodbc/

Connect strings are of the form::

    sybase+pyodbc://<username>:<password>@<dsn>/
    sybase+pyodbc://<username>:<password>@<host>/<database>

Unicode Support
---------------

The pyodbc driver currently supports usage of these Sybase types with
The pyodbc driver currently supports usage of these Sybase types with
Unicode or multibyte strings::

    CHAR
@@ -25,25 +31,28 @@ Currently *not* supported are::
    UNICHAR
    UNITEXT
    UNIVARCHAR

"""

from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
from sqlalchemy.dialects.sybase.base import SybaseDialect,\
    SybaseExecutionContext
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes, processors
import decimal
from sqlalchemy import types as sqltypes, util, processors


class _SybNumeric_pyodbc(sqltypes.Numeric):
    """Turns Decimals with adjusted() < -6 into floats.

    It's not yet known how to get decimals with many
    It's not yet known how to get decimals with many
    significant digits or very large adjusted() into Sybase
    via pyodbc.

    """

    def bind_processor(self, dialect):
        super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)
        super_process = super(_SybNumeric_pyodbc, self).\
            bind_processor(dialect)

        def process(value):
            if self.asdecimal and \
@@ -58,6 +67,7 @@ class _SybNumeric_pyodbc(sqltypes.Numeric):
                return value
        return process
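
[Aside — not part of this commit: the ``adjusted() < -6`` rule above relies on standard-library behavior; ``Decimal.adjusted()`` returns the exponent of the most significant digit, so small magnitudes go negative quickly.]

    from decimal import Decimal

    print(Decimal("0.01").adjusted())        # -2
    print(Decimal("0.0000001").adjusted())   # -7  -> such values get float()-ed
    print(Decimal("123.45").adjusted())      # 2   -> passed through unchanged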

class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
    def set_ddl_autocommit(self, connection, value):
        if value:
@@ -65,11 +75,12 @@ class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
        else:
            connection.autocommit = False


class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
    execution_ctx_cls = SybaseExecutionContext_pyodbc

    colspecs = {
        sqltypes.Numeric:_SybNumeric_pyodbc,
        sqltypes.Numeric: _SybNumeric_pyodbc,
    }

dialect = SybaseDialect_pyodbc

@@ -1,17 +1,17 @@
# pysybase.py
# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com
# sybase/pysybase.py
# Copyright (C) 2010-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
Support for Sybase via the python-sybase driver.

http://python-sybase.sourceforge.net/

Connect strings are of the form::

    sybase+pysybase://<username>:<password>@<dsn>/[database name]
.. dialect:: sybase+pysybase
    :name: Python-Sybase
    :dbapi: Sybase
    :connectstring: sybase+pysybase://<username>:<password>@<dsn>/\
[database name]
    :url: http://python-sybase.sourceforge.net/

Unicode Support
---------------
@@ -23,7 +23,7 @@ kind at this time.

from sqlalchemy import types as sqltypes, processors
from sqlalchemy.dialects.sybase.base import SybaseDialect, \
        SybaseExecutionContext, SybaseSQLCompiler
    SybaseExecutionContext, SybaseSQLCompiler


class _SybNumeric(sqltypes.Numeric):
@@ -33,12 +33,13 @@ class _SybNumeric(sqltypes.Numeric):
        else:
            return sqltypes.Numeric.result_processor(self, dialect, type_)


class SybaseExecutionContext_pysybase(SybaseExecutionContext):

    def set_ddl_autocommit(self, dbapi_connection, value):
        if value:
            # call commit() on the Sybase connection directly,
            # to avoid any side effects of calling a Connection
            # to avoid any side effects of calling a Connection
            # transactional method inside of pre_exec()
            dbapi_connection.commit()

@@ -52,17 +53,18 @@ class SybaseExecutionContext_pysybase(SybaseExecutionContext):


class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
    def bindparam_string(self, name):
    def bindparam_string(self, name, **kw):
        return "@" + name


class SybaseDialect_pysybase(SybaseDialect):
    driver = 'pysybase'
    execution_ctx_cls = SybaseExecutionContext_pysybase
    statement_compiler = SybaseSQLCompiler_pysybase

    colspecs={
        sqltypes.Numeric:_SybNumeric,
        sqltypes.Float:sqltypes.Float
    colspecs = {
        sqltypes.Numeric: _SybNumeric,
        sqltypes.Float: sqltypes.Float
    }

    @classmethod
@@ -82,12 +84,14 @@ class SybaseDialect_pysybase(SybaseDialect):
        cursor.execute(statement, param)

    def _get_server_version_info(self, connection):
        vers = connection.scalar("select @@version_number")
        # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0), (12, 5, 0, 0)
        return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10)
        vers = connection.scalar("select @@version_number")
        # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0),
        # (12, 5, 0, 0)
        return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10)

    def is_disconnect(self, e):
        if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)):
    def is_disconnect(self, e, connection, cursor):
        if isinstance(e, (self.dbapi.OperationalError,
                          self.dbapi.ProgrammingError)):
            msg = str(e)
            return ('Unable to complete network request to host' in msg or
                    'Invalid connection state' in msg or
1286  sqlalchemy/engine/interfaces.py  Normal file
File diff suppressed because it is too large
1435  sqlalchemy/engine/result.py  Normal file
File diff suppressed because it is too large
74  sqlalchemy/engine/util.py  Normal file
@@ -0,0 +1,74 @@
# engine/util.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from .. import util


def connection_memoize(key):
    """Decorator, memoize a function in a connection.info stash.

    Only applicable to functions which take no arguments other than a
    connection.  The memo will be stored in ``connection.info[key]``.
    """

    @util.decorator
    def decorated(fn, self, connection):
        connection = connection.connect()
        try:
            return connection.info[key]
        except KeyError:
            connection.info[key] = val = fn(self, connection)
            return val

    return decorated
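
[Illustration — not part of this commit: a hypothetical use of ``connection_memoize``; the method body runs once per connection, after which the value is served from ``connection.info``. The class and query here are made up.]

    class ExampleDialect(object):
        @connection_memoize('example_server_flavor')
        def server_flavor(self, connection):
            # Runs only on the first call for a given connection; the result
            # is then cached in connection.info['example_server_flavor'].
            return connection.scalar("select @@version")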


def py_fallback():
    def _distill_params(multiparams, params):
        """Given arguments from the calling form *multiparams, **params,
        return a list of bind parameter structures, usually a list of
        dictionaries.

        In the case of 'raw' execution which accepts positional parameters,
        it may be a list of tuples or lists.

        """

        if not multiparams:
            if params:
                return [params]
            else:
                return []
        elif len(multiparams) == 1:
            zero = multiparams[0]
            if isinstance(zero, (list, tuple)):
                if not zero or hasattr(zero[0], '__iter__') and \
                        not hasattr(zero[0], 'strip'):
                    # execute(stmt, [{}, {}, {}, ...])
                    # execute(stmt, [(), (), (), ...])
                    return zero
                else:
                    # execute(stmt, ("value", "value"))
                    return [zero]
            elif hasattr(zero, 'keys'):
                # execute(stmt, {"key":"value"})
                return [zero]
            else:
                # execute(stmt, "value")
                return [[zero]]
        else:
            if hasattr(multiparams[0], '__iter__') and \
                    not hasattr(multiparams[0], 'strip'):
                return multiparams
            else:
                return [multiparams]

    return locals()
try:
    from sqlalchemy.cutils import _distill_params
except ImportError:
    globals().update(py_fallback())
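
[Illustration — not part of this commit: how the pure-Python ``_distill_params`` fallback above normalizes the common ``execute()`` calling forms; the results follow directly from the branches in the code.]

    distill = py_fallback()['_distill_params']

    print(distill((), {'x': 1}))                 # [{'x': 1}]            execute(stmt, x=1)
    print(distill(({'x': 1}, {'x': 2}), {}))     # ({'x': 1}, {'x': 2})  executemany form
    print(distill(([{'x': 1}, {'x': 2}],), {}))  # [{'x': 1}, {'x': 2}]  list of dicts
    print(distill((('a', 'b'),), {}))            # [('a', 'b')]          one positional tuple
    print(distill(('a',), {}))                   # [['a']]               single raw value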
11  sqlalchemy/event/__init__.py  Normal file
@@ -0,0 +1,11 @@
# event/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from .api import CANCEL, NO_RETVAL, listen, listens_for, remove, contains
from .base import Events, dispatcher
from .attr import RefCollection
from .legacy import _legacy_signature
188  sqlalchemy/event/api.py  Normal file
@@ -0,0 +1,188 @@
# event/api.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Public API functions for the event system.

"""
from __future__ import absolute_import

from .. import util, exc
from .base import _registrars
from .registry import _EventKey

CANCEL = util.symbol('CANCEL')
NO_RETVAL = util.symbol('NO_RETVAL')


def _event_key(target, identifier, fn):
    for evt_cls in _registrars[identifier]:
        tgt = evt_cls._accept_with(target)
        if tgt is not None:
            return _EventKey(target, identifier, fn, tgt)
    else:
        raise exc.InvalidRequestError("No such event '%s' for target '%s'" %
                                      (identifier, target))


def listen(target, identifier, fn, *args, **kw):
    """Register a listener function for the given target.

    e.g.::

        from sqlalchemy import event
        from sqlalchemy.schema import UniqueConstraint

        def unique_constraint_name(const, table):
            const.name = "uq_%s_%s" % (
                table.name,
                list(const.columns)[0].name
            )
        event.listen(
            UniqueConstraint,
            "after_parent_attach",
            unique_constraint_name)


    A given function can also be invoked for only the first invocation
    of the event using the ``once`` argument::

        def on_config():
            do_config()

        event.listen(Mapper, "before_configure", on_config, once=True)

    .. versionadded:: 0.9.4 Added ``once=True`` to :func:`.event.listen`
       and :func:`.event.listens_for`.

    .. note::

        The :func:`.listen` function cannot be called at the same time
        that the target event is being run.   This has implications
        for thread safety, and also means an event cannot be added
        from inside the listener function for itself.  The list of
        events to be run are present inside of a mutable collection
        that can't be changed during iteration.

        Event registration and removal is not intended to be a "high
        velocity" operation; it is a configurational operation.  For
        systems that need to quickly associate and deassociate with
        events at high scale, use a mutable structure that is handled
        from inside of a single listener.

        .. versionchanged:: 1.0.0 - a ``collections.deque()`` object is now
           used as the container for the list of events, which explicitly
           disallows collection mutation while the collection is being
           iterated.

    .. seealso::

        :func:`.listens_for`

        :func:`.remove`

    """

    _event_key(target, identifier, fn).listen(*args, **kw)


def listens_for(target, identifier, *args, **kw):
    """Decorate a function as a listener for the given target + identifier.

    e.g.::

        from sqlalchemy import event
        from sqlalchemy.schema import UniqueConstraint

        @event.listens_for(UniqueConstraint, "after_parent_attach")
        def unique_constraint_name(const, table):
            const.name = "uq_%s_%s" % (
                table.name,
                list(const.columns)[0].name
            )

    A given function can also be invoked for only the first invocation
    of the event using the ``once`` argument::

        @event.listens_for(Mapper, "before_configure", once=True)
        def on_config():
            do_config()


    .. versionadded:: 0.9.4 Added ``once=True`` to :func:`.event.listen`
       and :func:`.event.listens_for`.

    .. seealso::

        :func:`.listen` - general description of event listening

    """
    def decorate(fn):
        listen(target, identifier, fn, *args, **kw)
        return fn
    return decorate


def remove(target, identifier, fn):
    """Remove an event listener.

    The arguments here should match exactly those which were sent to
    :func:`.listen`; all the event registration which proceeded as a result
    of this call will be reverted by calling :func:`.remove` with the same
    arguments.

    e.g.::

        # if a function was registered like this...
        @event.listens_for(SomeMappedClass, "before_insert", propagate=True)
        def my_listener_function(*arg):
            pass

        # ... it's removed like this
        event.remove(SomeMappedClass, "before_insert", my_listener_function)

    Above, the listener function associated with ``SomeMappedClass`` was also
    propagated to subclasses of ``SomeMappedClass``; the :func:`.remove`
    function will revert all of these operations.

    .. versionadded:: 0.9.0

    .. note::

        The :func:`.remove` function cannot be called at the same time
        that the target event is being run.   This has implications
        for thread safety, and also means an event cannot be removed
        from inside the listener function for itself.  The list of
        events to be run are present inside of a mutable collection
        that can't be changed during iteration.

        Event registration and removal is not intended to be a "high
        velocity" operation; it is a configurational operation.  For
        systems that need to quickly associate and deassociate with
        events at high scale, use a mutable structure that is handled
        from inside of a single listener.

        .. versionchanged:: 1.0.0 - a ``collections.deque()`` object is now
           used as the container for the list of events, which explicitly
           disallows collection mutation while the collection is being
           iterated.

    .. seealso::

        :func:`.listen`

    """
    _event_key(target, identifier, fn).remove()


def contains(target, identifier, fn):
    """Return True if the given target/ident/fn is set up to listen.

    .. versionadded:: 0.9.0

    """

    return _event_key(target, identifier, fn).contains()
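
[Illustration — not part of this commit: the public functions above share the same (target, identifier, fn) key, so a listener registered with ``listen()`` can be checked with ``contains()`` and reverted with ``remove()``.]

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    def on_connect(dbapi_conn, connection_record):
        print("new DBAPI connection")

    event.listen(engine, "connect", on_connect)
    assert event.contains(engine, "connect", on_connect)

    event.remove(engine, "connect", on_connect)
    assert not event.contains(engine, "connect", on_connect)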
373  sqlalchemy/event/attr.py  Normal file
@@ -0,0 +1,373 @@
# event/attr.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Attribute implementation for _Dispatch classes.

The various listener targets for a particular event class are represented
as attributes, which refer to collections of listeners to be fired off.
These collections can exist at the class level as well as at the instance
level.  An event is fired off using code like this::

    some_object.dispatch.first_connect(arg1, arg2)

Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and
``first_connect`` is typically an instance of ``_ListenerCollection``
if event listeners are present, or ``_EmptyListener`` if none are present.

The attribute mechanics here spend effort trying to ensure listener functions
are available with a minimum of function call overhead, that unnecessary
objects aren't created (i.e. many empty per-instance listener collections),
as well as that everything is garbage collectable when owning references are
lost.  Other features such as "propagation" of listener functions across
many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances,
as well as support for subclass propagation (e.g. events assigned to
``Pool`` vs. ``QueuePool``) are all implemented here.

"""

from __future__ import absolute_import, with_statement

from .. import util
from ..util import threading
from . import registry
from . import legacy
from itertools import chain
import weakref
import collections


class RefCollection(util.MemoizedSlots):
    __slots__ = 'ref',

    def _memoized_attr_ref(self):
        return weakref.ref(self, registry._collection_gced)


class _ClsLevelDispatch(RefCollection):
    """Class-level events on :class:`._Dispatch` classes."""

    __slots__ = ('name', 'arg_names', 'has_kw',
                 'legacy_signatures', '_clslevel', '__weakref__')

    def __init__(self, parent_dispatch_cls, fn):
        self.name = fn.__name__
        argspec = util.inspect_getargspec(fn)
        self.arg_names = argspec.args[1:]
        self.has_kw = bool(argspec.keywords)
        self.legacy_signatures = list(reversed(
            sorted(
                getattr(fn, '_legacy_signatures', []),
                key=lambda s: s[0]
            )
        ))
        fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn)

        self._clslevel = weakref.WeakKeyDictionary()

    def _adjust_fn_spec(self, fn, named):
        if named:
            fn = self._wrap_fn_for_kw(fn)
        if self.legacy_signatures:
            try:
                argspec = util.get_callable_argspec(fn, no_self=True)
            except TypeError:
                pass
            else:
                fn = legacy._wrap_fn_for_legacy(self, fn, argspec)
        return fn

    def _wrap_fn_for_kw(self, fn):
        def wrap_kw(*args, **kw):
            argdict = dict(zip(self.arg_names, args))
            argdict.update(kw)
            return fn(**argdict)
        return wrap_kw
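
[Illustration — not part of this commit: at the public API level, ``_wrap_fn_for_kw`` is what backs the ``named=True`` option; positional event arguments are zipped into a dict keyed by the event's declared argument names.]

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "connect", named=True)
    def on_connect(**kw):
        # kw holds 'dbapi_connection' and 'connection_record',
        # assembled by a wrapper equivalent to wrap_kw() above.
        print(sorted(kw))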

    def insert(self, event_key, propagate):
        target = event_key.dispatch_target
        assert isinstance(target, type), \
            "Class-level Event targets must be classes."
        stack = [target]
        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
            if cls is not target and cls not in self._clslevel:
                self.update_subclass(cls)
            else:
                if cls not in self._clslevel:
                    self._clslevel[cls] = collections.deque()
                self._clslevel[cls].appendleft(event_key._listen_fn)
        registry._stored_in_collection(event_key, self)

    def append(self, event_key, propagate):
        target = event_key.dispatch_target
        assert isinstance(target, type), \
            "Class-level Event targets must be classes."

        stack = [target]
        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
            if cls is not target and cls not in self._clslevel:
                self.update_subclass(cls)
            else:
                if cls not in self._clslevel:
                    self._clslevel[cls] = collections.deque()
                self._clslevel[cls].append(event_key._listen_fn)
        registry._stored_in_collection(event_key, self)

    def update_subclass(self, target):
        if target not in self._clslevel:
            self._clslevel[target] = collections.deque()
        clslevel = self._clslevel[target]
        for cls in target.__mro__[1:]:
            if cls in self._clslevel:
                clslevel.extend([
                    fn for fn
                    in self._clslevel[cls]
                    if fn not in clslevel
                ])

    def remove(self, event_key):
        target = event_key.dispatch_target
        stack = [target]
        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
            if cls in self._clslevel:
                self._clslevel[cls].remove(event_key._listen_fn)
        registry._removed_from_collection(event_key, self)

    def clear(self):
        """Clear all class level listeners"""

        to_clear = set()
        for dispatcher in self._clslevel.values():
            to_clear.update(dispatcher)
            dispatcher.clear()
        registry._clear(self, to_clear)

    def for_modify(self, obj):
        """Return an event collection which can be modified.

        For _ClsLevelDispatch at the class level of
        a dispatcher, this returns self.

        """
        return self


class _InstanceLevelDispatch(RefCollection):
    __slots__ = ()

    def _adjust_fn_spec(self, fn, named):
        return self.parent._adjust_fn_spec(fn, named)


class _EmptyListener(_InstanceLevelDispatch):
    """Serves as a proxy interface to the events
    served by a _ClsLevelDispatch, when there are no
    instance-level events present.

    Is replaced by _ListenerCollection when instance-level
    events are added.

    """

    propagate = frozenset()
    listeners = ()

    __slots__ = 'parent', 'parent_listeners', 'name'

    def __init__(self, parent, target_cls):
        if target_cls not in parent._clslevel:
            parent.update_subclass(target_cls)
        self.parent = parent  # _ClsLevelDispatch
        self.parent_listeners = parent._clslevel[target_cls]
        self.name = parent.name

    def for_modify(self, obj):
        """Return an event collection which can be modified.

        For _EmptyListener at the instance level of
        a dispatcher, this generates a new
        _ListenerCollection, applies it to the instance,
        and returns it.

        """
        result = _ListenerCollection(self.parent, obj._instance_cls)
        if getattr(obj, self.name) is self:
            setattr(obj, self.name, result)
        else:
            assert isinstance(getattr(obj, self.name), _JoinedListener)
        return result

    def _needs_modify(self, *args, **kw):
        raise NotImplementedError("need to call for_modify()")

    exec_once = insert = append = remove = clear = _needs_modify

    def __call__(self, *args, **kw):
        """Execute this event."""

        for fn in self.parent_listeners:
            fn(*args, **kw)

    def __len__(self):
        return len(self.parent_listeners)

    def __iter__(self):
        return iter(self.parent_listeners)

    def __bool__(self):
        return bool(self.parent_listeners)

    __nonzero__ = __bool__


class _CompoundListener(_InstanceLevelDispatch):
    __slots__ = '_exec_once_mutex', '_exec_once'

    def _memoized_attr__exec_once_mutex(self):
        return threading.Lock()

    def exec_once(self, *args, **kw):
        """Execute this event, but only if it has not been
        executed already for this collection."""

        if not self._exec_once:
            with self._exec_once_mutex:
                if not self._exec_once:
                    try:
                        self(*args, **kw)
                    finally:
                        self._exec_once = True

    def __call__(self, *args, **kw):
        """Execute this event."""

        for fn in self.parent_listeners:
            fn(*args, **kw)
        for fn in self.listeners:
            fn(*args, **kw)

    def __len__(self):
        return len(self.parent_listeners) + len(self.listeners)

    def __iter__(self):
        return chain(self.parent_listeners, self.listeners)

    def __bool__(self):
        return bool(self.listeners or self.parent_listeners)

    __nonzero__ = __bool__


class _ListenerCollection(_CompoundListener):
    """Instance-level attributes on instances of :class:`._Dispatch`.

    Represents a collection of listeners.

    As of 0.7.9, _ListenerCollection is only first
    created via the _EmptyListener.for_modify() method.

    """

    __slots__ = (
        'parent_listeners', 'parent', 'name', 'listeners',
        'propagate', '__weakref__')

    def __init__(self, parent, target_cls):
        if target_cls not in parent._clslevel:
            parent.update_subclass(target_cls)
        self._exec_once = False
        self.parent_listeners = parent._clslevel[target_cls]
        self.parent = parent
        self.name = parent.name
        self.listeners = collections.deque()
        self.propagate = set()

    def for_modify(self, obj):
        """Return an event collection which can be modified.

        For _ListenerCollection at the instance level of
        a dispatcher, this returns self.

        """
        return self

    def _update(self, other, only_propagate=True):
        """Populate from the listeners in another :class:`_Dispatch`
        object."""

        existing_listeners = self.listeners
        existing_listener_set = set(existing_listeners)
        self.propagate.update(other.propagate)
        other_listeners = [l for l
                           in other.listeners
                           if l not in existing_listener_set
                           and not only_propagate or l in self.propagate
                           ]

        existing_listeners.extend(other_listeners)

        to_associate = other.propagate.union(other_listeners)
        registry._stored_in_collection_multi(self, other, to_associate)

    def insert(self, event_key, propagate):
        if event_key.prepend_to_list(self, self.listeners):
            if propagate:
                self.propagate.add(event_key._listen_fn)

    def append(self, event_key, propagate):
        if event_key.append_to_list(self, self.listeners):
            if propagate:
                self.propagate.add(event_key._listen_fn)

    def remove(self, event_key):
        self.listeners.remove(event_key._listen_fn)
        self.propagate.discard(event_key._listen_fn)
        registry._removed_from_collection(event_key, self)

    def clear(self):
        registry._clear(self, self.listeners)
        self.propagate.clear()
        self.listeners.clear()


class _JoinedListener(_CompoundListener):
    __slots__ = 'parent', 'name', 'local', 'parent_listeners'

    def __init__(self, parent, name, local):
        self._exec_once = False
        self.parent = parent
        self.name = name
        self.local = local
        self.parent_listeners = self.local

    @property
    def listeners(self):
        return getattr(self.parent, self.name)

    def _adjust_fn_spec(self, fn, named):
        return self.local._adjust_fn_spec(fn, named)

    def for_modify(self, obj):
        self.local = self.parent_listeners = self.local.for_modify(obj)
        return self

    def insert(self, event_key, propagate):
        self.local.insert(event_key, propagate)

    def append(self, event_key, propagate):
        self.local.append(event_key, propagate)

    def remove(self, event_key):
        self.local.remove(event_key)

    def clear(self):
        raise NotImplementedError()
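
[Illustration — not part of this commit: a visible consequence of the subclass walk in ``insert()``/``append()`` above; a listener attached to ``Pool`` also fires for subclasses such as ``QueuePool``.]

    from sqlalchemy import event
    from sqlalchemy.pool import Pool

    def on_checkout(dbapi_conn, connection_record, connection_proxy):
        print("connection checked out")

    # Registered against Pool, but delivered for QueuePool instances too,
    # because the class-level collection is propagated to subclasses.
    event.listen(Pool, "checkout", on_checkout)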
289  sqlalchemy/event/base.py  Normal file
@@ -0,0 +1,289 @@
# event/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Base implementation classes.

The public-facing ``Events`` serves as the base class for an event interface;
its public attributes represent different kinds of events.   These attributes
are mirrored onto a ``_Dispatch`` class, which serves as a container for
collections of listener functions.   These collections are represented both
at the class level of a particular ``_Dispatch`` class as well as within
instances of ``_Dispatch``.

"""
from __future__ import absolute_import

import weakref

from .. import util
from .attr import _JoinedListener, \
    _EmptyListener, _ClsLevelDispatch

_registrars = util.defaultdict(list)


def _is_event_name(name):
    return not name.startswith('_') and name != 'dispatch'


class _UnpickleDispatch(object):
    """Serializable callable that re-generates an instance of
    :class:`_Dispatch` given a particular :class:`.Events` subclass.

    """

    def __call__(self, _instance_cls):
        for cls in _instance_cls.__mro__:
            if 'dispatch' in cls.__dict__:
                return cls.__dict__['dispatch'].\
                    dispatch_cls._for_class(_instance_cls)
        else:
            raise AttributeError("No class with a 'dispatch' member present.")


class _Dispatch(object):
    """Mirror the event listening definitions of an Events class with
    listener collections.

    Classes which define a "dispatch" member will return a
    non-instantiated :class:`._Dispatch` subclass when the member
    is accessed at the class level.  When the "dispatch" member is
    accessed at the instance level of its owner, an instance
    of the :class:`._Dispatch` class is returned.

    A :class:`._Dispatch` class is generated for each :class:`.Events`
    class defined, by the :func:`._create_dispatcher_class` function.
    The original :class:`.Events` classes remain untouched.
    This decouples the construction of :class:`.Events` subclasses from
    the implementation used by the event internals, and allows
    inspecting tools like Sphinx to work in an unsurprising
    way against the public API.

    """

    # in one ORM edge case, an attribute is added to _Dispatch,
    # so __dict__ is used in just that case and potentially others.
    __slots__ = '_parent', '_instance_cls', '__dict__', '_empty_listeners'

    _empty_listener_reg = weakref.WeakKeyDictionary()

    def __init__(self, parent, instance_cls=None):
        self._parent = parent
        self._instance_cls = instance_cls
        if instance_cls:
            try:
                self._empty_listeners = self._empty_listener_reg[instance_cls]
            except KeyError:
                self._empty_listeners = \
                    self._empty_listener_reg[instance_cls] = dict(
                        (ls.name, _EmptyListener(ls, instance_cls))
                        for ls in parent._event_descriptors
                    )
        else:
            self._empty_listeners = {}

    def __getattr__(self, name):
        # assign EmptyListeners as attributes on demand
        # to reduce startup time for new dispatch objects
        try:
            ls = self._empty_listeners[name]
        except KeyError:
            raise AttributeError(name)
        else:
            setattr(self, ls.name, ls)
            return ls

    @property
    def _event_descriptors(self):
        for k in self._event_names:
            yield getattr(self, k)

    def _for_class(self, instance_cls):
        return self.__class__(self, instance_cls)

    def _for_instance(self, instance):
        instance_cls = instance.__class__
        return self._for_class(instance_cls)

    @property
    def _listen(self):
return self._events._listen
|
||||
|
||||
def _join(self, other):
|
||||
"""Create a 'join' of this :class:`._Dispatch` and another.
|
||||
|
||||
This new dispatcher will dispatch events to both
|
||||
:class:`._Dispatch` objects.
|
||||
|
||||
"""
|
||||
if '_joined_dispatch_cls' not in self.__class__.__dict__:
|
||||
cls = type(
|
||||
"Joined%s" % self.__class__.__name__,
|
||||
(_JoinedDispatcher, ), {'__slots__': self._event_names}
|
||||
)
|
||||
|
||||
self.__class__._joined_dispatch_cls = cls
|
||||
return self._joined_dispatch_cls(self, other)
|
||||
|
||||
def __reduce__(self):
|
||||
return _UnpickleDispatch(), (self._instance_cls, )
|
||||
|
||||
def _update(self, other, only_propagate=True):
|
||||
"""Populate from the listeners in another :class:`_Dispatch`
|
||||
object."""
|
||||
for ls in other._event_descriptors:
|
||||
if isinstance(ls, _EmptyListener):
|
||||
continue
|
||||
getattr(self, ls.name).\
|
||||
for_modify(self)._update(ls, only_propagate=only_propagate)
|
||||
|
||||
def _clear(self):
|
||||
for ls in self._event_descriptors:
|
||||
ls.for_modify(self).clear()
|
||||
|
||||
|
||||
class _EventMeta(type):
|
||||
"""Intercept new Event subclasses and create
|
||||
associated _Dispatch classes."""
|
||||
|
||||
def __init__(cls, classname, bases, dict_):
|
||||
_create_dispatcher_class(cls, classname, bases, dict_)
|
||||
return type.__init__(cls, classname, bases, dict_)
|
||||
|
||||
|
||||
def _create_dispatcher_class(cls, classname, bases, dict_):
|
||||
"""Create a :class:`._Dispatch` class corresponding to an
|
||||
:class:`.Events` class."""
|
||||
|
||||
# there's all kinds of ways to do this,
|
||||
# i.e. make a Dispatch class that shares the '_listen' method
|
||||
# of the Event class, this is the straight monkeypatch.
|
||||
if hasattr(cls, 'dispatch'):
|
||||
dispatch_base = cls.dispatch.__class__
|
||||
else:
|
||||
dispatch_base = _Dispatch
|
||||
|
||||
event_names = [k for k in dict_ if _is_event_name(k)]
|
||||
dispatch_cls = type("%sDispatch" % classname,
|
||||
(dispatch_base, ), {'__slots__': event_names})
|
||||
|
||||
dispatch_cls._event_names = event_names
|
||||
|
||||
dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
|
||||
for k in dispatch_cls._event_names:
|
||||
setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
|
||||
_registrars[k].append(cls)
|
||||
|
||||
for super_ in dispatch_cls.__bases__:
|
||||
if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
|
||||
for ls in super_._events.dispatch._event_descriptors:
|
||||
setattr(dispatch_inst, ls.name, ls)
|
||||
dispatch_cls._event_names.append(ls.name)
|
||||
|
||||
if getattr(cls, '_dispatch_target', None):
|
||||
cls._dispatch_target.dispatch = dispatcher(cls)
|
||||
|
||||
|
||||
def _remove_dispatcher(cls):
|
||||
for k in cls.dispatch._event_names:
|
||||
_registrars[k].remove(cls)
|
||||
if not _registrars[k]:
|
||||
del _registrars[k]
|
||||
|
||||
|
||||
class Events(util.with_metaclass(_EventMeta, object)):
|
||||
"""Define event listening functions for a particular target type."""
|
||||
|
||||
@staticmethod
|
||||
def _set_dispatch(cls, dispatch_cls):
|
||||
# this allows an Events subclass to define additional utility
|
||||
# methods made available to the target via
|
||||
# "self.dispatch._events.<utilitymethod>"
|
||||
# @staticemethod to allow easy "super" calls while in a metaclass
|
||||
# constructor.
|
||||
cls.dispatch = dispatch_cls(None)
|
||||
dispatch_cls._events = cls
|
||||
return cls.dispatch
|
||||
|
||||
@classmethod
|
||||
def _accept_with(cls, target):
|
||||
# Mapper, ClassManager, Session override this to
|
||||
# also accept classes, scoped_sessions, sessionmakers, etc.
|
||||
if hasattr(target, 'dispatch') and (
|
||||
|
||||
isinstance(target.dispatch, cls.dispatch.__class__) or
|
||||
|
||||
|
||||
(
|
||||
isinstance(target.dispatch, type) and
|
||||
isinstance(target.dispatch, cls.dispatch.__class__)
|
||||
) or
|
||||
|
||||
(
|
||||
isinstance(target.dispatch, _JoinedDispatcher) and
|
||||
isinstance(target.dispatch.parent, cls.dispatch.__class__)
|
||||
)
|
||||
|
||||
|
||||
):
|
||||
return target
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _listen(cls, event_key, propagate=False, insert=False, named=False):
|
||||
event_key.base_listen(propagate=propagate, insert=insert, named=named)
|
||||
|
||||
@classmethod
|
||||
def _remove(cls, event_key):
|
||||
event_key.remove()
|
||||
|
||||
@classmethod
|
||||
def _clear(cls):
|
||||
cls.dispatch._clear()
|
||||
|
||||
|
||||
class _JoinedDispatcher(object):
|
||||
"""Represent a connection between two _Dispatch objects."""
|
||||
|
||||
__slots__ = 'local', 'parent', '_instance_cls'
|
||||
|
||||
def __init__(self, local, parent):
|
||||
self.local = local
|
||||
self.parent = parent
|
||||
self._instance_cls = self.local._instance_cls
|
||||
|
||||
def __getattr__(self, name):
|
||||
# assign _JoinedListeners as attributes on demand
|
||||
# to reduce startup time for new dispatch objects
|
||||
ls = getattr(self.local, name)
|
||||
jl = _JoinedListener(self.parent, ls.name, ls)
|
||||
setattr(self, ls.name, jl)
|
||||
return jl
|
||||
|
||||
@property
|
||||
def _listen(self):
|
||||
return self.parent._listen
|
||||
|
||||
|
||||
class dispatcher(object):
|
||||
"""Descriptor used by target classes to
|
||||
deliver the _Dispatch class at the class level
|
||||
and produce new _Dispatch instances for target
|
||||
instances.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, events):
|
||||
self.dispatch_cls = events.dispatch
|
||||
self.events = events
|
||||
|
||||
def __get__(self, obj, cls):
|
||||
if obj is None:
|
||||
return self.dispatch_cls
|
||||
obj.__dict__['dispatch'] = disp = self.dispatch_cls._for_instance(obj)
|
||||
return disp
|
||||
169
sqlalchemy/event/legacy.py
Normal file
169
sqlalchemy/event/legacy.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# event/legacy.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Routines to handle adaption of legacy call signatures,
|
||||
generation of deprecation notes and docstrings.
|
||||
|
||||
"""
|
||||
|
||||
from .. import util
|
||||
|
||||
|
||||
def _legacy_signature(since, argnames, converter=None):
|
||||
def leg(fn):
|
||||
if not hasattr(fn, '_legacy_signatures'):
|
||||
fn._legacy_signatures = []
|
||||
fn._legacy_signatures.append((since, argnames, converter))
|
||||
return fn
|
||||
return leg
|
||||
|
||||
|
||||
def _wrap_fn_for_legacy(dispatch_collection, fn, argspec):
|
||||
for since, argnames, conv in dispatch_collection.legacy_signatures:
|
||||
if argnames[-1] == "**kw":
|
||||
has_kw = True
|
||||
argnames = argnames[0:-1]
|
||||
else:
|
||||
has_kw = False
|
||||
|
||||
if len(argnames) == len(argspec.args) \
|
||||
and has_kw is bool(argspec.keywords):
|
||||
|
||||
if conv:
|
||||
assert not has_kw
|
||||
|
||||
def wrap_leg(*args):
|
||||
return fn(*conv(*args))
|
||||
else:
|
||||
def wrap_leg(*args, **kw):
|
||||
argdict = dict(zip(dispatch_collection.arg_names, args))
|
||||
args = [argdict[name] for name in argnames]
|
||||
if has_kw:
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
return fn(*args)
|
||||
return wrap_leg
|
||||
else:
|
||||
return fn
|
||||
|
||||
|
||||
def _indent(text, indent):
|
||||
return "\n".join(
|
||||
indent + line
|
||||
for line in text.split("\n")
|
||||
)
|
||||
|
||||
|
||||
def _standard_listen_example(dispatch_collection, sample_target, fn):
|
||||
example_kw_arg = _indent(
|
||||
"\n".join(
|
||||
"%(arg)s = kw['%(arg)s']" % {"arg": arg}
|
||||
for arg in dispatch_collection.arg_names[0:2]
|
||||
),
|
||||
" ")
|
||||
if dispatch_collection.legacy_signatures:
|
||||
current_since = max(since for since, args, conv
|
||||
in dispatch_collection.legacy_signatures)
|
||||
else:
|
||||
current_since = None
|
||||
text = (
|
||||
"from sqlalchemy import event\n\n"
|
||||
"# standard decorator style%(current_since)s\n"
|
||||
"@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
|
||||
"def receive_%(event_name)s("
|
||||
"%(named_event_arguments)s%(has_kw_arguments)s):\n"
|
||||
" \"listen for the '%(event_name)s' event\"\n"
|
||||
"\n # ... (event handling logic) ...\n"
|
||||
)
|
||||
|
||||
if len(dispatch_collection.arg_names) > 3:
|
||||
text += (
|
||||
|
||||
"\n# named argument style (new in 0.9)\n"
|
||||
"@event.listens_for("
|
||||
"%(sample_target)s, '%(event_name)s', named=True)\n"
|
||||
"def receive_%(event_name)s(**kw):\n"
|
||||
" \"listen for the '%(event_name)s' event\"\n"
|
||||
"%(example_kw_arg)s\n"
|
||||
"\n # ... (event handling logic) ...\n"
|
||||
)
|
||||
|
||||
text %= {
|
||||
"current_since": " (arguments as of %s)" %
|
||||
current_since if current_since else "",
|
||||
"event_name": fn.__name__,
|
||||
"has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
|
||||
"named_event_arguments": ", ".join(dispatch_collection.arg_names),
|
||||
"example_kw_arg": example_kw_arg,
|
||||
"sample_target": sample_target
|
||||
}
|
||||
return text
|
||||
|
||||
|
||||
def _legacy_listen_examples(dispatch_collection, sample_target, fn):
|
||||
text = ""
|
||||
for since, args, conv in dispatch_collection.legacy_signatures:
|
||||
text += (
|
||||
"\n# legacy calling style (pre-%(since)s)\n"
|
||||
"@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
|
||||
"def receive_%(event_name)s("
|
||||
"%(named_event_arguments)s%(has_kw_arguments)s):\n"
|
||||
" \"listen for the '%(event_name)s' event\"\n"
|
||||
"\n # ... (event handling logic) ...\n" % {
|
||||
"since": since,
|
||||
"event_name": fn.__name__,
|
||||
"has_kw_arguments": " **kw"
|
||||
if dispatch_collection.has_kw else "",
|
||||
"named_event_arguments": ", ".join(args),
|
||||
"sample_target": sample_target
|
||||
}
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
def _version_signature_changes(dispatch_collection):
|
||||
since, args, conv = dispatch_collection.legacy_signatures[0]
|
||||
return (
|
||||
"\n.. versionchanged:: %(since)s\n"
|
||||
" The ``%(event_name)s`` event now accepts the \n"
|
||||
" arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n"
|
||||
" Listener functions which accept the previous argument \n"
|
||||
" signature(s) listed above will be automatically \n"
|
||||
" adapted to the new signature." % {
|
||||
"since": since,
|
||||
"event_name": dispatch_collection.name,
|
||||
"named_event_arguments": ", ".join(dispatch_collection.arg_names),
|
||||
"has_kw_arguments": ", **kw" if dispatch_collection.has_kw else ""
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn):
|
||||
header = ".. container:: event_signatures\n\n"\
|
||||
" Example argument forms::\n"\
|
||||
"\n"
|
||||
|
||||
sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj")
|
||||
text = (
|
||||
header +
|
||||
_indent(
|
||||
_standard_listen_example(
|
||||
dispatch_collection, sample_target, fn),
|
||||
" " * 8)
|
||||
)
|
||||
if dispatch_collection.legacy_signatures:
|
||||
text += _indent(
|
||||
_legacy_listen_examples(
|
||||
dispatch_collection, sample_target, fn),
|
||||
" " * 8)
|
||||
|
||||
text += _version_signature_changes(dispatch_collection)
|
||||
|
||||
return util.inject_docstring_text(fn.__doc__,
|
||||
text,
|
||||
1
|
||||
)
|
||||
262
sqlalchemy/event/registry.py
Normal file
262
sqlalchemy/event/registry.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# event/registry.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Provides managed registration services on behalf of :func:`.listen`
|
||||
arguments.
|
||||
|
||||
By "managed registration", we mean that event listening functions and
|
||||
other objects can be added to various collections in such a way that their
|
||||
membership in all those collections can be revoked at once, based on
|
||||
an equivalent :class:`._EventKey`.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import weakref
|
||||
import collections
|
||||
import types
|
||||
from .. import exc, util
|
||||
|
||||
|
||||
_key_to_collection = collections.defaultdict(dict)
|
||||
"""
|
||||
Given an original listen() argument, can locate all
|
||||
listener collections and the listener fn contained
|
||||
|
||||
(target, identifier, fn) -> {
|
||||
ref(listenercollection) -> ref(listener_fn)
|
||||
ref(listenercollection) -> ref(listener_fn)
|
||||
ref(listenercollection) -> ref(listener_fn)
|
||||
}
|
||||
"""
|
||||
|
||||
_collection_to_key = collections.defaultdict(dict)
|
||||
"""
|
||||
Given a _ListenerCollection or _ClsLevelListener, can locate
|
||||
all the original listen() arguments and the listener fn contained
|
||||
|
||||
ref(listenercollection) -> {
|
||||
ref(listener_fn) -> (target, identifier, fn),
|
||||
ref(listener_fn) -> (target, identifier, fn),
|
||||
ref(listener_fn) -> (target, identifier, fn),
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _collection_gced(ref):
|
||||
# defaultdict, so can't get a KeyError
|
||||
if not _collection_to_key or ref not in _collection_to_key:
|
||||
return
|
||||
listener_to_key = _collection_to_key.pop(ref)
|
||||
for key in listener_to_key.values():
|
||||
if key in _key_to_collection:
|
||||
# defaultdict, so can't get a KeyError
|
||||
dispatch_reg = _key_to_collection[key]
|
||||
dispatch_reg.pop(ref)
|
||||
if not dispatch_reg:
|
||||
_key_to_collection.pop(key)
|
||||
|
||||
|
||||
def _stored_in_collection(event_key, owner):
|
||||
key = event_key._key
|
||||
|
||||
dispatch_reg = _key_to_collection[key]
|
||||
|
||||
owner_ref = owner.ref
|
||||
listen_ref = weakref.ref(event_key._listen_fn)
|
||||
|
||||
if owner_ref in dispatch_reg:
|
||||
return False
|
||||
|
||||
dispatch_reg[owner_ref] = listen_ref
|
||||
|
||||
listener_to_key = _collection_to_key[owner_ref]
|
||||
listener_to_key[listen_ref] = key
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _removed_from_collection(event_key, owner):
|
||||
key = event_key._key
|
||||
|
||||
dispatch_reg = _key_to_collection[key]
|
||||
|
||||
listen_ref = weakref.ref(event_key._listen_fn)
|
||||
|
||||
owner_ref = owner.ref
|
||||
dispatch_reg.pop(owner_ref, None)
|
||||
if not dispatch_reg:
|
||||
del _key_to_collection[key]
|
||||
|
||||
if owner_ref in _collection_to_key:
|
||||
listener_to_key = _collection_to_key[owner_ref]
|
||||
listener_to_key.pop(listen_ref)
|
||||
|
||||
|
||||
def _stored_in_collection_multi(newowner, oldowner, elements):
|
||||
if not elements:
|
||||
return
|
||||
|
||||
oldowner = oldowner.ref
|
||||
newowner = newowner.ref
|
||||
|
||||
old_listener_to_key = _collection_to_key[oldowner]
|
||||
new_listener_to_key = _collection_to_key[newowner]
|
||||
|
||||
for listen_fn in elements:
|
||||
listen_ref = weakref.ref(listen_fn)
|
||||
key = old_listener_to_key[listen_ref]
|
||||
dispatch_reg = _key_to_collection[key]
|
||||
if newowner in dispatch_reg:
|
||||
assert dispatch_reg[newowner] == listen_ref
|
||||
else:
|
||||
dispatch_reg[newowner] = listen_ref
|
||||
|
||||
new_listener_to_key[listen_ref] = key
|
||||
|
||||
|
||||
def _clear(owner, elements):
|
||||
if not elements:
|
||||
return
|
||||
|
||||
owner = owner.ref
|
||||
listener_to_key = _collection_to_key[owner]
|
||||
for listen_fn in elements:
|
||||
listen_ref = weakref.ref(listen_fn)
|
||||
key = listener_to_key[listen_ref]
|
||||
dispatch_reg = _key_to_collection[key]
|
||||
dispatch_reg.pop(owner, None)
|
||||
|
||||
if not dispatch_reg:
|
||||
del _key_to_collection[key]
|
||||
|
||||
|
||||
class _EventKey(object):
|
||||
"""Represent :func:`.listen` arguments.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'target', 'identifier', 'fn', 'fn_key', 'fn_wrap', 'dispatch_target'
|
||||
)
|
||||
|
||||
def __init__(self, target, identifier,
|
||||
fn, dispatch_target, _fn_wrap=None):
|
||||
self.target = target
|
||||
self.identifier = identifier
|
||||
self.fn = fn
|
||||
if isinstance(fn, types.MethodType):
|
||||
self.fn_key = id(fn.__func__), id(fn.__self__)
|
||||
else:
|
||||
self.fn_key = id(fn)
|
||||
self.fn_wrap = _fn_wrap
|
||||
self.dispatch_target = dispatch_target
|
||||
|
||||
@property
|
||||
def _key(self):
|
||||
return (id(self.target), self.identifier, self.fn_key)
|
||||
|
||||
def with_wrapper(self, fn_wrap):
|
||||
if fn_wrap is self._listen_fn:
|
||||
return self
|
||||
else:
|
||||
return _EventKey(
|
||||
self.target,
|
||||
self.identifier,
|
||||
self.fn,
|
||||
self.dispatch_target,
|
||||
_fn_wrap=fn_wrap
|
||||
)
|
||||
|
||||
def with_dispatch_target(self, dispatch_target):
|
||||
if dispatch_target is self.dispatch_target:
|
||||
return self
|
||||
else:
|
||||
return _EventKey(
|
||||
self.target,
|
||||
self.identifier,
|
||||
self.fn,
|
||||
dispatch_target,
|
||||
_fn_wrap=self.fn_wrap
|
||||
)
|
||||
|
||||
def listen(self, *args, **kw):
|
||||
once = kw.pop("once", False)
|
||||
named = kw.pop("named", False)
|
||||
|
||||
target, identifier, fn = \
|
||||
self.dispatch_target, self.identifier, self._listen_fn
|
||||
|
||||
dispatch_collection = getattr(target.dispatch, identifier)
|
||||
|
||||
adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named)
|
||||
|
||||
self = self.with_wrapper(adjusted_fn)
|
||||
|
||||
if once:
|
||||
self.with_wrapper(
|
||||
util.only_once(self._listen_fn)).listen(*args, **kw)
|
||||
else:
|
||||
self.dispatch_target.dispatch._listen(self, *args, **kw)
|
||||
|
||||
def remove(self):
|
||||
key = self._key
|
||||
|
||||
if key not in _key_to_collection:
|
||||
raise exc.InvalidRequestError(
|
||||
"No listeners found for event %s / %r / %s " %
|
||||
(self.target, self.identifier, self.fn)
|
||||
)
|
||||
dispatch_reg = _key_to_collection.pop(key)
|
||||
|
||||
for collection_ref, listener_ref in dispatch_reg.items():
|
||||
collection = collection_ref()
|
||||
listener_fn = listener_ref()
|
||||
if collection is not None and listener_fn is not None:
|
||||
collection.remove(self.with_wrapper(listener_fn))
|
||||
|
||||
def contains(self):
|
||||
"""Return True if this event key is registered to listen.
|
||||
"""
|
||||
return self._key in _key_to_collection
|
||||
|
||||
def base_listen(self, propagate=False, insert=False,
|
||||
named=False):
|
||||
|
||||
target, identifier, fn = \
|
||||
self.dispatch_target, self.identifier, self._listen_fn
|
||||
|
||||
dispatch_collection = getattr(target.dispatch, identifier)
|
||||
|
||||
if insert:
|
||||
dispatch_collection.\
|
||||
for_modify(target.dispatch).insert(self, propagate)
|
||||
else:
|
||||
dispatch_collection.\
|
||||
for_modify(target.dispatch).append(self, propagate)
|
||||
|
||||
@property
|
||||
def _listen_fn(self):
|
||||
return self.fn_wrap or self.fn
|
||||
|
||||
def append_to_list(self, owner, list_):
|
||||
if _stored_in_collection(self, owner):
|
||||
list_.append(self._listen_fn)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def remove_from_list(self, owner, list_):
|
||||
_removed_from_collection(self, owner)
|
||||
list_.remove(self._listen_fn)
|
||||
|
||||
def prepend_to_list(self, owner, list_):
|
||||
if _stored_in_collection(self, owner):
|
||||
list_.appendleft(self._listen_fn)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
1173
sqlalchemy/events.py
Normal file
1173
sqlalchemy/events.py
Normal file
File diff suppressed because it is too large
Load Diff
1048
sqlalchemy/ext/automap.py
Normal file
1048
sqlalchemy/ext/automap.py
Normal file
File diff suppressed because it is too large
Load Diff
559
sqlalchemy/ext/baked.py
Normal file
559
sqlalchemy/ext/baked.py
Normal file
@@ -0,0 +1,559 @@
|
||||
# sqlalchemy/ext/baked.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
"""Baked query extension.
|
||||
|
||||
Provides a creational pattern for the :class:`.query.Query` object which
|
||||
allows the fully constructed object, Core select statement, and string
|
||||
compiled result to be fully cached.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
from ..orm.query import Query
|
||||
from ..orm import strategies, attributes, properties, \
|
||||
strategy_options, util as orm_util, interfaces
|
||||
from .. import log as sqla_log
|
||||
from ..sql import util as sql_util, func, literal_column
|
||||
from ..orm import exc as orm_exc
|
||||
from .. import exc as sa_exc
|
||||
from .. import util
|
||||
|
||||
import copy
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BakedQuery(object):
|
||||
"""A builder object for :class:`.query.Query` objects."""
|
||||
|
||||
__slots__ = 'steps', '_bakery', '_cache_key', '_spoiled'
|
||||
|
||||
def __init__(self, bakery, initial_fn, args=()):
|
||||
self._cache_key = ()
|
||||
self._update_cache_key(initial_fn, args)
|
||||
self.steps = [initial_fn]
|
||||
self._spoiled = False
|
||||
self._bakery = bakery
|
||||
|
||||
@classmethod
|
||||
def bakery(cls, size=200):
|
||||
"""Construct a new bakery."""
|
||||
|
||||
_bakery = util.LRUCache(size)
|
||||
|
||||
def call(initial_fn, *args):
|
||||
return cls(_bakery, initial_fn, args)
|
||||
|
||||
return call
|
||||
|
||||
def _clone(self):
|
||||
b1 = BakedQuery.__new__(BakedQuery)
|
||||
b1._cache_key = self._cache_key
|
||||
b1.steps = list(self.steps)
|
||||
b1._bakery = self._bakery
|
||||
b1._spoiled = self._spoiled
|
||||
return b1
|
||||
|
||||
def _update_cache_key(self, fn, args=()):
|
||||
self._cache_key += (fn.__code__,) + args
|
||||
|
||||
def __iadd__(self, other):
|
||||
if isinstance(other, tuple):
|
||||
self.add_criteria(*other)
|
||||
else:
|
||||
self.add_criteria(other)
|
||||
return self
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, tuple):
|
||||
return self.with_criteria(*other)
|
||||
else:
|
||||
return self.with_criteria(other)
|
||||
|
||||
def add_criteria(self, fn, *args):
|
||||
"""Add a criteria function to this :class:`.BakedQuery`.
|
||||
|
||||
This is equivalent to using the ``+=`` operator to
|
||||
modify a :class:`.BakedQuery` in-place.
|
||||
|
||||
"""
|
||||
self._update_cache_key(fn, args)
|
||||
self.steps.append(fn)
|
||||
return self
|
||||
|
||||
def with_criteria(self, fn, *args):
|
||||
"""Add a criteria function to a :class:`.BakedQuery` cloned from this one.
|
||||
|
||||
This is equivalent to using the ``+`` operator to
|
||||
produce a new :class:`.BakedQuery` with modifications.
|
||||
|
||||
"""
|
||||
return self._clone().add_criteria(fn, *args)
|
||||
|
||||
def for_session(self, session):
|
||||
"""Return a :class:`.Result` object for this :class:`.BakedQuery`.
|
||||
|
||||
This is equivalent to calling the :class:`.BakedQuery` as a
|
||||
Python callable, e.g. ``result = my_baked_query(session)``.
|
||||
|
||||
"""
|
||||
return Result(self, session)
|
||||
|
||||
def __call__(self, session):
|
||||
return self.for_session(session)
|
||||
|
||||
def spoil(self, full=False):
|
||||
"""Cancel any query caching that will occur on this BakedQuery object.
|
||||
|
||||
The BakedQuery can continue to be used normally, however additional
|
||||
creational functions will not be cached; they will be called
|
||||
on every invocation.
|
||||
|
||||
This is to support the case where a particular step in constructing
|
||||
a baked query disqualifies the query from being cacheable, such
|
||||
as a variant that relies upon some uncacheable value.
|
||||
|
||||
:param full: if False, only functions added to this
|
||||
:class:`.BakedQuery` object subsequent to the spoil step will be
|
||||
non-cached; the state of the :class:`.BakedQuery` up until
|
||||
this point will be pulled from the cache. If True, then the
|
||||
entire :class:`.Query` object is built from scratch each
|
||||
time, with all creational functions being called on each
|
||||
invocation.
|
||||
|
||||
"""
|
||||
if not full:
|
||||
_spoil_point = self._clone()
|
||||
_spoil_point._cache_key += ('_query_only', )
|
||||
self.steps = [_spoil_point._retrieve_baked_query]
|
||||
self._spoiled = True
|
||||
return self
|
||||
|
||||
def _retrieve_baked_query(self, session):
|
||||
query = self._bakery.get(self._cache_key, None)
|
||||
if query is None:
|
||||
query = self._as_query(session)
|
||||
self._bakery[self._cache_key] = query.with_session(None)
|
||||
return query.with_session(session)
|
||||
|
||||
def _bake(self, session):
|
||||
query = self._as_query(session)
|
||||
|
||||
context = query._compile_context()
|
||||
self._bake_subquery_loaders(session, context)
|
||||
context.session = None
|
||||
context.query = query = context.query.with_session(None)
|
||||
query._execution_options = query._execution_options.union(
|
||||
{"compiled_cache": self._bakery}
|
||||
)
|
||||
# we'll be holding onto the query for some of its state,
|
||||
# so delete some compilation-use-only attributes that can take up
|
||||
# space
|
||||
for attr in (
|
||||
'_correlate', '_from_obj', '_mapper_adapter_map',
|
||||
'_joinpath', '_joinpoint'):
|
||||
query.__dict__.pop(attr, None)
|
||||
self._bakery[self._cache_key] = context
|
||||
return context
|
||||
|
||||
def _as_query(self, session):
|
||||
query = self.steps[0](session)
|
||||
|
||||
for step in self.steps[1:]:
|
||||
query = step(query)
|
||||
return query
|
||||
|
||||
def _bake_subquery_loaders(self, session, context):
|
||||
"""convert subquery eager loaders in the cache into baked queries.
|
||||
|
||||
For subquery eager loading to work, all we need here is that the
|
||||
Query point to the correct session when it is run. However, since
|
||||
we are "baking" anyway, we may as well also turn the query into
|
||||
a "baked" query so that we save on performance too.
|
||||
|
||||
"""
|
||||
context.attributes['baked_queries'] = baked_queries = []
|
||||
for k, v in list(context.attributes.items()):
|
||||
if isinstance(v, Query):
|
||||
if 'subquery' in k:
|
||||
bk = BakedQuery(self._bakery, lambda *args: v)
|
||||
bk._cache_key = self._cache_key + k
|
||||
bk._bake(session)
|
||||
baked_queries.append((k, bk._cache_key, v))
|
||||
del context.attributes[k]
|
||||
|
||||
def _unbake_subquery_loaders(self, session, context, params):
|
||||
"""Retrieve subquery eager loaders stored by _bake_subquery_loaders
|
||||
and turn them back into Result objects that will iterate just
|
||||
like a Query object.
|
||||
|
||||
"""
|
||||
for k, cache_key, query in context.attributes["baked_queries"]:
|
||||
bk = BakedQuery(self._bakery,
|
||||
lambda sess, q=query: q.with_session(sess))
|
||||
bk._cache_key = cache_key
|
||||
context.attributes[k] = bk.for_session(session).params(**params)
|
||||
|
||||
|
||||
class Result(object):
|
||||
"""Invokes a :class:`.BakedQuery` against a :class:`.Session`.
|
||||
|
||||
The :class:`.Result` object is where the actual :class:`.query.Query`
|
||||
object gets created, or retrieved from the cache,
|
||||
against a target :class:`.Session`, and is then invoked for results.
|
||||
|
||||
"""
|
||||
__slots__ = 'bq', 'session', '_params'
|
||||
|
||||
def __init__(self, bq, session):
|
||||
self.bq = bq
|
||||
self.session = session
|
||||
self._params = {}
|
||||
|
||||
def params(self, *args, **kw):
|
||||
"""Specify parameters to be replaced into the string SQL statement."""
|
||||
|
||||
if len(args) == 1:
|
||||
kw.update(args[0])
|
||||
elif len(args) > 0:
|
||||
raise sa_exc.ArgumentError(
|
||||
"params() takes zero or one positional argument, "
|
||||
"which is a dictionary.")
|
||||
self._params.update(kw)
|
||||
return self
|
||||
|
||||
def _as_query(self):
|
||||
return self.bq._as_query(self.session).params(self._params)
|
||||
|
||||
def __str__(self):
|
||||
return str(self._as_query())
|
||||
|
||||
def __iter__(self):
|
||||
bq = self.bq
|
||||
if bq._spoiled:
|
||||
return iter(self._as_query())
|
||||
|
||||
baked_context = bq._bakery.get(bq._cache_key, None)
|
||||
if baked_context is None:
|
||||
baked_context = bq._bake(self.session)
|
||||
|
||||
context = copy.copy(baked_context)
|
||||
context.session = self.session
|
||||
context.attributes = context.attributes.copy()
|
||||
|
||||
bq._unbake_subquery_loaders(self.session, context, self._params)
|
||||
|
||||
context.statement.use_labels = True
|
||||
if context.autoflush and not context.populate_existing:
|
||||
self.session._autoflush()
|
||||
return context.query.params(self._params).\
|
||||
with_session(self.session)._execute_and_instances(context)
|
||||
|
||||
def count(self):
|
||||
"""return the 'count'.
|
||||
|
||||
Equivalent to :meth:`.Query.count`.
|
||||
|
||||
Note this uses a subquery to ensure an accurate count regardless
|
||||
of the structure of the original statement.
|
||||
|
||||
.. versionadded:: 1.1.6
|
||||
|
||||
"""
|
||||
|
||||
col = func.count(literal_column('*'))
|
||||
bq = self.bq.with_criteria(lambda q: q.from_self(col))
|
||||
return bq.for_session(self.session).params(self._params).scalar()
|
||||
|
||||
def scalar(self):
|
||||
"""Return the first element of the first result or None
|
||||
if no rows present. If multiple rows are returned,
|
||||
raises MultipleResultsFound.
|
||||
|
||||
Equivalent to :meth:`.Query.scalar`.
|
||||
|
||||
.. versionadded:: 1.1.6
|
||||
|
||||
"""
|
||||
try:
|
||||
ret = self.one()
|
||||
if not isinstance(ret, tuple):
|
||||
return ret
|
||||
return ret[0]
|
||||
except orm_exc.NoResultFound:
|
||||
return None
|
||||
|
||||
def first(self):
|
||||
"""Return the first row.
|
||||
|
||||
Equivalent to :meth:`.Query.first`.
|
||||
|
||||
"""
|
||||
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
|
||||
ret = list(bq.for_session(self.session).params(self._params))
|
||||
if len(ret) > 0:
|
||||
return ret[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def one(self):
|
||||
"""Return exactly one result or raise an exception.
|
||||
|
||||
Equivalent to :meth:`.Query.one`.
|
||||
|
||||
"""
|
||||
try:
|
||||
ret = self.one_or_none()
|
||||
except orm_exc.MultipleResultsFound:
|
||||
raise orm_exc.MultipleResultsFound(
|
||||
"Multiple rows were found for one()")
|
||||
else:
|
||||
if ret is None:
|
||||
raise orm_exc.NoResultFound("No row was found for one()")
|
||||
return ret
|
||||
|
||||
def one_or_none(self):
|
||||
"""Return one or zero results, or raise an exception for multiple
|
||||
rows.
|
||||
|
||||
Equivalent to :meth:`.Query.one_or_none`.
|
||||
|
||||
.. versionadded:: 1.0.9
|
||||
|
||||
"""
|
||||
ret = list(self)
|
||||
|
||||
l = len(ret)
|
||||
if l == 1:
|
||||
return ret[0]
|
||||
elif l == 0:
|
||||
return None
|
||||
else:
|
||||
raise orm_exc.MultipleResultsFound(
|
||||
"Multiple rows were found for one_or_none()")
|
||||
|
||||
def all(self):
|
||||
"""Return all rows.
|
||||
|
||||
Equivalent to :meth:`.Query.all`.
|
||||
|
||||
"""
|
||||
return list(self)
|
||||
|
||||
def get(self, ident):
|
||||
"""Retrieve an object based on identity.
|
||||
|
||||
Equivalent to :meth:`.Query.get`.
|
||||
|
||||
"""
|
||||
|
||||
query = self.bq.steps[0](self.session)
|
||||
return query._get_impl(ident, self._load_on_ident)
|
||||
|
||||
def _load_on_ident(self, query, key):
|
||||
"""Load the given identity key from the database."""
|
||||
|
||||
ident = key[1]
|
||||
|
||||
mapper = query._mapper_zero()
|
||||
|
||||
_get_clause, _get_params = mapper._get_clause
|
||||
|
||||
def setup(query):
|
||||
_lcl_get_clause = _get_clause
|
||||
q = query._clone()
|
||||
q._get_condition()
|
||||
q._order_by = None
|
||||
|
||||
# None present in ident - turn those comparisons
|
||||
# into "IS NULL"
|
||||
if None in ident:
|
||||
nones = set([
|
||||
_get_params[col].key for col, value in
|
||||
zip(mapper.primary_key, ident) if value is None
|
||||
])
|
||||
_lcl_get_clause = sql_util.adapt_criterion_to_null(
|
||||
_lcl_get_clause, nones)
|
||||
|
||||
_lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False)
|
||||
q._criterion = _lcl_get_clause
|
||||
return q
|
||||
|
||||
# cache the query against a key that includes
|
||||
# which positions in the primary key are NULL
|
||||
# (remember, we can map to an OUTER JOIN)
|
||||
bq = self.bq
|
||||
|
||||
# add the clause we got from mapper._get_clause to the cache
|
||||
# key so that if a race causes multiple calls to _get_clause,
|
||||
# we've cached on ours
|
||||
bq = bq._clone()
|
||||
bq._cache_key += (_get_clause, )
|
||||
|
||||
bq = bq.with_criteria(setup, tuple(elem is None for elem in ident))
|
||||
|
||||
params = dict([
|
||||
(_get_params[primary_key].key, id_val)
|
||||
for id_val, primary_key in zip(ident, mapper.primary_key)
|
||||
])
|
||||
|
||||
result = list(bq.for_session(self.session).params(**params))
|
||||
l = len(result)
|
||||
if l > 1:
|
||||
raise orm_exc.MultipleResultsFound()
|
||||
elif l:
|
||||
return result[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def bake_lazy_loaders():
|
||||
"""Enable the use of baked queries for all lazyloaders systemwide.
|
||||
|
||||
This operation should be safe for all lazy loaders, and will reduce
|
||||
Python overhead for these operations.
|
||||
|
||||
"""
|
||||
BakedLazyLoader._strategy_keys[:] = []
|
||||
|
||||
properties.RelationshipProperty.strategy_for(
|
||||
lazy="select")(BakedLazyLoader)
|
||||
properties.RelationshipProperty.strategy_for(
|
||||
lazy=True)(BakedLazyLoader)
|
||||
properties.RelationshipProperty.strategy_for(
|
||||
lazy="baked_select")(BakedLazyLoader)
|
||||
|
||||
strategies.LazyLoader._strategy_keys[:] = BakedLazyLoader._strategy_keys[:]
|
||||
|
||||
|
||||
def unbake_lazy_loaders():
|
||||
"""Disable the use of baked queries for all lazyloaders systemwide.
|
||||
|
||||
This operation reverts the changes produced by :func:`.bake_lazy_loaders`.
|
||||
|
||||
"""
|
||||
strategies.LazyLoader._strategy_keys[:] = []
|
||||
BakedLazyLoader._strategy_keys[:] = []
|
||||
|
||||
properties.RelationshipProperty.strategy_for(
|
||||
lazy="select")(strategies.LazyLoader)
|
||||
properties.RelationshipProperty.strategy_for(
|
||||
lazy=True)(strategies.LazyLoader)
|
||||
properties.RelationshipProperty.strategy_for(
|
||||
lazy="baked_select")(BakedLazyLoader)
|
||||
assert strategies.LazyLoader._strategy_keys
|
||||
|
||||
|
||||
@sqla_log.class_logger
|
||||
@properties.RelationshipProperty.strategy_for(lazy="baked_select")
|
||||
class BakedLazyLoader(strategies.LazyLoader):
|
||||
|
||||
def _emit_lazyload(self, session, state, ident_key, passive):
|
||||
q = BakedQuery(
|
||||
self.mapper._compiled_cache,
|
||||
lambda session: session.query(self.mapper))
|
||||
q.add_criteria(
|
||||
lambda q: q._adapt_all_clauses()._with_invoke_all_eagers(False),
|
||||
self.parent_property)
|
||||
|
||||
if not self.parent_property.bake_queries:
|
||||
q.spoil(full=True)
|
||||
|
||||
if self.parent_property.secondary is not None:
|
||||
q.add_criteria(
|
||||
lambda q:
|
||||
q.select_from(self.mapper, self.parent_property.secondary))
|
||||
|
||||
pending = not state.key
|
||||
|
||||
# don't autoflush on pending
|
||||
if pending or passive & attributes.NO_AUTOFLUSH:
|
||||
q.add_criteria(lambda q: q.autoflush(False))
|
||||
|
||||
if state.load_options:
|
||||
q.spoil()
|
||||
args = state.load_path[self.parent_property]
|
||||
q.add_criteria(
|
||||
lambda q:
|
||||
q._with_current_path(args), args)
|
||||
q.add_criteria(
|
||||
lambda q: q._conditional_options(*state.load_options))
|
||||
|
||||
if self.use_get:
|
||||
return q(session)._load_on_ident(
|
||||
session.query(self.mapper), ident_key)
|
||||
|
||||
if self.parent_property.order_by:
|
||||
q.add_criteria(
|
||||
lambda q:
|
||||
q.order_by(*util.to_list(self.parent_property.order_by)))
|
||||
|
||||
for rev in self.parent_property._reverse_property:
|
||||
# reverse props that are MANYTOONE are loading *this*
|
||||
# object from get(), so don't need to eager out to those.
|
||||
if rev.direction is interfaces.MANYTOONE and \
|
||||
rev._use_get and \
|
||||
not isinstance(rev.strategy, strategies.LazyLoader):
|
||||
|
||||
q.add_criteria(
|
||||
lambda q:
|
||||
q.options(
|
||||
strategy_options.Load.for_existing_path(
|
||||
q._current_path[rev.parent]
|
||||
).baked_lazyload(rev.key)
|
||||
)
|
||||
)
|
||||
|
||||
lazy_clause, params = self._generate_lazy_clause(state, passive)
|
||||
|
||||
if pending:
|
||||
if orm_util._none_set.intersection(params.values()):
|
||||
return None
|
||||
|
||||
q.add_criteria(lambda q: q.filter(lazy_clause))
|
||||
result = q(session).params(**params).all()
|
||||
if self.uselist:
|
||||
return result
|
||||
else:
|
||||
l = len(result)
|
||||
if l:
|
||||
if l > 1:
|
||||
util.warn(
|
||||
"Multiple rows returned with "
|
||||
"uselist=False for lazily-loaded attribute '%s' "
|
||||
% self.parent_property)
|
||||
|
||||
return result[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@strategy_options.loader_option()
|
||||
def baked_lazyload(loadopt, attr):
|
||||
"""Indicate that the given attribute should be loaded using "lazy"
|
||||
loading with a "baked" query used in the load.
|
||||
|
||||
"""
|
||||
return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"})
|
||||
|
||||
|
||||
@baked_lazyload._add_unbound_fn
|
||||
def baked_lazyload(*keys):
|
||||
return strategy_options._UnboundLoad._from_keys(
|
||||
strategy_options._UnboundLoad.baked_lazyload, keys, False, {})
|
||||
|
||||
|
||||
@baked_lazyload._add_unbound_all_fn
|
||||
def baked_lazyload_all(*keys):
|
||||
return strategy_options._UnboundLoad._from_keys(
|
||||
strategy_options._UnboundLoad.baked_lazyload, keys, True, {})
|
||||
|
||||
baked_lazyload = baked_lazyload._unbound_fn
|
||||
baked_lazyload_all = baked_lazyload_all._unbound_all_fn
|
||||
|
||||
bakery = BakedQuery.bakery
|
||||
18
sqlalchemy/ext/declarative/__init__.py
Normal file
18
sqlalchemy/ext/declarative/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# ext/declarative/__init__.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .api import declarative_base, synonym_for, comparable_using, \
|
||||
instrument_declarative, ConcreteBase, AbstractConcreteBase, \
|
||||
DeclarativeMeta, DeferredReflection, has_inherited_table,\
|
||||
declared_attr, as_declarative
|
||||
|
||||
|
||||
__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table',
|
||||
'comparable_using', 'instrument_declarative', 'declared_attr',
|
||||
'as_declarative',
|
||||
'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta',
|
||||
'DeferredReflection']
|
||||
696
sqlalchemy/ext/declarative/api.py
Normal file
696
sqlalchemy/ext/declarative/api.py
Normal file
@@ -0,0 +1,696 @@
|
||||
# ext/declarative/api.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
"""Public API functions and helpers for declarative."""
|
||||
|
||||
|
||||
from ...schema import Table, MetaData, Column
|
||||
from ...orm import synonym as _orm_synonym, \
|
||||
comparable_property,\
|
||||
interfaces, properties, attributes
|
||||
from ...orm.util import polymorphic_union
|
||||
from ...orm.base import _mapper_or_none
|
||||
from ...util import OrderedDict, hybridmethod, hybridproperty
|
||||
from ... import util
|
||||
from ... import exc
|
||||
import weakref
|
||||
|
||||
from .base import _as_declarative, \
|
||||
_declarative_constructor,\
|
||||
_DeferredMapperConfig, _add_attribute
|
||||
from .clsregistry import _class_resolver
|
||||
|
||||
|
||||
def instrument_declarative(cls, registry, metadata):
|
||||
"""Given a class, configure the class declaratively,
|
||||
using the given registry, which can be any dictionary, and
|
||||
MetaData object.
|
||||
|
||||
"""
|
||||
if '_decl_class_registry' in cls.__dict__:
|
||||
raise exc.InvalidRequestError(
|
||||
"Class %r already has been "
|
||||
"instrumented declaratively" % cls)
|
||||
cls._decl_class_registry = registry
|
||||
cls.metadata = metadata
|
||||
_as_declarative(cls, cls.__name__, cls.__dict__)
|
||||
|
||||
|
||||
def has_inherited_table(cls):
|
||||
"""Given a class, return True if any of the classes it inherits from has a
|
||||
mapped table, otherwise return False.
|
||||
|
||||
This is used in declarative mixins to build attributes that behave
|
||||
differently for the base class vs. a subclass in an inheritance
|
||||
hierarchy.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`decl_mixin_inheritance`
|
||||
|
||||
"""
|
||||
for class_ in cls.__mro__[1:]:
|
||||
if getattr(class_, '__table__', None) is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class DeclarativeMeta(type):
|
||||
def __init__(cls, classname, bases, dict_):
|
||||
if '_decl_class_registry' not in cls.__dict__:
|
||||
_as_declarative(cls, classname, cls.__dict__)
|
||||
type.__init__(cls, classname, bases, dict_)
|
||||
|
||||
def __setattr__(cls, key, value):
|
||||
_add_attribute(cls, key, value)
|
||||
|
||||
|
||||
def synonym_for(name, map_column=False):
|
||||
"""Decorator, make a Python @property a query synonym for a column.
|
||||
|
||||
A decorator version of :func:`~sqlalchemy.orm.synonym`. The function being
|
||||
decorated is the 'descriptor', otherwise passes its arguments through to
|
||||
synonym()::
|
||||
|
||||
@synonym_for('col')
|
||||
@property
|
||||
def prop(self):
|
||||
return 'special sauce'
|
||||
|
||||
The regular ``synonym()`` is also usable directly in a declarative setting
|
||||
and may be convenient for read/write properties::
|
||||
|
||||
prop = synonym('col', descriptor=property(_read_prop, _write_prop))
|
||||
|
||||
"""
|
||||
def decorate(fn):
|
||||
return _orm_synonym(name, map_column=map_column, descriptor=fn)
|
||||
return decorate
|
||||
|
||||
|
||||
def comparable_using(comparator_factory):
|
||||
"""Decorator, allow a Python @property to be used in query criteria.
|
||||
|
||||
This is a decorator front end to
|
||||
:func:`~sqlalchemy.orm.comparable_property` that passes
|
||||
through the comparator_factory and the function being decorated::
|
||||
|
||||
@comparable_using(MyComparatorType)
|
||||
@property
|
||||
def prop(self):
|
||||
return 'special sauce'
|
||||
|
||||
The regular ``comparable_property()`` is also usable directly in a
|
||||
declarative setting and may be convenient for read/write properties::
|
||||
|
||||
prop = comparable_property(MyComparatorType)
|
||||
|
||||
"""
|
||||
def decorate(fn):
|
||||
return comparable_property(comparator_factory, fn)
|
||||
return decorate
|
||||
|
||||
|
||||
class declared_attr(interfaces._MappedAttribute, property):
|
||||
"""Mark a class-level method as representing the definition of
|
||||
a mapped property or special declarative member name.
|
||||
|
||||
@declared_attr turns the attribute into a scalar-like
|
||||
property that can be invoked from the uninstantiated class.
|
||||
Declarative treats attributes specifically marked with
|
||||
@declared_attr as returning a construct that is specific
|
||||
to mapping or declarative table configuration. The name
|
||||
of the attribute is that of what the non-dynamic version
|
||||
of the attribute would be.
|
||||
|
||||
@declared_attr is more often than not applicable to mixins,
|
||||
to define relationships that are to be applied to different
|
||||
implementors of the class::
|
||||
|
||||
class ProvidesUser(object):
|
||||
"A mixin that adds a 'user' relationship to classes."
|
||||
|
||||
@declared_attr
|
||||
def user(self):
|
||||
return relationship("User")
|
||||
|
||||
It also can be applied to mapped classes, such as to provide
|
||||
a "polymorphic" scheme for inheritance::
|
||||
|
||||
class Employee(Base):
|
||||
id = Column(Integer, primary_key=True)
|
||||
type = Column(String(50), nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def __tablename__(cls):
|
||||
return cls.__name__.lower()
|
||||
|
||||
@declared_attr
|
||||
def __mapper_args__(cls):
|
||||
if cls.__name__ == 'Employee':
|
||||
return {
|
||||
"polymorphic_on":cls.type,
|
||||
"polymorphic_identity":"Employee"
|
||||
}
|
||||
else:
|
||||
return {"polymorphic_identity":cls.__name__}
|
||||
|
||||
.. versionchanged:: 0.8 :class:`.declared_attr` can be used with
|
||||
non-ORM or extension attributes, such as user-defined attributes
|
||||
or :func:`.association_proxy` objects, which will be assigned
|
||||
to the class at class construction time.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, fget, cascading=False):
|
||||
super(declared_attr, self).__init__(fget)
|
||||
self.__doc__ = fget.__doc__
|
||||
self._cascading = cascading
|
||||
|
||||
def __get__(desc, self, cls):
|
||||
reg = cls.__dict__.get('_sa_declared_attr_reg', None)
|
||||
if reg is None:
|
||||
manager = attributes.manager_of_class(cls)
|
||||
if manager is None:
|
||||
util.warn(
|
||||
"Unmanaged access of declarative attribute %s from "
|
||||
"non-mapped class %s" %
|
||||
(desc.fget.__name__, cls.__name__))
|
||||
return desc.fget(cls)
|
||||
elif desc in reg:
|
||||
return reg[desc]
|
||||
else:
|
||||
reg[desc] = obj = desc.fget(cls)
|
||||
return obj
|
||||
|
||||
@hybridmethod
|
||||
def _stateful(cls, **kw):
|
||||
return _stateful_declared_attr(**kw)
|
||||
|
||||
@hybridproperty
|
||||
def cascading(cls):
|
||||
"""Mark a :class:`.declared_attr` as cascading.
|
||||
|
||||
This is a special-use modifier which indicates that a column
|
||||
or MapperProperty-based declared attribute should be configured
|
||||
distinctly per mapped subclass, within a mapped-inheritance scenario.
|
||||
|
||||
Below, both MyClass as well as MySubClass will have a distinct
|
||||
``id`` Column object established::
|
||||
|
||||
class HasIdMixin(object):
|
||||
@declared_attr.cascading
|
||||
def id(cls):
|
||||
if has_inherited_table(cls):
|
||||
return Column(ForeignKey('myclass.id'), primary_key=True)
|
||||
else:
|
||||
return Column(Integer, primary_key=True)
|
||||
|
||||
class MyClass(HasIdMixin, Base):
|
||||
__tablename__ = 'myclass'
|
||||
# ...
|
||||
|
||||
class MySubClass(MyClass):
|
||||
""
|
||||
# ...
|
||||
|
||||
The behavior of the above configuration is that ``MySubClass``
|
||||
will refer to both its own ``id`` column as well as that of
|
||||
``MyClass`` underneath the attribute named ``some_id``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`declarative_inheritance`
|
||||
|
||||
:ref:`mixin_inheritance_columns`
|
||||
|
||||
|
||||
"""
|
||||
return cls._stateful(cascading=True)
|
||||
|
||||
|
||||
class _stateful_declared_attr(declared_attr):
|
||||
def __init__(self, **kw):
|
||||
self.kw = kw
|
||||
|
||||
def _stateful(self, **kw):
|
||||
new_kw = self.kw.copy()
|
||||
new_kw.update(kw)
|
||||
return _stateful_declared_attr(**new_kw)
|
||||
|
||||
def __call__(self, fn):
|
||||
return declared_attr(fn, **self.kw)
|
||||
|
||||
|
||||
def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
|
||||
name='Base', constructor=_declarative_constructor,
|
||||
class_registry=None,
|
||||
metaclass=DeclarativeMeta):
|
||||
r"""Construct a base class for declarative class definitions.
|
||||
|
||||
The new base class will be given a metaclass that produces
|
||||
appropriate :class:`~sqlalchemy.schema.Table` objects and makes
|
||||
the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the
|
||||
information provided declaratively in the class and any subclasses
|
||||
of the class.
|
||||
|
||||
:param bind: An optional
|
||||
:class:`~sqlalchemy.engine.Connectable`, will be assigned
|
||||
the ``bind`` attribute on the :class:`~sqlalchemy.schema.MetaData`
|
||||
instance.
|
||||
|
||||
:param metadata:
|
||||
An optional :class:`~sqlalchemy.schema.MetaData` instance. All
|
||||
:class:`~sqlalchemy.schema.Table` objects implicitly declared by
|
||||
subclasses of the base will share this MetaData. A MetaData instance
|
||||
will be created if none is provided. The
|
||||
:class:`~sqlalchemy.schema.MetaData` instance will be available via the
|
||||
`metadata` attribute of the generated declarative base class.
|
||||
|
||||
:param mapper:
|
||||
An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will
|
||||
be used to map subclasses to their Tables.
|
||||
|
||||
:param cls:
|
||||
Defaults to :class:`object`. A type to use as the base for the generated
|
||||
declarative base class. May be a class or tuple of classes.
|
||||
|
||||
:param name:
|
||||
Defaults to ``Base``. The display name for the generated
|
||||
class. Customizing this is not required, but can improve clarity in
|
||||
tracebacks and debugging.
|
||||
|
||||
:param constructor:
|
||||
Defaults to
|
||||
:func:`~sqlalchemy.ext.declarative.base._declarative_constructor`, an
|
||||
__init__ implementation that assigns \**kwargs for declared
|
||||
fields and relationships to an instance. If ``None`` is supplied,
|
||||
no __init__ will be provided and construction will fall back to
|
||||
cls.__init__ by way of the normal Python semantics.
|
||||
|
||||
:param class_registry: optional dictionary that will serve as the
|
||||
registry of class names-> mapped classes when string names
|
||||
are used to identify classes inside of :func:`.relationship`
|
||||
and others. Allows two or more declarative base classes
|
||||
to share the same registry of class names for simplified
|
||||
inter-base relationships.
|
||||
|
||||
:param metaclass:
|
||||
Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__
|
||||
compatible callable to use as the meta type of the generated
|
||||
declarative base class.
|
||||
|
||||
.. versionchanged:: 1.1 if :paramref:`.declarative_base.cls` is a single class (rather
|
||||
than a tuple), the constructed base class will inherit its docstring.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:func:`.as_declarative`
|
||||
|
||||
"""
|
||||
lcl_metadata = metadata or MetaData()
|
||||
if bind:
|
||||
lcl_metadata.bind = bind
|
||||
|
||||
if class_registry is None:
|
||||
class_registry = weakref.WeakValueDictionary()
|
||||
|
||||
bases = not isinstance(cls, tuple) and (cls,) or cls
|
||||
class_dict = dict(_decl_class_registry=class_registry,
|
||||
metadata=lcl_metadata)
|
||||
|
||||
if isinstance(cls, type):
|
||||
class_dict['__doc__'] = cls.__doc__
|
||||
|
||||
if constructor:
|
||||
class_dict['__init__'] = constructor
|
||||
if mapper:
|
||||
class_dict['__mapper_cls__'] = mapper
|
||||
|
||||
return metaclass(name, bases, class_dict)


def as_declarative(**kw):
    """
    Class decorator for :func:`.declarative_base`.

    Provides a syntactical shortcut to the ``cls`` argument
    sent to :func:`.declarative_base`, allowing the base class
    to be converted in-place to a "declarative" base::

        from sqlalchemy.ext.declarative import as_declarative

        @as_declarative()
        class Base(object):
            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower()
            id = Column(Integer, primary_key=True)

        class MyMappedClass(Base):
            # ...

    All keyword arguments passed to :func:`.as_declarative` are passed
    along to :func:`.declarative_base`.

    .. versionadded:: 0.8.3

    .. seealso::

        :func:`.declarative_base`

    """
    def decorate(cls):
        kw['cls'] = cls
        kw['name'] = cls.__name__
        return declarative_base(**kw)

    return decorate


class ConcreteBase(object):
    """A helper class for 'concrete' declarative mappings.

    :class:`.ConcreteBase` will use the :func:`.polymorphic_union`
    function automatically, against all tables mapped as a subclass
    to this class.  The function is called via the
    ``__declare_first__()`` function, which is essentially
    a hook for the :meth:`.before_configured` event.

    :class:`.ConcreteBase` produces a mapped
    table for the class itself.  Compare to :class:`.AbstractConcreteBase`,
    which does not.

    Example::

        from sqlalchemy.ext.declarative import ConcreteBase

        class Employee(ConcreteBase, Base):
            __tablename__ = 'employee'
            employee_id = Column(Integer, primary_key=True)
            name = Column(String(50))
            __mapper_args__ = {
                'polymorphic_identity': 'employee',
                'concrete': True}

        class Manager(Employee):
            __tablename__ = 'manager'
            employee_id = Column(Integer, primary_key=True)
            name = Column(String(50))
            manager_data = Column(String(40))
            __mapper_args__ = {
                'polymorphic_identity': 'manager',
                'concrete': True}

    .. seealso::

        :class:`.AbstractConcreteBase`

        :ref:`concrete_inheritance`

        :ref:`inheritance_concrete_helpers`

    """

    @classmethod
    def _create_polymorphic_union(cls, mappers):
        return polymorphic_union(OrderedDict(
            (mp.polymorphic_identity, mp.local_table)
            for mp in mappers
        ), 'type', 'pjoin')
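
    # Illustrative note (assumption, not part of the original module):
    # for the Employee/Manager example above, the helper builds roughly::
    #
    #     pjoin = polymorphic_union(OrderedDict([
    #         ('employee', employee_table),
    #         ('manager', manager_table),
    #     ]), 'type', 'pjoin')
    #
    # i.e. a UNION ALL of all concrete tables, with a string 'type'
    # discriminator column, which the mapper then selects from.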

    @classmethod
    def __declare_first__(cls):
        m = cls.__mapper__
        if m.with_polymorphic:
            return

        mappers = list(m.self_and_descendants)
        pjoin = cls._create_polymorphic_union(mappers)
        m._set_with_polymorphic(("*", pjoin))
        m._set_polymorphic_on(pjoin.c.type)


class AbstractConcreteBase(ConcreteBase):
    """A helper class for 'concrete' declarative mappings.

    :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
    function automatically, against all tables mapped as a subclass
    to this class.  The function is called via the
    ``__declare_first__()`` function, which is essentially
    a hook for the :meth:`.before_configured` event.

    :class:`.AbstractConcreteBase` does produce a mapped class
    for the base class, however it is not persisted to any table; it
    is instead mapped directly to the "polymorphic" selectable
    and is only used for selecting.  Compare to :class:`.ConcreteBase`,
    which does create a persisted table for the base class.

    Example::

        from sqlalchemy.ext.declarative import AbstractConcreteBase

        class Employee(AbstractConcreteBase, Base):
            pass

        class Manager(Employee):
            __tablename__ = 'manager'
            employee_id = Column(Integer, primary_key=True)
            name = Column(String(50))
            manager_data = Column(String(40))

            __mapper_args__ = {
                'polymorphic_identity': 'manager',
                'concrete': True}

    The abstract base class is handled by declarative in a special way;
    at class configuration time, it behaves like a declarative mixin
    or an ``__abstract__`` base class.  Once classes are configured
    and mappings are produced, it then gets mapped itself, but
    after all of its descendants.  This is a system of mapping
    not found in any other part of SQLAlchemy.

    Using this approach, we can specify columns and properties
    that will take place on mapped subclasses, in the way that
    we normally do as in :ref:`declarative_mixins`::

        class Company(Base):
            __tablename__ = 'company'
            id = Column(Integer, primary_key=True)

        class Employee(AbstractConcreteBase, Base):
            employee_id = Column(Integer, primary_key=True)

            @declared_attr
            def company_id(cls):
                return Column(ForeignKey('company.id'))

            @declared_attr
            def company(cls):
                return relationship("Company")

        class Manager(Employee):
            __tablename__ = 'manager'

            name = Column(String(50))
            manager_data = Column(String(40))

            __mapper_args__ = {
                'polymorphic_identity': 'manager',
                'concrete': True}

    When we make use of our mappings however, both ``Manager`` and
    ``Employee`` will have an independently usable ``.company`` attribute::

        session.query(Employee).filter(Employee.company.has(id=5))

    .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
       have been reworked to support relationships established directly
       on the abstract base, without any special configurational steps.

    .. seealso::

        :class:`.ConcreteBase`

        :ref:`concrete_inheritance`

        :ref:`inheritance_concrete_helpers`

    """

    __no_table__ = True

    @classmethod
    def __declare_first__(cls):
        cls._sa_decl_prepare_nocascade()

    @classmethod
    def _sa_decl_prepare_nocascade(cls):
        if getattr(cls, '__mapper__', None):
            return

        to_map = _DeferredMapperConfig.config_for_cls(cls)

        # can't rely on 'self_and_descendants' here
        # since technically an immediate subclass
        # might not be mapped, but a subclass
        # may be.
        mappers = []
        stack = list(cls.__subclasses__())
        while stack:
            klass = stack.pop()
            stack.extend(klass.__subclasses__())
            mn = _mapper_or_none(klass)
            if mn is not None:
                mappers.append(mn)
        pjoin = cls._create_polymorphic_union(mappers)

        # For columns that were declared on the class, these
        # are normally ignored with the "__no_table__" mapping,
        # unless they have a different attribute key vs. col name
        # and are in the properties argument.
        # In that case, ensure we update the properties entry
        # to the correct column from the pjoin target table.
        declared_cols = set(to_map.declared_columns)
        for k, v in list(to_map.properties.items()):
            if v in declared_cols:
                to_map.properties[k] = pjoin.c[v.key]

        to_map.local_table = pjoin

        m_args = to_map.mapper_args_fn or dict

        def mapper_args():
            args = m_args()
            args['polymorphic_on'] = pjoin.c.type
            return args
        to_map.mapper_args_fn = mapper_args

        m = to_map.map()

        for scls in cls.__subclasses__():
            sm = _mapper_or_none(scls)
            if sm and sm.concrete and cls in scls.__bases__:
                sm._set_concrete_base(m)


class DeferredReflection(object):
    """A helper class for construction of mappings based on
    a deferred reflection step.

    Normally, declarative can be used with reflection by
    setting a :class:`.Table` object using autoload=True
    as the ``__table__`` attribute on a declarative class.
    The caveat is that the :class:`.Table` must be fully
    reflected, or at the very least have a primary key column,
    at the point at which a normal declarative mapping is
    constructed, meaning the :class:`.Engine` must be available
    at class declaration time.

    The :class:`.DeferredReflection` mixin moves the construction
    of mappers to be at a later point, after a specific
    method is called which first reflects all :class:`.Table`
    objects created so far.  Classes can define it as such::

        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.ext.declarative import DeferredReflection
        Base = declarative_base()

        class MyClass(DeferredReflection, Base):
            __tablename__ = 'mytable'

    Above, ``MyClass`` is not yet mapped.  After a series of
    classes have been defined in the above fashion, all tables
    can be reflected and mappings created using
    :meth:`.prepare`::

        engine = create_engine("someengine://...")
        DeferredReflection.prepare(engine)

    The :class:`.DeferredReflection` mixin can be applied to individual
    classes, used as the base for the declarative base itself,
    or used in a custom abstract class.  Using an abstract base
    allows only a subset of classes to be prepared for a
    particular prepare step, which is necessary for applications
    that use more than one engine.  For example, if an application
    has two engines, you might use two bases, and prepare each
    separately, e.g.::

        class ReflectedOne(DeferredReflection, Base):
            __abstract__ = True

        class ReflectedTwo(DeferredReflection, Base):
            __abstract__ = True

        class MyClass(ReflectedOne):
            __tablename__ = 'mytable'

        class MyOtherClass(ReflectedOne):
            __tablename__ = 'myothertable'

        class YetAnotherClass(ReflectedTwo):
            __tablename__ = 'yetanothertable'

        # ... etc.

    Above, the class hierarchies for ``ReflectedOne`` and
    ``ReflectedTwo`` can be configured separately::

        ReflectedOne.prepare(engine_one)
        ReflectedTwo.prepare(engine_two)

    .. versionadded:: 0.8

    """
    @classmethod
    def prepare(cls, engine):
        """Reflect all :class:`.Table` objects for all current
        :class:`.DeferredReflection` subclasses"""

        to_map = _DeferredMapperConfig.classes_for_base(cls)
        for thingy in to_map:
            cls._sa_decl_prepare(thingy.local_table, engine)
            thingy.map()
            mapper = thingy.cls.__mapper__
            metadata = mapper.class_.metadata
            for rel in mapper._props.values():
                if isinstance(rel, properties.RelationshipProperty) and \
                        rel.secondary is not None:
                    if isinstance(rel.secondary, Table):
                        cls._reflect_table(rel.secondary, engine)
                    elif isinstance(rel.secondary, _class_resolver):
                        rel.secondary._resolvers += (
                            cls._sa_deferred_table_resolver(engine, metadata),
                        )

    @classmethod
    def _sa_deferred_table_resolver(cls, engine, metadata):
        def _resolve(key):
            t1 = Table(key, metadata)
            cls._reflect_table(t1, engine)
            return t1
        return _resolve

    @classmethod
    def _sa_decl_prepare(cls, local_table, engine):
        # autoload Table, which is already
        # present in the metadata.  This
        # will fill in db-loaded columns
        # into the existing Table object.
        if local_table is not None:
            cls._reflect_table(local_table, engine)

    @classmethod
    def _reflect_table(cls, table, engine):
        Table(table.name,
              table.metadata,
              extend_existing=True,
              autoload_replace=False,
              autoload=True,
              autoload_with=engine,
              schema=table.schema)

662 sqlalchemy/ext/declarative/base.py Normal file
@@ -0,0 +1,662 @@
# ext/declarative/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Internal implementation for declarative."""

from ...schema import Table, Column
from ...orm import mapper, class_mapper, synonym
from ...orm.interfaces import MapperProperty
from ...orm.properties import ColumnProperty, CompositeProperty
from ...orm.attributes import QueryableAttribute
from ...orm.base import _is_mapped_class
from ... import util, exc
from ...util import topological
from ...sql import expression
from ... import event
from . import clsregistry
import collections
import weakref
from sqlalchemy.orm import instrumentation

declared_attr = declarative_props = None


def _declared_mapping_info(cls):
    # deferred mapping
    if _DeferredMapperConfig.has_cls(cls):
        return _DeferredMapperConfig.config_for_cls(cls)
    # regular mapping
    elif _is_mapped_class(cls):
        return class_mapper(cls, configure=False)
    else:
        return None


def _resolve_for_abstract(cls):
    if cls is object:
        return None

    if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
        for sup in cls.__bases__:
            sup = _resolve_for_abstract(sup)
            if sup is not None:
                return sup
        else:
            return None
    else:
        return cls


def _get_immediate_cls_attr(cls, attrname, strict=False):
    """return an attribute of the class that is either present directly
    on the class, e.g. not on a superclass, or is from a superclass but
    this superclass is a mixin, that is, not a descendant of
    the declarative base.

    This is used to detect attributes that indicate something about
    a mapped class independently from any mapped classes that it may
    inherit from.

    """
    if not issubclass(cls, object):
        return None

    for base in cls.__mro__:
        _is_declarative_inherits = hasattr(base, '_decl_class_registry')
        if attrname in base.__dict__ and (
            base is cls or
            ((base in cls.__bases__ if strict else True)
             and not _is_declarative_inherits)
        ):
            return getattr(base, attrname)
    else:
        return None
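
# Illustrative sketch (assumption, not part of the original module):
# _get_immediate_cls_attr() answers "does this class declare the attribute
# itself, or get it from a mixin?" while ignoring attributes inherited
# from an already-mapped declarative superclass.  Roughly::
#
#     class HasTablenameMixin(object):        # plain mixin, no registry
#         __tablename__ = 'fixed_name'
#
#     class Widget(HasTablenameMixin, Base):  # hypothetical mapped class
#         pass
#
#     # found via the mixin, since the mixin is not declaratively mapped:
#     # _get_immediate_cls_attr(Widget, '__tablename__') == 'fixed_name'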


def _as_declarative(cls, classname, dict_):
    global declared_attr, declarative_props
    if declared_attr is None:
        from .api import declared_attr
        declarative_props = (declared_attr, util.classproperty)

    if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
        return

    _MapperConfig.setup_mapping(cls, classname, dict_)


class _MapperConfig(object):

    @classmethod
    def setup_mapping(cls, cls_, classname, dict_):
        defer_map = _get_immediate_cls_attr(
            cls_, '_sa_decl_prepare_nocascade', strict=True) or \
            hasattr(cls_, '_sa_decl_prepare')

        if defer_map:
            cfg_cls = _DeferredMapperConfig
        else:
            cfg_cls = _MapperConfig
        cfg_cls(cls_, classname, dict_)

    def __init__(self, cls_, classname, dict_):

        self.cls = cls_

        # dict_ will be a dictproxy, which we can't write to, and we need to!
        self.dict_ = dict(dict_)
        self.classname = classname
        self.mapped_table = None
        self.properties = util.OrderedDict()
        self.declared_columns = set()
        self.column_copies = {}
        self._setup_declared_events()

        # temporary registry.  While early 1.0 versions
        # set up the ClassManager here, by API contract
        # we can't do that until there's a mapper.
        self.cls._sa_declared_attr_reg = {}

        self._scan_attributes()

        clsregistry.add_class(self.classname, self.cls)

        self._extract_mappable_attributes()

        self._extract_declared_columns()

        self._setup_table()

        self._setup_inheritance()

        self._early_mapping()

    def _early_mapping(self):
        self.map()

    def _setup_declared_events(self):
        if _get_immediate_cls_attr(self.cls, '__declare_last__'):
            @event.listens_for(mapper, "after_configured")
            def after_configured():
                self.cls.__declare_last__()

        if _get_immediate_cls_attr(self.cls, '__declare_first__'):
            @event.listens_for(mapper, "before_configured")
            def before_configured():
                self.cls.__declare_first__()

    def _scan_attributes(self):
        cls = self.cls
        dict_ = self.dict_
        column_copies = self.column_copies
        mapper_args_fn = None
        table_args = inherited_table_args = None
        tablename = None

        for base in cls.__mro__:
            class_mapped = base is not cls and \
                _declared_mapping_info(base) is not None and \
                not _get_immediate_cls_attr(
                    base, '_sa_decl_prepare_nocascade', strict=True)

            if not class_mapped and base is not cls:
                self._produce_column_copies(base)

            for name, obj in vars(base).items():
                if name == '__mapper_args__':
                    if not mapper_args_fn and (
                        not class_mapped or
                        isinstance(obj, declarative_props)
                    ):
                        # don't even invoke __mapper_args__ until
                        # after we've determined everything about the
                        # mapped table.
                        # make a copy of it so a class-level dictionary
                        # is not overwritten when we update column-based
                        # arguments.
                        mapper_args_fn = lambda: dict(cls.__mapper_args__)
                elif name == '__tablename__':
                    if not tablename and (
                        not class_mapped or
                        isinstance(obj, declarative_props)
                    ):
                        tablename = cls.__tablename__
                elif name == '__table_args__':
                    if not table_args and (
                        not class_mapped or
                        isinstance(obj, declarative_props)
                    ):
                        table_args = cls.__table_args__
                        if not isinstance(
                                table_args, (tuple, dict, type(None))):
                            raise exc.ArgumentError(
                                "__table_args__ value must be a tuple, "
                                "dict, or None")
                        if base is not cls:
                            inherited_table_args = True
                elif class_mapped:
                    if isinstance(obj, declarative_props):
                        util.warn("Regular (i.e. not __special__) "
                                  "attribute '%s.%s' uses @declared_attr, "
                                  "but owning class %s is mapped - "
                                  "not applying to subclass %s."
                                  % (base.__name__, name, base, cls))
                    continue
                elif base is not cls:
                    # we're a mixin, abstract base, or something that is
                    # acting like that for now.
                    if isinstance(obj, Column):
                        # already copied columns to the mapped class.
                        continue
                    elif isinstance(obj, MapperProperty):
                        raise exc.InvalidRequestError(
                            "Mapper properties (i.e. deferred,"
                            "column_property(), relationship(), etc.) must "
                            "be declared as @declared_attr callables "
                            "on declarative mixin classes.")
                    elif isinstance(obj, declarative_props):
                        oldclassprop = isinstance(obj, util.classproperty)
                        if not oldclassprop and obj._cascading:
                            dict_[name] = column_copies[obj] = \
                                ret = obj.__get__(obj, cls)
                            setattr(cls, name, ret)
                        else:
                            if oldclassprop:
                                util.warn_deprecated(
                                    "Use of sqlalchemy.util.classproperty on "
                                    "declarative classes is deprecated.")
                            dict_[name] = column_copies[obj] = \
                                ret = getattr(cls, name)
                        if isinstance(ret, (Column, MapperProperty)) and \
                                ret.doc is None:
                            ret.doc = obj.__doc__

        if inherited_table_args and not tablename:
            table_args = None

        self.table_args = table_args
        self.tablename = tablename
        self.mapper_args_fn = mapper_args_fn

    def _produce_column_copies(self, base):
        cls = self.cls
        dict_ = self.dict_
        column_copies = self.column_copies
        # copy mixin columns to the mapped class
        for name, obj in vars(base).items():
            if isinstance(obj, Column):
                if getattr(cls, name) is not obj:
                    # if column has been overridden
                    # (like by the InstrumentedAttribute of the
                    # superclass), skip
                    continue
                elif obj.foreign_keys:
                    raise exc.InvalidRequestError(
                        "Columns with foreign keys to other columns "
                        "must be declared as @declared_attr callables "
                        "on declarative mixin classes. ")
                elif name not in dict_ and not (
                        '__table__' in dict_ and
                        (obj.name or name) in dict_['__table__'].c
                ):
                    column_copies[obj] = copy_ = obj.copy()
                    copy_._creation_order = obj._creation_order
                    setattr(cls, name, copy_)
                    dict_[name] = copy_

    def _extract_mappable_attributes(self):
        cls = self.cls
        dict_ = self.dict_

        our_stuff = self.properties

        for k in list(dict_):

            if k in ('__table__', '__tablename__', '__mapper_args__'):
                continue

            value = dict_[k]
            if isinstance(value, declarative_props):
                value = getattr(cls, k)

            elif isinstance(value, QueryableAttribute) and \
                    value.class_ is not cls and \
                    value.key != k:
                # detect a QueryableAttribute that's already mapped being
                # assigned elsewhere in userland, turn into a synonym()
                value = synonym(value.key)
                setattr(cls, k, value)

            if (isinstance(value, tuple) and len(value) == 1 and
                    isinstance(value[0], (Column, MapperProperty))):
                util.warn("Ignoring declarative-like tuple value of attribute "
                          "%s: possibly a copy-and-paste error with a comma "
                          "left at the end of the line?" % k)
                continue
            elif not isinstance(value, (Column, MapperProperty)):
                # using @declared_attr for some object that
                # isn't Column/MapperProperty; remove from the dict_
                # and place the evaluated value onto the class.
                if not k.startswith('__'):
                    dict_.pop(k)
                    setattr(cls, k, value)
                continue
            # we expect to see the name 'metadata' in some valid cases;
            # however at this point we see it's assigned to something trying
            # to be mapped, so raise for that.
            elif k == 'metadata':
                raise exc.InvalidRequestError(
                    "Attribute name 'metadata' is reserved "
                    "for the MetaData instance when using a "
                    "declarative base class."
                )
            prop = clsregistry._deferred_relationship(cls, value)
            our_stuff[k] = prop

    def _extract_declared_columns(self):
        our_stuff = self.properties

        # set up attributes in the order they were created
        our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)

        # extract columns from the class dict
        declared_columns = self.declared_columns
        name_to_prop_key = collections.defaultdict(set)
        for key, c in list(our_stuff.items()):
            if isinstance(c, (ColumnProperty, CompositeProperty)):
                for col in c.columns:
                    if isinstance(col, Column) and \
                            col.table is None:
                        _undefer_column_name(key, col)
                        if not isinstance(c, CompositeProperty):
                            name_to_prop_key[col.name].add(key)
                        declared_columns.add(col)
            elif isinstance(c, Column):
                _undefer_column_name(key, c)
                name_to_prop_key[c.name].add(key)
                declared_columns.add(c)
                # if the column is the same name as the key,
                # remove it from the explicit properties dict.
                # the normal rules for assigning column-based properties
                # will take over, including precedence of columns
                # in multi-column ColumnProperties.
                if key == c.key:
                    del our_stuff[key]

        for name, keys in name_to_prop_key.items():
            if len(keys) > 1:
                util.warn(
                    "On class %r, Column object %r named "
                    "directly multiple times, "
                    "only one will be used: %s. "
                    "Consider using orm.synonym instead" %
                    (self.classname, name, (", ".join(sorted(keys))))
                )

    def _setup_table(self):
        cls = self.cls
        tablename = self.tablename
        table_args = self.table_args
        dict_ = self.dict_
        declared_columns = self.declared_columns

        declared_columns = self.declared_columns = sorted(
            declared_columns, key=lambda c: c._creation_order)
        table = None

        if hasattr(cls, '__table_cls__'):
            table_cls = util.unbound_method_to_callable(cls.__table_cls__)
        else:
            table_cls = Table

        if '__table__' not in dict_:
            if tablename is not None:

                args, table_kw = (), {}
                if table_args:
                    if isinstance(table_args, dict):
                        table_kw = table_args
                    elif isinstance(table_args, tuple):
                        if isinstance(table_args[-1], dict):
                            args, table_kw = table_args[0:-1], table_args[-1]
                        else:
                            args = table_args

                autoload = dict_.get('__autoload__')
                if autoload:
                    table_kw['autoload'] = True

                cls.__table__ = table = table_cls(
                    tablename, cls.metadata,
                    *(tuple(declared_columns) + tuple(args)),
                    **table_kw)
        else:
            table = cls.__table__
            if declared_columns:
                for c in declared_columns:
                    if not table.c.contains_column(c):
                        raise exc.ArgumentError(
                            "Can't add additional column %r when "
                            "specifying __table__" % c.key
                        )
        self.local_table = table

    def _setup_inheritance(self):
        table = self.local_table
        cls = self.cls
        table_args = self.table_args
        declared_columns = self.declared_columns
        for c in cls.__bases__:
            c = _resolve_for_abstract(c)
            if c is None:
                continue
            if _declared_mapping_info(c) is not None and \
                    not _get_immediate_cls_attr(
                        c, '_sa_decl_prepare_nocascade', strict=True):
                self.inherits = c
                break
        else:
            self.inherits = None

        if table is None and self.inherits is None and \
                not _get_immediate_cls_attr(cls, '__no_table__'):

            raise exc.InvalidRequestError(
                "Class %r does not have a __table__ or __tablename__ "
                "specified and does not inherit from an existing "
                "table-mapped class." % cls
            )
        elif self.inherits:
            inherited_mapper = _declared_mapping_info(self.inherits)
            inherited_table = inherited_mapper.local_table
            inherited_mapped_table = inherited_mapper.mapped_table

            if table is None:
                # single table inheritance.
                # ensure no table args
                if table_args:
                    raise exc.ArgumentError(
                        "Can't place __table_args__ on an inherited class "
                        "with no table."
                    )
                # add any columns declared here to the inherited table.
                for c in declared_columns:
                    if c.primary_key:
                        raise exc.ArgumentError(
                            "Can't place primary key columns on an inherited "
                            "class with no table."
                        )
                    if c.name in inherited_table.c:
                        if inherited_table.c[c.name] is c:
                            continue
                        raise exc.ArgumentError(
                            "Column '%s' on class %s conflicts with "
                            "existing column '%s'" %
                            (c, cls, inherited_table.c[c.name])
                        )
                    inherited_table.append_column(c)
                    if inherited_mapped_table is not None and \
                            inherited_mapped_table is not inherited_table:
                        inherited_mapped_table._refresh_for_new_column(c)

    def _prepare_mapper_arguments(self):
        properties = self.properties
        if self.mapper_args_fn:
            mapper_args = self.mapper_args_fn()
        else:
            mapper_args = {}

        # make sure that column copies are used rather
        # than the original columns from any mixins
        for k in ('version_id_col', 'polymorphic_on',):
            if k in mapper_args:
                v = mapper_args[k]
                mapper_args[k] = self.column_copies.get(v, v)

        assert 'inherits' not in mapper_args, \
            "Can't specify 'inherits' explicitly with declarative mappings"

        if self.inherits:
            mapper_args['inherits'] = self.inherits

        if self.inherits and not mapper_args.get('concrete', False):
            # single or joined inheritance
            # exclude any cols on the inherited table which are
            # not mapped on the parent class, to avoid
            # mapping columns specific to sibling/nephew classes
            inherited_mapper = _declared_mapping_info(self.inherits)
            inherited_table = inherited_mapper.local_table

            if 'exclude_properties' not in mapper_args:
                mapper_args['exclude_properties'] = exclude_properties = \
                    set(
                        [c.key for c in inherited_table.c
                         if c not in inherited_mapper._columntoproperty]
                    ).union(
                        inherited_mapper.exclude_properties or ()
                    )
                exclude_properties.difference_update(
                    [c.key for c in self.declared_columns])

            # look through columns in the current mapper that
            # are keyed to a propname different than the colname
            # (if names were the same, we'd have popped it out above,
            # in which case the mapper makes this combination).
            # See if the superclass has a similar column property.
            # If so, join them together.
            for k, col in list(properties.items()):
                if not isinstance(col, expression.ColumnElement):
                    continue
                if k in inherited_mapper._props:
                    p = inherited_mapper._props[k]
                    if isinstance(p, ColumnProperty):
                        # note here we place the subclass column
                        # first.  See [ticket:1892] for background.
                        properties[k] = [col] + p.columns
        result_mapper_args = mapper_args.copy()
        result_mapper_args['properties'] = properties
        self.mapper_args = result_mapper_args

    def map(self):
        self._prepare_mapper_arguments()
        if hasattr(self.cls, '__mapper_cls__'):
            mapper_cls = util.unbound_method_to_callable(
                self.cls.__mapper_cls__)
        else:
            mapper_cls = mapper

        self.cls.__mapper__ = mp_ = mapper_cls(
            self.cls,
            self.local_table,
            **self.mapper_args
        )
        del self.cls._sa_declared_attr_reg
        return mp_


class _DeferredMapperConfig(_MapperConfig):
    _configs = util.OrderedDict()

    def _early_mapping(self):
        pass

    @property
    def cls(self):
        return self._cls()

    @cls.setter
    def cls(self, class_):
        self._cls = weakref.ref(class_, self._remove_config_cls)
        self._configs[self._cls] = self

    @classmethod
    def _remove_config_cls(cls, ref):
        cls._configs.pop(ref, None)

    @classmethod
    def has_cls(cls, class_):
        # 2.6 fails on weakref if class_ is an old style class
        return isinstance(class_, type) and \
            weakref.ref(class_) in cls._configs

    @classmethod
    def config_for_cls(cls, class_):
        return cls._configs[weakref.ref(class_)]

    @classmethod
    def classes_for_base(cls, base_cls, sort=True):
        classes_for_base = [m for m in cls._configs.values()
                            if issubclass(m.cls, base_cls)]
        if not sort:
            return classes_for_base

        all_m_by_cls = dict(
            (m.cls, m)
            for m in classes_for_base
        )

        tuples = []
        for m_cls in all_m_by_cls:
            tuples.extend(
                (all_m_by_cls[base_cls], all_m_by_cls[m_cls])
                for base_cls in m_cls.__bases__
                if base_cls in all_m_by_cls
            )
        return list(
            topological.sort(
                tuples,
                classes_for_base
            )
        )
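
    # Illustrative note (assumption, not part of the original module):
    # the (parent, child) tuples fed to topological.sort() guarantee that
    # base classes are mapped before their subclasses.  Conceptually::
    #
    #     class A(DeferredReflection, Base): ...   # hypothetical classes
    #     class B(A): ...
    #
    #     # classes_for_base(DeferredReflection) yields the config for A
    #     # before the config for B, so B can inherit A's mapper.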

    def map(self):
        self._configs.pop(self._cls, None)
        return super(_DeferredMapperConfig, self).map()


def _add_attribute(cls, key, value):
    """add an attribute to an existing declarative class.

    This runs through the logic to determine MapperProperty,
    adds it to the Mapper, adds a column to the mapped Table, etc.

    """

    if '__mapper__' in cls.__dict__:
        if isinstance(value, Column):
            _undefer_column_name(key, value)
            cls.__table__.append_column(value)
            cls.__mapper__.add_property(key, value)
        elif isinstance(value, ColumnProperty):
            for col in value.columns:
                if isinstance(col, Column) and col.table is None:
                    _undefer_column_name(key, col)
                    cls.__table__.append_column(col)
            cls.__mapper__.add_property(key, value)
        elif isinstance(value, MapperProperty):
            cls.__mapper__.add_property(
                key,
                clsregistry._deferred_relationship(cls, value)
            )
        elif isinstance(value, QueryableAttribute) and value.key != key:
            # detect a QueryableAttribute that's already mapped being
            # assigned elsewhere in userland, turn into a synonym()
            value = synonym(value.key)
            cls.__mapper__.add_property(
                key,
                clsregistry._deferred_relationship(cls, value)
            )
        else:
            type.__setattr__(cls, key, value)
    else:
        type.__setattr__(cls, key, value)


def _declarative_constructor(self, **kwargs):
    """A simple constructor that allows initialization from kwargs.

    Sets attributes on the constructed instance using the names and
    values in ``kwargs``.

    Only keys that are present as
    attributes of the instance's class are allowed.  These could be,
    for example, any mapped columns or relationships.
    """
    cls_ = type(self)
    for k in kwargs:
        if not hasattr(cls_, k):
            raise TypeError(
                "%r is an invalid keyword argument for %s" %
                (k, cls_.__name__))
        setattr(self, k, kwargs[k])
_declarative_constructor.__name__ = '__init__'
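
# Illustrative sketch (assumption, not part of the original module): this
# constructor is what lets declarative classes accept mapped attributes
# as keyword arguments by default::
#
#     Base = declarative_base(constructor=_declarative_constructor)
#     # equivalent to the default; then, for a mapped User class:
#     #     user = User(name='ed')     # sets user.name = 'ed'
#     #     user = User(bogus='x')     # raises TypeError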


def _undefer_column_name(key, column):
    if column.key is None:
        column.key = key
    if column.name is None:
        column.name = key
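
# Illustrative note (assumption, not part of the original module): a
# Column created without an explicit name inherits the attribute name
# under which it was declared::
#
#     id = Column(Integer, primary_key=True)
#     # _undefer_column_name('id', col) gives the column key/name 'id'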

328 sqlalchemy/ext/declarative/clsregistry.py Normal file
@@ -0,0 +1,328 @@
# ext/declarative/clsregistry.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Routines to handle the string class registry used by declarative.

This system allows specification of classes and expressions used in
:func:`.relationship` using strings.

"""
from ...orm.properties import ColumnProperty, RelationshipProperty, \
    SynonymProperty
from ...schema import _get_table_key
from ...orm import class_mapper, interfaces
from ... import util
from ... import inspection
from ... import exc
import weakref

# strong references to registries which we place in
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
_registries = set()


def add_class(classname, cls):
    """Add a class to the _decl_class_registry associated with the
    given declarative class.

    """
    if classname in cls._decl_class_registry:
        # class already exists.
        existing = cls._decl_class_registry[classname]
        if not isinstance(existing, _MultipleClassMarker):
            existing = \
                cls._decl_class_registry[classname] = \
                _MultipleClassMarker([cls, existing])
    else:
        cls._decl_class_registry[classname] = cls

    try:
        root_module = cls._decl_class_registry['_sa_module_registry']
    except KeyError:
        cls._decl_class_registry['_sa_module_registry'] = \
            root_module = _ModuleMarker('_sa_module_registry', None)

    tokens = cls.__module__.split(".")

    # build up a tree like this:
    # modulename: myapp.snacks.nuts
    #
    # myapp->snacks->nuts->(classes)
    # snacks->nuts->(classes)
    # nuts->(classes)
    #
    # this allows partial token paths to be used.
    while tokens:
        token = tokens.pop(0)
        module = root_module.get_module(token)
        for token in tokens:
            module = module.get_module(token)
        module.add_class(classname, cls)
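
# Illustrative sketch (assumption, not part of the original module): the
# module tree above means a class can be referenced by any unambiguous
# suffix of its module path inside relationship() strings::
#
#     # for a hypothetical class myapp.snacks.nuts.Peanut, all of these
#     # resolve to the same class:
#     relationship("Peanut")
#     relationship("nuts.Peanut")
#     relationship("snacks.nuts.Peanut")
#     relationship("myapp.snacks.nuts.Peanut")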


class _MultipleClassMarker(object):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    """

    __slots__ = 'on_remove', 'contents', '__weakref__'

    def __init__(self, classes, on_remove=None):
        self.on_remove = on_remove
        self.contents = set([
            weakref.ref(item, self._remove_item) for item in classes])
        _registries.add(self)

    def __iter__(self):
        return (ref() for ref in self.contents)

    def attempt_get(self, path, key):
        if len(self.contents) > 1:
            raise exc.InvalidRequestError(
                "Multiple classes found for path \"%s\" "
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path." %
                (".".join(path + [key]))
            )
        else:
            ref = list(self.contents)[0]
            cls = ref()
            if cls is None:
                raise NameError(key)
            return cls

    def _remove_item(self, ref):
        self.contents.remove(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item):
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208]
        modules = set([
            cls.__module__ for cls in
            [ref() for ref in self.contents] if cls is not None])
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table." % (
                    item.__module__,
                    item.__name__
                )
            )
        self.contents.add(weakref.ref(item, self._remove_item))


class _ModuleMarker(object):
    """refers to a module name within
    _decl_class_registry.

    """

    __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__'

    def __init__(self, name, parent):
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        if self.parent:
            self.path = self.parent.path + [self.name]
        else:
            self.path = []
        _registries.add(self)

    def __contains__(self, name):
        return name in self.contents

    def __getitem__(self, name):
        return self.contents[name]

    def _remove_item(self, name):
        self.contents.pop(name, None)
        if not self.contents and self.parent is not None:
            self.parent._remove_item(self.name)
            _registries.discard(self)

    def resolve_attr(self, key):
        return getattr(self.mod_ns, key)

    def get_module(self, name):
        if name not in self.contents:
            marker = _ModuleMarker(name, self)
            self.contents[name] = marker
        else:
            marker = self.contents[name]
        return marker

    def add_class(self, name, cls):
        if name in self.contents:
            existing = self.contents[name]
            existing.add_item(cls)
        else:
            existing = self.contents[name] = \
                _MultipleClassMarker(
                    [cls],
                    on_remove=lambda: self._remove_item(name))


class _ModNS(object):
    __slots__ = '__parent',

    def __init__(self, parent):
        self.__parent = parent

    def __getattr__(self, key):
        try:
            value = self.__parent.contents[key]
        except KeyError:
            pass
        else:
            if value is not None:
                if isinstance(value, _ModuleMarker):
                    return value.mod_ns
                else:
                    assert isinstance(value, _MultipleClassMarker)
                    return value.attempt_get(self.__parent.path, key)
        raise AttributeError("Module %r has no mapped classes "
                             "registered under the name %r" % (
                                 self.__parent.name, key))


class _GetColumns(object):
    __slots__ = 'cls',

    def __init__(self, cls):
        self.cls = cls

    def __getattr__(self, key):
        mp = class_mapper(self.cls, configure=False)
        if mp:
            if key not in mp.all_orm_descriptors:
                raise exc.InvalidRequestError(
                    "Class %r does not have a mapped column named %r"
                    % (self.cls, key))

            desc = mp.all_orm_descriptors[key]
            if desc.extension_type is interfaces.NOT_EXTENSION:
                prop = desc.property
                if isinstance(prop, SynonymProperty):
                    key = prop.name
                elif not isinstance(prop, ColumnProperty):
                    raise exc.InvalidRequestError(
                        "Property %r is not an instance of"
                        " ColumnProperty (i.e. does not correspond"
                        " directly to a Column)." % key)
        return getattr(self.cls, key)

inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls))


class _GetTable(object):
    __slots__ = 'key', 'metadata'

    def __init__(self, key, metadata):
        self.key = key
        self.metadata = metadata

    def __getattr__(self, key):
        return self.metadata.tables[
            _get_table_key(key, self.key)
        ]
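
# Illustrative note (assumption, not part of the original module):
# _GetTable supports schema-qualified table references in relationship
# strings; attribute access supplies the table name and ``self.key`` the
# schema::
#
#     # "someschema.sometable" in a primaryjoin string resolves as
#     _GetTable('someschema', metadata).sometable
#     # -> metadata.tables['someschema.sometable']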


def _determine_container(key, value):
    if isinstance(value, _MultipleClassMarker):
        value = value.attempt_get([], key)
    return _GetColumns(value)


class _class_resolver(object):
    def __init__(self, cls, prop, fallback, arg):
        self.cls = cls
        self.prop = prop
        self.arg = self._declarative_arg = arg
        self.fallback = fallback
        self._dict = util.PopulateDict(self._access_cls)
        self._resolvers = ()

    def _access_cls(self, key):
        cls = self.cls
        if key in cls._decl_class_registry:
            return _determine_container(key, cls._decl_class_registry[key])
        elif key in cls.metadata.tables:
            return cls.metadata.tables[key]
        elif key in cls.metadata._schemas:
            return _GetTable(key, cls.metadata)
        elif '_sa_module_registry' in cls._decl_class_registry and \
                key in cls._decl_class_registry['_sa_module_registry']:
            registry = cls._decl_class_registry['_sa_module_registry']
            return registry.resolve_attr(key)
        elif self._resolvers:
            for resolv in self._resolvers:
                value = resolv(key)
                if value is not None:
                    return value

        return self.fallback[key]

    def __call__(self):
        try:
            x = eval(self.arg, globals(), self._dict)

            if isinstance(x, _GetColumns):
                return x.cls
            else:
                return x
        except NameError as n:
            raise exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined." %
                (self.prop.parent, self.arg, n.args[0], self.cls)
            )


def _resolver(cls, prop):
    import sqlalchemy
    from sqlalchemy.orm import foreign, remote

    fallback = sqlalchemy.__dict__.copy()
    fallback.update({'foreign': foreign, 'remote': remote})

    def resolve_arg(arg):
        return _class_resolver(cls, prop, fallback, arg)
    return resolve_arg
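
# Illustrative sketch (assumption, not part of the original module): the
# resolver is what allows relationship() arguments to be Python
# expressions in string form, evaluated late against the class registry
# plus the sqlalchemy namespace::
#
#     class User(Base):                       # hypothetical mapping
#         __tablename__ = 'user'
#         id = Column(Integer, primary_key=True)
#         addresses = relationship(
#             "Address",
#             primaryjoin="User.id == foreign(Address.user_id)")
#
#     # both "Address" and "foreign(...)" are resolved via _class_resolver
#     # when mappers are first configured.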


def _deferred_relationship(cls, prop):

    if isinstance(prop, RelationshipProperty):
        resolve_arg = _resolver(cls, prop)

        for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
                     'secondary', '_user_defined_foreign_keys', 'remote_side'):
            v = getattr(prop, attr)
            if isinstance(v, util.string_types):
                setattr(prop, attr, resolve_arg(v))

        if prop.backref and isinstance(prop.backref, tuple):
            key, kwargs = prop.backref
            for attr in ('primaryjoin', 'secondaryjoin', 'secondary',
                         'foreign_keys', 'remote_side', 'order_by'):
                if attr in kwargs and isinstance(kwargs[attr],
                                                 util.string_types):
                    kwargs[attr] = resolve_arg(kwargs[attr])

    return prop

841 sqlalchemy/ext/hybrid.py Normal file
@@ -0,0 +1,841 @@
# ext/hybrid.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

r"""Define attributes on ORM-mapped classes that have "hybrid" behavior.

"hybrid" means the attribute has distinct behaviors defined at the
class level and at the instance level.

The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of
method decorator; it is around 50 lines of code and has almost no
dependencies on the rest of SQLAlchemy.  It can, in theory, work with
any descriptor-based expression system.

Consider a mapping ``Interval``, representing integer ``start`` and ``end``
values.  We can define higher level functions on mapped classes that produce
SQL expressions at the class level, and Python expression evaluation at the
instance level.  Below, each function decorated with :class:`.hybrid_method` or
:class:`.hybrid_property` may receive ``self`` as an instance of the class, or
as the class itself::

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, aliased
    from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method

    Base = declarative_base()

    class Interval(Base):
        __tablename__ = 'interval'

        id = Column(Integer, primary_key=True)
        start = Column(Integer, nullable=False)
        end = Column(Integer, nullable=False)

        def __init__(self, start, end):
            self.start = start
            self.end = end

        @hybrid_property
        def length(self):
            return self.end - self.start

        @hybrid_method
        def contains(self, point):
            return (self.start <= point) & (point <= self.end)

        @hybrid_method
        def intersects(self, other):
            return self.contains(other.start) | self.contains(other.end)

Above, the ``length`` property returns the difference between the
``end`` and ``start`` attributes.  With an instance of ``Interval``,
this subtraction occurs in Python, using normal Python descriptor
mechanics::

    >>> i1 = Interval(5, 10)
    >>> i1.length
    5

When dealing with the ``Interval`` class itself, the :class:`.hybrid_property`
descriptor evaluates the function body given the ``Interval`` class as
the argument, which when evaluated with SQLAlchemy expression mechanics
returns a new SQL expression::

    >>> print Interval.length
    interval."end" - interval.start

    >>> print Session().query(Interval).filter(Interval.length > 10)
    SELECT interval.id AS interval_id, interval.start AS interval_start,
    interval."end" AS interval_end
    FROM interval
    WHERE interval."end" - interval.start > :param_1

ORM methods such as :meth:`~.Query.filter_by` generally use ``getattr()`` to
locate attributes, so can also be used with hybrid attributes::

    >>> print Session().query(Interval).filter_by(length=5)
    SELECT interval.id AS interval_id, interval.start AS interval_start,
    interval."end" AS interval_end
    FROM interval
    WHERE interval."end" - interval.start = :param_1

The ``Interval`` class example also illustrates two methods,
``contains()`` and ``intersects()``, decorated with
:class:`.hybrid_method`.  This decorator applies the same idea to
methods that :class:`.hybrid_property` applies to attributes.  The
methods return boolean values, and take advantage of the Python ``|``
and ``&`` bitwise operators to produce equivalent instance-level and
SQL expression-level boolean behavior::

    >>> i1.contains(6)
    True
    >>> i1.contains(15)
    False
    >>> i1.intersects(Interval(7, 18))
    True
    >>> i1.intersects(Interval(25, 29))
    False

    >>> print Session().query(Interval).filter(Interval.contains(15))
    SELECT interval.id AS interval_id, interval.start AS interval_start,
    interval."end" AS interval_end
    FROM interval
    WHERE interval.start <= :start_1 AND interval."end" > :end_1

    >>> ia = aliased(Interval)
    >>> print Session().query(Interval, ia).filter(Interval.intersects(ia))
    SELECT interval.id AS interval_id, interval.start AS interval_start,
    interval."end" AS interval_end, interval_1.id AS interval_1_id,
    interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end
    FROM interval, interval AS interval_1
    WHERE interval.start <= interval_1.start
    AND interval."end" > interval_1.start
    OR interval.start <= interval_1."end"
    AND interval."end" > interval_1."end"

Defining Expression Behavior Distinct from Attribute Behavior
-------------------------------------------------------------

Our usage of the ``&`` and ``|`` bitwise operators above was
fortunate, considering our functions operated on two boolean values to
return a new one.  In many cases, the construction of an in-Python
function and a SQLAlchemy SQL expression have enough differences that
two separate Python expressions should be defined.  The
:mod:`~sqlalchemy.ext.hybrid` decorators define the
:meth:`.hybrid_property.expression` modifier for this purpose.  As an
example we'll define the radius of the interval, which requires the
usage of the absolute value function::

    from sqlalchemy import func

    class Interval(object):
        # ...

        @hybrid_property
        def radius(self):
            return abs(self.length) / 2

        @radius.expression
        def radius(cls):
            return func.abs(cls.length) / 2

Above the Python function ``abs()`` is used for instance-level
operations, the SQL function ``ABS()`` is used via the :data:`.func`
object for class-level expressions::

    >>> i1.radius
    2

    >>> print Session().query(Interval).filter(Interval.radius > 5)
    SELECT interval.id AS interval_id, interval.start AS interval_start,
    interval."end" AS interval_end
    FROM interval
    WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1

Defining Setters
----------------

Hybrid properties can also define setter methods.  If we wanted
``length`` above, when set, to modify the endpoint value::

    class Interval(object):
        # ...

        @hybrid_property
        def length(self):
            return self.end - self.start

        @length.setter
        def length(self, value):
            self.end = self.start + value

The ``length(self, value)`` method is now called upon set::

    >>> i1 = Interval(5, 10)
    >>> i1.length
    5
    >>> i1.length = 12
    >>> i1.end
    17

Working with Relationships
--------------------------

There's no essential difference when creating hybrids that work with
related objects as opposed to column-based data.  The need for distinct
expressions tends to be greater.  The two variants we'll illustrate
are the "join-dependent" hybrid, and the "correlated subquery" hybrid.

Join-Dependent Relationship Hybrid
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Consider the following declarative
mapping which relates a ``User`` to a ``SavingsAccount``::

    from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
    from sqlalchemy.orm import relationship
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.hybrid import hybrid_property

    Base = declarative_base()

    class SavingsAccount(Base):
        __tablename__ = 'account'
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
        balance = Column(Numeric(15, 5))

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        name = Column(String(100), nullable=False)

        accounts = relationship("SavingsAccount", backref="owner")

        @hybrid_property
        def balance(self):
            if self.accounts:
                return self.accounts[0].balance
            else:
                return None

        @balance.setter
        def balance(self, value):
            if not self.accounts:
                account = SavingsAccount(owner=self)
            else:
                account = self.accounts[0]
            account.balance = value

        @balance.expression
        def balance(cls):
            return SavingsAccount.balance

The above hybrid property ``balance`` works with the first
``SavingsAccount`` entry in the list of accounts for this user.  The
in-Python getter/setter methods can treat ``accounts`` as a Python
list available on ``self``.

However, at the expression level, it's expected that the ``User`` class will
be used in an appropriate context such that an appropriate join to
``SavingsAccount`` will be present::

    >>> print Session().query(User, User.balance).\
    ...     join(User.accounts).filter(User.balance > 5000)
    SELECT "user".id AS user_id, "user".name AS user_name,
    account.balance AS account_balance
    FROM "user" JOIN account ON "user".id = account.user_id
    WHERE account.balance > :balance_1

Note however, that while the instance level accessors need to worry
about whether ``self.accounts`` is even present, this issue expresses
itself differently at the SQL expression level, where we basically
would use an outer join::

    >>> from sqlalchemy import or_
    >>> print (Session().query(User, User.balance).outerjoin(User.accounts).
    ...         filter(or_(User.balance < 5000, User.balance == None)))
    SELECT "user".id AS user_id, "user".name AS user_name,
    account.balance AS account_balance
    FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id
    WHERE account.balance < :balance_1 OR account.balance IS NULL

Correlated Subquery Relationship Hybrid
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We can, of course, forego being dependent on the enclosing query's usage
of joins in favor of the correlated subquery, which can portably be packed
into a single column expression.  A correlated subquery is more portable, but
often performs more poorly at the SQL level.  Using the same technique
illustrated at :ref:`mapper_column_property_sql_expressions`,
we can adjust our ``SavingsAccount`` example to aggregate the balances for
*all* accounts, and use a correlated subquery for the column expression::

    from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
    from sqlalchemy.orm import relationship
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.hybrid import hybrid_property
    from sqlalchemy import select, func

    Base = declarative_base()

    class SavingsAccount(Base):
        __tablename__ = 'account'
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
        balance = Column(Numeric(15, 5))

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        name = Column(String(100), nullable=False)

        accounts = relationship("SavingsAccount", backref="owner")

        @hybrid_property
        def balance(self):
            return sum(acc.balance for acc in self.accounts)

        @balance.expression
        def balance(cls):
            return select([func.sum(SavingsAccount.balance)]).\
                where(SavingsAccount.user_id == cls.id).\
                label('total_balance')

The above recipe will give us the ``balance`` column which renders
a correlated SELECT::

    >>> print s.query(User).filter(User.balance > 400)
    SELECT "user".id AS user_id, "user".name AS user_name
    FROM "user"
    WHERE (SELECT sum(account.balance) AS sum_1
    FROM account
    WHERE account.user_id = "user".id) > :param_1

.. _hybrid_custom_comparators:

Building Custom Comparators
---------------------------

The hybrid property also includes a helper that allows construction of
custom comparators.  A comparator object allows one to customize the
behavior of each SQLAlchemy expression operator individually.  They
are useful when creating custom types that have some highly
idiosyncratic behavior on the SQL side.

The example class below allows case-insensitive comparisons on the attribute
named ``word_insensitive``::

    from sqlalchemy.ext.hybrid import Comparator, hybrid_property
    from sqlalchemy import func, Column, Integer, String
    from sqlalchemy.orm import Session
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class CaseInsensitiveComparator(Comparator):
        def __eq__(self, other):
            return func.lower(self.__clause_element__()) == func.lower(other)

    class SearchWord(Base):
        __tablename__ = 'searchword'
        id = Column(Integer, primary_key=True)
        word = Column(String(255), nullable=False)

        @hybrid_property
        def word_insensitive(self):
            return self.word.lower()

        @word_insensitive.comparator
        def word_insensitive(cls):
            return CaseInsensitiveComparator(cls.word)

Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()``
SQL function to both sides::

    >>> print Session().query(SearchWord).filter_by(word_insensitive="Trucks")
    SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
    FROM searchword
    WHERE lower(searchword.word) = lower(:lower_1)

The ``CaseInsensitiveComparator`` above implements part of the
:class:`.ColumnOperators` interface.  A "coercion" operation like
lowercasing can be applied to all comparison operations (i.e. ``eq``,
``lt``, ``gt``, etc.) using :meth:`.Operators.operate`::

    class CaseInsensitiveComparator(Comparator):
        def operate(self, op, other):
            return op(func.lower(self.__clause_element__()), func.lower(other))
|
||||
|
||||
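In this form, the coercion applies to every comparison operator, not just
equality. As an illustrative sketch (output assumed, following the pattern
of the equality example above)::

    >>> print Session().query(SearchWord).filter(
    ...     SearchWord.word_insensitive > "Trucks")
    SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
    FROM searchword
    WHERE lower(searchword.word) > lower(:lower_1)
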
Hybrid Value Objects
--------------------

Note in our previous example, if we were to compare the
``word_insensitive`` attribute of a ``SearchWord`` instance to a plain
Python string, the plain Python string would not be coerced to lower
case - the ``CaseInsensitiveComparator`` we built, being returned by
``@word_insensitive.comparator``, only applies to the SQL side.

A more comprehensive form of the custom comparator is to construct a
*Hybrid Value Object*. This technique applies the target value or
expression to a value object which is then returned by the accessor in
all cases. The value object allows control of all operations upon
the value as well as how compared values are treated, both on the SQL
expression side as well as the Python value side. Replacing the
previous ``CaseInsensitiveComparator`` class with a new
``CaseInsensitiveWord`` class::

    class CaseInsensitiveWord(Comparator):
        "Hybrid value representing a lower case representation of a word."

        def __init__(self, word):
            if isinstance(word, basestring):
                self.word = word.lower()
            elif isinstance(word, CaseInsensitiveWord):
                self.word = word.word
            else:
                self.word = func.lower(word)

        def operate(self, op, other):
            if not isinstance(other, CaseInsensitiveWord):
                other = CaseInsensitiveWord(other)
            return op(self.word, other.word)

        def __clause_element__(self):
            return self.word

        def __str__(self):
            return self.word

        key = 'word'
        "Label to apply to Query tuple results"

Above, the ``CaseInsensitiveWord`` object represents ``self.word``,
which may be a SQL function, or may be a Python native. By
overriding ``operate()`` and ``__clause_element__()`` to work in terms
of ``self.word``, all comparison operations will work against the
"converted" form of ``word``, whether it be SQL side or Python side.
Our ``SearchWord`` class can now deliver the ``CaseInsensitiveWord``
object unconditionally from a single hybrid call::

    class SearchWord(Base):
        __tablename__ = 'searchword'
        id = Column(Integer, primary_key=True)
        word = Column(String(255), nullable=False)

        @hybrid_property
        def word_insensitive(self):
            return CaseInsensitiveWord(self.word)

The ``word_insensitive`` attribute now has case-insensitive comparison
behavior universally, including SQL expression vs. Python expression
(note the Python value is converted to lower case on the Python side
here)::

    >>> print Session().query(SearchWord).filter_by(word_insensitive="Trucks")
    SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
    FROM searchword
    WHERE lower(searchword.word) = :lower_1

SQL expression versus SQL expression::

    >>> sw1 = aliased(SearchWord)
    >>> sw2 = aliased(SearchWord)
    >>> print Session().query(
    ...     sw1.word_insensitive,
    ...     sw2.word_insensitive).\
    ...     filter(
    ...         sw1.word_insensitive > sw2.word_insensitive
    ...     )
    SELECT lower(searchword_1.word) AS lower_1,
    lower(searchword_2.word) AS lower_2
    FROM searchword AS searchword_1, searchword AS searchword_2
    WHERE lower(searchword_1.word) > lower(searchword_2.word)

Python only expression::

    >>> ws1 = SearchWord(word="SomeWord")
    >>> ws1.word_insensitive == "sOmEwOrD"
    True
    >>> ws1.word_insensitive == "XOmEwOrX"
    False
    >>> print ws1.word_insensitive
    someword

The Hybrid Value pattern is very useful for any kind of value that may
have multiple representations, such as timestamps, time deltas, units
of measurement, currencies and encrypted passwords.

.. seealso::

    `Hybrids and Value Agnostic Types
    <http://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/>`_
    - on the techspot.zzzeek.org blog

    `Value Agnostic Types, Part II
    <http://techspot.zzzeek.org/2011/10/29/value-agnostic-types-part-ii/>`_ -
    on the techspot.zzzeek.org blog

.. _hybrid_transformers:

Building Transformers
----------------------

A *transformer* is an object which can receive a :class:`.Query`
object and return a new one. The :class:`.Query` object includes a
method :meth:`.with_transformation` that returns a new :class:`.Query`
transformed by the given function.

We can combine this with the :class:`.Comparator` class to produce one type
of recipe which can both set up the FROM clause of a query as well as assign
filtering criterion.

Consider a mapped class ``Node``, which assembles using adjacency list
into a hierarchical tree pattern::

    from sqlalchemy import Column, Integer, ForeignKey
    from sqlalchemy.orm import relationship
    from sqlalchemy.ext.declarative import declarative_base
    Base = declarative_base()

    class Node(Base):
        __tablename__ = 'node'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('node.id'))
        parent = relationship("Node", remote_side=id)

Suppose we wanted to add an accessor ``grandparent``. This would
return the ``parent`` of ``Node.parent``. When we have an instance of
``Node``, this is simple::

    from sqlalchemy.ext.hybrid import hybrid_property

    class Node(Base):
        # ...

        @hybrid_property
        def grandparent(self):
            return self.parent.parent

For the expression, things are not so clear. We'd need to construct
a :class:`.Query` where we :meth:`~.Query.join` twice along
``Node.parent`` to get to the ``grandparent``. We can instead return
a transforming callable that we'll combine with the
:class:`.Comparator` class to receive any :class:`.Query` object, and
return a new one that's joined to the ``Node.parent`` attribute and
filtered based on the given criterion::

    from sqlalchemy.ext.hybrid import Comparator

    class GrandparentTransformer(Comparator):
        def operate(self, op, other):
            def transform(q):
                cls = self.__clause_element__()
                parent_alias = aliased(cls)
                return q.join(parent_alias, cls.parent).\
                    filter(op(parent_alias.parent, other))
            return transform

    Base = declarative_base()

    class Node(Base):
        __tablename__ = 'node'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('node.id'))
        parent = relationship("Node", remote_side=id)

        @hybrid_property
        def grandparent(self):
            return self.parent.parent

        @grandparent.comparator
        def grandparent(cls):
            return GrandparentTransformer(cls)

The ``GrandparentTransformer`` overrides the core
:meth:`.Operators.operate` method at the base of the
:class:`.Comparator` hierarchy to return a query-transforming
callable, which then runs the given comparison operation in a
particular context. For example, in the case above, the ``operate``
method is called, given the :attr:`.Operators.eq` callable as well as
the right side of the comparison ``Node(id=5)``. A function
``transform`` is then returned which will transform a :class:`.Query`
first to join to ``Node.parent``, then to compare ``parent_alias``
using :attr:`.Operators.eq` against the left and right sides, passing
into :meth:`.Query.filter`:

.. sourcecode:: pycon+sql

    >>> from sqlalchemy.orm import Session
    >>> session = Session()
    {sql}>>> session.query(Node).\
    ...     with_transformation(Node.grandparent==Node(id=5)).\
    ...     all()
    SELECT node.id AS node_id, node.parent_id AS node_parent_id
    FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
    WHERE :param_1 = node_1.parent_id
    {stop}

We can modify the pattern to be more verbose but flexible by separating
the "join" step from the "filter" step. The tricky part here is ensuring
that successive instances of ``GrandparentTransformer`` use the same
:class:`.AliasedClass` object against ``Node``. Below we use a simple
memoizing approach that associates a ``GrandparentTransformer``
with each class::

    class Node(Base):

        # ...

        @grandparent.comparator
        def grandparent(cls):
            # memoize a GrandparentTransformer
            # per class
            if '_gp' not in cls.__dict__:
                cls._gp = GrandparentTransformer(cls)
            return cls._gp

    class GrandparentTransformer(Comparator):

        def __init__(self, cls):
            self.parent_alias = aliased(cls)

        @property
        def join(self):
            def go(q):
                return q.join(self.parent_alias, Node.parent)
            return go

        def operate(self, op, other):
            return op(self.parent_alias.parent, other)

.. sourcecode:: pycon+sql

    {sql}>>> session.query(Node).\
    ...     with_transformation(Node.grandparent.join).\
    ...     filter(Node.grandparent==Node(id=5))
    SELECT node.id AS node_id, node.parent_id AS node_parent_id
    FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
    WHERE :param_1 = node_1.parent_id
    {stop}

The "transformer" pattern is an experimental pattern that starts
to make usage of some functional programming paradigms.
While it's only recommended for advanced and/or patient developers,
there's probably a whole lot of amazing things it can be used for.

"""
from .. import util
from ..orm import attributes, interfaces

HYBRID_METHOD = util.symbol('HYBRID_METHOD')
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_method`.

Is assigned to the :attr:`.InspectionAttr.extension_type`
attribute.

.. seealso::

    :attr:`.Mapper.all_orm_attributes`

"""

HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY')
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_property`.

Is assigned to the :attr:`.InspectionAttr.extension_type`
attribute.

.. seealso::

    :attr:`.Mapper.all_orm_attributes`

"""


class hybrid_method(interfaces.InspectionAttrInfo):
    """A decorator which allows definition of a Python object method with both
    instance-level and class-level behavior.

    """

    is_attribute = True
    extension_type = HYBRID_METHOD

    def __init__(self, func, expr=None):
        """Create a new :class:`.hybrid_method`.

        Usage is typically via decorator::

            from sqlalchemy.ext.hybrid import hybrid_method

            class SomeClass(object):
                @hybrid_method
                def value(self, x, y):
                    return self._value + x + y

                @value.expression
                def value(self, x, y):
                    return func.some_function(self._value, x, y)

        """
        self.func = func
        self.expression(expr or func)

    def __get__(self, instance, owner):
        if instance is None:
            return self.expr.__get__(owner, owner.__class__)
        else:
            return self.func.__get__(instance, owner)

    def expression(self, expr):
        """Provide a modifying decorator that defines a
        SQL-expression producing method."""

        self.expr = expr
        if not self.expr.__doc__:
            self.expr.__doc__ = self.func.__doc__
        return self

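# Illustrative usage sketch, not part of the original source: for the
# ``SomeClass`` example in the docstring above, the descriptor dispatches
# on instance vs. class access.
#
#     obj = SomeClass()
#     obj.value(1, 2)        # plain Python: self._value + 1 + 2
#     SomeClass.value(1, 2)  # SQL expression via the @value.expression form
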
class hybrid_property(interfaces.InspectionAttrInfo):
    """A decorator which allows definition of a Python descriptor with both
    instance-level and class-level behavior.

    """

    is_attribute = True
    extension_type = HYBRID_PROPERTY

    def __init__(self, fget, fset=None, fdel=None, expr=None):
        """Create a new :class:`.hybrid_property`.

        Usage is typically via decorator::

            from sqlalchemy.ext.hybrid import hybrid_property

            class SomeClass(object):
                @hybrid_property
                def value(self):
                    return self._value

                @value.setter
                def value(self, value):
                    self._value = value

        """
        self.fget = fget
        self.fset = fset
        self.fdel = fdel
        self.expression(expr or fget)
        util.update_wrapper(self, fget)

    def __get__(self, instance, owner):
        if instance is None:
            return self.expr(owner)
        else:
            return self.fget(instance)

    def __set__(self, instance, value):
        if self.fset is None:
            raise AttributeError("can't set attribute")
        self.fset(instance, value)

    def __delete__(self, instance):
        if self.fdel is None:
            raise AttributeError("can't delete attribute")
        self.fdel(instance)

    def setter(self, fset):
        """Provide a modifying decorator that defines a value-setter method."""

        self.fset = fset
        return self

    def deleter(self, fdel):
        """Provide a modifying decorator that defines a
        value-deletion method."""

        self.fdel = fdel
        return self

    def expression(self, expr):
        """Provide a modifying decorator that defines a SQL-expression
        producing method."""

        def _expr(cls):
            return ExprComparator(expr(cls), self)
        util.update_wrapper(_expr, expr)

        self.expr = _expr
        return self.comparator(_expr)

    def comparator(self, comparator):
        """Provide a modifying decorator that defines a custom
        comparator producing method.

        The return value of the decorated method should be an instance of
        :class:`~.hybrid.Comparator`.

        """

        proxy_attr = attributes.\
            create_proxied_attribute(self)

        def expr(owner):
            return proxy_attr(
                owner, self.__name__, self, comparator(owner),
                doc=comparator.__doc__ or self.__doc__)
        self.expr = expr
        return self


class Comparator(interfaces.PropComparator):
    """A helper class that allows easy construction of custom
    :class:`~.orm.interfaces.PropComparator`
    classes for usage with hybrids."""

    property = None

    def __init__(self, expression):
        self.expression = expression

    def __clause_element__(self):
        expr = self.expression
        if hasattr(expr, '__clause_element__'):
            expr = expr.__clause_element__()
        return expr

    def adapt_to_entity(self, adapt_to_entity):
        # interesting....
        return self


class ExprComparator(Comparator):
    def __init__(self, expression, hybrid):
        self.expression = expression
        self.hybrid = hybrid

    def __getattr__(self, key):
        return getattr(self.expression, key)

    @property
    def info(self):
        return self.hybrid.info

    @property
    def property(self):
        return self.expression.property

    def operate(self, op, *other, **kwargs):
        return op(self.expression, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        return op(other, self.expression, **kwargs)
sqlalchemy/ext/indexable.py (new file, 349 lines)
@@ -0,0 +1,349 @@
# ext/indexable.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Define attributes on ORM-mapped classes that have "index" attributes for
columns with :class:`~.types.Indexable` types.

"index" means the attribute is associated with an element of an
:class:`~.types.Indexable` column with the predefined index to access it.
The :class:`~.types.Indexable` types include types such as
:class:`~.types.ARRAY`, :class:`~.types.JSON` and
:class:`~.postgresql.HSTORE`.

The :mod:`~sqlalchemy.ext.indexable` extension provides a
:class:`~.schema.Column`-like interface for any element of an
:class:`~.types.Indexable` typed column. In simple cases, it can be
treated as a :class:`~.schema.Column`-mapped attribute.

.. versionadded:: 1.1

Synopsis
========

``Person`` below is a model with a primary key and a JSON data field.
While this field may contain any number of elements encoded within it,
we would like to refer to the element called ``name`` individually
as a dedicated attribute which behaves like a standalone column::

    from sqlalchemy import Column, JSON, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.indexable import index_property

    Base = declarative_base()

    class Person(Base):
        __tablename__ = 'person'

        id = Column(Integer, primary_key=True)
        data = Column(JSON)

        name = index_property('data', 'name')

Above, the ``name`` attribute now behaves like a mapped column. We
can compose a new ``Person`` and set the value of ``name``::

    >>> person = Person(name='Alchemist')

The value is now accessible::

    >>> person.name
    'Alchemist'

Behind the scenes, the JSON field was initialized to a new blank dictionary
and the field was set::

    >>> person.data
    {'name': 'Alchemist'}

The field is mutable in place::

    >>> person.name = 'Renamed'
    >>> person.name
    'Renamed'
    >>> person.data
    {'name': 'Renamed'}

When using :class:`.index_property`, the change that we make to the indexable
structure is also automatically tracked as history; we no longer need
to use :class:`~.mutable.MutableDict` in order to track this change
for the unit of work.

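For example, an in-place change marks the parent as dirty in the same way
(an illustrative sketch, assuming a ``Session`` configured against a
database)::

    >>> from sqlalchemy.orm import Session
    >>> session = Session()
    >>> person = Person(name='Alchemist')
    >>> session.add(person)
    >>> session.commit()
    >>> person.name = 'Renamed'
    >>> person in session.dirty
    True
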
Deletions work normally as well::

    >>> del person.name
    >>> person.data
    {}

Above, deletion of ``person.name`` deletes the value from the dictionary,
but not the dictionary itself.

A missing key will produce ``AttributeError``::

    >>> person = Person()
    >>> person.name
    ...
    AttributeError: 'name'

Unless you set a default value::

    >>> class Person(Base):
    ...     __tablename__ = 'person'
    ...
    ...     id = Column(Integer, primary_key=True)
    ...     data = Column(JSON)
    ...
    ...     name = index_property('data', 'name', default=None)  # See default

    >>> person = Person()
    >>> print(person.name)
    None

The attributes are also accessible at the class level.
Below, we illustrate ``Person.name`` used to generate
an indexed SQL criterion::

    >>> from sqlalchemy.orm import Session
    >>> session = Session()
    >>> query = session.query(Person).filter(Person.name == 'Alchemist')

The above query is equivalent to::

    >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')

Multiple :class:`.index_property` objects can be chained to produce
multiple levels of indexing::

    from sqlalchemy import Column, JSON, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.indexable import index_property

    Base = declarative_base()

    class Person(Base):
        __tablename__ = 'person'

        id = Column(Integer, primary_key=True)
        data = Column(JSON)

        birthday = index_property('data', 'birthday')
        year = index_property('birthday', 'year')
        month = index_property('birthday', 'month')
        day = index_property('birthday', 'day')

Above, we can write a query such as::

    q = session.query(Person).filter(Person.year == '1980')

On a PostgreSQL backend, the above query will render as::

    SELECT person.id, person.data
    FROM person
    WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s

Default Values
==============

:class:`.index_property` includes special behaviors for when the indexed
data structure does not exist, and a set operation is called:

* For an :class:`.index_property` that is given an integer index value,
  the default data structure will be a Python list of ``None`` values,
  at least as long as the index value; the value is then set at its
  place in the list. This means for an index value of zero, the list
  will be initialized to ``[None]`` before setting the given value,
  and for an index value of five, the list will be initialized to
  ``[None, None, None, None, None, None]`` before setting the element
  at index five to the given value. Note that an existing list is
  **not** extended in place to receive a value.

* For an :class:`.index_property` that is given any other kind of index
  value (e.g. strings usually), a Python dictionary is used as the
  default data structure.

* The default data structure can be set to any Python callable using the
  :paramref:`.index_property.datatype` parameter, overriding the previous
  rules (see the sketch following this list).

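As a minimal sketch of the ``datatype`` override (illustrative only; the
``Document`` class and its attributes are hypothetical)::

    class Document(Base):
        __tablename__ = 'document'

        id = Column(Integer, primary_key=True)
        data = Column(JSON)

        # reserve three list slots up front, rather than the default
        # list of ``index + 1`` elements
        first_tag = index_property(
            'data', 0, datatype=lambda: [None, None, None])

Setting ``Document(first_tag='draft')`` would initialize ``data`` to
``['draft', None, None]``.
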
Subclassing
===========

:class:`.index_property` can be subclassed, in particular for the common
use case of providing coercion of values or SQL expressions as they are
accessed. Below is a common recipe for use with a PostgreSQL JSON type,
where we want to also include automatic casting plus ``astext()``::

    class pg_json_property(index_property):
        def __init__(self, attr_name, index, cast_type):
            super(pg_json_property, self).__init__(attr_name, index)
            self.cast_type = cast_type

        def expr(self, model):
            expr = super(pg_json_property, self).expr(model)
            return expr.astext.cast(self.cast_type)

The above subclass can be used with the PostgreSQL-specific
version of :class:`.postgresql.JSON`::

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.dialects.postgresql import JSON

    Base = declarative_base()

    class Person(Base):
        __tablename__ = 'person'

        id = Column(Integer, primary_key=True)
        data = Column(JSON)

        age = pg_json_property('data', 'age', Integer)

The ``age`` attribute at the instance level works as before; however
when rendering SQL, PostgreSQL's ``->>`` operator will be used
for indexed access, instead of the usual index operator of ``->``::

    >>> query = session.query(Person).filter(Person.age < 20)

The above query will render::

    SELECT person.id, person.data
    FROM person
    WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s

"""
from __future__ import absolute_import

from sqlalchemy import inspect
from ..orm.attributes import flag_modified
from ..ext.hybrid import hybrid_property


__all__ = ['index_property']


class index_property(hybrid_property):  # noqa
    """A property generator. The generated property describes an object
    attribute that corresponds to an :class:`~.types.Indexable`
    column.

    .. versionadded:: 1.1

    .. seealso::

        :mod:`sqlalchemy.ext.indexable`

    """

    _NO_DEFAULT_ARGUMENT = object()

    def __init__(
            self, attr_name, index, default=_NO_DEFAULT_ARGUMENT,
            datatype=None, mutable=True, onebased=True):
        """Create a new :class:`.index_property`.

        :param attr_name:
            An attribute name of an `Indexable` typed column, or other
            attribute that returns an indexable structure.
        :param index:
            The index to be used for getting and setting this value. This
            should be the Python-side index value for integers.
        :param default:
            A value which will be returned instead of `AttributeError`
            when there is not a value at the given index.
        :param datatype: default datatype to use when the field is empty.
            By default, this is derived from the type of index used; a
            Python list for an integer index, or a Python dictionary for
            any other style of index. For a list, the list will be
            initialized to a list of None values that is at least
            ``index`` elements long.
        :param mutable: if False, writes and deletes to the attribute will
            be disallowed.
        :param onebased: assume the SQL representation of this value is
            one-based; that is, the first index in SQL is 1, not zero.
        """

        if mutable:
            super(index_property, self).__init__(
                self.fget, self.fset, self.fdel, self.expr
            )
        else:
            super(index_property, self).__init__(
                self.fget, None, None, self.expr
            )
        self.attr_name = attr_name
        self.index = index
        self.default = default
        is_numeric = isinstance(index, int)
        onebased = is_numeric and onebased

        if datatype is not None:
            self.datatype = datatype
        else:
            if is_numeric:
                self.datatype = lambda: [None for x in range(index + 1)]
            else:
                self.datatype = dict
        self.onebased = onebased

    def _fget_default(self):
        if self.default == self._NO_DEFAULT_ARGUMENT:
            raise AttributeError(self.attr_name)
        else:
            return self.default

    def fget(self, instance):
        attr_name = self.attr_name
        column_value = getattr(instance, attr_name)
        if column_value is None:
            return self._fget_default()
        try:
            value = column_value[self.index]
        except (KeyError, IndexError):
            return self._fget_default()
        else:
            return value

    def fset(self, instance, value):
        attr_name = self.attr_name
        column_value = getattr(instance, attr_name, None)
        if column_value is None:
            column_value = self.datatype()
            setattr(instance, attr_name, column_value)
        column_value[self.index] = value
        setattr(instance, attr_name, column_value)
        if attr_name in inspect(instance).mapper.attrs:
            flag_modified(instance, attr_name)

    def fdel(self, instance):
        attr_name = self.attr_name
        column_value = getattr(instance, attr_name)
        if column_value is None:
            raise AttributeError(self.attr_name)
        try:
            del column_value[self.index]
        except KeyError:
            raise AttributeError(self.attr_name)
        else:
            setattr(instance, attr_name, column_value)
            flag_modified(instance, attr_name)

    def expr(self, model):
        column = getattr(model, self.attr_name)
        index = self.index
        if self.onebased:
            index += 1
        return column[index]
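# Illustrative note, not part of the original source: with the default
# onebased=True, an integer index is read as-is on the Python side but
# shifted by one when rendering SQL, matching one-based ARRAY semantics.
#
#     favorite = index_property('data', 0)
#     # Python access reads data[0]; the SQL expression renders column[1]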
sqlalchemy/ext/instrumentation.py (new file, 414 lines)
@@ -0,0 +1,414 @@
"""Extensible class instrumentation.

The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
systems of class instrumentation within the ORM. Class instrumentation
refers to how the ORM places attributes on the class which maintain
data and track changes to that data, as well as event hooks installed
on the class.

.. note::
    The extension package is provided for the benefit of integration
    with other object management packages, which already perform
    their own instrumentation. It is not intended for general use.

For examples of how the instrumentation extension is used,
see the example :ref:`examples_instrumentation`.

.. versionchanged:: 0.8
    The :mod:`sqlalchemy.orm.instrumentation` was split out so
    that all functionality having to do with non-standard
    instrumentation was moved out to :mod:`sqlalchemy.ext.instrumentation`.
    When imported, the module installs itself within
    :mod:`sqlalchemy.orm.instrumentation` so that it
    takes effect, including recognition of
    ``__sa_instrumentation_manager__`` on mapped classes, as
    well as :data:`.instrumentation_finders`
    being used to determine class instrumentation resolution.

"""
from ..orm import instrumentation as orm_instrumentation
from ..orm.instrumentation import (
    ClassManager, InstrumentationFactory, _default_state_getter,
    _default_dict_getter, _default_manager_getter
)
from ..orm import attributes, collections, base as orm_base
from .. import util
from ..orm import exc as orm_exc
import weakref

INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__'
"""Attribute, elects custom instrumentation when present on a mapped class.

Allows a class to specify a slightly or wildly different technique for
tracking changes made to mapped attributes and collections.

Only one instrumentation implementation is allowed in a given object
inheritance hierarchy.

The value of this attribute must be a callable and will be passed a class
object. The callable must return one of:

- An instance of an InstrumentationManager or subclass
- An object implementing all or some of InstrumentationManager (TODO)
- A dictionary of callables, implementing all or some of the above (TODO)
- An instance of a ClassManager or subclass

This attribute is consulted by SQLAlchemy instrumentation
resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
has been imported. If custom finders are installed in the global
instrumentation_finders list, they may or may not choose to honor this
attribute.

"""

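# Illustrative sketch, not part of the original module: a mapped class
# elects custom instrumentation by assigning a callable to this attribute;
# here, the InstrumentationManager class defined below serves as that
# callable, since calling it with the class returns an instance of it.
#
#     class MyEntity(object):
#         __sa_instrumentation_manager__ = InstrumentationManager
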
def find_native_user_instrumentation_hook(cls):
    """Find user-specified instrumentation management for a class."""
    return getattr(cls, INSTRUMENTATION_MANAGER, None)

instrumentation_finders = [find_native_user_instrumentation_hook]
"""An extensible sequence of callables which return instrumentation
implementations.

When a class is registered, each callable will be passed a class object.
If None is returned, the next finder in the sequence is consulted.
Otherwise the return must be an instrumentation factory that follows
the same guidelines as
sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.

By default, the only finder is find_native_user_instrumentation_hook, which
searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
ClassManager instrumentation is used.

"""


class ExtendedInstrumentationRegistry(InstrumentationFactory):
    """Extends :class:`.InstrumentationFactory` with additional
    bookkeeping, to accommodate multiple types of
    class managers.

    """
    _manager_finders = weakref.WeakKeyDictionary()
    _state_finders = weakref.WeakKeyDictionary()
    _dict_finders = weakref.WeakKeyDictionary()
    _extended = False

    def _locate_extended_factory(self, class_):
        for finder in instrumentation_finders:
            factory = finder(class_)
            if factory is not None:
                manager = self._extended_class_manager(class_, factory)
                return manager, factory
        else:
            return None, None

    def _check_conflicts(self, class_, factory):
        existing_factories = self._collect_management_factories_for(class_).\
            difference([factory])
        if existing_factories:
            raise TypeError(
                "multiple instrumentation implementations specified "
                "in %s inheritance hierarchy: %r" % (
                    class_.__name__, list(existing_factories)))

    def _extended_class_manager(self, class_, factory):
        manager = factory(class_)
        if not isinstance(manager, ClassManager):
            manager = _ClassInstrumentationAdapter(class_, manager)

        if factory != ClassManager and not self._extended:
            # somebody invoked a custom ClassManager.
            # reinstall global "getter" functions with the more
            # expensive ones.
            self._extended = True
            _install_instrumented_lookups()

        self._manager_finders[class_] = manager.manager_getter()
        self._state_finders[class_] = manager.state_getter()
        self._dict_finders[class_] = manager.dict_getter()
        return manager

    def _collect_management_factories_for(self, cls):
        """Return a collection of factories in play or specified for a
        hierarchy.

        Traverses the entire inheritance graph of a cls and returns a
        collection of instrumentation factories for those classes. Factories
        are extracted from active ClassManagers, if available, otherwise
        instrumentation_finders is consulted.

        """
        hierarchy = util.class_hierarchy(cls)
        factories = set()
        for member in hierarchy:
            manager = self.manager_of_class(member)
            if manager is not None:
                factories.add(manager.factory)
            else:
                for finder in instrumentation_finders:
                    factory = finder(member)
                    if factory is not None:
                        break
                else:
                    factory = None
                factories.add(factory)
        factories.discard(None)
        return factories

    def unregister(self, class_):
        if class_ in self._manager_finders:
            del self._manager_finders[class_]
            del self._state_finders[class_]
            del self._dict_finders[class_]
        super(ExtendedInstrumentationRegistry, self).unregister(class_)

    def manager_of_class(self, cls):
        if cls is None:
            return None
        try:
            finder = self._manager_finders.get(cls, _default_manager_getter)
        except TypeError:
            # due to weakref lookup on invalid object
            return None
        else:
            return finder(cls)

    def state_of(self, instance):
        if instance is None:
            raise AttributeError("None has no persistent state.")
        return self._state_finders.get(
            instance.__class__, _default_state_getter)(instance)

    def dict_of(self, instance):
        if instance is None:
            raise AttributeError("None has no persistent state.")
        return self._dict_finders.get(
            instance.__class__, _default_dict_getter)(instance)


orm_instrumentation._instrumentation_factory = \
    _instrumentation_factory = ExtendedInstrumentationRegistry()
orm_instrumentation.instrumentation_finders = instrumentation_finders


class InstrumentationManager(object):
    """User-defined class instrumentation extension.

    :class:`.InstrumentationManager` can be subclassed in order
    to change
    how class instrumentation proceeds. This class exists for
    the purposes of integration with other object management
    frameworks which would like to entirely modify the
    instrumentation methodology of the ORM, and is not intended
    for regular usage. For interception of class instrumentation
    events, see :class:`.InstrumentationEvents`.

    The API for this class should be considered as semi-stable,
    and may change slightly with new releases.

    .. versionchanged:: 0.8
        :class:`.InstrumentationManager` was moved from
        :mod:`sqlalchemy.orm.instrumentation` to
        :mod:`sqlalchemy.ext.instrumentation`.

    """

    # r4361 added a mandatory (cls) constructor to this interface.
    # given that, perhaps class_ should be dropped from all of these
    # signatures.

    def __init__(self, class_):
        pass

    def manage(self, class_, manager):
        setattr(class_, '_default_class_manager', manager)

    def dispose(self, class_, manager):
        delattr(class_, '_default_class_manager')

    def manager_getter(self, class_):
        def get(cls):
            return cls._default_class_manager
        return get

    def instrument_attribute(self, class_, key, inst):
        pass

    def post_configure_attribute(self, class_, key, inst):
        pass

    def install_descriptor(self, class_, key, inst):
        setattr(class_, key, inst)

    def uninstall_descriptor(self, class_, key):
        delattr(class_, key)

    def install_member(self, class_, key, implementation):
        setattr(class_, key, implementation)

    def uninstall_member(self, class_, key):
        delattr(class_, key)

    def instrument_collection_class(self, class_, key, collection_class):
        return collections.prepare_instrumentation(collection_class)

    def get_instance_dict(self, class_, instance):
        return instance.__dict__

    def initialize_instance_dict(self, class_, instance):
        pass

    def install_state(self, class_, instance, state):
        setattr(instance, '_default_state', state)

    def remove_state(self, class_, instance):
        delattr(instance, '_default_state')

    def state_getter(self, class_):
        return lambda instance: getattr(instance, '_default_state')

    def dict_getter(self, class_):
        return lambda inst: self.get_instance_dict(class_, inst)


class _ClassInstrumentationAdapter(ClassManager):
    """Adapts a user-defined InstrumentationManager to a ClassManager."""

    def __init__(self, class_, override):
        self._adapted = override
        self._get_state = self._adapted.state_getter(class_)
        self._get_dict = self._adapted.dict_getter(class_)

        ClassManager.__init__(self, class_)

    def manage(self):
        self._adapted.manage(self.class_, self)

    def dispose(self):
        self._adapted.dispose(self.class_)

    def manager_getter(self):
        return self._adapted.manager_getter(self.class_)

    def instrument_attribute(self, key, inst, propagated=False):
        ClassManager.instrument_attribute(self, key, inst, propagated)
        if not propagated:
            self._adapted.instrument_attribute(self.class_, key, inst)

    def post_configure_attribute(self, key):
        super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
        self._adapted.post_configure_attribute(self.class_, key, self[key])

    def install_descriptor(self, key, inst):
        self._adapted.install_descriptor(self.class_, key, inst)

    def uninstall_descriptor(self, key):
        self._adapted.uninstall_descriptor(self.class_, key)

    def install_member(self, key, implementation):
        self._adapted.install_member(self.class_, key, implementation)

    def uninstall_member(self, key):
        self._adapted.uninstall_member(self.class_, key)

    def instrument_collection_class(self, key, collection_class):
        return self._adapted.instrument_collection_class(
            self.class_, key, collection_class)

    def initialize_collection(self, key, state, factory):
        delegate = getattr(self._adapted, 'initialize_collection', None)
        if delegate:
            return delegate(key, state, factory)
        else:
            return ClassManager.initialize_collection(self, key,
                                                      state, factory)

    def new_instance(self, state=None):
        instance = self.class_.__new__(self.class_)
        self.setup_instance(instance, state)
        return instance

    def _new_state_if_none(self, instance):
        """Install a default InstanceState if none is present.

        A private convenience method used by the __init__ decorator.
        """
        if self.has_state(instance):
            return False
        else:
            return self.setup_instance(instance)

    def setup_instance(self, instance, state=None):
        self._adapted.initialize_instance_dict(self.class_, instance)

        if state is None:
            state = self._state_constructor(instance, self)

        # the given instance is assumed to have no state
        self._adapted.install_state(self.class_, instance, state)
        return state

    def teardown_instance(self, instance):
        self._adapted.remove_state(self.class_, instance)

    def has_state(self, instance):
        try:
            self._get_state(instance)
        except orm_exc.NO_STATE:
            return False
        else:
            return True

    def state_getter(self):
        return self._get_state

    def dict_getter(self):
        return self._get_dict


def _install_instrumented_lookups():
    """Replace global class/object management functions
    with ExtendedInstrumentationRegistry implementations, which
    allow multiple types of class managers to be present,
    at the cost of performance.

    This function is called only by ExtendedInstrumentationRegistry
    and unit tests specific to this behavior.

    The _reinstall_default_lookups() function can be called
    after this one to re-establish the default functions.

    """
    _install_lookups(
        dict(
            instance_state=_instrumentation_factory.state_of,
            instance_dict=_instrumentation_factory.dict_of,
            manager_of_class=_instrumentation_factory.manager_of_class
        )
    )


def _reinstall_default_lookups():
    """Restore simplified lookups."""
    _install_lookups(
        dict(
            instance_state=_default_state_getter,
            instance_dict=_default_dict_getter,
            manager_of_class=_default_manager_getter
        )
    )
    _instrumentation_factory._extended = False


def _install_lookups(lookups):
    global instance_state, instance_dict, manager_of_class
    instance_state = lookups['instance_state']
    instance_dict = lookups['instance_dict']
    manager_of_class = lookups['manager_of_class']
    orm_base.instance_state = attributes.instance_state = \
        orm_instrumentation.instance_state = instance_state
    orm_base.instance_dict = attributes.instance_dict = \
        orm_instrumentation.instance_dict = instance_dict
    orm_base.manager_of_class = attributes.manager_of_class = \
        orm_instrumentation.manager_of_class = manager_of_class
sqlalchemy/ext/mutable.py (new file, 904 lines)
@@ -0,0 +1,904 @@
|
||||
# ext/mutable.py
|
||||
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
r"""Provide support for tracking of in-place changes to scalar values,
|
||||
which are propagated into ORM change events on owning parent objects.
|
||||
|
||||
.. versionadded:: 0.7 :mod:`sqlalchemy.ext.mutable` replaces SQLAlchemy's
|
||||
legacy approach to in-place mutations of scalar values; see
|
||||
:ref:`07_migration_mutation_extension`.
|
||||
|
||||
.. _mutable_scalars:
|
||||
|
||||
Establishing Mutability on Scalar Column Values
|
||||
===============================================
|
||||
|
||||
A typical example of a "mutable" structure is a Python dictionary.
|
||||
Following the example introduced in :ref:`types_toplevel`, we
|
||||
begin with a custom type that marshals Python dictionaries into
|
||||
JSON strings before being persisted::
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, VARCHAR
|
||||
import json
|
||||
|
||||
class JSONEncodedDict(TypeDecorator):
|
||||
"Represents an immutable structure as a json-encoded string."
|
||||
|
||||
impl = VARCHAR
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is not None:
|
||||
value = json.dumps(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is not None:
|
||||
value = json.loads(value)
|
||||
return value
|
||||
|
||||
The usage of ``json`` is only for the purposes of example. The
|
||||
:mod:`sqlalchemy.ext.mutable` extension can be used
|
||||
with any type whose target Python type may be mutable, including
|
||||
:class:`.PickleType`, :class:`.postgresql.ARRAY`, etc.
|
||||
|
||||
When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
|
||||
tracks all parents which reference it. Below, we illustrate a simple
|
||||
version of the :class:`.MutableDict` dictionary object, which applies
|
||||
the :class:`.Mutable` mixin to a plain Python dictionary::
|
||||
|
||||
from sqlalchemy.ext.mutable import Mutable
|
||||
|
||||
class MutableDict(Mutable, dict):
|
||||
@classmethod
|
||||
def coerce(cls, key, value):
|
||||
"Convert plain dictionaries to MutableDict."
|
||||
|
||||
if not isinstance(value, MutableDict):
|
||||
if isinstance(value, dict):
|
||||
return MutableDict(value)
|
||||
|
||||
# this call will raise ValueError
|
||||
return Mutable.coerce(key, value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"Detect dictionary set events and emit change events."
|
||||
|
||||
dict.__setitem__(self, key, value)
|
||||
self.changed()
|
||||
|
||||
def __delitem__(self, key):
|
||||
"Detect dictionary del events and emit change events."
|
||||
|
||||
dict.__delitem__(self, key)
|
||||
self.changed()
|
||||
|
||||
The above dictionary class takes the approach of subclassing the Python
|
||||
built-in ``dict`` to produce a dict
|
||||
subclass which routes all mutation events through ``__setitem__``. There are
|
||||
variants on this approach, such as subclassing ``UserDict.UserDict`` or
|
||||
``collections.MutableMapping``; the part that's important to this example is
|
||||
that the :meth:`.Mutable.changed` method is called whenever an in-place
|
||||
change to the datastructure takes place.
|
||||
|
||||
We also redefine the :meth:`.Mutable.coerce` method which will be used to
|
||||
convert any values that are not instances of ``MutableDict``, such
|
||||
as the plain dictionaries returned by the ``json`` module, into the
|
||||
appropriate type. Defining this method is optional; we could just as well
|
||||
created our ``JSONEncodedDict`` such that it always returns an instance
|
||||
of ``MutableDict``, and additionally ensured that all calling code
|
||||
uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not
|
||||
overridden, any values applied to a parent object which are not instances
|
||||
of the mutable type will raise a ``ValueError``.
|
||||
|
||||
Our new ``MutableDict`` type offers a class method
|
||||
:meth:`~.Mutable.as_mutable` which we can use within column metadata
|
||||
to associate with types. This method grabs the given type object or
|
||||
class and associates a listener that will detect all future mappings
|
||||
of this type, applying event listening instrumentation to the mapped
|
||||
attribute. Such as, with classical table metadata::
|
||||
|
||||
from sqlalchemy import Table, Column, Integer
|
||||
|
||||
my_data = Table('my_data', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('data', MutableDict.as_mutable(JSONEncodedDict))
|
||||
)
|
||||
|
||||
Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
|
||||
(if the type object was not an instance already), which will intercept any
|
||||
attributes which are mapped against this type. Below we establish a simple
|
||||
mapping against the ``my_data`` table::
|
||||
|
||||
from sqlalchemy import mapper
|
||||
|
||||
class MyDataClass(object):
|
||||
pass
|
||||
|
||||
# associates mutation listeners with MyDataClass.data
|
||||
mapper(MyDataClass, my_data)
|
||||
|
||||
The ``MyDataClass.data`` member will now be notified of in place changes
|
||||
to its value.
|
||||
|
||||
There's no difference in usage when using declarative::
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class MyDataClass(Base):
|
||||
__tablename__ = 'my_data'
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(MutableDict.as_mutable(JSONEncodedDict))
|
||||
|
||||
Any in-place changes to the ``MyDataClass.data`` member
|
||||
will flag the attribute as "dirty" on the parent object::
|
||||
|
||||
>>> from sqlalchemy.orm import Session
|
||||
|
||||
>>> sess = Session()
|
||||
>>> m1 = MyDataClass(data={'value1':'foo'})
|
||||
>>> sess.add(m1)
|
||||
>>> sess.commit()
|
||||
|
||||
>>> m1.data['value1'] = 'bar'
|
||||
>>> assert m1 in sess.dirty
|
||||
True
|
||||
|
||||
The ``MutableDict`` can be associated with all future instances
|
||||
of ``JSONEncodedDict`` in one step, using
|
||||
:meth:`~.Mutable.associate_with`. This is similar to
|
||||
:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
|
||||
of ``MutableDict`` in all mappings unconditionally, without
|
||||
the need to declare it individually::
|
||||
|
||||
MutableDict.associate_with(JSONEncodedDict)
|
||||
|
||||
class MyDataClass(Base):
|
||||
__tablename__ = 'my_data'
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSONEncodedDict)
|
||||
|
||||
|
||||
Supporting Pickling
|
||||
--------------------
|
||||
|
||||
The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the
|
||||
placement of a ``weakref.WeakKeyDictionary`` upon the value object, which
|
||||
stores a mapping of parent mapped objects keyed to the attribute name under
|
||||
which they are associated with this value. ``WeakKeyDictionary`` objects are
|
||||
not picklable, due to the fact that they contain weakrefs and function
|
||||
callbacks. In our case, this is a good thing, since if this dictionary were
|
||||
picklable, it could lead to an excessively large pickle size for our value
|
||||
objects that are pickled by themselves outside of the context of the parent.
|
||||
The developer responsibility here is only to provide a ``__getstate__`` method
|
||||
that excludes the :meth:`~MutableBase._parents` collection from the pickle
|
||||
stream::
|
||||
|
||||
class MyMutableType(Mutable):
|
||||
def __getstate__(self):
|
||||
d = self.__dict__.copy()
|
||||
d.pop('_parents', None)
|
||||
return d
|
||||
|
||||
With our dictionary example, we need to return the contents of the dict itself
|
||||
(and also restore them on __setstate__)::
|
||||
|
||||
class MutableDict(Mutable, dict):
|
||||
# ....
|
||||
|
||||
def __getstate__(self):
|
||||
return dict(self)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.update(state)
|
||||
|
||||
In the case that our mutable value object is pickled as it is attached to one
|
||||
or more parent objects that are also part of the pickle, the :class:`.Mutable`
|
||||
mixin will re-establish the :attr:`.Mutable._parents` collection on each value
|
||||
object as the owning parents themselves are unpickled.

.. _mutable_composites:

Establishing Mutability on Composites
=====================================

Composites are a special ORM feature which allow a single scalar attribute to
be assigned an object value which represents information "composed" from one
or more columns from the underlying mapped table. The usual example is that of
a geometric "point", and is introduced in :ref:`mapper_composite`.

.. versionchanged:: 0.7
    The internals of :func:`.orm.composite` have been
    greatly simplified and in-place mutation detection is no longer enabled by
    default; instead, the user-defined value must detect changes on its own and
    propagate them to all owning parents. The :mod:`sqlalchemy.ext.mutable`
    extension provides the helper class :class:`.MutableComposite`, which is a
    slight variant on the :class:`.Mutable` class.

As is the case with :class:`.Mutable`, the user-defined composite class
subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
change events to its parents via the :meth:`.MutableComposite.changed` method.
In the case of a composite class, the detection is usually via the usage of
Python descriptors (i.e. ``@property``), or alternatively via the special
Python method ``__setattr__()``. Below we expand upon the ``Point`` class
introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
and to also route attribute set events via ``__setattr__`` to the
:meth:`.MutableComposite.changed` method::

    from sqlalchemy.ext.mutable import MutableComposite

    class Point(MutableComposite):
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __setattr__(self, key, value):
            "Intercept set events"

            # set the attribute
            object.__setattr__(self, key, value)

            # alert all parents to the change
            self.changed()

        def __composite_values__(self):
            return self.x, self.y

        def __eq__(self, other):
            return isinstance(other, Point) and \
                other.x == self.x and \
                other.y == self.y

        def __ne__(self, other):
            return not self.__eq__(other)

The :class:`.MutableComposite` class uses a Python metaclass to automatically
establish listeners for any usage of :func:`.orm.composite` that specifies our
``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
listeners are established which will route change events from ``Point``
objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::

    from sqlalchemy.orm import composite, mapper
    from sqlalchemy import Table, Column

    vertices = Table('vertices', metadata,
        Column('id', Integer, primary_key=True),
        Column('x1', Integer),
        Column('y1', Integer),
        Column('x2', Integer),
        Column('y2', Integer),
        )

    class Vertex(object):
        pass

    mapper(Vertex, vertices, properties={
        'start': composite(Point, vertices.c.x1, vertices.c.y1),
        'end': composite(Point, vertices.c.x2, vertices.c.y2)
    })

Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
will flag the attribute as "dirty" on the parent object::

    >>> from sqlalchemy.orm import Session

    >>> sess = Session()
    >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
    >>> sess.add(v1)
    >>> sess.commit()

    >>> v1.end.x = 8
    >>> v1 in sess.dirty
    True

Coercing Mutable Composites
---------------------------

The :meth:`.MutableBase.coerce` method is also supported on composite types.
In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce`
method is only called for attribute set operations, not load operations.
Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent
to using a :func:`.validates` validation routine for all attributes which
make use of the custom composite type::

    class Point(MutableComposite):
        # other Point methods
        # ...

        @classmethod
        def coerce(cls, key, value):
            if isinstance(value, tuple):
                value = Point(*value)
            elif not isinstance(value, Point):
                raise ValueError("tuple or Point expected")
            return value

.. versionadded:: 0.7.10,0.8.0b2
    Support for the :meth:`.MutableBase.coerce` method in conjunction with
    objects of type :class:`.MutableComposite`.
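
With such a ``coerce`` method in place (an illustrative sketch, assuming the
``Vertex`` mapping shown above), a plain tuple assigned to a composite
attribute is converted into a ``Point`` before being established on the
parent::

    v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
    v1.end = (8, 10)   # coerced to Point(8, 10) via Point.coerce()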

Supporting Pickling
--------------------

As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper
class uses a ``weakref.WeakKeyDictionary`` available via the
:meth:`MutableBase._parents` attribute which isn't picklable. If we need to
pickle instances of ``Point`` or its owning class ``Vertex``, we at least need
to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary.
Below we define both a ``__getstate__`` and a ``__setstate__`` that package up
the minimal form of our ``Point`` class::

    class Point(MutableComposite):
        # ...

        def __getstate__(self):
            return self.x, self.y

        def __setstate__(self, state):
            self.x, self.y = state

As with :class:`.Mutable`, the :class:`.MutableComposite` augments the
pickling process of the parent's object-relational state so that the
:meth:`MutableBase._parents` collection is restored to all ``Point`` objects.

"""
from ..orm.attributes import flag_modified
from .. import event, types
from ..orm import mapper, object_mapper, Mapper
from ..util import memoized_property
from ..sql.base import SchemaEventTarget
import weakref


class MutableBase(object):
    """Common base class to :class:`.Mutable`
    and :class:`.MutableComposite`.

    """

    @memoized_property
    def _parents(self):
        """Dictionary of parent object->attribute name on the parent.

        This attribute is a so-called "memoized" property. It initializes
        itself with a new ``weakref.WeakKeyDictionary`` the first time
        it is accessed, returning the same object upon subsequent access.

        """

        return weakref.WeakKeyDictionary()

    @classmethod
    def coerce(cls, key, value):
        """Given a value, coerce it into the target type.

        Can be overridden by custom subclasses to coerce incoming
        data into a particular type.

        By default, raises ``ValueError``.

        This method is called in different scenarios depending on whether
        the parent class is of type :class:`.Mutable` or of type
        :class:`.MutableComposite`. In the case of the former, it is called
        for both attribute-set operations as well as during ORM loading
        operations. For the latter, it is only called during attribute-set
        operations; the mechanics of the :func:`.composite` construct
        handle coercion during load operations.


        :param key: string name of the ORM-mapped attribute being set.
        :param value: the incoming value.
        :return: the method should return the coerced value, or raise
            ``ValueError`` if the coercion cannot be completed.

        """
        if value is None:
            return None
        msg = "Attribute '%s' does not accept objects of type %s"
        raise ValueError(msg % (key, type(value)))

    @classmethod
    def _get_listen_keys(cls, attribute):
        """Given a descriptor attribute, return a ``set()`` of the attribute
        keys which indicate a change in the state of this attribute.

        This is normally just ``set([attribute.key])``, but can be overridden
        to provide for additional keys. E.g. a :class:`.MutableComposite`
        augments this set with the attribute keys associated with the columns
        that comprise the composite value.

        This collection is consulted in the case of intercepting the
        :meth:`.InstanceEvents.refresh` and
        :meth:`.InstanceEvents.refresh_flush` events, which pass along a list
        of attribute names that have been refreshed; the list is compared
        against this set to determine if action needs to be taken.

        .. versionadded:: 1.0.5

        """
        return set([attribute.key])

    @classmethod
    def _listen_on_attribute(cls, attribute, coerce, parent_cls):
        """Establish this type as a mutation listener for the given
        mapped descriptor.

        """
        key = attribute.key
        if parent_cls is not attribute.class_:
            return

        # rely on "propagate" here
        parent_cls = attribute.class_

        listen_keys = cls._get_listen_keys(attribute)

        def load(state, *args):
            """Listen for objects loaded or refreshed.

            Wrap the target data member's value with
            ``Mutable``.

            """
            val = state.dict.get(key, None)
            if val is not None:
                if coerce:
                    val = cls.coerce(key, val)
                    state.dict[key] = val
                val._parents[state.obj()] = key

        def load_attrs(state, ctx, attrs):
            if not attrs or listen_keys.intersection(attrs):
                load(state)

        def set(target, value, oldvalue, initiator):
            """Listen for set/replace events on the target
            data member.

            Establish a weak reference to the parent object
            on the incoming value, remove it for the one
            outgoing.

            """
            if value is oldvalue:
                return value

            if not isinstance(value, cls):
                value = cls.coerce(key, value)
            if value is not None:
                value._parents[target.obj()] = key
            if isinstance(oldvalue, cls):
                oldvalue._parents.pop(target.obj(), None)
            return value

        def pickle(state, state_dict):
            val = state.dict.get(key, None)
            if val is not None:
                if 'ext.mutable.values' not in state_dict:
                    state_dict['ext.mutable.values'] = []
                state_dict['ext.mutable.values'].append(val)

        def unpickle(state, state_dict):
            if 'ext.mutable.values' in state_dict:
                for val in state_dict['ext.mutable.values']:
                    val._parents[state.obj()] = key

        event.listen(parent_cls, 'load', load,
                     raw=True, propagate=True)
        event.listen(parent_cls, 'refresh', load_attrs,
                     raw=True, propagate=True)
        event.listen(parent_cls, 'refresh_flush', load_attrs,
                     raw=True, propagate=True)
        event.listen(attribute, 'set', set,
                     raw=True, retval=True, propagate=True)
        event.listen(parent_cls, 'pickle', pickle,
                     raw=True, propagate=True)
        event.listen(parent_cls, 'unpickle', unpickle,
                     raw=True, propagate=True)


class Mutable(MutableBase):
    """Mixin that defines transparent propagation of change
    events to a parent object.

    See the example in :ref:`mutable_scalars` for usage information.

    """

    def changed(self):
        """Subclasses should call this method whenever change events occur."""

        for parent, key in self._parents.items():
            flag_modified(parent, key)
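
    # Illustrative note (editorial, not part of the original module): a
    # user-defined subclass invokes changed() from each mutation method, e.g.
    #
    #     class MyMutableType(Mutable):
    #         def set_value(self, value):
    #             self._value = value
    #             self.changed()   # flags every mapped parent as "dirty"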

    @classmethod
    def associate_with_attribute(cls, attribute):
        """Establish this type as a mutation listener for the given
        mapped descriptor.

        """
        cls._listen_on_attribute(attribute, True, attribute.class_)

    @classmethod
    def associate_with(cls, sqltype):
        """Associate this wrapper with all future mapped columns
        of the given type.

        This is a convenience method that calls
        ``associate_with_attribute`` automatically.

        .. warning::

           The listeners established by this method are *global*
           to all mappers, and are *not* garbage collected. Only use
           :meth:`.associate_with` for types that are permanent to an
           application, not with ad-hoc types, else this will cause unbounded
           growth in memory usage.

        """

        def listen_for_type(mapper, class_):
            for prop in mapper.column_attrs:
                if isinstance(prop.columns[0].type, sqltype):
                    cls.associate_with_attribute(getattr(class_, prop.key))

        event.listen(mapper, 'mapper_configured', listen_for_type)
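
    # Usage sketch (editorial, names illustrative): associating a Mutable
    # subclass with every mapped column whose type is an instance of a
    # given type class:
    #
    #     MyMutableType.associate_with(PickleType)
    #
    # after which any mapper configured against a PickleType column
    # receives mutation tracking automatically.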

    @classmethod
    def as_mutable(cls, sqltype):
        """Associate a SQL type with this mutable Python type.

        This establishes listeners that will detect ORM mappings against
        the given type, adding mutation event trackers to those mappings.

        The type is returned, unconditionally as an instance, so that
        :meth:`.as_mutable` can be used inline::

            Table('mytable', metadata,
                Column('id', Integer, primary_key=True),
                Column('data', MyMutableType.as_mutable(PickleType))
            )

        Note that the returned type is always an instance, even if a class
        is given, and that only columns which are declared specifically with
        that type instance receive additional instrumentation.

        To associate a particular mutable type with all occurrences of a
        particular type, use the :meth:`.Mutable.associate_with` classmethod
        of the particular :class:`.Mutable` subclass to establish a global
        association.

        .. warning::

           The listeners established by this method are *global*
           to all mappers, and are *not* garbage collected. Only use
           :meth:`.as_mutable` for types that are permanent to an application,
           not with ad-hoc types, else this will cause unbounded growth
           in memory usage.

        """
        sqltype = types.to_instance(sqltype)

        # a SchemaType will be copied when the Column is copied,
        # and we'll lose our ability to link that type back to the original.
        # so track our original type w/ columns
        if isinstance(sqltype, SchemaEventTarget):
            @event.listens_for(sqltype, "before_parent_attach")
            def _add_column_memo(sqltyp, parent):
                parent.info['_ext_mutable_orig_type'] = sqltyp
            schema_event_check = True
        else:
            schema_event_check = False

        def listen_for_type(mapper, class_):
            for prop in mapper.column_attrs:
                if (
                    schema_event_check and
                    hasattr(prop.expression, 'info') and
                    prop.expression.info.get('_ext_mutable_orig_type')
                    is sqltype
                ) or (
                    prop.columns[0].type is sqltype
                ):
                    cls.associate_with_attribute(getattr(class_, prop.key))

        event.listen(mapper, 'mapper_configured', listen_for_type)

        return sqltype


class MutableComposite(MutableBase):
    """Mixin that defines transparent propagation of change
    events on a SQLAlchemy "composite" object to its
    owning parent or parents.

    See the example in :ref:`mutable_composites` for usage information.

    """

    @classmethod
    def _get_listen_keys(cls, attribute):
        return set([attribute.key]).union(attribute.property._attribute_keys)

    def changed(self):
        """Subclasses should call this method whenever change events occur."""

        for parent, key in self._parents.items():

            prop = object_mapper(parent).get_property(key)
            for value, attr_name in zip(
                    self.__composite_values__(),
                    prop._attribute_keys):
                setattr(parent, attr_name, value)


def _setup_composite_listener():
    def _listen_for_type(mapper, class_):
        for prop in mapper.iterate_properties:
            if (hasattr(prop, 'composite_class') and
                    isinstance(prop.composite_class, type) and
                    issubclass(prop.composite_class, MutableComposite)):
                prop.composite_class._listen_on_attribute(
                    getattr(class_, prop.key), False, class_)
    if not event.contains(Mapper, "mapper_configured", _listen_for_type):
        event.listen(Mapper, 'mapper_configured', _listen_for_type)
_setup_composite_listener()


class MutableDict(Mutable, dict):
    """A dictionary type that implements :class:`.Mutable`.

    The :class:`.MutableDict` object implements a dictionary that will
    emit change events to the underlying mapping when the contents of
    the dictionary are altered, including when values are added or removed.

    Note that :class:`.MutableDict` does **not** apply mutable tracking to the
    *values themselves* inside the dictionary. Therefore it is not a sufficient
    solution for the use case of tracking deep changes to a *recursive*
    dictionary structure, such as a JSON structure. To support this use case,
    build a subclass of :class:`.MutableDict` that provides appropriate
    coercion to the values placed in the dictionary so that they too are
    "mutable", and emit events up to their parent structure.

    .. versionadded:: 0.8

    .. seealso::

        :class:`.MutableList`

        :class:`.MutableSet`

    """

    def __setitem__(self, key, value):
        """Detect dictionary set events and emit change events."""
        dict.__setitem__(self, key, value)
        self.changed()

    def setdefault(self, key, value):
        result = dict.setdefault(self, key, value)
        self.changed()
        return result

    def __delitem__(self, key):
        """Detect dictionary del events and emit change events."""
        dict.__delitem__(self, key)
        self.changed()

    def update(self, *a, **kw):
        dict.update(self, *a, **kw)
        self.changed()

    def pop(self, *arg):
        result = dict.pop(self, *arg)
        self.changed()
        return result

    def popitem(self):
        result = dict.popitem(self)
        self.changed()
        return result

    def clear(self):
        dict.clear(self)
        self.changed()

    @classmethod
    def coerce(cls, key, value):
        """Convert plain dictionary to instance of this class."""
        if not isinstance(value, cls):
            if isinstance(value, dict):
                return cls(value)
            return Mutable.coerce(key, value)
        else:
            return value

    def __getstate__(self):
        return dict(self)

    def __setstate__(self, state):
        self.update(state)
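
    # Usage sketch (editorial, illustrative; assumes a declarative Base and
    # the PickleType column type -- any mutable-friendly type works):
    #
    #     class MyModel(Base):
    #         __tablename__ = 'my_model'
    #         id = Column(Integer, primary_key=True)
    #         data = Column(MutableDict.as_mutable(PickleType))
    #
    #     obj.data['key'] = 'value'   # marks obj "dirty" in its Session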


class MutableList(Mutable, list):
    """A list type that implements :class:`.Mutable`.

    The :class:`.MutableList` object implements a list that will
    emit change events to the underlying mapping when the contents of
    the list are altered, including when values are added or removed.

    Note that :class:`.MutableList` does **not** apply mutable tracking to the
    *values themselves* inside the list. Therefore it is not a sufficient
    solution for the use case of tracking deep changes to a *recursive*
    mutable structure, such as a JSON structure. To support this use case,
    build a subclass of :class:`.MutableList` that provides appropriate
    coercion to the values placed in the list so that they too are
    "mutable", and emit events up to their parent structure.

    .. versionadded:: 1.1

    .. seealso::

        :class:`.MutableDict`

        :class:`.MutableSet`

    """

    def __setitem__(self, index, value):
        """Detect list set events and emit change events."""
        list.__setitem__(self, index, value)
        self.changed()

    def __setslice__(self, start, end, value):
        """Detect list set events and emit change events."""
        list.__setslice__(self, start, end, value)
        self.changed()

    def __delitem__(self, index):
        """Detect list del events and emit change events."""
        list.__delitem__(self, index)
        self.changed()

    def __delslice__(self, start, end):
        """Detect list del events and emit change events."""
        list.__delslice__(self, start, end)
        self.changed()
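
    # note (editorial): __setslice__ and __delslice__ above are only invoked
    # on Python 2; on Python 3, slice operations are routed through
    # __setitem__ / __delitem__ with a slice object, which the methods
    # above already intercept.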

    def pop(self, *arg):
        result = list.pop(self, *arg)
        self.changed()
        return result

    def append(self, x):
        list.append(self, x)
        self.changed()

    def extend(self, x):
        list.extend(self, x)
        self.changed()

    def insert(self, i, x):
        list.insert(self, i, x)
        self.changed()

    def remove(self, i):
        list.remove(self, i)
        self.changed()

    def clear(self):
        list.clear(self)
        self.changed()

    def sort(self):
        list.sort(self)
        self.changed()

    def reverse(self):
        list.reverse(self)
        self.changed()

    @classmethod
    def coerce(cls, index, value):
        """Convert plain list to instance of this class."""
        if not isinstance(value, cls):
            if isinstance(value, list):
                return cls(value)
            return Mutable.coerce(index, value)
        else:
            return value

    def __getstate__(self):
        return list(self)

    def __setstate__(self, state):
        self[:] = state


class MutableSet(Mutable, set):
    """A set type that implements :class:`.Mutable`.

    The :class:`.MutableSet` object implements a set that will
    emit change events to the underlying mapping when the contents of
    the set are altered, including when values are added or removed.

    Note that :class:`.MutableSet` does **not** apply mutable tracking to the
    *values themselves* inside the set. Therefore it is not a sufficient
    solution for the use case of tracking deep changes to a *recursive*
    mutable structure. To support this use case,
    build a subclass of :class:`.MutableSet` that provides appropriate
    coercion to the values placed in the set so that they too are
    "mutable", and emit events up to their parent structure.

    .. versionadded:: 1.1

    .. seealso::

        :class:`.MutableDict`

        :class:`.MutableList`


    """

    def update(self, *arg):
        set.update(self, *arg)
        self.changed()

    def intersection_update(self, *arg):
        set.intersection_update(self, *arg)
        self.changed()

    def difference_update(self, *arg):
        set.difference_update(self, *arg)
        self.changed()

    def symmetric_difference_update(self, *arg):
        set.symmetric_difference_update(self, *arg)
        self.changed()

    def add(self, elem):
        set.add(self, elem)
        self.changed()

    def remove(self, elem):
        set.remove(self, elem)
        self.changed()

    def discard(self, elem):
        set.discard(self, elem)
        self.changed()

    def pop(self, *arg):
        result = set.pop(self, *arg)
        self.changed()
        return result

    def clear(self):
        set.clear(self)
        self.changed()

    @classmethod
    def coerce(cls, index, value):
        """Convert plain set to instance of this class."""
        if not isinstance(value, cls):
            if isinstance(value, set):
                return cls(value)
            return Mutable.coerce(index, value)
        else:
            return value

    def __getstate__(self):
        return set(self)

    def __setstate__(self, state):
        self.update(state)

    def __reduce_ex__(self, proto):
        return (self.__class__, (list(self), ))

sqlalchemy/inspection.py
@@ -0,0 +1,93 @@
# sqlalchemy/inspection.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""The inspection module provides the :func:`.inspect` function,
which delivers runtime information about a wide variety
of SQLAlchemy objects, both within the Core as well as the
ORM.

The :func:`.inspect` function is the entry point to SQLAlchemy's
public API for viewing the configuration and construction
of in-memory objects. Depending on the type of object
passed to :func:`.inspect`, the return value will either be
a related object which provides a known interface, or in many
cases it will return the object itself.

The rationale for :func:`.inspect` is twofold. One is that
it replaces the need to be aware of a large variety of "information
getting" functions in SQLAlchemy, such as :meth:`.Inspector.from_engine`,
:func:`.orm.attributes.instance_state`, :func:`.orm.class_mapper`,
and others. The other is that the return value of :func:`.inspect`
is guaranteed to obey a documented API, thus allowing third party
tools which build on top of SQLAlchemy configurations to be constructed
in a forwards-compatible way.

.. versionadded:: 0.8 The :func:`.inspect` system is introduced
   as of version 0.8.

"""

from . import util, exc
_registrars = util.defaultdict(list)


def inspect(subject, raiseerr=True):
    """Produce an inspection object for the given target.

    The returned value in some cases may be the
    same object as the one given, such as if a
    :class:`.Mapper` object is passed. In other
    cases, it will be an instance of the registered
    inspection type for the given object, such as
    if an :class:`.engine.Engine` is passed, an
    :class:`.Inspector` object is returned.

    :param subject: the subject to be inspected.
    :param raiseerr: When ``True``, if the given subject does not
        correspond to a known SQLAlchemy inspected type,
        :class:`sqlalchemy.exc.NoInspectionAvailable`
        is raised. If ``False``, ``None`` is returned.

    """
    type_ = type(subject)
    for cls in type_.__mro__:
        if cls in _registrars:
            reg = _registrars[cls]
            if reg is True:
                return subject
            ret = reg(subject)
            if ret is not None:
                break
    else:
        reg = ret = None

    if raiseerr and (
        reg is None or ret is None
    ):
        raise exc.NoInspectionAvailable(
            "No inspection system is "
            "available for object of type %s" %
            type_)
    return ret
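

# Illustrative usage (editorial sketch; assumes an ORM-mapped class User
# and an Engine instance named engine):
#
#     from sqlalchemy import inspect
#
#     inspect(User)                       # -> the Mapper for User
#     inspect(engine)                     # -> an Inspector for the engine
#     inspect(object(), raiseerr=False)   # -> None; no registration found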


def _inspects(*types):
    def decorate(fn_or_cls):
        for type_ in types:
            if type_ in _registrars:
                raise AssertionError(
                    "Type %s is already "
                    "registered" % type_)
            _registrars[type_] = fn_or_cls
        return fn_or_cls
    return decorate


def _self_inspects(cls):
    _inspects(cls)(True)
    return cls
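
# Registration sketch (editorial, illustrative): _self_inspects marks a
# class as its own inspection result, so inspect() returns the object
# unchanged:
#
#     @_self_inspects
#     class SomeInspectable(object):
#         pass
#
#     obj = SomeInspectable()
#     assert inspect(obj) is obj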

sqlalchemy/orm/base.py
@@ -0,0 +1,540 @@
# orm/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Constants and rudimental functions used throughout the ORM.

"""

from .. import util, inspection, exc as sa_exc
from ..sql import expression
from . import exc
import operator

PASSIVE_NO_RESULT = util.symbol(
    'PASSIVE_NO_RESULT',
    """Symbol returned by a loader callable or other attribute/history
    retrieval operation when a value could not be determined, based
    on loader callable flags.
    """
)

ATTR_WAS_SET = util.symbol(
    'ATTR_WAS_SET',
    """Symbol returned by a loader callable to indicate the
    retrieved value, or values, were assigned to their attributes
    on the target object.
    """
)

ATTR_EMPTY = util.symbol(
    'ATTR_EMPTY',
    """Symbol used internally to indicate an attribute had no callable."""
)

NO_VALUE = util.symbol(
    'NO_VALUE',
    """Symbol which may be placed as the 'previous' value of an attribute,
    indicating no value was loaded for an attribute when it was modified,
    and flags indicated we were not to load it.
    """
)

NEVER_SET = util.symbol(
    'NEVER_SET',
    """Symbol which may be placed as the 'previous' value of an attribute
    indicating that the attribute had not been assigned to previously.
    """
)

NO_CHANGE = util.symbol(
    "NO_CHANGE",
    """No callables or SQL should be emitted on attribute access
    and no state should change
    """, canonical=0
)

CALLABLES_OK = util.symbol(
    "CALLABLES_OK",
    """Loader callables can be fired off if a value
    is not present.
    """, canonical=1
)

SQL_OK = util.symbol(
    "SQL_OK",
    """Loader callables can emit SQL at least on scalar value attributes.""",
    canonical=2
)

RELATED_OBJECT_OK = util.symbol(
    "RELATED_OBJECT_OK",
    """Callables can use SQL to load related objects as well
    as scalar value attributes.
    """, canonical=4
)

INIT_OK = util.symbol(
    "INIT_OK",
    """Attributes should be initialized with a blank
    value (None or an empty collection) upon get, if no other
    value can be obtained.
    """, canonical=8
)

NON_PERSISTENT_OK = util.symbol(
    "NON_PERSISTENT_OK",
    """Callables can be emitted if the parent is not persistent.""",
    canonical=16
)

LOAD_AGAINST_COMMITTED = util.symbol(
    "LOAD_AGAINST_COMMITTED",
    """Callables should use committed values as primary/foreign keys during a
    load.
    """, canonical=32
)

NO_AUTOFLUSH = util.symbol(
    "NO_AUTOFLUSH",
    """Loader callables should disable autoflush.""",
    canonical=64
)

# pre-packaged sets of flags used as inputs
PASSIVE_OFF = util.symbol(
    "PASSIVE_OFF",
    "Callables can be emitted in all cases.",
    canonical=(RELATED_OBJECT_OK | NON_PERSISTENT_OK |
               INIT_OK | CALLABLES_OK | SQL_OK)
)
PASSIVE_RETURN_NEVER_SET = util.symbol(
    "PASSIVE_RETURN_NEVER_SET",
    """PASSIVE_OFF ^ INIT_OK""",
    canonical=PASSIVE_OFF ^ INIT_OK
)
PASSIVE_NO_INITIALIZE = util.symbol(
    "PASSIVE_NO_INITIALIZE",
    "PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK",
    canonical=PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK
)
PASSIVE_NO_FETCH = util.symbol(
    "PASSIVE_NO_FETCH",
    "PASSIVE_OFF ^ SQL_OK",
    canonical=PASSIVE_OFF ^ SQL_OK
)
PASSIVE_NO_FETCH_RELATED = util.symbol(
    "PASSIVE_NO_FETCH_RELATED",
    "PASSIVE_OFF ^ RELATED_OBJECT_OK",
    canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK
)
PASSIVE_ONLY_PERSISTENT = util.symbol(
    "PASSIVE_ONLY_PERSISTENT",
    "PASSIVE_OFF ^ NON_PERSISTENT_OK",
    canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK
)
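
# Editorial note: the PASSIVE_* symbols above are bitflag combinations of
# the individual capability flags, so they can be tested with ordinary
# integer bit operations; e.g. (illustrative only):
#
#     assert PASSIVE_NO_FETCH == PASSIVE_OFF ^ SQL_OK
#     assert not (PASSIVE_NO_FETCH & SQL_OK)    # SQL loading is disabled
#     assert PASSIVE_NO_FETCH & CALLABLES_OK    # but callables may still fire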


DEFAULT_MANAGER_ATTR = '_sa_class_manager'
DEFAULT_STATE_ATTR = '_sa_instance_state'
_INSTRUMENTOR = ('mapper', 'instrumentor')

EXT_CONTINUE = util.symbol('EXT_CONTINUE')
EXT_STOP = util.symbol('EXT_STOP')

ONETOMANY = util.symbol(
    'ONETOMANY',
    """Indicates the one-to-many direction for a :func:`.relationship`.

    This symbol is typically used by the internals but may be exposed within
    certain API features.

    """)

MANYTOONE = util.symbol(
    'MANYTOONE',
    """Indicates the many-to-one direction for a :func:`.relationship`.

    This symbol is typically used by the internals but may be exposed within
    certain API features.

    """)

MANYTOMANY = util.symbol(
    'MANYTOMANY',
    """Indicates the many-to-many direction for a :func:`.relationship`.

    This symbol is typically used by the internals but may be exposed within
    certain API features.

    """)

NOT_EXTENSION = util.symbol(
    'NOT_EXTENSION',
    """Symbol indicating an :class:`InspectionAttr` that's
    not part of sqlalchemy.ext.

    Is assigned to the :attr:`.InspectionAttr.extension_type`
    attribute.

    """)

_never_set = frozenset([NEVER_SET])

_none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT])

_SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED")

_DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")


def _generative(*assertions):
    """Mark a method as generative, e.g. method-chained."""

    @util.decorator
    def generate(fn, *args, **kw):
        self = args[0]._clone()
        for assertion in assertions:
            assertion(self, fn.__name__)
        fn(self, *args[1:], **kw)
        return self
    return generate
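
# Usage sketch (editorial, illustrative only): a query-like class marks
# methods as generative so that each call mutates and returns a copy:
#
#     class FluentThing(object):
#         def _clone(self):
#             clone = FluentThing.__new__(FluentThing)
#             clone.__dict__.update(self.__dict__)
#             return clone
#
#         @_generative()
#         def option(self, value):
#             self.value = value   # applied to the clone; clone is returned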


# these can be replaced by sqlalchemy.ext.instrumentation
# if augmented class instrumentation is enabled.
def manager_of_class(cls):
    return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)

instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)

instance_dict = operator.attrgetter('__dict__')


def instance_str(instance):
    """Return a string describing an instance."""

    return state_str(instance_state(instance))


def state_str(state):
    """Return a string describing an instance via its InstanceState."""

    if state is None:
        return "None"
    else:
        return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))


def state_class_str(state):
    """Return a string describing an instance's class via its
    InstanceState.
    """

    if state is None:
        return "None"
    else:
        return '<%s>' % (state.class_.__name__, )


def attribute_str(instance, attribute):
    return instance_str(instance) + "." + attribute


def state_attribute_str(state, attribute):
    return state_str(state) + "." + attribute


def object_mapper(instance):
    """Given an object, return the primary Mapper associated with the object
    instance.

    Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
    if no mapping is configured.

    This function is available via the inspection system as::

        inspect(instance).mapper

    Using the inspection system will raise
    :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
    not part of a mapping.

    """
    return object_state(instance).mapper


def object_state(instance):
    """Given an object, return the :class:`.InstanceState`
    associated with the object.

    Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
    if no mapping is configured.

    Equivalent functionality is available via the :func:`.inspect`
    function as::

        inspect(instance)

    Using the inspection system will raise
    :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
    not part of a mapping.

    """
    state = _inspect_mapped_object(instance)
    if state is None:
        raise exc.UnmappedInstanceError(instance)
    else:
        return state


@inspection._inspects(object)
def _inspect_mapped_object(instance):
    try:
        return instance_state(instance)
    # TODO: what's the py-2/3 syntax to catch two
    # different kinds of exceptions at once?
    except exc.UnmappedClassError:
        return None
    except exc.NO_STATE:
        return None


def _class_to_mapper(class_or_mapper):
    insp = inspection.inspect(class_or_mapper, False)
    if insp is not None:
        return insp.mapper
    else:
        raise exc.UnmappedClassError(class_or_mapper)


def _mapper_or_none(entity):
    """Return the :class:`.Mapper` for the given class or None if the
    class is not mapped.
    """

    insp = inspection.inspect(entity, False)
    if insp is not None:
        return insp.mapper
    else:
        return None


def _is_mapped_class(entity):
    """Return True if the given object is a mapped class,
    :class:`.Mapper`, or :class:`.AliasedClass`.
    """

    insp = inspection.inspect(entity, False)
    return insp is not None and \
        not insp.is_clause_element and \
        (
            insp.is_mapper or insp.is_aliased_class
        )


def _attr_as_key(attr):
    if hasattr(attr, 'key'):
        return attr.key
    else:
        return expression._column_as_key(attr)


def _orm_columns(entity):
    insp = inspection.inspect(entity, False)
    if hasattr(insp, 'selectable') and hasattr(insp.selectable, 'c'):
        return [c for c in insp.selectable.c]
    else:
        return [entity]


def _is_aliased_class(entity):
    insp = inspection.inspect(entity, False)
    return insp is not None and \
        getattr(insp, "is_aliased_class", False)


def _entity_descriptor(entity, key):
    """Return a class attribute given an entity and string name.

    May return :class:`.InstrumentedAttribute` or user-defined
    attribute.

    """
    insp = inspection.inspect(entity)
    if insp.is_selectable:
        description = entity
        entity = insp.c
    elif insp.is_aliased_class:
        entity = insp.entity
        description = entity
    elif hasattr(insp, "mapper"):
        description = entity = insp.mapper.class_
    else:
        description = entity

    try:
        return getattr(entity, key)
    except AttributeError:
        raise sa_exc.InvalidRequestError(
            "Entity '%s' has no property '%s'" %
            (description, key)
        )

_state_mapper = util.dottedgetter('manager.mapper')


@inspection._inspects(type)
def _inspect_mapped_class(class_, configure=False):
    try:
        class_manager = manager_of_class(class_)
        if not class_manager.is_mapped:
            return None
        mapper = class_manager.mapper
    except exc.NO_STATE:
        return None
    else:
        if configure and mapper._new_mappers:
            mapper._configure_all()
        return mapper


def class_mapper(class_, configure=True):
    """Given a class, return the primary :class:`.Mapper` associated
    with the given class.

    Raises :exc:`.UnmappedClassError` if no mapping is configured
    on the given class, or :exc:`.ArgumentError` if a non-class
    object is passed.

    Equivalent functionality is available via the :func:`.inspect`
    function as::

        inspect(some_mapped_class)

    Using the inspection system will raise
    :class:`sqlalchemy.exc.NoInspectionAvailable` if the class is not mapped.

    """
    mapper = _inspect_mapped_class(class_, configure=configure)
    if mapper is None:
        if not isinstance(class_, type):
            raise sa_exc.ArgumentError(
                "Class object expected, got '%r'." % (class_, ))
        raise exc.UnmappedClassError(class_)
    else:
        return mapper


class InspectionAttr(object):
    """A base class applied to all ORM objects that can be returned
    by the :func:`.inspect` function.

    The attributes defined here allow the usage of simple boolean
    checks to test basic facts about the object returned.

    While the boolean checks here are basically the same as using
    the Python isinstance() function, the flags here can be used without
    the need to import all of these classes, and also such that
    the SQLAlchemy class system can change while leaving the flags
    here intact for forwards-compatibility.

    """
    __slots__ = ()

    is_selectable = False
    """Return True if this object is an instance of :class:`.Selectable`."""

    is_aliased_class = False
    """True if this object is an instance of :class:`.AliasedClass`."""

    is_instance = False
    """True if this object is an instance of :class:`.InstanceState`."""

    is_mapper = False
    """True if this object is an instance of :class:`.Mapper`."""

    is_property = False
    """True if this object is an instance of :class:`.MapperProperty`."""

    is_attribute = False
    """True if this object is a Python :term:`descriptor`.

    This can refer to one of many types. Usually a
    :class:`.QueryableAttribute` which handles attribute events on behalf
    of a :class:`.MapperProperty`. But can also be an extension type
    such as :class:`.AssociationProxy` or :class:`.hybrid_property`.
    The :attr:`.InspectionAttr.extension_type` will refer to a constant
    identifying the specific subtype.

    .. seealso::

        :attr:`.Mapper.all_orm_descriptors`

    """

    is_clause_element = False
    """True if this object is an instance of :class:`.ClauseElement`."""

    extension_type = NOT_EXTENSION
    """The extension type, if any.
    Defaults to :data:`.interfaces.NOT_EXTENSION`

    .. versionadded:: 0.8.0

    .. seealso::

        :data:`.HYBRID_METHOD`

        :data:`.HYBRID_PROPERTY`

        :data:`.ASSOCIATION_PROXY`

    """


class InspectionAttrInfo(InspectionAttr):
    """Adds the ``.info`` attribute to :class:`.InspectionAttr`.

    The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo`
    is that the former is compatible as a mixin for classes that specify
    ``__slots__``; this is essentially an implementation artifact.

    """

    @util.memoized_property
    def info(self):
        """Info dictionary associated with the object, allowing user-defined
        data to be associated with this :class:`.InspectionAttr`.

        The dictionary is generated when first accessed. Alternatively,
        it can be specified as a constructor argument to the
        :func:`.column_property`, :func:`.relationship`, or :func:`.composite`
        functions.

        .. versionadded:: 0.8 Added support for .info to all
           :class:`.MapperProperty` subclasses.

        .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also
           available on extension types via the
           :attr:`.InspectionAttrInfo.info` attribute, so that it can apply
           to a wider variety of ORM and extension constructs.

        .. seealso::

            :attr:`.QueryableAttribute.info`

            :attr:`.SchemaItem.info`

        """
        return {}


class _MappedAttribute(object):
    """Mixin for attributes which should be replaced by mapper-assigned
    attributes.

    """
    __slots__ = ()

sqlalchemy/orm/deprecated_interfaces.py
@@ -0,0 +1,487 @@
# orm/deprecated_interfaces.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from .. import event, util
from .interfaces import EXT_CONTINUE


@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces")
class MapperExtension(object):
    """Base implementation for :class:`.Mapper` event hooks.

    .. note::

       :class:`.MapperExtension` is deprecated. Please
       refer to :func:`.event.listen` as well as
       :class:`.MapperEvents`.

    New extension classes subclass :class:`.MapperExtension` and are specified
    using the ``extension`` mapper() argument, which is a single
    :class:`.MapperExtension` or a list of such::

        from sqlalchemy.orm.interfaces import MapperExtension

        class MyExtension(MapperExtension):
            def before_insert(self, mapper, connection, instance):
                print "instance %s before insert !" % instance

        m = mapper(User, users_table, extension=MyExtension())

    A single mapper can maintain a chain of ``MapperExtension``
    objects. When a particular mapping event occurs, the
    corresponding method on each ``MapperExtension`` is invoked
    serially, and each method has the ability to halt the chain
    from proceeding further::

        m = mapper(User, users_table, extension=[ext1, ext2, ext3])

    Each ``MapperExtension`` method returns the symbol
    EXT_CONTINUE by default. This symbol generally means "move
    to the next ``MapperExtension`` for processing". For methods
    that return objects like translated rows or new object
    instances, EXT_CONTINUE means the result of the method
    should be ignored. In some cases it's required for a
    default mapper activity to be performed, such as adding a
    new instance to a result list.

    The symbol EXT_STOP has significance within a chain
    of ``MapperExtension`` objects in that the chain will be stopped
    when this symbol is returned. Like EXT_CONTINUE, it also
    has additional significance in some cases that a default
    mapper activity will not be performed.

    """

    @classmethod
    def _adapt_instrument_class(cls, self, listener):
        cls._adapt_listener_methods(self, listener, ('instrument_class',))

    @classmethod
    def _adapt_listener(cls, self, listener):
        cls._adapt_listener_methods(
            self, listener,
            (
                'init_instance',
                'init_failed',
                'reconstruct_instance',
                'before_insert',
                'after_insert',
                'before_update',
                'after_update',
                'before_delete',
                'after_delete'
            ))

    @classmethod
    def _adapt_listener_methods(cls, self, listener, methods):

        for meth in methods:
            me_meth = getattr(MapperExtension, meth)
            ls_meth = getattr(listener, meth)

            if not util.methods_equivalent(me_meth, ls_meth):
                if meth == 'reconstruct_instance':
                    def go(ls_meth):
                        def reconstruct(instance, ctx):
                            ls_meth(self, instance)
                        return reconstruct
                    event.listen(self.class_manager, 'load',
                                 go(ls_meth), raw=False, propagate=True)
                elif meth == 'init_instance':
                    def go(ls_meth):
                        def init_instance(instance, args, kwargs):
                            ls_meth(self, self.class_,
                                    self.class_manager.original_init,
                                    instance, args, kwargs)
                        return init_instance
                    event.listen(self.class_manager, 'init',
                                 go(ls_meth), raw=False, propagate=True)
                elif meth == 'init_failed':
                    def go(ls_meth):
                        def init_failed(instance, args, kwargs):
                            util.warn_exception(
                                ls_meth, self, self.class_,
                                self.class_manager.original_init,
                                instance, args, kwargs)

                        return init_failed
                    event.listen(self.class_manager, 'init_failure',
                                 go(ls_meth), raw=False, propagate=True)
                else:
                    event.listen(self, "%s" % meth, ls_meth,
                                 raw=False, retval=True, propagate=True)

    def instrument_class(self, mapper, class_):
        """Receive a class when the mapper is first constructed, and has
        applied instrumentation to the mapped class.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """
        return EXT_CONTINUE

    def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
        """Receive an instance when its constructor is called.

        This method is only called during a userland construction of
        an object. It is not called when an object is loaded from the
        database.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """
        return EXT_CONTINUE

    def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
        """Receive an instance when its constructor has been called
        and raised an exception.

        This method is only called during a userland construction of
        an object. It is not called when an object is loaded from the
        database.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """
        return EXT_CONTINUE

    def reconstruct_instance(self, mapper, instance):
        """Receive an object instance after it has been created via
        ``__new__``, and after initial attribute population has
        occurred.

        This typically occurs when the instance is created based on
        incoming result rows, and is only called once for that
        instance's lifetime.

        Note that during a result-row load, this method is called upon
        the first row received for this instance. Note that some
        attributes and collections may or may not be loaded or even
        initialized, depending on what's present in the result rows.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """
        return EXT_CONTINUE

    def before_insert(self, mapper, connection, instance):
        """Receive an object instance before that instance is inserted
        into its table.

        This is a good place to set up primary key values and such
        that aren't handled otherwise.

        Column-based attributes can be modified within this method
        which will result in the new value being inserted. However
        *no* changes to the overall flush plan can be made, and
        manipulation of the ``Session`` will not have the desired effect.
        To manipulate the ``Session`` within an extension, use
        ``SessionExtension``.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """

        return EXT_CONTINUE

    def after_insert(self, mapper, connection, instance):
        """Receive an object instance after that instance is inserted.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """

        return EXT_CONTINUE

    def before_update(self, mapper, connection, instance):
        """Receive an object instance before that instance is updated.

        Note that this method is called for all instances that are marked as
        "dirty", even those which have no net changes to their column-based
        attributes. An object is marked as dirty when any of its column-based
        attributes have a "set attribute" operation called or when any of its
        collections are modified. If, at update time, no column-based
        attributes have any net changes, no UPDATE statement will be issued.
        This means that an instance being sent to before_update is *not* a
        guarantee that an UPDATE statement will be issued (although you can
        affect the outcome here).

        To detect if the column-based attributes on the object have net
        changes, and will therefore generate an UPDATE statement, use
        ``object_session(instance).is_modified(instance,
        include_collections=False)``.

        Column-based attributes can be modified within this method
        which will result in the new value being updated. However
        *no* changes to the overall flush plan can be made, and
        manipulation of the ``Session`` will not have the desired effect.
        To manipulate the ``Session`` within an extension, use
        ``SessionExtension``.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """

        return EXT_CONTINUE

    def after_update(self, mapper, connection, instance):
        """Receive an object instance after that instance is updated.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """

        return EXT_CONTINUE

    def before_delete(self, mapper, connection, instance):
        """Receive an object instance before that instance is deleted.

        Note that *no* changes to the overall flush plan can be made
        here; and manipulation of the ``Session`` will not have the
        desired effect. To manipulate the ``Session`` within an
        extension, use ``SessionExtension``.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """

        return EXT_CONTINUE

    def after_delete(self, mapper, connection, instance):
        """Receive an object instance after that instance is deleted.

        The return value is only significant within the ``MapperExtension``
        chain; the parent mapper's behavior isn't modified by this method.

        """

        return EXT_CONTINUE


@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces")
class SessionExtension(object):

    """Base implementation for :class:`.Session` event hooks.

    .. note::

       :class:`.SessionExtension` is deprecated. Please
       refer to :func:`.event.listen` as well as
       :class:`.SessionEvents`.

    Subclasses may be installed into a :class:`.Session` (or
    :class:`.sessionmaker`) using the ``extension`` keyword
    argument::

        from sqlalchemy.orm.interfaces import SessionExtension

        class MySessionExtension(SessionExtension):
            def before_commit(self, session):
                print "before commit!"

        Session = sessionmaker(extension=MySessionExtension())

    The same :class:`.SessionExtension` instance can be used
    with any number of sessions.

    """
|
||||
|
||||
@classmethod
|
||||
def _adapt_listener(cls, self, listener):
|
||||
for meth in [
|
||||
'before_commit',
|
||||
'after_commit',
|
||||
'after_rollback',
|
||||
'before_flush',
|
||||
'after_flush',
|
||||
'after_flush_postexec',
|
||||
'after_begin',
|
||||
'after_attach',
|
||||
'after_bulk_update',
|
||||
'after_bulk_delete',
|
||||
]:
|
||||
me_meth = getattr(SessionExtension, meth)
|
||||
ls_meth = getattr(listener, meth)
|
||||
|
||||
if not util.methods_equivalent(me_meth, ls_meth):
|
||||
event.listen(self, meth, getattr(listener, meth))
|
||||
|
||||
    def before_commit(self, session):
        """Execute right before commit is called.

        Note that this may not be per-flush if a longer running
        transaction is ongoing."""

    def after_commit(self, session):
        """Execute after a commit has occurred.

        Note that this may not be per-flush if a longer running
        transaction is ongoing."""

    def after_rollback(self, session):
        """Execute after a rollback has occurred.

        Note that this may not be per-flush if a longer running
        transaction is ongoing."""

    def before_flush(self, session, flush_context, instances):
        """Execute before flush process has started.

        `instances` is an optional list of objects which were passed to
        the ``flush()`` method."""

    def after_flush(self, session, flush_context):
        """Execute after flush has completed, but before commit has been
        called.

        Note that the session's state is still in pre-flush, i.e. 'new',
        'dirty', and 'deleted' lists still show pre-flush state as well
        as the history settings on instance attributes."""

    def after_flush_postexec(self, session, flush_context):
        """Execute after flush has completed, and after the post-exec
        state occurs.

        This will be when the 'new', 'dirty', and 'deleted' lists are in
        their final state.  An actual commit() may or may not have
        occurred, depending on whether or not the flush started its own
        transaction or participated in a larger transaction."""

    def after_begin(self, session, transaction, connection):
        """Execute after a transaction is begun on a connection.

        `transaction` is the SessionTransaction.  This method is called
        after an engine level transaction is begun on a connection."""

    def after_attach(self, session, instance):
        """Execute after an instance is attached to a session.

        This is called after an add, delete or merge."""

    def after_bulk_update(self, session, query, query_context, result):
        """Execute after a bulk update operation to the session.

        This is called after a session.query(...).update()

        `query` is the query object that this update operation was
        called on.  `query_context` was the query context object.
        `result` is the result object returned from the bulk operation.
        """

    def after_bulk_delete(self, session, query, query_context, result):
        """Execute after a bulk delete operation to the session.

        This is called after a session.query(...).delete()

        `query` is the query object that this delete operation was
        called on.  `query_context` was the query context object.
        `result` is the result object returned from the bulk operation.
        """


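# A minimal sketch of the event.listen() form that supersedes
# SessionExtension, per the docstring above (the "Session" factory and
# listener name here are illustrative, not part of this module):

from sqlalchemy import event
from sqlalchemy.orm import sessionmaker

Session = sessionmaker()

@event.listens_for(Session, "before_commit")
def my_before_commit(session):
    # equivalent of MySessionExtension.before_commit() above
    print("before commit!")

# the other hooks adapted by _adapt_listener() ('after_flush',
# 'after_attach', etc.) are listened for in the same way.

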
@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces")
class AttributeExtension(object):
    """Base implementation for :class:`.AttributeImpl` event hooks, events
    that fire upon attribute mutations in user code.

    .. note::

       :class:`.AttributeExtension` is deprecated.  Please
       refer to :func:`.event.listen` as well as
       :class:`.AttributeEvents`.

    :class:`.AttributeExtension` is used to listen for set,
    remove, and append events on individual mapped attributes.
    It is established on an individual mapped attribute using
    the ``extension`` argument, available on
    :func:`.column_property`, :func:`.relationship`, and
    others::

        from sqlalchemy.orm.interfaces import AttributeExtension
        from sqlalchemy.orm import mapper, relationship, column_property

        class MyAttrExt(AttributeExtension):
            def append(self, state, value, initiator):
                print("append event!")
                return value

            def set(self, state, value, oldvalue, initiator):
                print("set event!")
                return value

        mapper(SomeClass, sometable, properties={
            'foo': column_property(sometable.c.foo, extension=MyAttrExt()),
            'bar': relationship(Bar, extension=MyAttrExt())
        })

    Note that the :class:`.AttributeExtension` methods
    :meth:`~.AttributeExtension.append` and
    :meth:`~.AttributeExtension.set` need to return the
    ``value`` parameter.  The returned value is used as the
    effective value, and allows the extension to change what is
    ultimately persisted.

    AttributeExtension is assembled within the descriptors associated
    with a mapped class.

    """

    active_history = True
    """indicates that the set() method would like to receive the 'old' value,
    even if it means firing lazy callables.

    Note that ``active_history`` can also be set directly via
    :func:`.column_property` and :func:`.relationship`.

    """

    @classmethod
    def _adapt_listener(cls, self, listener):
        event.listen(self, 'append', listener.append,
                     active_history=listener.active_history,
                     raw=True, retval=True)
        event.listen(self, 'remove', listener.remove,
                     active_history=listener.active_history,
                     raw=True, retval=True)
        event.listen(self, 'set', listener.set,
                     active_history=listener.active_history,
                     raw=True, retval=True)

    def append(self, state, value, initiator):
        """Receive a collection append event.

        The returned value will be used as the actual value to be
        appended.

        """
        return value

    def remove(self, state, value, initiator):
        """Receive a remove event.

        No return value is defined.

        """
        pass

    def set(self, state, value, oldvalue, initiator):
        """Receive a set event.

        The returned value will be used as the actual value to be
        set.

        """
        return value

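
# A minimal sketch of the AttributeEvents form that supersedes
# AttributeExtension (self-contained; the SomeClass mapping here is
# illustrative, mirroring the docstring example above):

from sqlalchemy import Column, Integer, String, event
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class SomeClass(Base):
    __tablename__ = 'some_table'
    id = Column(Integer, primary_key=True)
    foo = Column(String)

# retval=True mirrors AttributeExtension.set()'s contract: the value
# returned by the listener is what is ultimately persisted.
@event.listens_for(SomeClass.foo, 'set', retval=True)
def on_set(target, value, oldvalue, initiator):
    print("set event!")
    return value
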
699  sqlalchemy/orm/descriptor_props.py  (new file)
@@ -0,0 +1,699 @@
# orm/descriptor_props.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Descriptor properties are more "auxiliary" properties
that exist as configurational elements, but don't participate
as actively in the load/persist ORM loop.

"""

from .interfaces import MapperProperty, PropComparator
from .util import _none_set
from . import attributes
from .. import util, sql, exc as sa_exc, event, schema
from ..sql import expression
from . import properties
from . import query


class DescriptorProperty(MapperProperty):
    """:class:`.MapperProperty` which proxies access to a
    user-defined descriptor."""

    doc = None

    def instrument_class(self, mapper):
        prop = self

        class _ProxyImpl(object):
            accepts_scalar_loader = False
            expire_missing = True
            collection = False

            def __init__(self, key):
                self.key = key

            if hasattr(prop, 'get_history'):
                def get_history(self, state, dict_,
                                passive=attributes.PASSIVE_OFF):
                    return prop.get_history(state, dict_, passive)

        if self.descriptor is None:
            desc = getattr(mapper.class_, self.key, None)
            if mapper._is_userland_descriptor(desc):
                self.descriptor = desc

        if self.descriptor is None:
            def fset(obj, value):
                setattr(obj, self.name, value)

            def fdel(obj):
                delattr(obj, self.name)

            def fget(obj):
                return getattr(obj, self.name)

            self.descriptor = property(
                fget=fget,
                fset=fset,
                fdel=fdel,
            )

        proxy_attr = attributes.create_proxied_attribute(
            self.descriptor)(
            self.parent.class_,
            self.key,
            self.descriptor,
            lambda: self._comparator_factory(mapper),
            doc=self.doc,
            original_property=self
        )
        proxy_attr.impl = _ProxyImpl(self.key)
        mapper.class_manager.instrument_attribute(self.key, proxy_attr)

@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class CompositeProperty(DescriptorProperty):
    """Defines a "composite" mapped attribute, representing a collection
    of columns as one attribute.

    :class:`.CompositeProperty` is constructed using the :func:`.composite`
    function.

    .. seealso::

        :ref:`mapper_composite`

    """

    def __init__(self, class_, *attrs, **kwargs):
        r"""Return a composite column-based property for use with a Mapper.

        See the mapping documentation section :ref:`mapper_composite` for a
        full usage example.

        The :class:`.MapperProperty` returned by :func:`.composite`
        is the :class:`.CompositeProperty`.

        :param class\_:
          The "composite type" class.

        :param \*cols:
          List of Column objects to be mapped.

        :param active_history=False:
          When ``True``, indicates that the "previous" value for a
          scalar attribute should be loaded when replaced, if not
          already loaded.  See the same flag on :func:`.column_property`.

          .. versionchanged:: 0.7
              This flag specifically becomes meaningful
              - previously it was a placeholder.

        :param group:
          A group name for this property when marked as deferred.

        :param deferred:
          When True, the column property is "deferred", meaning that it does
          not load immediately, and is instead loaded when the attribute is
          first accessed on an instance.  See also
          :func:`~sqlalchemy.orm.deferred`.

        :param comparator_factory:  a class which extends
          :class:`.CompositeProperty.Comparator` and provides custom SQL
          clause generation for comparison operations.

        :param doc:
          optional string that will be applied as the doc on the
          class-bound descriptor.

        :param info: Optional data dictionary which will be populated into the
            :attr:`.MapperProperty.info` attribute of this object.

            .. versionadded:: 0.8

        :param extension:
          an :class:`.AttributeExtension` instance,
          or list of extensions, which will be prepended to the list of
          attribute listeners for the resulting descriptor placed on the
          class.  **Deprecated.**  Please see :class:`.AttributeEvents`.

        """
        super(CompositeProperty, self).__init__()

        self.attrs = attrs
        self.composite_class = class_
        self.active_history = kwargs.get('active_history', False)
        self.deferred = kwargs.get('deferred', False)
        self.group = kwargs.get('group', None)
        self.comparator_factory = kwargs.pop('comparator_factory',
                                             self.__class__.Comparator)
        if 'info' in kwargs:
            self.info = kwargs.pop('info')

        util.set_creation_order(self)
        self._create_descriptor()

    def instrument_class(self, mapper):
        super(CompositeProperty, self).instrument_class(mapper)
        self._setup_event_handlers()

    def do_init(self):
        """Initialization which occurs after the :class:`.CompositeProperty`
        has been associated with its parent mapper.

        """
        self._setup_arguments_on_columns()

    def _create_descriptor(self):
        """Create the Python descriptor that will serve as
        the access point on instances of the mapped class.

        """

        def fget(instance):
            dict_ = attributes.instance_dict(instance)
            state = attributes.instance_state(instance)

            if self.key not in dict_:
                # key not present.  Iterate through related
                # attributes, retrieve their values.  This
                # ensures they all load.
                values = [
                    getattr(instance, key)
                    for key in self._attribute_keys
                ]

                # current expected behavior here is that the composite is
                # created on access if the object is persistent or if
                # col attributes have non-None.  This would be better
                # if the composite were created unconditionally,
                # but that would be a behavioral change.
                if self.key not in dict_ and (
                    state.key is not None or
                    not _none_set.issuperset(values)
                ):
                    dict_[self.key] = self.composite_class(*values)
                    state.manager.dispatch.refresh(state, None, [self.key])

            return dict_.get(self.key, None)

        def fset(instance, value):
            dict_ = attributes.instance_dict(instance)
            state = attributes.instance_state(instance)
            attr = state.manager[self.key]
            previous = dict_.get(self.key, attributes.NO_VALUE)
            for fn in attr.dispatch.set:
                value = fn(state, value, previous, attr.impl)
            dict_[self.key] = value
            if value is None:
                for key in self._attribute_keys:
                    setattr(instance, key, None)
            else:
                for key, value in zip(
                        self._attribute_keys,
                        value.__composite_values__()):
                    setattr(instance, key, value)

        def fdel(instance):
            state = attributes.instance_state(instance)
            dict_ = attributes.instance_dict(instance)
            previous = dict_.pop(self.key, attributes.NO_VALUE)
            attr = state.manager[self.key]
            attr.dispatch.remove(state, previous, attr.impl)
            for key in self._attribute_keys:
                setattr(instance, key, None)

        self.descriptor = property(fget, fset, fdel)

    @util.memoized_property
    def _comparable_elements(self):
        return [
            getattr(self.parent.class_, prop.key)
            for prop in self.props
        ]

    @util.memoized_property
    def props(self):
        props = []
        for attr in self.attrs:
            if isinstance(attr, str):
                prop = self.parent.get_property(
                    attr, _configure_mappers=False)
            elif isinstance(attr, schema.Column):
                prop = self.parent._columntoproperty[attr]
            elif isinstance(attr, attributes.InstrumentedAttribute):
                prop = attr.property
            else:
                raise sa_exc.ArgumentError(
                    "Composite expects Column objects or mapped "
                    "attributes/attribute names as arguments, got: %r"
                    % (attr,))
            props.append(prop)
        return props

    @property
    def columns(self):
        return [a for a in self.attrs if isinstance(a, schema.Column)]

    def _setup_arguments_on_columns(self):
        """Propagate configuration arguments made on this composite
        to the target columns, for those that apply.

        """
        for prop in self.props:
            prop.active_history = self.active_history
            if self.deferred:
                prop.deferred = self.deferred
                prop.strategy_key = (
                    ("deferred", True),
                    ("instrument", True))
            prop.group = self.group

    def _setup_event_handlers(self):
        """Establish events that populate/expire the composite attribute."""

        def load_handler(state, *args):
            dict_ = state.dict

            if self.key in dict_:
                return

            # if column elements aren't loaded, skip.
            # __get__() will initiate a load for those
            # columns
            for k in self._attribute_keys:
                if k not in dict_:
                    return

            # assert self.key not in dict_
            dict_[self.key] = self.composite_class(
                *[state.dict[key] for key in
                  self._attribute_keys]
            )

        def expire_handler(state, keys):
            if keys is None or set(self._attribute_keys).intersection(keys):
                state.dict.pop(self.key, None)

        def insert_update_handler(mapper, connection, state):
            """After an insert or update, some columns may be expired due
            to server side defaults, or re-populated due to client side
            defaults.  Pop out the composite value here so that it
            recreates.

            """

            state.dict.pop(self.key, None)

        event.listen(self.parent, 'after_insert',
                     insert_update_handler, raw=True)
        event.listen(self.parent, 'after_update',
                     insert_update_handler, raw=True)
        event.listen(self.parent, 'load',
                     load_handler, raw=True, propagate=True)
        event.listen(self.parent, 'refresh',
                     load_handler, raw=True, propagate=True)
        event.listen(self.parent, 'expire',
                     expire_handler, raw=True, propagate=True)

        # TODO: need a deserialize hook here

    @util.memoized_property
    def _attribute_keys(self):
        return [
            prop.key for prop in self.props
        ]

    def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Provided for userland code that uses attributes.get_history()."""

        added = []
        deleted = []

        has_history = False
        for prop in self.props:
            key = prop.key
            hist = state.manager[key].impl.get_history(state, dict_)
            if hist.has_changes():
                has_history = True

            non_deleted = hist.non_deleted()
            if non_deleted:
                added.extend(non_deleted)
            else:
                added.append(None)
            if hist.deleted:
                deleted.extend(hist.deleted)
            else:
                deleted.append(None)

        if has_history:
            return attributes.History(
                [self.composite_class(*added)],
                (),
                [self.composite_class(*deleted)]
            )
        else:
            return attributes.History(
                (), [self.composite_class(*added)], ()
            )

    def _comparator_factory(self, mapper):
        return self.comparator_factory(self, mapper)

    class CompositeBundle(query.Bundle):
        def __init__(self, property, expr):
            self.property = property
            super(CompositeProperty.CompositeBundle, self).__init__(
                property.key, *expr)

        def create_row_processor(self, query, procs, labels):
            def proc(row):
                return self.property.composite_class(
                    *[proc(row) for proc in procs])
            return proc

    class Comparator(PropComparator):
        """Produce boolean, comparison, and other operators for
        :class:`.CompositeProperty` attributes.

        See the example in :ref:`composite_operations` for an overview
        of usage, as well as the documentation for :class:`.PropComparator`.

        See also:

        :class:`.PropComparator`

        :class:`.ColumnOperators`

        :ref:`types_operators`

        :attr:`.TypeEngine.comparator_factory`

        """

        __hash__ = None

        @property
        def clauses(self):
            return self.__clause_element__()

        def __clause_element__(self):
            return expression.ClauseList(
                group=False, *self._comparable_elements)

        def _query_clause_element(self):
            return CompositeProperty.CompositeBundle(
                self.prop, self.__clause_element__())

        @util.memoized_property
        def _comparable_elements(self):
            if self._adapt_to_entity:
                return [
                    getattr(
                        self._adapt_to_entity.entity,
                        prop.key
                    ) for prop in self.prop._comparable_elements
                ]
            else:
                return self.prop._comparable_elements

        def __eq__(self, other):
            if other is None:
                values = [None] * len(self.prop._comparable_elements)
            else:
                values = other.__composite_values__()
            comparisons = [
                a == b
                for a, b in zip(self.prop._comparable_elements, values)
            ]
            if self._adapt_to_entity:
                comparisons = [self.adapter(x) for x in comparisons]
            return sql.and_(*comparisons)

        def __ne__(self, other):
            return sql.not_(self.__eq__(other))

    def __str__(self):
        return str(self.parent.class_.__name__) + "." + self.key

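
# A minimal sketch of defining a composite attribute with
# :func:`.composite` as described above (the Point/Vertex names follow the
# standard mapper_composite documentation example and are illustrative):

from sqlalchemy import Column, Integer
from sqlalchemy.orm import composite
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __composite_values__(self):
        # consumed by CompositeProperty's fset()/get_history() above
        return self.x, self.y

    def __eq__(self, other):
        return isinstance(other, Point) and \
            other.x == self.x and other.y == self.y

class Vertex(Base):
    __tablename__ = 'vertices'
    id = Column(Integer, primary_key=True)
    x1 = Column(Integer)
    y1 = Column(Integer)

    # the two columns are presented as a single Point-valued attribute
    start = composite(Point, x1, y1)
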
@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class ConcreteInheritedProperty(DescriptorProperty):
    """A 'do nothing' :class:`.MapperProperty` that disables
    an attribute on a concrete subclass that is only present
    on the inherited mapper, not the concrete classes' mapper.

    Cases where this occurs include:

    * When the superclass mapper is mapped against a
      "polymorphic union", which includes all attributes from
      all subclasses.
    * When a relationship() is configured on an inherited mapper,
      but not on the subclass mapper.  Concrete mappers require
      that relationship() is configured explicitly on each
      subclass.

    """

    def _comparator_factory(self, mapper):
        comparator_callable = None

        for m in self.parent.iterate_to_root():
            p = m._props[self.key]
            if not isinstance(p, ConcreteInheritedProperty):
                comparator_callable = p.comparator_factory
                break
        return comparator_callable

    def __init__(self):
        super(ConcreteInheritedProperty, self).__init__()

        def warn():
            raise AttributeError("Concrete %s does not implement "
                                 "attribute %r at the instance level.  Add "
                                 "this property explicitly to %s." %
                                 (self.parent, self.key, self.parent))

        class NoninheritedConcreteProp(object):
            def __set__(s, obj, value):
                warn()

            def __delete__(s, obj):
                warn()

            def __get__(s, obj, owner):
                if obj is None:
                    return self.descriptor
                warn()

        self.descriptor = NoninheritedConcreteProp()

@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class SynonymProperty(DescriptorProperty):

    def __init__(self, name, map_column=None,
                 descriptor=None, comparator_factory=None,
                 doc=None, info=None):
        """Denote an attribute name as a synonym to a mapped property,
        in that the attribute will mirror the value and expression behavior
        of another attribute.

        :param name: the name of the existing mapped property.  This
          can refer to the string name of any :class:`.MapperProperty`
          configured on the class, including column-bound attributes
          and relationships.

        :param descriptor: a Python :term:`descriptor` that will be used
          as a getter (and potentially a setter) when this attribute is
          accessed at the instance level.

        :param map_column: if ``True``, the :func:`.synonym` construct will
          locate the existing named :class:`.MapperProperty` based on the
          attribute name of this :func:`.synonym`, and assign it to a new
          attribute linked to the name of this :func:`.synonym`.
          That is, given a mapping like::

                class MyClass(Base):
                    __tablename__ = 'my_table'

                    id = Column(Integer, primary_key=True)
                    job_status = Column(String(50))

                    job_status = synonym("_job_status", map_column=True)

          The above class ``MyClass`` will now have the ``job_status``
          :class:`.Column` object mapped to the attribute named
          ``_job_status``, and the attribute named ``job_status`` will refer
          to the synonym itself.  This feature is typically used in
          conjunction with the ``descriptor`` argument in order to link a
          user-defined descriptor as a "wrapper" for an existing column.

        :param info: Optional data dictionary which will be populated into the
            :attr:`.InspectionAttr.info` attribute of this object.

            .. versionadded:: 1.0.0

        :param comparator_factory: A subclass of :class:`.PropComparator`
          that will provide custom comparison behavior at the SQL expression
          level.

          .. note::

            For the use case of providing an attribute which redefines both
            Python-level and SQL-expression level behavior of an attribute,
            please refer to the Hybrid attribute introduced at
            :ref:`mapper_hybrids` for a more effective technique.

        .. seealso::

            :ref:`synonyms` - examples of functionality.

            :ref:`mapper_hybrids` - Hybrids provide a better approach for
            more complicated attribute-wrapping schemes than synonyms.

        """
        super(SynonymProperty, self).__init__()

        self.name = name
        self.map_column = map_column
        self.descriptor = descriptor
        self.comparator_factory = comparator_factory
        self.doc = doc or (descriptor and descriptor.__doc__) or None
        if info:
            self.info = info

        util.set_creation_order(self)

    # TODO: when initialized, check _proxied_property,
    # emit a warning if it's not a column-based property

    @util.memoized_property
    def _proxied_property(self):
        return getattr(self.parent.class_, self.name).property

    def _comparator_factory(self, mapper):
        prop = self._proxied_property

        if self.comparator_factory:
            comp = self.comparator_factory(prop, mapper)
        else:
            comp = prop.comparator_factory(prop, mapper)
        return comp

    def set_parent(self, parent, init):
        if self.map_column:
            # implement the 'map_column' option.
            if self.key not in parent.mapped_table.c:
                raise sa_exc.ArgumentError(
                    "Can't compile synonym '%s': no column on table "
                    "'%s' named '%s'"
                    % (self.name, parent.mapped_table.description, self.key))
            elif parent.mapped_table.c[self.key] in \
                    parent._columntoproperty and \
                    parent._columntoproperty[
                        parent.mapped_table.c[self.key]
                    ].key == self.name:
                raise sa_exc.ArgumentError(
                    "Can't call map_column=True for synonym %r=%r, "
                    "a ColumnProperty already exists keyed to the name "
                    "%r for column %r" %
                    (self.key, self.name, self.name, self.key)
                )
            p = properties.ColumnProperty(parent.mapped_table.c[self.key])
            parent._configure_property(
                self.name, p,
                init=init,
                setparent=True)
            p._mapped_by_synonym = self.key

        self.parent = parent

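
# A minimal sketch of the ``descriptor`` argument described above,
# wrapping an explicitly mapped column (MyClass is illustrative, following
# the standard synonyms documentation pattern):

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import synonym
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class MyClass(Base):
    __tablename__ = 'my_table'

    id = Column(Integer, primary_key=True)
    _job_status = Column("job_status", String(50))

    @property
    def job_status(self):
        # instance-level "wrapper" around the mapped column
        return "Status: " + self._job_status

    # expression-level access continues to refer to the column
    job_status = synonym("_job_status", descriptor=job_status)
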
@util.langhelpers.dependency_for("sqlalchemy.orm.properties")
class ComparableProperty(DescriptorProperty):
    """Instruments a Python property for use in query expressions."""

    def __init__(
            self, comparator_factory, descriptor=None, doc=None, info=None):
        """Provides a method of applying a :class:`.PropComparator`
        to any Python descriptor attribute.

        .. versionchanged:: 0.7
            :func:`.comparable_property` is superseded by
            the :mod:`~sqlalchemy.ext.hybrid` extension.  See the example
            at :ref:`hybrid_custom_comparators`.

        Allows any Python descriptor to behave like a SQL-enabled
        attribute when used at the class level in queries, allowing
        redefinition of expression operator behavior.

        In the example below we redefine :meth:`.PropComparator.operate`
        to wrap both sides of an expression in ``func.lower()`` to produce
        case-insensitive comparison::

            from sqlalchemy.orm import comparable_property
            from sqlalchemy.orm.interfaces import PropComparator
            from sqlalchemy.sql import func
            from sqlalchemy import Integer, String, Column
            from sqlalchemy.ext.declarative import declarative_base

            class CaseInsensitiveComparator(PropComparator):
                def __clause_element__(self):
                    return self.prop

                def operate(self, op, other):
                    return op(
                        func.lower(self.__clause_element__()),
                        func.lower(other)
                    )

            Base = declarative_base()

            class SearchWord(Base):
                __tablename__ = 'search_word'
                id = Column(Integer, primary_key=True)
                word = Column(String)
                word_insensitive = comparable_property(lambda prop, mapper:
                                CaseInsensitiveComparator(
                                    mapper.c.word, mapper)
                                )


        A mapping like the above allows the ``word_insensitive`` attribute
        to render an expression like::

            >>> print(SearchWord.word_insensitive == "Trucks")
            lower(search_word.word) = lower(:lower_1)

        :param comparator_factory:
          A PropComparator subclass or factory that defines operator behavior
          for this property.

        :param descriptor:
          Optional when used in a ``properties={}`` declaration.  The Python
          descriptor or property to layer comparison behavior on top of.

          The like-named descriptor will be automatically retrieved from the
          mapped class if left blank in a ``properties`` declaration.

        :param info: Optional data dictionary which will be populated into the
            :attr:`.InspectionAttr.info` attribute of this object.

            .. versionadded:: 1.0.0

        """
        super(ComparableProperty, self).__init__()
        self.descriptor = descriptor
        self.comparator_factory = comparator_factory
        self.doc = doc or (descriptor and descriptor.__doc__) or None
        if info:
            self.info = info
        util.set_creation_order(self)

    def _comparator_factory(self, mapper):
        return self.comparator_factory(self, mapper)

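
# A minimal sketch of the hybrid-based approach that the docstring above
# recommends over comparable_property() (names are illustrative, following
# the general mapper_hybrids pattern):

from sqlalchemy import Column, Integer, String, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property

Base = declarative_base()

class SearchWord(Base):
    __tablename__ = 'search_word'
    id = Column(Integer, primary_key=True)
    word = Column(String)

    @hybrid_property
    def word_insensitive(self):
        # Python-level behavior, per instance
        return self.word.lower()

    @word_insensitive.expression
    def word_insensitive(cls):
        # SQL-expression-level behavior, per class;
        # SearchWord.word_insensitive == "Trucks" renders
        # lower(search_word.word) = lower(:lower_1)
        return func.lower(cls.word)
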
2187  sqlalchemy/orm/events.py  (new file)
File diff suppressed because it is too large.
528  sqlalchemy/orm/instrumentation.py  (new file)
@@ -0,0 +1,528 @@
# orm/instrumentation.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Defines SQLAlchemy's system of class instrumentation.

This module is usually not directly visible to user applications, but
defines a large part of the ORM's interactivity.

instrumentation.py deals with registration of end-user classes
for state tracking.  It interacts closely with state.py
and attributes.py which establish per-instance and per-class-attribute
instrumentation, respectively.

The class instrumentation system can be customized on a per-class
or global basis using the :mod:`sqlalchemy.ext.instrumentation`
module, which provides the means to build and specify
alternate instrumentation forms.

.. versionchanged:: 0.8
   The instrumentation extension system was moved out of the
   ORM and into the external :mod:`sqlalchemy.ext.instrumentation`
   package.  When that package is imported, it installs
   itself within sqlalchemy.orm so that its more comprehensive
   resolution mechanics take effect.

"""


from . import exc, collections, interfaces, state
from .. import util
from . import base


_memoized_key_collection = util.group_expirable_memoized_property()

class ClassManager(dict):
    """tracks state information at the class level."""

    MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
    STATE_ATTR = base.DEFAULT_STATE_ATTR

    _state_setter = staticmethod(util.attrsetter(STATE_ATTR))

    deferred_scalar_loader = None

    original_init = object.__init__

    factory = None

    def __init__(self, class_):
        self.class_ = class_
        self.info = {}
        self.new_init = None
        self.local_attrs = {}
        self.originals = {}

        self._bases = [mgr for mgr in [
            manager_of_class(base)
            for base in self.class_.__bases__
            if isinstance(base, type)
        ] if mgr is not None]

        for base in self._bases:
            self.update(base)

        self.dispatch._events._new_classmanager_instance(class_, self)
        # events._InstanceEventsHold.populate(class_, self)

        for basecls in class_.__mro__:
            mgr = manager_of_class(basecls)
            if mgr is not None:
                self.dispatch._update(mgr.dispatch)
        self.manage()
        self._instrument_init()

        if '__del__' in class_.__dict__:
            util.warn("__del__() method on class %s will "
                      "cause unreachable cycles and memory leaks, "
                      "as SQLAlchemy instrumentation often creates "
                      "reference cycles.  Please remove this method." %
                      class_)

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return other is self

    @property
    def is_mapped(self):
        return 'mapper' in self.__dict__

    @_memoized_key_collection
    def _all_key_set(self):
        return frozenset(self)

    @_memoized_key_collection
    def _collection_impl_keys(self):
        return frozenset([
            attr.key for attr in self.values() if attr.impl.collection])

    @_memoized_key_collection
    def _scalar_loader_impls(self):
        return frozenset([
            attr.impl for attr in
            self.values() if attr.impl.accepts_scalar_loader])

    @util.memoized_property
    def mapper(self):
        # raises unless self.mapper has been assigned
        raise exc.UnmappedClassError(self.class_)

    def _all_sqla_attributes(self, exclude=None):
        """return an iterator of all classbound attributes that
        implement :class:`.InspectionAttr`.

        This includes :class:`.QueryableAttribute` as well as extension
        types such as :class:`.hybrid_property` and
        :class:`.AssociationProxy`.

        """
        if exclude is None:
            exclude = set()
        for supercls in self.class_.__mro__:
            for key in set(supercls.__dict__).difference(exclude):
                exclude.add(key)
                val = supercls.__dict__[key]
                if isinstance(val, interfaces.InspectionAttr):
                    yield key, val

    def _attr_has_impl(self, key):
        """Return True if the given attribute is fully initialized.

        i.e. has an impl.
        """

        return key in self and self[key].impl is not None

    def _subclass_manager(self, cls):
        """Create a new ClassManager for a subclass of this ClassManager's
        class.

        This is called automatically when attributes are instrumented so that
        the attributes can be propagated to subclasses against their own
        class-local manager, without the need for mappers etc. to have already
        pre-configured managers for the full class hierarchy.  Mappers
        can post-configure the auto-generated ClassManager when needed.

        """
        manager = manager_of_class(cls)
        if manager is None:
            manager = _instrumentation_factory.create_manager_for_cls(cls)
        return manager

    def _instrument_init(self):
        # TODO: self.class_.__init__ is often the already-instrumented
        # __init__ from an instrumented superclass.  We still need to make
        # our own wrapper, but it would
        # be nice to wrap the original __init__ and not our existing wrapper
        # of such, since this adds method overhead.
        self.original_init = self.class_.__init__
        self.new_init = _generate_init(self.class_, self)
        self.install_member('__init__', self.new_init)

    def _uninstrument_init(self):
        if self.new_init:
            self.uninstall_member('__init__')
            self.new_init = None

    @util.memoized_property
    def _state_constructor(self):
        self.dispatch.first_init(self, self.class_)
        return state.InstanceState

    def manage(self):
        """Mark this instance as the manager for its class."""

        setattr(self.class_, self.MANAGER_ATTR, self)

    def dispose(self):
        """Disassociate this manager from its class."""

        delattr(self.class_, self.MANAGER_ATTR)

    @util.hybridmethod
    def manager_getter(self):
        return _default_manager_getter

    @util.hybridmethod
    def state_getter(self):
        """Return a (instance) -> InstanceState callable.

        "state getter" callables should raise either KeyError or
        AttributeError if no InstanceState could be found for the
        instance.
        """

        return _default_state_getter

    @util.hybridmethod
    def dict_getter(self):
        return _default_dict_getter

    def instrument_attribute(self, key, inst, propagated=False):
        if propagated:
            if key in self.local_attrs:
                return  # don't override local attr with inherited attr
        else:
            self.local_attrs[key] = inst
            self.install_descriptor(key, inst)
        _memoized_key_collection.expire_instance(self)
        self[key] = inst

        for cls in self.class_.__subclasses__():
            manager = self._subclass_manager(cls)
            manager.instrument_attribute(key, inst, True)

    def subclass_managers(self, recursive):
        for cls in self.class_.__subclasses__():
            mgr = manager_of_class(cls)
            if mgr is not None and mgr is not self:
                yield mgr
                if recursive:
                    for m in mgr.subclass_managers(True):
                        yield m

    def post_configure_attribute(self, key):
        _instrumentation_factory.dispatch.\
            attribute_instrument(self.class_, key, self[key])

    def uninstrument_attribute(self, key, propagated=False):
        if key not in self:
            return
        if propagated:
            if key in self.local_attrs:
                return  # don't get rid of local attr
        else:
            del self.local_attrs[key]
            self.uninstall_descriptor(key)
        _memoized_key_collection.expire_instance(self)
        del self[key]
        for cls in self.class_.__subclasses__():
            manager = manager_of_class(cls)
            if manager:
                manager.uninstrument_attribute(key, True)

    def unregister(self):
        """remove all instrumentation established by this ClassManager."""

        self._uninstrument_init()

        self.mapper = self.dispatch = None
        self.info.clear()

        for key in list(self):
            if key in self.local_attrs:
                self.uninstrument_attribute(key)

    def install_descriptor(self, key, inst):
        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
            raise KeyError("%r: requested attribute name conflicts with "
                           "instrumentation attribute of the same name." %
                           key)
        setattr(self.class_, key, inst)

    def uninstall_descriptor(self, key):
        delattr(self.class_, key)

    def install_member(self, key, implementation):
        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
            raise KeyError("%r: requested attribute name conflicts with "
                           "instrumentation attribute of the same name." %
                           key)
        self.originals.setdefault(key, getattr(self.class_, key, None))
        setattr(self.class_, key, implementation)

    def uninstall_member(self, key):
        original = self.originals.pop(key, None)
        if original is not None:
            setattr(self.class_, key, original)

    def instrument_collection_class(self, key, collection_class):
        return collections.prepare_instrumentation(collection_class)

    def initialize_collection(self, key, state, factory):
        user_data = factory()
        adapter = collections.CollectionAdapter(
            self.get_impl(key), state, user_data)
        return adapter, user_data

    def is_instrumented(self, key, search=False):
        if search:
            return key in self
        else:
            return key in self.local_attrs

    def get_impl(self, key):
        return self[key].impl

    @property
    def attributes(self):
        return iter(self.values())

    # InstanceState management

    def new_instance(self, state=None):
        instance = self.class_.__new__(self.class_)
        if state is None:
            state = self._state_constructor(instance, self)
        self._state_setter(instance, state)
        return instance

    def setup_instance(self, instance, state=None):
        if state is None:
            state = self._state_constructor(instance, self)
        self._state_setter(instance, state)

    def teardown_instance(self, instance):
        delattr(instance, self.STATE_ATTR)

    def _serialize(self, state, state_dict):
        return _SerializeManager(state, state_dict)

    def _new_state_if_none(self, instance):
        """Install a default InstanceState if none is present.

        A private convenience method used by the __init__ decorator.

        """
        if hasattr(instance, self.STATE_ATTR):
            return False
        elif self.class_ is not instance.__class__ and \
                self.is_mapped:
            # this will create a new ClassManager for the
            # subclass, without a mapper.  This is likely a
            # user error situation but allow the object
            # to be constructed, so that it is usable
            # in a non-ORM context at least.
            return self._subclass_manager(instance.__class__).\
                _new_state_if_none(instance)
        else:
            state = self._state_constructor(instance, self)
            self._state_setter(instance, state)
            return state

    def has_state(self, instance):
        return hasattr(instance, self.STATE_ATTR)

    def has_parent(self, state, key, optimistic=False):
        """TODO"""
        return self.get_impl(key).hasparent(state, optimistic=optimistic)

    def __bool__(self):
        """All ClassManagers are non-zero regardless of attribute state."""
        return True

    __nonzero__ = __bool__

    def __repr__(self):
        return '<%s of %r at %x>' % (
            self.__class__.__name__, self.class_, id(self))

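
# A minimal sketch of inspecting the ClassManager that mapping installs
# (self-contained; the Widget mapping is illustrative):

from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.base import manager_of_class

Base = declarative_base()

class Widget(Base):
    __tablename__ = 'widgets'
    id = Column(Integer, primary_key=True)

mgr = manager_of_class(Widget)    # ClassManager set up during mapping
assert mgr.is_mapped              # 'mapper' has been assigned
assert 'id' in mgr.local_attrs    # instrumented attribute keys
assert mgr.get_impl('id') is Widget.id.impl
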
class _SerializeManager(object):
    """Provide serialization of a :class:`.ClassManager`.

    The :class:`.InstanceState` uses ``__init__()`` on serialize
    and ``__call__()`` on deserialize.

    """

    def __init__(self, state, d):
        self.class_ = state.class_
        manager = state.manager
        manager.dispatch.pickle(state, d)

    def __call__(self, state, inst, state_dict):
        state.manager = manager = manager_of_class(self.class_)
        if manager is None:
            raise exc.UnmappedInstanceError(
                inst,
                "Cannot deserialize object of type %r - "
                "no mapper() has "
                "been configured for this class within the current "
                "Python process!" %
                self.class_)
        elif manager.is_mapped and not manager.mapper.configured:
            manager.mapper._configure_all()

        # setup _sa_instance_state ahead of time so that
        # unpickle events can access the object normally.
        # see [ticket:2362]
        if inst is not None:
            manager.setup_instance(inst, state)
        manager.dispatch.unpickle(state, state_dict)


class InstrumentationFactory(object):
    """Factory for new ClassManager instances."""

    def create_manager_for_cls(self, class_):
        assert class_ is not None
        assert manager_of_class(class_) is None

        # give a more complicated subclass
        # a chance to do what it wants here
        manager, factory = self._locate_extended_factory(class_)

        if factory is None:
            factory = ClassManager
            manager = factory(class_)

        self._check_conflicts(class_, factory)

        manager.factory = factory

        self.dispatch.class_instrument(class_)
        return manager

    def _locate_extended_factory(self, class_):
        """Overridden by a subclass to do an extended lookup."""
        return None, None

    def _check_conflicts(self, class_, factory):
        """Overridden by a subclass to test for conflicting factories."""
        return

    def unregister(self, class_):
        manager = manager_of_class(class_)
        manager.unregister()
        manager.dispose()
        self.dispatch.class_uninstrument(class_)
        if ClassManager.MANAGER_ATTR in class_.__dict__:
            delattr(class_, ClassManager.MANAGER_ATTR)

# this attribute is replaced by sqlalchemy.ext.instrumentation
# when imported.
_instrumentation_factory = InstrumentationFactory()

# these attributes are replaced by sqlalchemy.ext.instrumentation
# when a non-standard InstrumentationManager class is first
# used to instrument a class.
instance_state = _default_state_getter = base.instance_state

instance_dict = _default_dict_getter = base.instance_dict

manager_of_class = _default_manager_getter = base.manager_of_class


def register_class(class_):
    """Register class instrumentation.

    Returns the existing or newly created class manager.

    """

    manager = manager_of_class(class_)
    if manager is None:
        manager = _instrumentation_factory.create_manager_for_cls(class_)
    return manager


def unregister_class(class_):
    """Unregister class instrumentation."""

    _instrumentation_factory.unregister(class_)


def is_instrumented(instance, key):
    """Return True if the given attribute on the given instance is
    instrumented by the attributes package.

    This function may be used regardless of instrumentation
    applied directly to the class, i.e. no descriptors are required.

    """
    return manager_of_class(instance.__class__).\
        is_instrumented(key, search=True)


def _generate_init(class_, class_manager):
    """Build an __init__ decorator that triggers ClassManager events."""

    # TODO: we should use the ClassManager's notion of the
    # original '__init__' method, once ClassManager is fixed
    # to always reference that.
    original__init__ = class_.__init__
    assert original__init__

    # Go through some effort here and don't change the user's __init__
    # calling signature, including the unlikely case that it has
    # a return value.
    # FIXME: need to juggle local names to avoid constructor argument
    # clashes.
    func_body = """\
def __init__(%(apply_pos)s):
    new_state = class_manager._new_state_if_none(%(self_arg)s)
    if new_state:
        return new_state._initialize_instance(%(apply_kw)s)
    else:
        return original__init__(%(apply_kw)s)
"""
    func_vars = util.format_argspec_init(original__init__, grouped=False)
    func_text = func_body % func_vars

    if util.py2k:
        func = getattr(original__init__, 'im_func', original__init__)
        func_defaults = getattr(func, 'func_defaults', None)
    else:
        func_defaults = getattr(original__init__, '__defaults__', None)
        func_kw_defaults = getattr(original__init__, '__kwdefaults__', None)

    env = locals().copy()
    exec(func_text, env)
    __init__ = env['__init__']
    __init__.__doc__ = original__init__.__doc__

    if func_defaults:
        __init__.__defaults__ = func_defaults
    if not util.py2k and func_kw_defaults:
        __init__.__kwdefaults__ = func_kw_defaults

    return __init__

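
# A minimal sketch of the public registration round trip defined above
# (uses only register_class(), manager_of_class() and unregister_class()
# from this module; the class is illustrative):

class _Plain(object):
    pass

manager = register_class(_Plain)             # installs a ClassManager
assert manager_of_class(_Plain) is manager   # manager is discoverable
assert register_class(_Plain) is manager     # idempotent: returns existing
unregister_class(_Plain)                     # removes instrumentation
assert manager_of_class(_Plain) is None
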
703  sqlalchemy/orm/loading.py  (new file)
@@ -0,0 +1,703 @@
# orm/loading.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""private module containing functions used to convert database
rows into object instances and associated state.

the functions here are called primarily by Query, Mapper,
as well as some of the attribute loading strategies.

"""
from __future__ import absolute_import

from .. import util
from . import attributes, exc as orm_exc
from ..sql import util as sql_util
from . import strategy_options

from .util import _none_set, state_str
from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE
from .. import exc as sa_exc
import collections

_new_runid = util.counter()


def instances(query, cursor, context):
    """Return an ORM result as an iterator."""

    context.runid = _new_runid()

    filtered = query._has_mapper_entities

    single_entity = len(query._entities) == 1 and \
        query._entities[0].supports_single_entity

    if filtered:
        if single_entity:
            filter_fn = id
        else:
            def filter_fn(row):
                return tuple(
                    id(item)
                    if ent.use_id_for_hash
                    else item
                    for ent, item in zip(query._entities, row)
                )

    try:
        (process, labels) = \
            list(zip(*[
                query_entity.row_processor(query,
                                           context, cursor)
                for query_entity in query._entities
            ]))

        if not single_entity:
            keyed_tuple = util.lightweight_named_tuple('result', labels)

        while True:
            context.partials = {}

            if query._yield_per:
                fetch = cursor.fetchmany(query._yield_per)
                if not fetch:
                    break
            else:
                fetch = cursor.fetchall()

            if single_entity:
                proc = process[0]
                rows = [proc(row) for row in fetch]
            else:
                rows = [keyed_tuple([proc(row) for proc in process])
                        for row in fetch]

            if filtered:
                rows = util.unique_list(rows, filter_fn)

            for row in rows:
                yield row

            if not query._yield_per:
                break
    except Exception as err:
        cursor.close()
        util.raise_from_cause(err)

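
# A minimal, self-contained sketch of driving the yield_per branch in
# instances() above (in-memory SQLite; the User mapping is illustrative):

from sqlalchemy import create_engine, Column, Integer
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = 'users'
    id = Column(Integer, primary_key=True)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add_all([User() for _ in range(250)])
session.commit()

# with yield_per, instances() calls cursor.fetchmany(100) per batch
# instead of a single cursor.fetchall()
for user in session.query(User).yield_per(100):
    pass
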
@util.dependencies("sqlalchemy.orm.query")
def merge_result(querylib, query, iterator, load=True):
    """Merge a result into this :class:`.Query` object's Session."""

    session = query.session
    if load:
        # flush current contents if we expect to load data
        session._autoflush()

    autoflush = session.autoflush
    try:
        session.autoflush = False
        single_entity = len(query._entities) == 1
        if single_entity:
            if isinstance(query._entities[0], querylib._MapperEntity):
                result = [session._merge(
                    attributes.instance_state(instance),
                    attributes.instance_dict(instance),
                    load=load, _recursive={}, _resolve_conflict_map={})
                    for instance in iterator]
            else:
                result = list(iterator)
        else:
            mapped_entities = [i for i, e in enumerate(query._entities)
                               if isinstance(e, querylib._MapperEntity)]
            result = []
            keys = [ent._label_name for ent in query._entities]
            keyed_tuple = util.lightweight_named_tuple('result', keys)
            for row in iterator:
                newrow = list(row)
                for i in mapped_entities:
                    if newrow[i] is not None:
                        newrow[i] = session._merge(
                            attributes.instance_state(newrow[i]),
                            attributes.instance_dict(newrow[i]),
                            load=load, _recursive={},
                            _resolve_conflict_map={})
                result.append(keyed_tuple(newrow))

        return iter(result)
    finally:
        session.autoflush = autoflush


def get_from_identity(session, key, passive):
    """Look up the given key in the given session's identity map,
    check the object for expired state if found.

    """
    instance = session.identity_map.get(key)
    if instance is not None:

        state = attributes.instance_state(instance)

        # expired - ensure it still exists
        if state.expired:
            if not passive & attributes.SQL_OK:
                # TODO: no coverage here
                return attributes.PASSIVE_NO_RESULT
            elif not passive & attributes.RELATED_OBJECT_OK:
                # this mode is used within a flush and the instance's
                # expired state will be checked soon enough, if necessary
                return instance
            try:
                state._load_expired(state, passive)
            except orm_exc.ObjectDeletedError:
                session._remove_newly_deleted([state])
                return None
        return instance
    else:
        return None

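
# A minimal sketch of the identity-map short circuit provided by
# get_from_identity() (continues the illustrative User/session setup
# from the previous sketch):

u1 = session.query(User).get(1)   # first call emits a SELECT
u2 = session.query(User).get(1)   # found via get_from_identity(); no SQL
assert u1 is u2                   # same identity key, same object
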
def load_on_ident(query, key,
|
||||
refresh_state=None, lockmode=None,
|
||||
only_load_props=None):
|
||||
"""Load the given identity key from the database."""
|
||||
|
||||
if key is not None:
|
||||
ident = key[1]
|
||||
else:
|
||||
ident = None
|
||||
|
||||
if refresh_state is None:
|
||||
q = query._clone()
|
||||
q._get_condition()
|
||||
else:
|
||||
q = query._clone()
|
||||
|
||||
if ident is not None:
|
||||
mapper = query._mapper_zero()
|
||||
|
||||
(_get_clause, _get_params) = mapper._get_clause
|
||||
|
||||
# None present in ident - turn those comparisons
|
||||
# into "IS NULL"
|
||||
if None in ident:
|
||||
nones = set([
|
||||
_get_params[col].key for col, value in
|
||||
zip(mapper.primary_key, ident) if value is None
|
||||
])
|
||||
_get_clause = sql_util.adapt_criterion_to_null(
|
||||
_get_clause, nones)
|
||||
|
||||
_get_clause = q._adapt_clause(_get_clause, True, False)
|
||||
q._criterion = _get_clause
|
||||
|
||||
params = dict([
|
||||
(_get_params[primary_key].key, id_val)
|
||||
for id_val, primary_key in zip(ident, mapper.primary_key)
|
||||
])
|
||||
|
||||
q._params = params
|
||||
|
||||
if lockmode is not None:
|
||||
version_check = True
|
||||
q = q.with_lockmode(lockmode)
|
||||
elif query._for_update_arg is not None:
|
||||
version_check = True
|
||||
q._for_update_arg = query._for_update_arg
|
||||
else:
|
||||
version_check = False
|
||||
|
||||
q._get_options(
|
||||
populate_existing=bool(refresh_state),
|
||||
version_check=version_check,
|
||||
only_load_props=only_load_props,
|
||||
refresh_state=refresh_state)
|
||||
q._order_by = None
|
||||
|
||||
try:
|
||||
return q.one()
|
||||
except orm_exc.NoResultFound:
|
||||
return None
|
||||
|
||||
|
||||
def _setup_entity_query(
|
||||
context, mapper, query_entity,
|
||||
path, adapter, column_collection,
|
||||
with_polymorphic=None, only_load_props=None,
|
||||
polymorphic_discriminator=None, **kw):
|
||||
|
||||
if with_polymorphic:
|
||||
poly_properties = mapper._iterate_polymorphic_properties(
|
||||
with_polymorphic)
|
||||
else:
|
||||
poly_properties = mapper._polymorphic_properties
|
||||
|
||||
quick_populators = {}
|
||||
|
||||
path.set(
|
||||
context.attributes,
|
||||
"memoized_setups",
|
||||
quick_populators)
|
||||
|
||||
for value in poly_properties:
|
||||
if only_load_props and \
|
||||
value.key not in only_load_props:
|
||||
continue
|
||||
value.setup(
|
||||
context,
|
||||
query_entity,
|
||||
path,
|
||||
adapter,
|
||||
only_load_props=only_load_props,
|
||||
column_collection=column_collection,
|
||||
memoized_populators=quick_populators,
|
||||
**kw
|
||||
)
|
||||
|
||||
if polymorphic_discriminator is not None and \
|
||||
polymorphic_discriminator \
|
||||
is not mapper.polymorphic_on:
|
||||
|
||||
if adapter:
|
||||
pd = adapter.columns[polymorphic_discriminator]
|
||||
else:
|
||||
pd = polymorphic_discriminator
|
||||
column_collection.append(pd)
|
||||
|
||||
|
||||
def _instance_processor(
        mapper, context, result, path, adapter,
        only_load_props=None, refresh_state=None,
        polymorphic_discriminator=None,
        _polymorphic_from=None):
    """Produce a mapper level row processor callable
       which processes rows into mapped instances."""

    # note that this method, most of which exists in a closure
    # called _instance(), resists being broken out, as
    # attempts to do so tend to add significant function
    # call overhead.  _instance() is the most
    # performance-critical section in the whole ORM.

    pk_cols = mapper.primary_key

    if adapter:
        pk_cols = [adapter.columns[c] for c in pk_cols]

    identity_class = mapper._identity_class

    populators = collections.defaultdict(list)

    props = mapper._prop_set
    if only_load_props is not None:
        props = props.intersection(
            mapper._props[k] for k in only_load_props)

    quick_populators = path.get(
        context.attributes, "memoized_setups", _none_set)

    for prop in props:
        if prop in quick_populators:
            # this is an inlined path just for column-based attributes.
            col = quick_populators[prop]
            if col is _DEFER_FOR_STATE:
                populators["new"].append(
                    (prop.key, prop._deferred_column_loader))
            elif col is _SET_DEFERRED_EXPIRED:
                # note that in this path, we are no longer
                # searching in the result to see if the column might
                # be present in some unexpected way.
                populators["expire"].append((prop.key, False))
            else:
                if adapter:
                    col = adapter.columns[col]
                getter = result._getter(col, False)
                if getter:
                    populators["quick"].append((prop.key, getter))
                else:
                    # fall back to the ColumnProperty itself, which
                    # will iterate through all of its columns
                    # to see if one fits
                    prop.create_row_processor(
                        context, path, mapper, result, adapter, populators)
        else:
            prop.create_row_processor(
                context, path, mapper, result, adapter, populators)

    propagate_options = context.propagate_options
    load_path = context.query._current_path + path \
        if context.query._current_path.path else path

    session_identity_map = context.session.identity_map

    populate_existing = context.populate_existing or mapper.always_refresh
    load_evt = bool(mapper.class_manager.dispatch.load)
    refresh_evt = bool(mapper.class_manager.dispatch.refresh)
    persistent_evt = bool(context.session.dispatch.loaded_as_persistent)
    if persistent_evt:
        loaded_as_persistent = context.session.dispatch.loaded_as_persistent
    instance_state = attributes.instance_state
    instance_dict = attributes.instance_dict
    session_id = context.session.hash_key
    version_check = context.version_check
    runid = context.runid

    if refresh_state:
        refresh_identity_key = refresh_state.key
        if refresh_identity_key is None:
            # super-rare condition; a refresh is being called
            # on a non-instance-key instance; this is meant to only
            # occur within a flush()
            refresh_identity_key = \
                mapper._identity_key_from_state(refresh_state)
    else:
        refresh_identity_key = None

    if mapper.allow_partial_pks:
        is_not_primary_key = _none_set.issuperset
    else:
        is_not_primary_key = _none_set.intersection

    def _instance(row):

        # determine the state that we'll be populating
        if refresh_identity_key:
            # fixed state that we're refreshing
            state = refresh_state
            instance = state.obj()
            dict_ = instance_dict(instance)
            isnew = state.runid != runid
            currentload = True
            loaded_instance = False
        else:
            # look at the row, see if that identity is in the
            # session, or we have to create a new one
            identitykey = (
                identity_class,
                tuple([row[column] for column in pk_cols])
            )

            instance = session_identity_map.get(identitykey)

            if instance is not None:
                # existing instance
                state = instance_state(instance)
                dict_ = instance_dict(instance)

                isnew = state.runid != runid
                currentload = not isnew
                loaded_instance = False

                if version_check and not currentload:
                    _validate_version_id(mapper, state, dict_, row, adapter)

            else:
                # create a new instance

                # check for non-NULL values in the primary key columns,
                # else no entity is returned for the row
                if is_not_primary_key(identitykey[1]):
                    return None

                isnew = True
                currentload = True
                loaded_instance = True

                instance = mapper.class_manager.new_instance()

                dict_ = instance_dict(instance)
                state = instance_state(instance)
                state.key = identitykey

                # attach instance to session.
                state.session_id = session_id
                session_identity_map._add_unpresent(state, identitykey)

        # populate.  this looks at whether this state is new
        # for this load or was existing, and whether or not this
        # row is the first row with this identity.
        if currentload or populate_existing:
            # full population routines.  Objects here are either
            # just created, or we are doing a populate_existing

            # be conservative about setting load_path when populate_existing
            # is in effect; want to maintain options from the original
            # load.  see test_expire->test_refresh_maintains_deferred_options
            if isnew and (propagate_options or not populate_existing):
                state.load_options = propagate_options
                state.load_path = load_path

            _populate_full(
                context, row, state, dict_, isnew, load_path,
                loaded_instance, populate_existing, populators)

            if isnew:
                if loaded_instance:
                    if load_evt:
                        state.manager.dispatch.load(state, context)
                    if persistent_evt:
                        loaded_as_persistent(context.session, state.obj())
                elif refresh_evt:
                    state.manager.dispatch.refresh(
                        state, context, only_load_props)

                if populate_existing or state.modified:
                    if refresh_state and only_load_props:
                        state._commit(dict_, only_load_props)
                    else:
                        state._commit_all(dict_, session_identity_map)

        else:
            # partial population routines, for objects that were already
            # in the Session, but a row matches them; apply eager loaders
            # on existing objects, etc.
            unloaded = state.unloaded
            isnew = state not in context.partials

            if not isnew or unloaded or populators["eager"]:
                # state is having a partial set of its attributes
                # refreshed.  Populate those attributes,
                # and add to the "context.partials" collection.

                to_load = _populate_partial(
                    context, row, state, dict_, isnew, load_path,
                    unloaded, populators)

                if isnew:
                    if refresh_evt:
                        state.manager.dispatch.refresh(
                            state, context, to_load)

                    state._commit(dict_, to_load)

        return instance

    if mapper.polymorphic_map and not _polymorphic_from and not refresh_state:
        # if we are doing polymorphic, dispatch to a different _instance()
        # method specific to the subclass mapper
        _instance = _decorate_polymorphic_switch(
            _instance, context, mapper, result, path,
            polymorphic_discriminator, adapter)

    return _instance

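
For orientation, a minimal sketch of how the callable returned by ``_instance_processor()`` is driven; ``proc`` and ``rows`` are hypothetical stand-ins for the processor and the result rows, not names from this commit::

    # hypothetical driver loop over a result set
    def _sketch_drive(proc, rows):
        instances = []
        for row in rows:
            obj = proc(row)        # returns None when the row's primary
            if obj is not None:    # key columns are all NULL
                instances.append(obj)
        return instances
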
def _populate_full(
        context, row, state, dict_, isnew, load_path,
        loaded_instance, populate_existing, populators):
    if isnew:
        # first time we are seeing a row with this identity.
        state.runid = context.runid

        for key, getter in populators["quick"]:
            dict_[key] = getter(row)
        if populate_existing:
            for key, set_callable in populators["expire"]:
                dict_.pop(key, None)
                if set_callable:
                    state.expired_attributes.add(key)
        else:
            for key, set_callable in populators["expire"]:
                if set_callable:
                    state.expired_attributes.add(key)
        for key, populator in populators["new"]:
            populator(state, dict_, row)
        for key, populator in populators["delayed"]:
            populator(state, dict_, row)
    elif load_path != state.load_path:
        # new load path, e.g. object is present in more than one
        # column position in a series of rows
        state.load_path = load_path

        # if we have data, and the data isn't in the dict, OK, let's put
        # it in.
        for key, getter in populators["quick"]:
            if key not in dict_:
                dict_[key] = getter(row)

        # otherwise treat like an "already seen" row
        for key, populator in populators["existing"]:
            populator(state, dict_, row)
            # TODO: allow "existing" populator to know this is
            # a new path for the state:
            # populator(state, dict_, row, new_path=True)

    else:
        # have already seen rows with this identity in this same path.
        for key, populator in populators["existing"]:
            populator(state, dict_, row)

            # TODO: same path
            # populator(state, dict_, row, new_path=False)

def _populate_partial(
        context, row, state, dict_, isnew, load_path,
        unloaded, populators):

    if not isnew:
        to_load = context.partials[state]
        for key, populator in populators["existing"]:
            if key in to_load:
                populator(state, dict_, row)
    else:
        to_load = unloaded
        context.partials[state] = to_load

        for key, getter in populators["quick"]:
            if key in to_load:
                dict_[key] = getter(row)
        for key, set_callable in populators["expire"]:
            if key in to_load:
                dict_.pop(key, None)
                if set_callable:
                    state.expired_attributes.add(key)
        for key, populator in populators["new"]:
            if key in to_load:
                populator(state, dict_, row)
        for key, populator in populators["delayed"]:
            if key in to_load:
                populator(state, dict_, row)
    for key, populator in populators["eager"]:
        if key not in unloaded:
            populator(state, dict_, row)

    return to_load

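
For reference, the ``populators`` mapping consumed by ``_populate_full()`` and ``_populate_partial()`` is a defaultdict of lists keyed by bucket name; a hedged illustration of its shape, where the bucket names are taken from the code above but the sample entry is invented::

    import collections

    populators = collections.defaultdict(list)
    # "quick":    (key, getter)        - plain column getters applied to the row
    # "expire":   (key, set_callable)  - attributes to expire rather than load
    # "new":      (key, populator)     - run the first time an identity is seen
    # "delayed":  (key, populator)     - run after the "new" populators
    # "existing": (key, populator)     - run for rows already seen for this identity
    # "eager":    (key, populator)     - eager-loader populators for partial refreshes
    populators["quick"].append(("name", lambda row: row["name"]))
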
def _validate_version_id(mapper, state, dict_, row, adapter):

    version_id_col = mapper.version_id_col

    if version_id_col is None:
        return

    if adapter:
        version_id_col = adapter.columns[version_id_col]

    if mapper._get_state_attr_by_column(
            state, dict_, mapper.version_id_col) != row[version_id_col]:
        raise orm_exc.StaleDataError(
            "Instance '%s' has version id '%s' which "
            "does not match database-loaded version id '%s'."
            % (state_str(state), mapper._get_state_attr_by_column(
                state, dict_, mapper.version_id_col),
                row[version_id_col]))

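
``_validate_version_id()`` backs the ORM's optimistic concurrency check; a hedged sketch of the classic-mapping configuration that activates it (table and class names are invented)::

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.orm import mapper

    metadata = MetaData()
    widgets = Table(
        'widgets', metadata,
        Column('id', Integer, primary_key=True),
        Column('version_id', Integer, nullable=False))

    class Widget(object):
        pass

    # rows whose version_id no longer matches the in-memory value will
    # raise StaleDataError when re-loaded with version checking enabled
    mapper(Widget, widgets, version_id_col=widgets.c.version_id)
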
def _decorate_polymorphic_switch(
        instance_fn, context, mapper, result, path,
        polymorphic_discriminator, adapter):
    if polymorphic_discriminator is not None:
        polymorphic_on = polymorphic_discriminator
    else:
        polymorphic_on = mapper.polymorphic_on
    if polymorphic_on is None:
        return instance_fn

    if adapter:
        polymorphic_on = adapter.columns[polymorphic_on]

    def configure_subclass_mapper(discriminator):
        try:
            sub_mapper = mapper.polymorphic_map[discriminator]
        except KeyError:
            raise AssertionError(
                "No such polymorphic_identity %r is defined" %
                discriminator)
        else:
            if sub_mapper is mapper:
                return None

            return _instance_processor(
                sub_mapper, context, result,
                path, adapter, _polymorphic_from=mapper)

    polymorphic_instances = util.PopulateDict(
        configure_subclass_mapper
    )

    def polymorphic_instance(row):
        discriminator = row[polymorphic_on]
        if discriminator is not None:
            _instance = polymorphic_instances[discriminator]
            if _instance:
                return _instance(row)
        return instance_fn(row)
    return polymorphic_instance

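
``_decorate_polymorphic_switch()`` consults ``mapper.polymorphic_map``, which is keyed by each subclass's ``polymorphic_identity``; a hedged declarative sketch of the configuration it dispatches on (class names invented)::

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Employee(Base):
        __tablename__ = 'employee'
        id = Column(Integer, primary_key=True)
        type = Column(String(50))
        __mapper_args__ = {
            'polymorphic_on': type,
            'polymorphic_identity': 'employee',
        }

    class Engineer(Employee):
        # rows with type == 'engineer' are dispatched to Engineer's
        # _instance() processor by the polymorphic switch
        __mapper_args__ = {'polymorphic_identity': 'engineer'}
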
def load_scalar_attributes(mapper, state, attribute_names):
    """initiate a column-based attribute refresh operation."""

    # assert mapper is _state_mapper(state)
    session = state.session
    if not session:
        raise orm_exc.DetachedInstanceError(
            "Instance %s is not bound to a Session; "
            "attribute refresh operation cannot proceed" %
            (state_str(state)))

    has_key = bool(state.key)

    result = False

    if mapper.inherits and not mapper.concrete:
        # because we are using Core to produce a select() that we
        # pass to the Query, we aren't calling setup() for mapped
        # attributes; in 1.0 this means deferred attrs won't get loaded
        # by default
        statement = mapper._optimized_get_statement(state, attribute_names)
        if statement is not None:
            result = load_on_ident(
                session.query(mapper).
                options(
                    strategy_options.Load(mapper).undefer("*")
                ).from_statement(statement),
                None,
                only_load_props=attribute_names,
                refresh_state=state
            )

    if result is False:
        if has_key:
            identity_key = state.key
        else:
            # this codepath is rare - only valid when inside a flush, and the
            # object is becoming persistent but hasn't yet been assigned
            # an identity_key.
            # check here to ensure we have the attrs we need.
            pk_attrs = [mapper._columntoproperty[col].key
                        for col in mapper.primary_key]
            if state.expired_attributes.intersection(pk_attrs):
                raise sa_exc.InvalidRequestError(
                    "Instance %s cannot be refreshed - it's not "
                    "persistent and does not "
                    "contain a full primary key." % state_str(state))
            identity_key = mapper._identity_key_from_state(state)

        if (_none_set.issubset(identity_key) and
                not mapper.allow_partial_pks) or \
                _none_set.issuperset(identity_key):
            util.warn_limited(
                "Instance %s to be refreshed doesn't "
                "contain a full primary key - can't be refreshed "
                "(and shouldn't be expired, either).",
                state_str(state))
            return

        result = load_on_ident(
            session.query(mapper),
            identity_key,
            refresh_state=state,
            only_load_props=attribute_names)

    # if instance is pending, a refresh operation
    # may not complete (even if PK attributes are assigned)
    if has_key and result is None:
        raise orm_exc.ObjectDeletedError(state)

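
``load_scalar_attributes()`` is what ultimately services attribute access on expired or deferred columns; a hedged usage sketch, assuming ``session`` and ``widget`` are an active Session and a mapped instance (both invented here)::

    def _sketch_refresh(session, widget):
        # expire one attribute; the next access emits a SELECT which is
        # serviced by load_scalar_attributes() for just that column
        session.expire(widget, ['name'])
        return widget.name   # DetachedInstanceError if widget has no Session
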
271  sqlalchemy/orm/path_registry.py  Normal file
@@ -0,0 +1,271 @@
# orm/path_registry.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Path tracking utilities, representing mapper graph traversals.

"""

from .. import inspection
from .. import util
from .. import exc
from itertools import chain
from .base import class_mapper
import logging

log = logging.getLogger(__name__)


def _unreduce_path(path):
    return PathRegistry.deserialize(path)


_WILDCARD_TOKEN = "*"
_DEFAULT_TOKEN = "_sa_default"


class PathRegistry(object):
    """Represent query load paths and registry functions.

    Basically represents structures like:

    (<User mapper>, "orders", <Order mapper>, "items", <Item mapper>)

    These structures are generated by things like
    query options (joinedload(), subqueryload(), etc.) and are
    used to compose keys stored in the query._attributes dictionary
    for various options.

    They are then re-composed at query compile/result row time as
    the query is formed and as rows are fetched, where they again
    serve to compose keys to look up options in the context.attributes
    dictionary, which is copied from query._attributes.

    The path structure has a limited amount of caching, where each
    "root" ultimately pulls from a fixed registry associated with
    the first mapper, that also contains elements for each of its
    property keys.  However paths longer than two elements, which
    are the exception rather than the rule, are generated on an
    as-needed basis.

    """

    is_token = False
    is_root = False

    def __eq__(self, other):
        return other is not None and \
            self.path == other.path

    def set(self, attributes, key, value):
        log.debug("set '%s' on path '%s' to '%s'", key, self, value)
        attributes[(key, self.path)] = value

    def setdefault(self, attributes, key, value):
        log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value)
        attributes.setdefault((key, self.path), value)

    def get(self, attributes, key, value=None):
        key = (key, self.path)
        if key in attributes:
            return attributes[key]
        else:
            return value

    def __len__(self):
        return len(self.path)

    @property
    def length(self):
        return len(self.path)

    def pairs(self):
        path = self.path
        for i in range(0, len(path), 2):
            yield path[i], path[i + 1]

    def contains_mapper(self, mapper):
        for path_mapper in [
            self.path[i] for i in range(0, len(self.path), 2)
        ]:
            if path_mapper.is_mapper and \
                    path_mapper.isa(mapper):
                return True
        else:
            return False

    def contains(self, attributes, key):
        return (key, self.path) in attributes

    def __reduce__(self):
        return _unreduce_path, (self.serialize(), )

    def serialize(self):
        path = self.path
        return list(zip(
            [m.class_ for m in [path[i] for i in range(0, len(path), 2)]],
            [path[i].key for i in range(1, len(path), 2)] + [None]
        ))

    @classmethod
    def deserialize(cls, path):
        if path is None:
            return None

        p = tuple(chain(*[(class_mapper(mcls),
                           class_mapper(mcls).attrs[key]
                           if key is not None else None)
                          for mcls, key in path]))
        if p and p[-1] is None:
            p = p[0:-1]
        return cls.coerce(p)

    @classmethod
    def per_mapper(cls, mapper):
        return EntityRegistry(
            cls.root, mapper
        )

    @classmethod
    def coerce(cls, raw):
        return util.reduce(lambda prev, next: prev[next], raw, cls.root)

    def token(self, token):
        if token.endswith(':' + _WILDCARD_TOKEN):
            return TokenRegistry(self, token)
        elif token.endswith(":" + _DEFAULT_TOKEN):
            return TokenRegistry(self.root, token)
        else:
            raise exc.ArgumentError("invalid token: %s" % token)

    def __add__(self, other):
        return util.reduce(
            lambda prev, next: prev[next],
            other.path, self)

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.path, )


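The ``set()``/``get()``/``contains()`` methods above all compose the same ``(key, self.path)`` tuple; a minimal sketch of that contract using a plain dict in place of ``context.attributes`` (the stand-in class and values are invented)::

    class _FakePath(object):
        # stand-in exposing the same .path attribute as PathRegistry
        path = ("UserMapper", "addresses")

    attributes = {}
    key = ("loader", _FakePath.path)
    attributes[key] = "joined"                         # what set() does
    assert ("loader", _FakePath.path) in attributes    # what contains() checks

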
class RootRegistry(PathRegistry):
    """Root registry, defers to mappers so that
    paths are maintained per-root-mapper.

    """
    path = ()
    has_entity = False
    is_aliased_class = False
    is_root = True

    def __getitem__(self, entity):
        return entity._path_registry

PathRegistry.root = RootRegistry()


class TokenRegistry(PathRegistry):
    def __init__(self, parent, token):
        self.token = token
        self.parent = parent
        self.path = parent.path + (token,)

    has_entity = False

    is_token = True

    def generate_for_superclasses(self):
        if not self.parent.is_aliased_class and not self.parent.is_root:
            for ent in self.parent.mapper.iterate_to_root():
                yield TokenRegistry(self.parent.parent[ent], self.token)
        else:
            yield self

    def __getitem__(self, entity):
        raise NotImplementedError()


class PropRegistry(PathRegistry):
    def __init__(self, parent, prop):
        # restate this path in terms of the
        # given MapperProperty's parent.
        insp = inspection.inspect(parent[-1])
        if not insp.is_aliased_class or insp._use_mapper_path:
            parent = parent.parent[prop.parent]
        elif insp.is_aliased_class and insp.with_polymorphic_mappers:
            if prop.parent is not insp.mapper and \
                    prop.parent in insp.with_polymorphic_mappers:
                subclass_entity = parent[-1]._entity_for_mapper(prop.parent)
                parent = parent.parent[subclass_entity]

        self.prop = prop
        self.parent = parent
        self.path = parent.path + (prop,)

        self._wildcard_path_loader_key = (
            "loader",
            self.parent.path + self.prop._wildcard_token
        )
        self._default_path_loader_key = self.prop._default_path_loader_key
        self._loader_key = ("loader", self.path)

    def __str__(self):
        return " -> ".join(
            str(elem) for elem in self.path
        )

    @util.memoized_property
    def has_entity(self):
        return hasattr(self.prop, "mapper")

    @util.memoized_property
    def entity(self):
        return self.prop.mapper

    @property
    def mapper(self):
        return self.entity

    @property
    def entity_path(self):
        return self[self.entity]

    def __getitem__(self, entity):
        if isinstance(entity, (int, slice)):
            return self.path[entity]
        else:
            return EntityRegistry(
                self, entity
            )


class EntityRegistry(PathRegistry, dict):
    is_aliased_class = False
    has_entity = True

    def __init__(self, parent, entity):
        self.key = entity
        self.parent = parent
        self.is_aliased_class = entity.is_aliased_class
        self.entity = entity
        self.path = parent.path + (entity,)
        self.entity_path = self

    @property
    def mapper(self):
        return inspection.inspect(self.entity).mapper

    def __bool__(self):
        return True
    __nonzero__ = __bool__

    def __getitem__(self, entity):
        if isinstance(entity, (int, slice)):
            return self.path[entity]
        else:
            return dict.__getitem__(self, entity)

    def __missing__(self, key):
        self[key] = item = PropRegistry(self, key)
        return item
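
Because ``EntityRegistry.__missing__()`` creates ``PropRegistry`` entries on demand, a path can be built up by plain indexing; a hedged sketch assuming a mapped class ``User`` with an ``orders`` relationship (both invented)::

    from sqlalchemy import inspect

    user_mapper = inspect(User)            # Mapper for the mapped class
    path = user_mapper._path_registry      # EntityRegistry for the mapper
    # indexing by a MapperProperty yields a PropRegistry, built on demand,
    # alternating entity / property just as the class docstring describes
    path = path[user_mapper.attrs['orders']]
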
1460  sqlalchemy/orm/persistence.py  Normal file
File diff suppressed because it is too large
2875  sqlalchemy/orm/relationships.py  Normal file
File diff suppressed because it is too large
1106  sqlalchemy/orm/strategy_options.py  Normal file
File diff suppressed because it is too large
203  sqlalchemy/sql/annotation.py  Normal file
@@ -0,0 +1,203 @@
# sql/annotation.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""The :class:`.Annotated` class and related routines; creates hash-equivalent
copies of SQL constructs which contain context-specific markers and
associations.

"""

from .. import util
from . import operators


class Annotated(object):
    """clones a ClauseElement and applies an 'annotations' dictionary.

    Unlike regular clones, this clone also mimics __hash__() and
    __cmp__() of the original element so that it takes its place
    in hashed collections.

    A reference to the original element is maintained, for the important
    reason of keeping its hash value current.  When GC'ed, the
    hash value may be reused, causing conflicts.

    .. note:: The rationale for Annotated producing a brand new class,
       rather than placing the functionality directly within ClauseElement,
       is **performance**.  The __hash__() method is absent on plain
       ClauseElement which leads to significantly reduced function call
       overhead, as the use of sets and dictionaries against ClauseElement
       objects is prevalent, but most are not "annotated".

    """

    def __new__(cls, *args):
        if not args:
            # clone constructor
            return object.__new__(cls)
        else:
            element, values = args
            # pull appropriate subclass from registry of annotated
            # classes
            try:
                cls = annotated_classes[element.__class__]
            except KeyError:
                cls = _new_annotation_type(element.__class__, cls)
            return object.__new__(cls)

    def __init__(self, element, values):
        self.__dict__ = element.__dict__.copy()
        self.__element = element
        self._annotations = values
        self._hash = hash(element)

    def _annotate(self, values):
        _values = self._annotations.copy()
        _values.update(values)
        return self._with_annotations(_values)

    def _with_annotations(self, values):
        clone = self.__class__.__new__(self.__class__)
        clone.__dict__ = self.__dict__.copy()
        clone._annotations = values
        return clone

    def _deannotate(self, values=None, clone=True):
        if values is None:
            return self.__element
        else:
            _values = self._annotations.copy()
            for v in values:
                _values.pop(v, None)
            return self._with_annotations(_values)

    def _compiler_dispatch(self, visitor, **kw):
        return self.__element.__class__._compiler_dispatch(
            self, visitor, **kw)

    @property
    def _constructor(self):
        return self.__element._constructor

    def _clone(self):
        clone = self.__element._clone()
        if clone is self.__element:
            # detect immutable, don't change anything
            return self
        else:
            # update the clone with any changes that have occurred
            # to this object's __dict__.
            clone.__dict__.update(self.__dict__)
            return self.__class__(clone, self._annotations)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        if isinstance(self.__element, operators.ColumnOperators):
            return self.__element.__class__.__eq__(self, other)
        else:
            return hash(other) == hash(self)

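
A hedged sketch of the hash-equivalence contract described in the docstring; ``col`` stands for any ClauseElement, and ``_annotate()`` is internal API::

    from sqlalchemy import Column, Integer

    col = Column('x', Integer)
    annotated = col._annotate({'parententity': 'SomeEntity'})

    # the annotated copy hashes like the original, so it can stand in
    # for it inside sets and dictionaries...
    assert hash(annotated) == hash(col)
    # ...while carrying its own annotation dictionary
    assert annotated._annotations['parententity'] == 'SomeEntity'
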
# hard-generate Annotated subclasses.  this technique
# is used instead of on-the-fly types (i.e. type.__new__())
# so that the resulting objects are pickleable.
annotated_classes = {}


def _deep_annotate(element, annotations, exclude=None):
    """Deep copy the given ClauseElement, annotating each element
    with the given annotations dictionary.

    Elements within the exclude collection will be cloned but not annotated.

    """
    def clone(elem):
        if exclude and \
                hasattr(elem, 'proxy_set') and \
                elem.proxy_set.intersection(exclude):
            newelem = elem._clone()
        elif annotations != elem._annotations:
            newelem = elem._annotate(annotations)
        else:
            newelem = elem
        newelem._copy_internals(clone=clone)
        return newelem

    if element is not None:
        element = clone(element)
    return element


def _deep_deannotate(element, values=None):
    """Deep copy the given element, removing annotations."""

    cloned = util.column_dict()

    def clone(elem):
        # if a values dict is given,
        # the elem must be cloned each time it appears,
        # as there may be different annotations in source
        # elements that are remaining.  if totally
        # removing all annotations, can assume the same
        # slate...
        if values or elem not in cloned:
            newelem = elem._deannotate(values=values, clone=True)
            newelem._copy_internals(clone=clone)
            if not values:
                cloned[elem] = newelem
            return newelem
        else:
            return cloned[elem]

    if element is not None:
        element = clone(element)
    return element


def _shallow_annotate(element, annotations):
    """Annotate the given ClauseElement and copy its internals so that
    internal objects refer to the new annotated object.

    Basically used to apply a "dont traverse" annotation to a
    selectable, without digging throughout the whole
    structure wasting time.
    """
    element = element._annotate(annotations)
    element._copy_internals()
    return element


def _new_annotation_type(cls, base_cls):
    if issubclass(cls, Annotated):
        return cls
    elif cls in annotated_classes:
        return annotated_classes[cls]

    for super_ in cls.__mro__:
        # check if an Annotated subclass more specific than
        # the given base_cls is already registered, such
        # as AnnotatedColumnElement.
        if super_ in annotated_classes:
            base_cls = annotated_classes[super_]
            break

    annotated_classes[cls] = anno_cls = type(
        "Annotated%s" % cls.__name__,
        (base_cls, cls), {})
    globals()["Annotated%s" % cls.__name__] = anno_cls
    return anno_cls


def _prepare_annotations(target_hierarchy, base_cls):
    stack = [target_hierarchy]
    while stack:
        cls = stack.pop()
        stack.extend(cls.__subclasses__())

        _new_annotation_type(cls, base_cls)
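
A short sketch contrasting ``_deep_annotate()`` with ``_deep_deannotate()`` (internal helpers; the expression is invented)::

    from sqlalchemy import Column, Integer
    from sqlalchemy.sql.annotation import _deep_annotate, _deep_deannotate

    expr = Column('a', Integer) == 5

    # every element in the expression tree receives the annotation...
    annotated = _deep_annotate(expr, {'mark': True})
    # ...and _deep_deannotate returns hash-equivalent unannotated copies
    plain = _deep_deannotate(annotated)
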
633  sqlalchemy/sql/base.py  Normal file
@@ -0,0 +1,633 @@
# sql/base.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Foundational utilities common to many sql modules.

"""


from .. import util, exc
import itertools
from .visitors import ClauseVisitor
import re
import collections

PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT')
NO_ARG = util.symbol('NO_ARG')


class Immutable(object):
    """mark a ClauseElement as 'immutable' when expressions are cloned."""

    def unique_params(self, *optionaldict, **kwargs):
        raise NotImplementedError("Immutable objects do not support copying")

    def params(self, *optionaldict, **kwargs):
        raise NotImplementedError("Immutable objects do not support copying")

    def _clone(self):
        return self


def _from_objects(*elements):
    return itertools.chain(*[element._from_objects for element in elements])


@util.decorator
def _generative(fn, *args, **kw):
    """Mark a method as generative."""

    self = args[0]._generate()
    fn(self, *args[1:], **kw)
    return self


class _DialectArgView(collections.MutableMapping):
    """A dictionary view of dialect-level arguments in the form
    <dialectname>_<argument_name>.

    """

    def __init__(self, obj):
        self.obj = obj

    def _key(self, key):
        try:
            dialect, value_key = key.split("_", 1)
        except ValueError:
            raise KeyError(key)
        else:
            return dialect, value_key

    def __getitem__(self, key):
        dialect, value_key = self._key(key)

        try:
            opt = self.obj.dialect_options[dialect]
        except exc.NoSuchModuleError:
            raise KeyError(key)
        else:
            return opt[value_key]

    def __setitem__(self, key, value):
        try:
            dialect, value_key = self._key(key)
        except KeyError:
            raise exc.ArgumentError(
                "Keys must be of the form <dialectname>_<argname>")
        else:
            self.obj.dialect_options[dialect][value_key] = value

    def __delitem__(self, key):
        dialect, value_key = self._key(key)
        del self.obj.dialect_options[dialect][value_key]

    def __len__(self):
        return sum(len(args._non_defaults) for args in
                   self.obj.dialect_options.values())

    def __iter__(self):
        return (
            util.safe_kwarg("%s_%s" % (dialect_name, value_name))
            for dialect_name in self.obj.dialect_options
            for value_name in
            self.obj.dialect_options[dialect_name]._non_defaults
        )


class _DialectArgDict(collections.MutableMapping):
    """A dictionary view of dialect-level arguments for a specific
    dialect.

    Maintains a separate collection of user-specified arguments
    and dialect-specified default arguments.

    """

    def __init__(self):
        self._non_defaults = {}
        self._defaults = {}

    def __len__(self):
        return len(set(self._non_defaults).union(self._defaults))

    def __iter__(self):
        return iter(set(self._non_defaults).union(self._defaults))

    def __getitem__(self, key):
        if key in self._non_defaults:
            return self._non_defaults[key]
        else:
            return self._defaults[key]

    def __setitem__(self, key, value):
        self._non_defaults[key] = value

    def __delitem__(self, key):
        del self._non_defaults[key]


class DialectKWArgs(object):
    """Establish the ability for a class to have dialect-specific arguments
    with defaults and constructor validation.

    The :class:`.DialectKWArgs` interacts with the
    :attr:`.DefaultDialect.construct_arguments` present on a dialect.

    .. seealso::

        :attr:`.DefaultDialect.construct_arguments`

    """

    @classmethod
    def argument_for(cls, dialect_name, argument_name, default):
        """Add a new kind of dialect-specific keyword argument for this class.

        E.g.::

            Index.argument_for("mydialect", "length", None)

            some_index = Index('a', 'b', mydialect_length=5)

        The :meth:`.DialectKWArgs.argument_for` method is a per-argument
        way of adding extra arguments to the
        :attr:`.DefaultDialect.construct_arguments` dictionary.  This
        dictionary provides a list of argument names accepted by various
        schema-level constructs on behalf of a dialect.

        New dialects should typically specify this dictionary all at once as a
        data member of the dialect class.  The use case for ad-hoc addition of
        argument names is typically for end-user code that is also using
        a custom compilation scheme which consumes the additional arguments.

        :param dialect_name: name of a dialect.  The dialect must be
         locatable, else a :class:`.NoSuchModuleError` is raised.   The
         dialect must also include an existing
         :attr:`.DefaultDialect.construct_arguments` collection, indicating
         that it participates in the keyword-argument validation and default
         system, else :class:`.ArgumentError` is raised.  If the dialect does
         not include this collection, then any keyword argument can be
         specified on behalf of this dialect already.  All dialects packaged
         within SQLAlchemy include this collection, however for third party
         dialects, support may vary.

        :param argument_name: name of the parameter.

        :param default: default value of the parameter.

        .. versionadded:: 0.9.4

        """

        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        if construct_arg_dictionary is None:
            raise exc.ArgumentError(
                "Dialect '%s' does not have keyword-argument "
                "validation and defaults enabled" %
                dialect_name)
        if cls not in construct_arg_dictionary:
            construct_arg_dictionary[cls] = {}
        construct_arg_dictionary[cls][argument_name] = default

    @util.memoized_property
    def dialect_kwargs(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        The arguments are present here in their original ``<dialect>_<kwarg>``
        format.  Only arguments that were actually passed are included;
        unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
        contains all options known by this dialect including defaults.

        The collection is also writable; keys are accepted of the
        form ``<dialect>_<kwarg>`` where the value will be assembled
        into the list of options.

        .. versionadded:: 0.9.2

        .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
           collection is now writable.

        .. seealso::

            :attr:`.DialectKWArgs.dialect_options` - nested dictionary form

        """
        return _DialectArgView(self)

    @property
    def kwargs(self):
        """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
        return self.dialect_kwargs

    @util.dependencies("sqlalchemy.dialects")
    def _kw_reg_for_dialect(dialects, dialect_name):
        dialect_cls = dialects.registry.load(dialect_name)
        if dialect_cls.construct_arguments is None:
            return None
        return dict(dialect_cls.construct_arguments)
    _kw_registry = util.PopulateDict(_kw_reg_for_dialect)

    def _kw_reg_for_dialect_cls(self, dialect_name):
        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        d = _DialectArgDict()

        if construct_arg_dictionary is None:
            d._defaults.update({"*": None})
        else:
            for cls in reversed(self.__class__.__mro__):
                if cls in construct_arg_dictionary:
                    d._defaults.update(construct_arg_dictionary[cls])
        return d

    @util.memoized_property
    def dialect_options(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        This is a two-level nested registry, keyed to ``<dialect_name>``
        and ``<argument_name>``.  For example, the ``postgresql_where``
        argument would be locatable as::

            arg = my_object.dialect_options['postgresql']['where']

        .. versionadded:: 0.9.2

        .. seealso::

            :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form

        """

        return util.PopulateDict(
            util.portable_instancemethod(self._kw_reg_for_dialect_cls)
        )

    def _validate_dialect_kwargs(self, kwargs):
        # validate remaining kwargs that they all specify DB prefixes

        if not kwargs:
            return

        for k in kwargs:
            m = re.match('^(.+?)_(.+)$', k)
            if not m:
                raise TypeError(
                    "Additional arguments should be "
                    "named <dialectname>_<argument>, got '%s'" % k)
            dialect_name, arg_name = m.group(1, 2)

            try:
                construct_arg_dictionary = self.dialect_options[dialect_name]
            except exc.NoSuchModuleError:
                util.warn(
                    "Can't validate argument %r; can't "
                    "locate any SQLAlchemy dialect named %r" %
                    (k, dialect_name))
                self.dialect_options[dialect_name] = d = _DialectArgDict()
                d._defaults.update({"*": None})
                d._non_defaults[arg_name] = kwargs[k]
            else:
                if "*" not in construct_arg_dictionary and \
                        arg_name not in construct_arg_dictionary:
                    raise exc.ArgumentError(
                        "Argument %r is not accepted by "
                        "dialect %r on behalf of %r" % (
                            k,
                            dialect_name, self.__class__
                        ))
                else:
                    construct_arg_dictionary[arg_name] = kwargs[k]


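A hedged sketch tying together the collections defined above; 'mydialect' is an invented dialect name, which only produces a warning rather than an error during validation::

    from sqlalchemy import Column, Index, Integer, MetaData, Table

    metadata = MetaData()
    t = Table('t', metadata, Column('q', Integer))

    # constructor kwargs of the form <dialect>_<arg> flow through
    # _validate_dialect_kwargs(); unknown dialects are warned about
    idx = Index('ix_q', t.c.q, mydialect_length=5)

    idx.dialect_kwargs['mydialect_length']        # flat view: 5
    idx.dialect_options['mydialect']['length']    # nested view: 5

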
class Generative(object):
    """Allow a ClauseElement to generate itself via the
    @_generative decorator.

    """

    def _generate(self):
        s = self.__class__.__new__(self.__class__)
        s.__dict__ = self.__dict__.copy()
        return s


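The ``@_generative`` decorator and ``Generative._generate()`` implement copy-on-write chaining; a minimal standalone sketch of the same pattern, not SQLAlchemy API::

    class Query(object):
        # simplified stand-in for Generative._generate()
        def _generate(self):
            new = self.__class__.__new__(self.__class__)
            new.__dict__ = self.__dict__.copy()
            return new

        def limit(self, n):
            # mutate the copy, never self - callers chain the return value
            new = self._generate()
            new._limit = n
            return new

    q1 = Query()
    q2 = q1.limit(10)      # q1 is untouched; q2 carries _limit = 10

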
class Executable(Generative):
    """Mark a ClauseElement as supporting execution.

    :class:`.Executable` is a superclass for all "statement" types
    of objects, including :func:`select`, :func:`delete`, :func:`update`,
    :func:`insert`, :func:`text`.

    """

    supports_execution = True
    _execution_options = util.immutabledict()
    _bind = None

    @_generative
    def execution_options(self, **kw):
        """ Set non-SQL options for the statement which take effect during
        execution.

        Execution options can be set on a per-statement or
        per :class:`.Connection` basis.   Additionally, the
        :class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide
        access to execution options which they in turn configure upon
        connections.

        The :meth:`execution_options` method is generative.  A new
        instance of this statement is returned that contains the options::

            statement = select([table.c.x, table.c.y])
            statement = statement.execution_options(autocommit=True)

        Note that only a subset of possible execution options can be applied
        to a statement - these include "autocommit" and "stream_results",
        but not "isolation_level" or "compiled_cache".
        See :meth:`.Connection.execution_options` for a full list of
        possible options.

        .. seealso::

            :meth:`.Connection.execution_options()`

            :meth:`.Query.execution_options()`

        """
        if 'isolation_level' in kw:
            raise exc.ArgumentError(
                "'isolation_level' execution option may only be specified "
                "on Connection.execution_options(), or "
                "per-engine using the isolation_level "
                "argument to create_engine()."
            )
        if 'compiled_cache' in kw:
            raise exc.ArgumentError(
                "'compiled_cache' execution option may only be specified "
                "on Connection.execution_options(), not per statement."
            )
        self._execution_options = self._execution_options.union(kw)

    def execute(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`."""
        e = self.bind
        if e is None:
            label = getattr(self, 'description', self.__class__.__name__)
            msg = ('This %s is not directly bound to a Connection or Engine. '
                   'Use the .execute() method of a Connection or Engine '
                   'to execute this construct.' % label)
            raise exc.UnboundExecutionError(msg)
        return e._execute_clauseelement(self, multiparams, params)

    def scalar(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`, returning the
        result's scalar representation.

        """
        return self.execute(*multiparams, **params).scalar()

    @property
    def bind(self):
        """Returns the :class:`.Engine` or :class:`.Connection` to
        which this :class:`.Executable` is bound, or None if none found.

        This is a traversal which checks locally, then
        checks among the "from" clauses of associated objects
        until a bound engine or connection is found.

        """
        if self._bind is not None:
            return self._bind

        for f in _from_objects(self):
            if f is self:
                continue
            engine = f.bind
            if engine is not None:
                return engine
        else:
            return None


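A hedged sketch of the ``bind`` traversal above; binding a MetaData to an engine this way was the pre-2.0 idiom, and the in-memory SQLite engine is standard usage::

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy import create_engine, select

    engine = create_engine('sqlite://')
    metadata = MetaData(bind=engine)
    t = Table('t', metadata, Column('x', Integer))

    stmt = select([t.c.x])
    # no _bind set locally, so .bind walks _from_objects(stmt) and finds
    # the engine through the table's metadata
    assert stmt.bind is engine

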
class SchemaEventTarget(object):
    """Base class for elements that are the targets of :class:`.DDLEvents`
    events.

    This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.

    """

    def _set_parent(self, parent):
        """Associate with this SchemaEvent's parent object."""

    def _set_parent_with_dispatch(self, parent):
        self.dispatch.before_parent_attach(self, parent)
        self._set_parent(parent)
        self.dispatch.after_parent_attach(self, parent)


class SchemaVisitor(ClauseVisitor):
    """Define the visiting for ``SchemaItem`` objects."""

    __traverse_options__ = {'schema_visitor': True}


class ColumnCollection(util.OrderedProperties):
    """An ordered dictionary that stores a list of ColumnElement
    instances.

    Overrides the ``__eq__()`` method to produce SQL clauses between
    sets of correlated columns.

    """

    __slots__ = '_all_columns'

    def __init__(self, *columns):
        super(ColumnCollection, self).__init__()
        object.__setattr__(self, '_all_columns', [])
        for c in columns:
            self.add(c)

    def __str__(self):
        return repr([str(c) for c in self])

    def replace(self, column):
        """add the given column to this collection, removing unaliased
        versions of this column as well as existing columns with the
        same key.

        e.g.::

            t = Table('sometable', metadata, Column('col1', Integer))
            t.columns.replace(Column('col1', Integer, key='columnone'))

        will remove the original 'col1' from the collection, and add
        the new column under the name 'columnone'.

        Used by schema.Column to override columns during table reflection.

        """
        remove_col = None
        if column.name in self and column.key != column.name:
            other = self[column.name]
            if other.name == other.key:
                remove_col = other
                del self._data[other.key]

        if column.key in self._data:
            remove_col = self._data[column.key]

        self._data[column.key] = column
        if remove_col is not None:
            self._all_columns[:] = [column if c is remove_col
                                    else c for c in self._all_columns]
        else:
            self._all_columns.append(column)

    def add(self, column):
        """Add a column to this collection.

        The key attribute of the column will be used as the hash key
        for this dictionary.

        """
        if not column.key:
            raise exc.ArgumentError(
                "Can't add unnamed column to column collection")
        self[column.key] = column

    def __delitem__(self, key):
        raise NotImplementedError()

    def __setattr__(self, key, object):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        if key in self:

            # this warning is primarily to catch select() statements
            # which have conflicting column names in their exported
            # columns collection

            existing = self[key]
            if not existing.shares_lineage(value):
                util.warn('Column %r on table %r being replaced by '
                          '%r, which has the same key.  Consider '
                          'use_labels for select() statements.' %
                          (key, getattr(existing, 'table', None), value))

        # pop out memoized proxy_set as this
        # operation may very well be occurring
        # in a _make_proxy operation
        util.memoized_property.reset(value, "proxy_set")

        self._all_columns.append(value)
        self._data[key] = value

    def clear(self):
        raise NotImplementedError()

    def remove(self, column):
        del self._data[column.key]
        self._all_columns[:] = [
            c for c in self._all_columns if c is not column]

    def update(self, iter):
        cols = list(iter)
        all_col_set = set(self._all_columns)
        self._all_columns.extend(
            c for label, c in cols if c not in all_col_set)
        self._data.update((label, c) for label, c in cols)

    def extend(self, iter):
        cols = list(iter)
        all_col_set = set(self._all_columns)
        self._all_columns.extend(c for c in cols if c not in all_col_set)
        self._data.update((c.key, c) for c in cols)

    __hash__ = None

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        l = []
        for c in getattr(other, "_all_columns", other):
            for local in self._all_columns:
                if c.shares_lineage(local):
                    l.append(c == local)
        return elements.and_(*l)

    def __contains__(self, other):
        if not isinstance(other, util.string_types):
            raise exc.ArgumentError("__contains__ requires a string argument")
        return util.OrderedProperties.__contains__(self, other)

    def __getstate__(self):
        return {'_data': self._data,
                '_all_columns': self._all_columns}

    def __setstate__(self, state):
        object.__setattr__(self, '_data', state['_data'])
        object.__setattr__(self, '_all_columns', state['_all_columns'])

    def contains_column(self, col):
        return col in set(self._all_columns)

    def as_immutable(self):
        return ImmutableColumnCollection(self._data, self._all_columns)


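A short sketch of the overridden ``__eq__()`` producing a clause between correlated column collections; the table is invented, and the alias columns share lineage with the originals::

    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()
    parent = Table('parent', metadata, Column('id', Integer))
    pa = parent.alias('pa')

    # comparing the two collections AND-s together the matching pairs
    # instead of returning a boolean
    clause = parent.columns == pa.columns
    print(clause)    # renders roughly: parent.id = pa.id

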
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
    def __init__(self, data, all_columns):
        util.ImmutableProperties.__init__(self, data)
        object.__setattr__(self, '_all_columns', all_columns)

    extend = remove = util.ImmutableProperties._immutable


class ColumnSet(util.ordered_column_set):
    def contains_column(self, col):
        return col in self

    def extend(self, cols):
        for col in cols:
            self.add(col)

    def __add__(self, other):
        return list(self) + list(other)

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        l = []
        for c in other:
            for local in self:
                if c.shares_lineage(local):
                    l.append(c == local)
        return elements.and_(*l)

    def __hash__(self):
        return hash(tuple(x for x in self))


def _bind_or_error(schemaitem, msg=None):
    bind = schemaitem.bind
    if not bind:
        name = schemaitem.__class__.__name__
        label = getattr(schemaitem, 'fullname',
                        getattr(schemaitem, 'name', None))
        if label:
            item = '%s object %r' % (name, label)
        else:
            item = '%s object' % name
        if msg is None:
            msg = "%s is not bound to an Engine or Connection.  "\
                "Execution cannot proceed without a database to execute "\
                "against." % item
        raise exc.UnboundExecutionError(msg)
    return bind
692  sqlalchemy/sql/crud.py  Normal file
@@ -0,0 +1,692 @@
# sql/crud.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Functions used by compiler.py to determine the parameters rendered
within INSERT and UPDATE statements.

"""
from .. import util
from .. import exc
from . import dml
from . import elements
import operator

REQUIRED = util.symbol('REQUIRED', """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`.Connection.execute`.

This symbol is typically used when a :func:`.expression.insert`
or :func:`.expression.update` statement is compiled without parameter
values present.

""")

ISINSERT = util.symbol('ISINSERT')
ISUPDATE = util.symbol('ISUPDATE')
ISDELETE = util.symbol('ISDELETE')

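
A hedged sketch of when the REQUIRED symbol comes into play: compiling an insert() with no values leaves required bind placeholders that are filled in at execute() time (the table is invented)::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    metadata = MetaData()
    t = Table('t', metadata, Column('x', Integer), Column('y', Integer))

    engine = create_engine('sqlite://')
    metadata.create_all(engine)

    stmt = t.insert()    # compiled without values: binds are REQUIRED
    engine.execute(stmt, x=1, y=2)   # values supplied at execution time
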
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
    restore_isinsert = compiler.isinsert
    restore_isupdate = compiler.isupdate
    restore_isdelete = compiler.isdelete

    should_restore = (
        restore_isinsert or restore_isupdate or restore_isdelete
    ) or len(compiler.stack) > 1

    if local_stmt_type is ISINSERT:
        compiler.isupdate = False
        compiler.isinsert = True
    elif local_stmt_type is ISUPDATE:
        compiler.isupdate = True
        compiler.isinsert = False
    elif local_stmt_type is ISDELETE:
        if not should_restore:
            compiler.isdelete = True
    else:
        assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"

    try:
        if local_stmt_type in (ISINSERT, ISUPDATE):
            return _get_crud_params(compiler, stmt, **kw)
    finally:
        if should_restore:
            compiler.isinsert = restore_isinsert
            compiler.isupdate = restore_isupdate
            compiler.isdelete = restore_isdelete


def _get_crud_params(compiler, stmt, **kw):
    """create a set of tuples representing column/string pairs for use
    in an INSERT or UPDATE statement.

    Also generates the Compiled object's postfetch, prefetch, and
    returning column collections, used for default handling and ultimately
    populating the ResultProxy's prefetch_cols() and postfetch_cols()
    collections.

    """

    compiler.postfetch = []
    compiler.insert_prefetch = []
    compiler.update_prefetch = []
    compiler.returning = []

    # no parameters in the statement, no parameters in the
    # compiled params - return binds for all columns
    if compiler.column_keys is None and stmt.parameters is None:
        return [
            (c, _create_bind_param(
                compiler, c, None, required=True))
            for c in stmt.table.columns
        ]

    if stmt._has_multi_parameters:
        stmt_parameters = stmt.parameters[0]
    else:
        stmt_parameters = stmt.parameters

    # getters - these are normally just column.key,
    # but in the case of mysql multi-table update, the rules for
    # .key must conditionally take tablename into account
    _column_as_key, _getattr_col_key, _col_bind_name = \
        _key_getters_for_crud_column(compiler, stmt)

    # if we have statement parameters - set defaults in the
    # compiled params
    if compiler.column_keys is None:
        parameters = {}
    else:
        parameters = dict((_column_as_key(key), REQUIRED)
                          for key in compiler.column_keys
                          if not stmt_parameters or
                          key not in stmt_parameters)

    # create a list of column assignment clauses as tuples
    values = []

    if stmt_parameters is not None:
        _get_stmt_parameters_params(
            compiler,
            parameters, stmt_parameters, _column_as_key, values, kw)

    check_columns = {}

    # special logic that only occurs for multi-table UPDATE
    # statements
    if compiler.isupdate and stmt._extra_froms and stmt_parameters:
        _get_multitable_params(
            compiler, stmt, stmt_parameters, check_columns,
            _col_bind_name, _getattr_col_key, values, kw)

    if compiler.isinsert and stmt.select_names:
        _scan_insert_from_select_cols(
            compiler, stmt, parameters,
            _getattr_col_key, _column_as_key,
            _col_bind_name, check_columns, values, kw)
    else:
        _scan_cols(
            compiler, stmt, parameters,
            _getattr_col_key, _column_as_key,
            _col_bind_name, check_columns, values, kw)

    if parameters and stmt_parameters:
        check = set(parameters).intersection(
            _column_as_key(k) for k in stmt_parameters
        ).difference(check_columns)
        if check:
            raise exc.CompileError(
                "Unconsumed column names: %s" %
                (", ".join("%s" % c for c in check))
            )

    if stmt._has_multi_parameters:
        values = _extend_values_for_multiparams(compiler, stmt, values, kw)

    return values

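
The "Unconsumed column names" check above is user-visible; a hedged sketch of triggering it at compile time (the table is invented)::

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy import exc

    metadata = MetaData()
    t = Table('t', metadata, Column('x', Integer))

    stmt = t.insert().values(nonexistent=5)
    try:
        str(stmt)    # compilation runs _get_crud_params()
    except exc.CompileError as err:
        print(err)   # Unconsumed column names: nonexistent
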
def _create_bind_param(
        compiler, col, value, process=True,
        required=False, name=None, **kw):
    if name is None:
        name = col.key
    bindparam = elements.BindParameter(
        name, value, type_=col.type, required=required)
    bindparam._is_crud = True
    if process:
        bindparam = bindparam._compiler_dispatch(compiler, **kw)
    return bindparam


def _key_getters_for_crud_column(compiler, stmt):
    if compiler.isupdate and stmt._extra_froms:
        # when extra tables are present, refer to the columns
        # in those extra tables as table-qualified, including in
        # dictionaries and when rendering bind param names.
        # the "main" table of the statement remains unqualified,
        # allowing the most compatibility with a non-multi-table
        # statement.
        _et = set(stmt._extra_froms)

        def _column_as_key(key):
            str_key = elements._column_as_key(key)
            if hasattr(key, 'table') and key.table in _et:
                return (key.table.name, str_key)
            else:
                return str_key

        def _getattr_col_key(col):
            if col.table in _et:
                return (col.table.name, col.key)
            else:
                return col.key

        def _col_bind_name(col):
            if col.table in _et:
                return "%s_%s" % (col.table.name, col.key)
            else:
                return col.key

    else:
        _column_as_key = elements._column_as_key
        _getattr_col_key = _col_bind_name = operator.attrgetter("key")

    return _column_as_key, _getattr_col_key, _col_bind_name


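A hedged illustration of the table-qualified keys produced for a multi-table UPDATE; the table names are invented and the values shown follow from the getters above::

    # given: an UPDATE against `main` that also references `other`,
    # with `other` present in stmt._extra_froms:
    #
    #   _column_as_key(main.c.x)    -> "x"
    #   _getattr_col_key(other.c.y) -> ("other", "y")
    #   _col_bind_name(other.c.y)   -> "other_y"
    #
    # so bind params for extra-table columns render as :other_y and the
    # parameter dictionary is keyed with (table_name, key) tuples.

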
def _scan_insert_from_select_cols(
        compiler, stmt, parameters, _getattr_col_key,
        _column_as_key, _col_bind_name, check_columns, values, kw):

    need_pks, implicit_returning, \
        implicit_return_defaults, postfetch_lastrowid = \
        _get_returning_modifiers(compiler, stmt)

    cols = [stmt.table.c[_column_as_key(name)]
            for name in stmt.select_names]

    compiler._insert_from_select = stmt.select

    add_select_cols = []
    if stmt.include_insert_from_select_defaults:
        col_set = set(cols)
        for col in stmt.table.columns:
            if col not in col_set and col.default:
                cols.append(col)

    for c in cols:
        col_key = _getattr_col_key(c)
        if col_key in parameters and col_key not in check_columns:
            parameters.pop(col_key)
            values.append((c, None))
        else:
            _append_param_insert_select_hasdefault(
                compiler, stmt, c, add_select_cols, kw)

    if add_select_cols:
        values.extend(add_select_cols)
        compiler._insert_from_select = compiler._insert_from_select._generate()
        compiler._insert_from_select._raw_columns = \
            tuple(compiler._insert_from_select._raw_columns) + tuple(
                expr for col, expr in add_select_cols)


def _scan_cols(
        compiler, stmt, parameters, _getattr_col_key,
        _column_as_key, _col_bind_name, check_columns, values, kw):

    need_pks, implicit_returning, \
        implicit_return_defaults, postfetch_lastrowid = \
        _get_returning_modifiers(compiler, stmt)

    if stmt._parameter_ordering:
        parameter_ordering = [
            _column_as_key(key) for key in stmt._parameter_ordering
        ]
        ordered_keys = set(parameter_ordering)
        cols = [
            stmt.table.c[key] for key in parameter_ordering
        ] + [
            c for c in stmt.table.c if c.key not in ordered_keys
        ]
    else:
        cols = stmt.table.columns

    for c in cols:
        col_key = _getattr_col_key(c)

        if col_key in parameters and col_key not in check_columns:

            _append_param_parameter(
                compiler, stmt, c, col_key, parameters, _col_bind_name,
                implicit_returning, implicit_return_defaults, values, kw)

        elif compiler.isinsert:
            if c.primary_key and \
                    need_pks and \
                    (
                        implicit_returning or
                        not postfetch_lastrowid or
                        c is not stmt.table._autoincrement_column
                    ):

                if implicit_returning:
                    _append_param_insert_pk_returning(
                        compiler, stmt, c, values, kw)
                else:
                    _append_param_insert_pk(compiler, stmt, c, values, kw)

            elif c.default is not None:

                _append_param_insert_hasdefault(
                    compiler, stmt, c, implicit_return_defaults,
                    values, kw)

            elif c.server_default is not None:
                if implicit_return_defaults and \
                        c in implicit_return_defaults:
                    compiler.returning.append(c)
                elif not c.primary_key:
                    compiler.postfetch.append(c)
            elif implicit_return_defaults and \
                    c in implicit_return_defaults:
                compiler.returning.append(c)
            elif c.primary_key and \
                    c is not stmt.table._autoincrement_column and \
                    not c.nullable:
                _warn_pk_with_no_anticipated_value(c)

        elif compiler.isupdate:
            _append_param_update(
                compiler, stmt, c, implicit_return_defaults, values, kw)

def _append_param_parameter(
        compiler, stmt, c, col_key, parameters, _col_bind_name,
        implicit_returning, implicit_return_defaults, values, kw):
    value = parameters.pop(col_key)
    if elements._is_literal(value):
        value = _create_bind_param(
            compiler, c, value, required=value is REQUIRED,
            name=_col_bind_name(c)
            if not stmt._has_multi_parameters
            else "%s_m0" % _col_bind_name(c),
            **kw
        )
    else:
        if isinstance(value, elements.BindParameter) and \
                value.type._isnull:
            value = value._clone()
            value.type = c.type

        if c.primary_key and implicit_returning:
            compiler.returning.append(c)
            value = compiler.process(value.self_group(), **kw)
        elif implicit_return_defaults and \
                c in implicit_return_defaults:
            compiler.returning.append(c)
            value = compiler.process(value.self_group(), **kw)
        else:
            compiler.postfetch.append(c)
            value = compiler.process(value.self_group(), **kw)
    values.append((c, value))

def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
    """Create a primary key expression in the INSERT statement and
    possibly a RETURNING clause for it.

    If the column has a Python-side default, we will create a bound
    parameter for it and "pre-execute" the Python function.  If
    the column has a SQL expression default, or is a sequence,
    we will add it directly into the INSERT statement and add a
    RETURNING element to get the new value.  If the column has a
    server side default or is marked as the "autoincrement" column,
    we will add a RETURNING element to get at the value.

    If all the above tests fail, that indicates a primary key column with no
    noted default generation capabilities that has no parameter passed;
    raise an exception.

    """
    if c.default is not None:
        if c.default.is_sequence:
            if compiler.dialect.supports_sequences and \
                    (not c.default.optional or
                     not compiler.dialect.sequences_optional):
                proc = compiler.process(c.default, **kw)
                values.append((c, proc))
            compiler.returning.append(c)
        elif c.default.is_clause_element:
            values.append(
                (c, compiler.process(
                    c.default.arg.self_group(), **kw))
            )
            compiler.returning.append(c)
        else:
            values.append(
                (c, _create_insert_prefetch_bind_param(compiler, c))
            )
    elif c is stmt.table._autoincrement_column or c.server_default is not None:
        compiler.returning.append(c)
    elif not c.nullable:
        # no .default, no .server_default, not autoincrement, we have
        # no indication this primary key column will have any value
        _warn_pk_with_no_anticipated_value(c)


def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
    param = _create_bind_param(compiler, c, None, process=process, name=name)
    compiler.insert_prefetch.append(c)
    return param


def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
    param = _create_bind_param(compiler, c, None, process=process, name=name)
    compiler.update_prefetch.append(c)
    return param

class _multiparam_column(elements.ColumnElement):
    def __init__(self, original, index):
        self.key = "%s_m%d" % (original.key, index + 1)
        self.original = original
        self.default = original.default
        self.type = original.type

    def __eq__(self, other):
        return isinstance(other, _multiparam_column) and \
            other.key == self.key and \
            other.original == self.original


def _process_multiparam_default_bind(compiler, stmt, c, index, kw):

    if not c.default:
        raise exc.CompileError(
            "INSERT value for column %s is explicitly rendered as a bound "
            "parameter in the VALUES clause; "
            "a Python-side value or SQL expression is required" % c)
    elif c.default.is_clause_element:
        return compiler.process(c.default.arg.self_group(), **kw)
    else:
        col = _multiparam_column(c, index)
        if isinstance(stmt, dml.Insert):
            return _create_insert_prefetch_bind_param(compiler, col)
        else:
            return _create_update_prefetch_bind_param(compiler, col)

def _append_param_insert_pk(compiler, stmt, c, values, kw):
    """Create a bound parameter in the INSERT statement to receive a
    'prefetched' default value.

    The 'prefetched' value indicates that we are to invoke a Python-side
    default function or explicit SQL expression before the INSERT statement
    proceeds, so that we have a primary key value available.

    If the column has no noted default generation capabilities, it has
    no value passed in either; raise an exception.

    """
    if (
            (
                # column has a Python-side default
                c.default is not None and
                (
                    # and it won't be a Sequence
                    not c.default.is_sequence or
                    compiler.dialect.supports_sequences
                )
            )
            or
            (
                # column is the "autoincrement column"
                c is stmt.table._autoincrement_column and
                (
                    # and it's either a "sequence" or a
                    # pre-executable "autoincrement" sequence
                    compiler.dialect.supports_sequences or
                    compiler.dialect.preexecute_autoincrement_sequences
                )
            )
    ):
        values.append(
            (c, _create_insert_prefetch_bind_param(compiler, c))
        )
    elif c.default is None and c.server_default is None and not c.nullable:
        # no .default, no .server_default, not autoincrement, we have
        # no indication this primary key column will have any value
        _warn_pk_with_no_anticipated_value(c)

def _append_param_insert_hasdefault(
        compiler, stmt, c, implicit_return_defaults, values, kw):

    if c.default.is_sequence:
        if compiler.dialect.supports_sequences and \
                (not c.default.optional or
                 not compiler.dialect.sequences_optional):
            proc = compiler.process(c.default, **kw)
            values.append((c, proc))
            if implicit_return_defaults and \
                    c in implicit_return_defaults:
                compiler.returning.append(c)
            elif not c.primary_key:
                compiler.postfetch.append(c)
    elif c.default.is_clause_element:
        proc = compiler.process(c.default.arg.self_group(), **kw)
        values.append((c, proc))

        if implicit_return_defaults and \
                c in implicit_return_defaults:
            compiler.returning.append(c)
        elif not c.primary_key:
            # don't add primary key column to postfetch
            compiler.postfetch.append(c)
    else:
        values.append(
            (c, _create_insert_prefetch_bind_param(compiler, c))
        )


def _append_param_insert_select_hasdefault(
        compiler, stmt, c, values, kw):

    if c.default.is_sequence:
        if compiler.dialect.supports_sequences and \
                (not c.default.optional or
                 not compiler.dialect.sequences_optional):
            proc = c.default
            values.append((c, proc.next_value()))
    elif c.default.is_clause_element:
        proc = c.default.arg.self_group()
        values.append((c, proc))
    else:
        values.append(
            (c, _create_insert_prefetch_bind_param(compiler, c, process=False))
        )

def _append_param_update(
        compiler, stmt, c, implicit_return_defaults, values, kw):

    if c.onupdate is not None and not c.onupdate.is_sequence:
        if c.onupdate.is_clause_element:
            values.append(
                (c, compiler.process(
                    c.onupdate.arg.self_group(), **kw))
            )
            if implicit_return_defaults and \
                    c in implicit_return_defaults:
                compiler.returning.append(c)
            else:
                compiler.postfetch.append(c)
        else:
            values.append(
                (c, _create_update_prefetch_bind_param(compiler, c))
            )
    elif c.server_onupdate is not None:
        if implicit_return_defaults and \
                c in implicit_return_defaults:
            compiler.returning.append(c)
        else:
            compiler.postfetch.append(c)
    elif implicit_return_defaults and \
            stmt._return_defaults is not True and \
            c in implicit_return_defaults:
        compiler.returning.append(c)

def _get_multitable_params(
        compiler, stmt, stmt_parameters, check_columns,
        _col_bind_name, _getattr_col_key, values, kw):

    normalized_params = dict(
        (elements._clause_element_as_expr(c), param)
        for c, param in stmt_parameters.items()
    )
    affected_tables = set()
    for t in stmt._extra_froms:
        for c in t.c:
            if c in normalized_params:
                affected_tables.add(t)
                check_columns[_getattr_col_key(c)] = c
                value = normalized_params[c]
                if elements._is_literal(value):
                    value = _create_bind_param(
                        compiler, c, value, required=value is REQUIRED,
                        name=_col_bind_name(c))
                else:
                    compiler.postfetch.append(c)
                    value = compiler.process(value.self_group(), **kw)
                values.append((c, value))
    # determine tables which are actually to be updated - process onupdate
    # and server_onupdate for these
    for t in affected_tables:
        for c in t.c:
            if c in normalized_params:
                continue
            elif (c.onupdate is not None and not
                    c.onupdate.is_sequence):
                if c.onupdate.is_clause_element:
                    values.append(
                        (c, compiler.process(
                            c.onupdate.arg.self_group(),
                            **kw)
                         )
                    )
                    compiler.postfetch.append(c)
                else:
                    values.append(
                        (c, _create_update_prefetch_bind_param(
                            compiler, c, name=_col_bind_name(c)))
                    )
            elif c.server_onupdate is not None:
                compiler.postfetch.append(c)

def _extend_values_for_multiparams(compiler, stmt, values, kw):
    values_0 = values
    values = [values]

    values.extend(
        [
            (
                c,
                (_create_bind_param(
                    compiler, c, row[c.key],
                    name="%s_m%d" % (c.key, i + 1), **kw
                ) if elements._is_literal(row[c.key])
                    else compiler.process(
                        row[c.key].self_group(), **kw))
                if c.key in row else
                _process_multiparam_default_bind(compiler, stmt, c, i, kw)
            )
            for (c, param) in values_0
        ]
        for i, row in enumerate(stmt.parameters[1:])
    )
    return values

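# Illustrative note (not part of the module; ``users`` is a hypothetical
# table): for a multi-VALUES insert such as
#
#     users.insert().values([{"name": "a"}, {"name": "b"}, {"name": "c"}])
#
# the first row is bound with names like ``name_m0`` and the function
# above extends ``values`` with one additional row of binds per extra
# parameter set (``name_m1``, ``name_m2``, ...), invoking a column's
# default per row wherever a key is missing from that row.
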
def _get_stmt_parameters_params(
        compiler, parameters, stmt_parameters, _column_as_key, values, kw):
    for k, v in stmt_parameters.items():
        colkey = _column_as_key(k)
        if colkey is not None:
            parameters.setdefault(colkey, v)
        else:
            # a non-Column expression on the left side;
            # add it to values() in an "as-is" state,
            # coercing right side to bound param
            if elements._is_literal(v):
                v = compiler.process(
                    elements.BindParameter(None, v, type_=k.type),
                    **kw)
            else:
                v = compiler.process(v.self_group(), **kw)

            values.append((k, v))

def _get_returning_modifiers(compiler, stmt):
    need_pks = compiler.isinsert and \
        not compiler.inline and \
        not stmt._returning and \
        not stmt._has_multi_parameters

    implicit_returning = need_pks and \
        compiler.dialect.implicit_returning and \
        stmt.table.implicit_returning

    if compiler.isinsert:
        implicit_return_defaults = (implicit_returning and
                                    stmt._return_defaults)
    elif compiler.isupdate:
        implicit_return_defaults = (compiler.dialect.implicit_returning and
                                    stmt.table.implicit_returning and
                                    stmt._return_defaults)
    else:
        # this line is unused, currently we are always
        # isinsert or isupdate
        implicit_return_defaults = False  # pragma: no cover

    if implicit_return_defaults:
        if stmt._return_defaults is True:
            implicit_return_defaults = set(stmt.table.c)
        else:
            implicit_return_defaults = set(stmt._return_defaults)

    postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid

    return need_pks, implicit_returning, \
        implicit_return_defaults, postfetch_lastrowid

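# Illustrative note (not part of the module): these flags drive how newly
# generated primary keys are retrieved.  On a backend whose dialect has
# ``implicit_returning=True`` (e.g. PostgreSQL), a single-row
# ``users.insert()`` compiles with a RETURNING clause for the primary key;
# on a backend with ``postfetch_lastrowid=True`` (e.g. SQLite), no
# RETURNING is emitted and the key is read from ``cursor.lastrowid`` after
# execution.  (``users`` is a hypothetical table.)
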
def _warn_pk_with_no_anticipated_value(c):
    msg = (
        "Column '%s.%s' is marked as a member of the "
        "primary key for table '%s', "
        "but has no Python-side or server-side default generator indicated, "
        "nor does it indicate 'autoincrement=True' or 'nullable=True', "
        "and no explicit value is passed.  "
        "Primary key columns typically may not store NULL."
        %
        (c.table.fullname, c.name, c.table.fullname))
    if len(c.table.primary_key) > 1:
        msg += (
            " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
            "indicated explicitly for composite (e.g. multicolumn) primary "
            "keys if AUTO_INCREMENT/SERIAL/IDENTITY "
            "behavior is expected for one of the columns in the primary key. "
            "CREATE TABLE statements are impacted by this change as well on "
            "most backends.")
    util.warn(msg)

1100    sqlalchemy/sql/ddl.py    (new file; diff suppressed because it is too large)

308     sqlalchemy/sql/default_comparator.py    (new file)
@@ -0,0 +1,308 @@
# sql/default_comparator.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Default implementation of SQL comparison operations.
"""

from .. import exc, util
from . import type_api
from . import operators
from .elements import BindParameter, True_, False_, BinaryExpression, \
    Null, _const_expr, _clause_element_as_expr, \
    ClauseList, ColumnElement, TextClause, UnaryExpression, \
    collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
    Slice, Visitable, _literal_as_binds
from .selectable import SelectBase, Alias, Selectable, ScalarSelect

def _boolean_compare(expr, op, obj, negate=None, reverse=False,
                     _python_is_types=(util.NoneType, bool),
                     result_type=None,
                     **kwargs):

    if result_type is None:
        result_type = type_api.BOOLEANTYPE

    if isinstance(obj, _python_is_types + (Null, True_, False_)):

        # allow x ==/!= True/False to be treated as a literal.
        # this comes out to "== / != true/false" or "1/0" if those
        # constants aren't supported and works on all platforms
        if op in (operators.eq, operators.ne) and \
                isinstance(obj, (bool, True_, False_)):
            return BinaryExpression(expr,
                                    _literal_as_text(obj),
                                    op,
                                    type_=result_type,
                                    negate=negate, modifiers=kwargs)
        elif op in (operators.is_distinct_from, operators.isnot_distinct_from):
            return BinaryExpression(expr,
                                    _literal_as_text(obj),
                                    op,
                                    type_=result_type,
                                    negate=negate, modifiers=kwargs)
        else:
            # all other None/True/False uses IS, IS NOT
            if op in (operators.eq, operators.is_):
                return BinaryExpression(expr, _const_expr(obj),
                                        operators.is_,
                                        negate=operators.isnot)
            elif op in (operators.ne, operators.isnot):
                return BinaryExpression(expr, _const_expr(obj),
                                        operators.isnot,
                                        negate=operators.is_)
            else:
                raise exc.ArgumentError(
                    "Only '=', '!=', 'is_()', 'isnot()', "
                    "'is_distinct_from()', 'isnot_distinct_from()' "
                    "operators can be used with None/True/False")
    else:
        obj = _check_literal(expr, op, obj)

    if reverse:
        return BinaryExpression(obj,
                                expr,
                                op,
                                type_=result_type,
                                negate=negate, modifiers=kwargs)
    else:
        return BinaryExpression(expr,
                                obj,
                                op,
                                type_=result_type,
                                negate=negate, modifiers=kwargs)

def _binary_operate(expr, op, obj, reverse=False, result_type=None,
                    **kw):
    obj = _check_literal(expr, op, obj)

    if reverse:
        left, right = obj, expr
    else:
        left, right = expr, obj

    if result_type is None:
        op, result_type = left.comparator._adapt_expression(
            op, right.comparator)

    return BinaryExpression(
        left, right, op, type_=result_type, modifiers=kw)


def _conjunction_operate(expr, op, other, **kw):
    if op is operators.and_:
        return and_(expr, other)
    elif op is operators.or_:
        return or_(expr, other)
    else:
        raise NotImplementedError()


def _scalar(expr, op, fn, **kw):
    return fn(expr)

def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
    seq_or_selectable = _clause_element_as_expr(seq_or_selectable)

    if isinstance(seq_or_selectable, ScalarSelect):
        return _boolean_compare(expr, op, seq_or_selectable,
                                negate=negate_op)
    elif isinstance(seq_or_selectable, SelectBase):

        # TODO: if we ever want to support (x, y, z) IN (select x,
        # y, z from table), we would need a multi-column version of
        # as_scalar() to produce a multi- column selectable that
        # does not export itself as a FROM clause

        return _boolean_compare(
            expr, op, seq_or_selectable.as_scalar(),
            negate=negate_op, **kw)
    elif isinstance(seq_or_selectable, (Selectable, TextClause)):
        return _boolean_compare(expr, op, seq_or_selectable,
                                negate=negate_op, **kw)
    elif isinstance(seq_or_selectable, ClauseElement):
        raise exc.InvalidRequestError(
            'in_() accepts'
            ' either a list of expressions '
            'or a selectable: %r' % seq_or_selectable)

    # Handle non selectable arguments as sequences
    args = []
    for o in seq_or_selectable:
        if not _is_literal(o):
            if not isinstance(o, operators.ColumnOperators):
                raise exc.InvalidRequestError(
                    'in_() accepts'
                    ' either a list of expressions '
                    'or a selectable: %r' % o)
        elif o is None:
            o = Null()
        else:
            o = expr._bind_param(op, o)
        args.append(o)
    if len(args) == 0:

        # Special case handling for empty IN's, behave like
        # comparison against zero row selectable.  We use != to
        # build the contradiction as it handles NULL values
        # appropriately, i.e. "not (x IN ())" should not return NULL
        # values for x.

        util.warn('The IN-predicate on "%s" was invoked with an '
                  'empty sequence. This results in a '
                  'contradiction, which nonetheless can be '
                  'expensive to evaluate.  Consider alternative '
                  'strategies for improved performance.' % expr)
        if op is operators.in_op:
            return expr != expr
        else:
            return expr == expr

    return _boolean_compare(expr, op,
                            ClauseList(*args).self_group(against=op),
                            negate=negate_op)

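# Illustrative note (not part of the module; ``users``/``addresses`` are
# hypothetical tables):
#
#     users.c.id.in_([1, 2, 3])
#         # renders "users.id IN (:id_1, :id_2, :id_3)"
#     users.c.id.in_(select([addresses.c.user_id]))
#         # renders an IN against the subquery
#     users.c.id.in_([])
#         # emits the warning above and renders a contradiction
#         # such as "users.id != users.id"
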
def _getitem_impl(expr, op, other, **kw):
    if isinstance(expr.type, type_api.INDEXABLE):
        other = _check_literal(expr, op, other)
        return _binary_operate(expr, op, other, **kw)
    else:
        _unsupported_impl(expr, op, other, **kw)


def _unsupported_impl(expr, op, *arg, **kw):
    raise NotImplementedError("Operator '%s' is not supported on "
                              "this expression" % op.__name__)


def _inv_impl(expr, op, **kw):
    """See :meth:`.ColumnOperators.__inv__`."""
    if hasattr(expr, 'negation_clause'):
        return expr.negation_clause
    else:
        return expr._negate()


def _neg_impl(expr, op, **kw):
    """See :meth:`.ColumnOperators.__neg__`."""
    return UnaryExpression(expr, operator=operators.neg, type_=expr.type)


def _match_impl(expr, op, other, **kw):
    """See :meth:`.ColumnOperators.match`."""

    return _boolean_compare(
        expr, operators.match_op,
        _check_literal(
            expr, operators.match_op, other),
        result_type=type_api.MATCHTYPE,
        negate=operators.notmatch_op
        if op is operators.match_op else operators.match_op,
        **kw
    )


def _distinct_impl(expr, op, **kw):
    """See :meth:`.ColumnOperators.distinct`."""
    return UnaryExpression(expr, operator=operators.distinct_op,
                           type_=expr.type)


def _between_impl(expr, op, cleft, cright, **kw):
    """See :meth:`.ColumnOperators.between`."""
    return BinaryExpression(
        expr,
        ClauseList(
            _check_literal(expr, operators.and_, cleft),
            _check_literal(expr, operators.and_, cright),
            operator=operators.and_,
            group=False, group_contents=False),
        op,
        negate=operators.notbetween_op
        if op is operators.between_op
        else operators.between_op,
        modifiers=kw)


def _collate_impl(expr, op, other, **kw):
    return collate(expr, other)

# a mapping of operators with the method they use, along with
# their negated operator for comparison operators
operator_lookup = {
    "and_": (_conjunction_operate,),
    "or_": (_conjunction_operate,),
    "inv": (_inv_impl,),
    "add": (_binary_operate,),
    "mul": (_binary_operate,),
    "sub": (_binary_operate,),
    "div": (_binary_operate,),
    "mod": (_binary_operate,),
    "truediv": (_binary_operate,),
    "custom_op": (_binary_operate,),
    "json_path_getitem_op": (_binary_operate, ),
    "json_getitem_op": (_binary_operate, ),
    "concat_op": (_binary_operate,),
    "lt": (_boolean_compare, operators.ge),
    "le": (_boolean_compare, operators.gt),
    "ne": (_boolean_compare, operators.eq),
    "gt": (_boolean_compare, operators.le),
    "ge": (_boolean_compare, operators.lt),
    "eq": (_boolean_compare, operators.ne),
    "is_distinct_from": (_boolean_compare, operators.isnot_distinct_from),
    "isnot_distinct_from": (_boolean_compare, operators.is_distinct_from),
    "like_op": (_boolean_compare, operators.notlike_op),
    "ilike_op": (_boolean_compare, operators.notilike_op),
    "notlike_op": (_boolean_compare, operators.like_op),
    "notilike_op": (_boolean_compare, operators.ilike_op),
    "contains_op": (_boolean_compare, operators.notcontains_op),
    "startswith_op": (_boolean_compare, operators.notstartswith_op),
    "endswith_op": (_boolean_compare, operators.notendswith_op),
    "desc_op": (_scalar, UnaryExpression._create_desc),
    "asc_op": (_scalar, UnaryExpression._create_asc),
    "nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst),
    "nullslast_op": (_scalar, UnaryExpression._create_nullslast),
    "in_op": (_in_impl, operators.notin_op),
    "notin_op": (_in_impl, operators.in_op),
    "is_": (_boolean_compare, operators.is_),
    "isnot": (_boolean_compare, operators.isnot),
    "collate": (_collate_impl,),
    "match_op": (_match_impl,),
    "notmatch_op": (_match_impl,),
    "distinct_op": (_distinct_impl,),
    "between_op": (_between_impl, ),
    "notbetween_op": (_between_impl, ),
    "neg": (_neg_impl,),
    "getitem": (_getitem_impl,),
    "lshift": (_unsupported_impl,),
    "rshift": (_unsupported_impl,),
    "contains": (_unsupported_impl,),
}

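# Illustrative note (not part of the module): each entry above maps an
# operator name to ``(implementation[, negation])``.  For example,
# ``operator_lookup["eq"]`` is ``(_boolean_compare, operators.ne)``, so a
# comparison such as ``users.c.id == 5`` (``users`` being a hypothetical
# table) is ultimately dispatched as
# ``_boolean_compare(users.c.id, operators.eq, 5, negate=operators.ne)``.
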
def _check_literal(expr, operator, other, bindparam_type=None):
    if isinstance(other, (ColumnElement, TextClause)):
        if isinstance(other, BindParameter) and \
                other.type._isnull:
            other = other._clone()
            other.type = expr.type
        return other
    elif hasattr(other, '__clause_element__'):
        other = other.__clause_element__()
    elif isinstance(other, type_api.TypeEngine.Comparator):
        other = other.expr

    if isinstance(other, (SelectBase, Alias)):
        return other.as_scalar()
    elif not isinstance(other, Visitable):
        return expr._bind_param(operator, other, type_=bindparam_type)
    else:
        return other

851     sqlalchemy/sql/dml.py    (new file)
@@ -0,0 +1,851 @@
# sql/dml.py
# Copyright (C) 2009-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.

"""

from .base import Executable, _generative, _from_objects, DialectKWArgs, \
    ColumnCollection
from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
    _column_as_key
from .selectable import _interpret_as_from, _interpret_as_select, \
    HasPrefixes, HasCTE
from .. import util
from .. import exc


class UpdateBase(
        HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement):
    """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.

    """

    __visit_name__ = 'update_base'

    _execution_options = \
        Executable._execution_options.union({'autocommit': True})
    _hints = util.immutabledict()
    _parameter_ordering = None
    _prefixes = ()
    named_with_column = False

    def _process_colparams(self, parameters):
        def process_single(p):
            if isinstance(p, (list, tuple)):
                return dict(
                    (c.key, pval)
                    for c, pval in zip(self.table.c, p)
                )
            else:
                return p

        if self._preserve_parameter_order and parameters is not None:
            if not isinstance(parameters, list) or \
                    (parameters and not isinstance(parameters[0], tuple)):
                raise ValueError(
                    "When preserve_parameter_order is True, "
                    "values() only accepts a list of 2-tuples")
            self._parameter_ordering = [key for key, value in parameters]

            return dict(parameters), False

        if (isinstance(parameters, (list, tuple)) and parameters and
                isinstance(parameters[0], (list, tuple, dict))):

            if not self._supports_multi_parameters:
                raise exc.InvalidRequestError(
                    "This construct does not support "
                    "multiple parameter sets.")

            return [process_single(p) for p in parameters], True
        else:
            return process_single(parameters), False

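    # Illustrative note (not part of the module; ``users`` is a
    # hypothetical table with columns (id, name)): a full-table tuple is
    # converted positionally, so
    #
    #     users.insert().values((5, 'some name'))
    #
    # is processed into ``{'id': 5, 'name': 'some name'}``; a list of such
    # dicts or tuples returns ``(list_of_dicts, True)``, flagging the
    # "multiple values" form.
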
    def params(self, *arg, **kw):
        """Set the parameters for the statement.

        This method raises ``NotImplementedError`` on the base class,
        and is overridden by :class:`.ValuesBase` to provide the
        SET/VALUES clause of UPDATE and INSERT.

        """
        raise NotImplementedError(
            "params() is not supported for INSERT/UPDATE/DELETE statements."
            " To set the values for an INSERT or UPDATE statement, use"
            " stmt.values(**parameters).")

    def bind(self):
        """Return a 'bind' linked to this :class:`.UpdateBase`
        or a :class:`.Table` associated with it.

        """
        return self._bind or self.table.bind

    def _set_bind(self, bind):
        self._bind = bind
    bind = property(bind, _set_bind)

    @_generative
    def returning(self, *cols):
        r"""Add a :term:`RETURNING` or equivalent clause to this statement.

        e.g.::

            stmt = table.update().\
                where(table.c.data == 'value').\
                values(status='X').\
                returning(table.c.server_flag,
                          table.c.updated_timestamp)

            for server_flag, updated_timestamp in connection.execute(stmt):
                print(server_flag, updated_timestamp)

        The given collection of column expressions should be derived from
        the table that is
        the target of the INSERT, UPDATE, or DELETE.  While :class:`.Column`
        objects are typical, the elements can also be expressions::

            stmt = table.insert().returning(
                (table.c.first_name + " " + table.c.last_name).
                label('fullname'))

        Upon compilation, a RETURNING clause, or database equivalent,
        will be rendered within the statement.  For INSERT and UPDATE,
        the values are the newly inserted/updated values.  For DELETE,
        the values are those of the rows which were deleted.

        Upon execution, the values of the columns to be returned are made
        available via the result set and can be iterated using
        :meth:`.ResultProxy.fetchone` and similar.  For DBAPIs which do not
        natively support returning values (i.e. cx_oracle), SQLAlchemy will
        approximate this behavior at the result level so that a reasonable
        amount of behavioral neutrality is provided.

        Note that not all databases/DBAPIs
        support RETURNING.  For those backends with no support,
        an exception is raised upon compilation and/or execution.
        For those who do support it, the functionality across backends
        varies greatly, including restrictions on executemany()
        and other statements which return multiple rows.  Please
        read the documentation notes for the database in use in
        order to determine the availability of RETURNING.

        .. seealso::

            :meth:`.ValuesBase.return_defaults` - an alternative method
            tailored towards efficient fetching of server-side defaults
            and triggers for single-row INSERTs or UPDATEs.


        """
        self._returning = cols

    @_generative
    def with_hint(self, text, selectable=None, dialect_name="*"):
        """Add a table hint for a single table to this
        INSERT/UPDATE/DELETE statement.

        .. note::

            :meth:`.UpdateBase.with_hint` currently applies only to
            Microsoft SQL Server.  For MySQL INSERT/UPDATE/DELETE hints, use
            :meth:`.UpdateBase.prefix_with`.

        The text of the hint is rendered in the appropriate
        location for the database backend in use, relative
        to the :class:`.Table` that is the subject of this
        statement, or optionally to that of the given
        :class:`.Table` passed as the ``selectable`` argument.

        The ``dialect_name`` option will limit the rendering of a particular
        hint to a particular backend.  Such as, to add a hint
        that only takes effect for SQL Server::

            mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")

        .. versionadded:: 0.7.6

        :param text: Text of the hint.
        :param selectable: optional :class:`.Table` that specifies
         an element of the FROM clause within an UPDATE or DELETE
         to be the subject of the hint - applies only to certain backends.
        :param dialect_name: defaults to ``*``, if specified as the name
         of a particular dialect, will apply these hints only when
         that dialect is in use.
        """
        if selectable is None:
            selectable = self.table

        self._hints = self._hints.union(
            {(selectable, dialect_name): text})

class ValuesBase(UpdateBase):
    """Supplies support for :meth:`.ValuesBase.values` to
    INSERT and UPDATE constructs."""

    __visit_name__ = 'values_base'

    _supports_multi_parameters = False
    _has_multi_parameters = False
    _preserve_parameter_order = False
    select = None
    _post_values_clause = None

    def __init__(self, table, values, prefixes):
        self.table = _interpret_as_from(table)
        self.parameters, self._has_multi_parameters = \
            self._process_colparams(values)
        if prefixes:
            self._setup_prefixes(prefixes)

    @_generative
    def values(self, *args, **kwargs):
        r"""specify a fixed VALUES clause for an INSERT statement, or the SET
        clause for an UPDATE.

        Note that the :class:`.Insert` and :class:`.Update` constructs support
        per-execution time formatting of the VALUES and/or SET clauses,
        based on the arguments passed to :meth:`.Connection.execute`.
        However, the :meth:`.ValuesBase.values` method can be used to "fix" a
        particular set of parameters into the statement.

        Multiple calls to :meth:`.ValuesBase.values` will produce a new
        construct, each one with the parameter list modified to include
        the new parameters sent.  In the typical case of a single
        dictionary of parameters, the newly passed keys will replace
        the same keys in the previous construct.  In the case of a list-based
        "multiple values" construct, each new list of values is extended
        onto the existing list of values.

        :param \**kwargs: key value pairs representing the string key
          of a :class:`.Column` mapped to the value to be rendered into the
          VALUES or SET clause::

                users.insert().values(name="some name")

                users.update().where(users.c.id==5).values(name="some name")

        :param \*args: As an alternative to passing key/value parameters,
         a dictionary, tuple, or list of dictionaries or tuples can be passed
         as a single positional argument in order to form the VALUES or
         SET clause of the statement.  The forms that are accepted vary
         based on whether this is an :class:`.Insert` or an :class:`.Update`
         construct.

         For either an :class:`.Insert` or :class:`.Update` construct, a
         single dictionary can be passed, which works the same as that of
         the kwargs form::

            users.insert().values({"name": "some name"})

            users.update().values({"name": "some new name"})

         Also for either form but more typically for the :class:`.Insert`
         construct, a tuple that contains an entry for every column in the
         table is also accepted::

            users.insert().values((5, "some name"))

         The :class:`.Insert` construct also supports being passed a list
         of dictionaries or full-table-tuples, which on the server will
         render the less common SQL syntax of "multiple values" - this
         syntax is supported on backends such as SQLite, PostgreSQL, MySQL,
         but not necessarily others::

            users.insert().values([
                {"name": "some name"},
                {"name": "some other name"},
                {"name": "yet another name"},
            ])

         The above form would render a multiple VALUES statement similar to::

                INSERT INTO users (name) VALUES
                                (:name_1),
                                (:name_2),
                                (:name_3)

         It is essential to note that **passing multiple values is
         NOT the same as using traditional executemany() form**.  The above
         syntax is a **special** syntax not typically used.  To emit an
         INSERT statement against multiple rows, the normal method is
         to pass a multiple values list to the :meth:`.Connection.execute`
         method, which is supported by all database backends and is generally
         more efficient for a very large number of parameters.

           .. seealso::

               :ref:`execute_multiple` - an introduction to
               the traditional Core method of multiple parameter set
               invocation for INSERTs and other statements.

           .. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES
              clause, even a list of length one,
              implies that the :paramref:`.Insert.inline` flag is set to
              True, indicating that the statement will not attempt to fetch
              the "last inserted primary key" or other defaults.  The
              statement deals with an arbitrary number of rows, so the
              :attr:`.ResultProxy.inserted_primary_key` accessor does not
              apply.

           .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports
              columns with Python side default values and callables in the
              same way as that of an "executemany" style of invocation; the
              callable is invoked for each row.  See :ref:`bug_3288`
              for other details.

         The :class:`.Update` construct supports a special form which is a
         list of 2-tuples, which when provided must be passed in conjunction
         with the
         :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
         parameter.
         This form causes the UPDATE statement to render the SET clauses
         using the order of parameters given to :meth:`.Update.values`, rather
         than the ordering of columns given in the :class:`.Table`.

           .. versionadded:: 1.0.10 - added support for parameter-ordered
              UPDATE statements via the
              :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
              flag.

           .. seealso::

              :ref:`updates_order_parameters` - full example of the
              :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
              flag

        .. seealso::

            :ref:`inserts_and_updates` - SQL Expression
            Language Tutorial

            :func:`~.expression.insert` - produce an ``INSERT`` statement

            :func:`~.expression.update` - produce an ``UPDATE`` statement

        """
        if self.select is not None:
            raise exc.InvalidRequestError(
                "This construct already inserts from a SELECT")
        if self._has_multi_parameters and kwargs:
            raise exc.InvalidRequestError(
                "This construct already has multiple parameter sets.")

        if args:
            if len(args) > 1:
                raise exc.ArgumentError(
                    "Only a single dictionary/tuple or list of "
                    "dictionaries/tuples is accepted positionally.")
            v = args[0]
        else:
            v = {}

        if self.parameters is None:
            self.parameters, self._has_multi_parameters = \
                self._process_colparams(v)
        else:
            if self._has_multi_parameters:
                self.parameters = list(self.parameters)
                p, self._has_multi_parameters = self._process_colparams(v)
                if not self._has_multi_parameters:
                    raise exc.ArgumentError(
                        "Can't mix single-values and multiple values "
                        "formats in one statement")

                self.parameters.extend(p)
            else:
                self.parameters = self.parameters.copy()
                p, self._has_multi_parameters = self._process_colparams(v)
                if self._has_multi_parameters:
                    raise exc.ArgumentError(
                        "Can't mix single-values and multiple values "
                        "formats in one statement")
                self.parameters.update(p)

        if kwargs:
            if self._has_multi_parameters:
                raise exc.ArgumentError(
                    "Can't pass kwargs and multiple parameter sets "
                    "simultaneously")
            else:
                self.parameters.update(kwargs)

    @_generative
    def return_defaults(self, *cols):
        """Make use of a :term:`RETURNING` clause for the purpose
        of fetching server-side expressions and defaults.

        E.g.::

            stmt = table.insert().values(data='newdata').return_defaults()

            result = connection.execute(stmt)

            server_created_at = result.returned_defaults['created_at']

        When used against a backend that supports RETURNING, all column
        values generated by SQL expression or server-side-default will be
        added to any existing RETURNING clause, provided that
        :meth:`.UpdateBase.returning` is not used simultaneously.  The column
        values will then be available on the result using the
        :attr:`.ResultProxy.returned_defaults` accessor as a dictionary,
        referring to values keyed to the :class:`.Column` object as well as
        its ``.key``.

        This method differs from :meth:`.UpdateBase.returning` in these ways:

        1. :meth:`.ValuesBase.return_defaults` is only intended for use with
           an INSERT or an UPDATE statement that matches exactly one row.
           While the RETURNING construct in the general sense supports
           multiple rows for a multi-row UPDATE or DELETE statement, or for
           special cases of INSERT that return multiple rows (e.g. INSERT from
           SELECT, multi-valued VALUES clause),
           :meth:`.ValuesBase.return_defaults` is intended only for an
           "ORM-style" single-row INSERT/UPDATE statement.  The row returned
           by the statement is also consumed implicitly when
           :meth:`.ValuesBase.return_defaults` is used.  By contrast,
           :meth:`.UpdateBase.returning` leaves the RETURNING result-set
           intact with a collection of any number of rows.

        2. It is compatible with the existing logic to fetch auto-generated
           primary key values, also known as "implicit returning".  Backends
           that support RETURNING will automatically make use of RETURNING in
           order to fetch the value of newly generated primary keys; while the
           :meth:`.UpdateBase.returning` method circumvents this behavior,
           :meth:`.ValuesBase.return_defaults` leaves it intact.

        3. It can be called against any backend.  Backends that don't support
           RETURNING will skip the usage of the feature, rather than raising
           an exception.  The return value of
           :attr:`.ResultProxy.returned_defaults` will be ``None``

        :meth:`.ValuesBase.return_defaults` is used by the ORM to provide
        an efficient implementation for the ``eager_defaults`` feature of
        :func:`.mapper`.

        :param cols: optional list of column key names or :class:`.Column`
         objects.  If omitted, all column expressions evaluated on the server
         are added to the returning list.

        .. versionadded:: 0.9.0

        .. seealso::

            :meth:`.UpdateBase.returning`

            :attr:`.ResultProxy.returned_defaults`

        """
        self._return_defaults = cols or True

class Insert(ValuesBase):
    """Represent an INSERT construct.

    The :class:`.Insert` object is created using the
    :func:`~.expression.insert()` function.

    .. seealso::

        :ref:`coretutorial_insert_expressions`

    """
    __visit_name__ = 'insert'

    _supports_multi_parameters = True

    def __init__(self,
                 table,
                 values=None,
                 inline=False,
                 bind=None,
                 prefixes=None,
                 returning=None,
                 return_defaults=False,
                 **dialect_kw):
        """Construct an :class:`.Insert` object.

        Similar functionality is available via the
        :meth:`~.TableClause.insert` method on
        :class:`~.schema.Table`.

        :param table: :class:`.TableClause` which is the subject of the
         insert.

        :param values: collection of values to be inserted; see
         :meth:`.Insert.values` for a description of allowed formats here.
         Can be omitted entirely; a :class:`.Insert` construct will also
         dynamically render the VALUES clause at execution time based on
         the parameters passed to :meth:`.Connection.execute`.

        :param inline: if True, no attempt will be made to retrieve the
         SQL-generated default values to be provided within the statement;
         in particular,
         this allows SQL expressions to be rendered 'inline' within the
         statement without the need to pre-execute them beforehand; for
         backends that support "returning", this turns off the "implicit
         returning" feature for the statement.

        If both `values` and compile-time bind parameters are present, the
        compile-time bind parameters override the information specified
        within `values` on a per-key basis.

        The keys within `values` can be either
        :class:`~sqlalchemy.schema.Column` objects or their string
        identifiers.  Each key may reference one of:

        * a literal data value (i.e. string, number, etc.);
        * a Column object;
        * a SELECT statement.

        If a ``SELECT`` statement is specified which references this
        ``INSERT`` statement's table, the statement will be correlated
        against the ``INSERT`` statement.

        .. seealso::

            :ref:`coretutorial_insert_expressions` - SQL Expression Tutorial

            :ref:`inserts_and_updates` - SQL Expression Tutorial

        """
        ValuesBase.__init__(self, table, values, prefixes)
        self._bind = bind
        self.select = self.select_names = None
        self.include_insert_from_select_defaults = False
        self.inline = inline
        self._returning = returning
        self._validate_dialect_kwargs(dialect_kw)
        self._return_defaults = return_defaults

    def get_children(self, **kwargs):
        if self.select is not None:
            return self.select,
        else:
            return ()

    @_generative
    def from_select(self, names, select, include_defaults=True):
        """Return a new :class:`.Insert` construct which represents
        an ``INSERT...FROM SELECT`` statement.

        e.g.::

            sel = select([table1.c.a, table1.c.b]).where(table1.c.c > 5)
            ins = table2.insert().from_select(['a', 'b'], sel)

        :param names: a sequence of string column names or :class:`.Column`
         objects representing the target columns.
        :param select: a :func:`.select` construct, :class:`.FromClause`
         or other construct which resolves into a :class:`.FromClause`,
         such as an ORM :class:`.Query` object, etc.  The order of
         columns returned from this FROM clause should correspond to the
         order of columns sent as the ``names`` parameter; while this
         is not checked before passing along to the database, the database
         would normally raise an exception if these column lists don't
         correspond.
        :param include_defaults: if True, non-server default values and
         SQL expressions as specified on :class:`.Column` objects
         (as documented in :ref:`metadata_defaults_toplevel`) not
         otherwise specified in the list of names will be rendered
         into the INSERT and SELECT statements, so that these values are also
         included in the data to be inserted.

         .. note:: A Python-side default that uses a Python callable function
            will only be invoked **once** for the whole statement, and **not
            per row**.

         .. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders
            Python-side and SQL expression column defaults into the
            SELECT statement for columns otherwise not included in the
            list of column names.

        .. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
           implies that the :paramref:`.insert.inline` flag is set to
           True, indicating that the statement will not attempt to fetch
           the "last inserted primary key" or other defaults.  The statement
           deals with an arbitrary number of rows, so the
           :attr:`.ResultProxy.inserted_primary_key` accessor does not apply.

        .. versionadded:: 0.8.3

        """
        if self.parameters:
            raise exc.InvalidRequestError(
                "This construct already inserts value expressions")

        self.parameters, self._has_multi_parameters = \
            self._process_colparams(
                dict((_column_as_key(n), Null()) for n in names))

        self.select_names = names
        self.inline = True
        self.include_insert_from_select_defaults = include_defaults
        self.select = _interpret_as_select(select)

    def _copy_internals(self, clone=_clone, **kw):
        # TODO: coverage
        self.parameters = self.parameters.copy()
        if self.select is not None:
            self.select = _clone(self.select)

class Update(ValuesBase):
|
||||
"""Represent an Update construct.
|
||||
|
||||
The :class:`.Update` object is created using the :func:`update()`
|
||||
function.
|
||||
|
||||
"""
|
||||
__visit_name__ = 'update'
|
||||
|
||||
def __init__(self,
|
||||
table,
|
||||
whereclause=None,
|
||||
values=None,
|
||||
inline=False,
|
||||
bind=None,
|
||||
prefixes=None,
|
||||
returning=None,
|
||||
return_defaults=False,
|
||||
preserve_parameter_order=False,
|
||||
**dialect_kw):
|
||||
r"""Construct an :class:`.Update` object.
|
||||
|
||||
E.g.::
|
||||
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = update(users).where(users.c.id==5).\
|
||||
values(name='user #5')
|
||||
|
||||
Similar functionality is available via the
|
||||
:meth:`~.TableClause.update` method on
|
||||
:class:`.Table`::
|
||||
|
||||
stmt = users.update().\
|
||||
where(users.c.id==5).\
|
||||
values(name='user #5')
|
||||
|
||||
:param table: A :class:`.Table` object representing the database
|
||||
table to be updated.
|
||||
|
||||
:param whereclause: Optional SQL expression describing the ``WHERE``
|
||||
condition of the ``UPDATE`` statement. Modern applications
|
||||
may prefer to use the generative :meth:`~Update.where()`
|
||||
method to specify the ``WHERE`` clause.
|
||||
|
||||
The WHERE clause can refer to multiple tables.
|
||||
For databases which support this, an ``UPDATE FROM`` clause will
|
||||
be generated, or on MySQL, a multi-table update. The statement
|
||||
will fail on databases that don't have support for multi-table
|
||||
update statements. A SQL-standard method of referring to
|
||||
additional tables in the WHERE clause is to use a correlated
|
||||
subquery::
|
||||
|
||||
            users.update().values(name='ed').where(
                    users.c.name==select([addresses.c.email_address]).\
                                where(addresses.c.user_id==users.c.id).\
                                as_scalar()
                    )

        .. versionchanged:: 0.7.4
            The WHERE clause can refer to multiple tables.

        :param values:
          Optional dictionary which specifies the ``SET`` conditions of the
          ``UPDATE``.  If left as ``None``, the ``SET``
          conditions are determined from those parameters passed to the
          statement during the execution and/or compilation of the
          statement.   When compiled standalone without any parameters,
          the ``SET`` clause is generated for all columns.

          Modern applications may prefer to use the generative
          :meth:`.Update.values` method to set the values of the
          UPDATE statement.

        :param inline:
          if True, SQL defaults present on :class:`.Column` objects via
          the ``default`` keyword will be compiled 'inline' into the statement
          and not pre-executed.  This means that their values will not
          be available in the dictionary returned from
          :meth:`.ResultProxy.last_updated_params`.

        :param preserve_parameter_order: if True, the update statement is
          expected to receive parameters **only** via the :meth:`.Update.values`
          method, and they must be passed as a Python ``list`` of 2-tuples.
          The rendered UPDATE statement will emit the SET clause for each
          referenced column maintaining this order.

          .. versionadded:: 1.0.10

          .. seealso::

            :ref:`updates_order_parameters` - full example of the
            :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
            flag

        If both ``values`` and compile-time bind parameters are present, the
        compile-time bind parameters override the information specified
        within ``values`` on a per-key basis.

        The keys within ``values`` can be either :class:`.Column`
        objects or their string identifiers (specifically the "key" of the
        :class:`.Column`, normally but not necessarily equivalent to
        its "name").  Normally, the
        :class:`.Column` objects used here are expected to be
        part of the target :class:`.Table` that is the table
        to be updated.  However when using MySQL, a multiple-table
        UPDATE statement can refer to columns from any of
        the tables referred to in the WHERE clause.

        The values referred to in ``values`` are typically:

        * a literal data value (i.e. string, number, etc.)
        * a SQL expression, such as a related :class:`.Column`,
          a scalar-returning :func:`.select` construct,
          etc.

        When combining :func:`.select` constructs within the values
        clause of an :func:`.update` construct,
        the subquery represented by the :func:`.select` should be
        *correlated* to the parent table, that is, providing criterion
        which links the table inside the subquery to the outer table
        being updated::

            users.update().values(
                    name=select([addresses.c.email_address]).\
                            where(addresses.c.user_id==users.c.id).\
                            as_scalar()
                )

        .. seealso::

            :ref:`inserts_and_updates` - SQL Expression
            Language Tutorial


        """
        self._preserve_parameter_order = preserve_parameter_order
        ValuesBase.__init__(self, table, values, prefixes)
        self._bind = bind
        self._returning = returning
        if whereclause is not None:
            self._whereclause = _literal_as_text(whereclause)
        else:
            self._whereclause = None
        self.inline = inline
        self._validate_dialect_kwargs(dialect_kw)
        self._return_defaults = return_defaults

    def get_children(self, **kwargs):
        if self._whereclause is not None:
            return self._whereclause,
        else:
            return ()

    def _copy_internals(self, clone=_clone, **kw):
        # TODO: coverage
        self._whereclause = clone(self._whereclause, **kw)
        self.parameters = self.parameters.copy()

    @_generative
    def where(self, whereclause):
        """return a new update() construct with the given expression added to
        its WHERE clause, joined to the existing clause via AND, if any.

        """
        if self._whereclause is not None:
            self._whereclause = and_(self._whereclause,
                                     _literal_as_text(whereclause))
        else:
            self._whereclause = _literal_as_text(whereclause)

    @property
    def _extra_froms(self):
        # TODO: this could be made memoized
        # if the memoization is reset on each generative call.
        froms = []
        seen = set([self.table])

        if self._whereclause is not None:
            for item in _from_objects(self._whereclause):
                if not seen.intersection(item._cloned_set):
                    froms.append(item)
                    seen.update(item._cloned_set)

        return froms
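
# --- illustrative sketch, not part of the diff: typical use of the Update
# construct documented above.  The ``users`` table, ``fullname`` column and
# ``conn`` connection are the hypothetical names from the docstring examples.
#
#     stmt = users.update().\
#         where(users.c.id == 5).\
#         values(name='ed')
#     conn.execute(stmt)
#
#     # with preserve_parameter_order=True, values() takes a list of
#     # 2-tuples and the SET clause renders in exactly that order:
#     stmt = users.update(preserve_parameter_order=True).\
#         values([(users.c.name, 'ed'), (users.c.fullname, 'Ed Jones')])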


class Delete(UpdateBase):
    """Represent a DELETE construct.

    The :class:`.Delete` object is created using the :func:`delete()`
    function.

    """

    __visit_name__ = 'delete'

    def __init__(self,
                 table,
                 whereclause=None,
                 bind=None,
                 returning=None,
                 prefixes=None,
                 **dialect_kw):
        """Construct :class:`.Delete` object.

        Similar functionality is available via the
        :meth:`~.TableClause.delete` method on
        :class:`~.schema.Table`.

        :param table: The table to delete rows from.

        :param whereclause: A :class:`.ClauseElement` describing the ``WHERE``
          condition of the ``DELETE`` statement. Note that the
          :meth:`~Delete.where()` generative method may be used instead.

        .. seealso::

            :ref:`deletes` - SQL Expression Tutorial

        """
        self._bind = bind
        self.table = _interpret_as_from(table)
        self._returning = returning

        if prefixes:
            self._setup_prefixes(prefixes)

        if whereclause is not None:
            self._whereclause = _literal_as_text(whereclause)
        else:
            self._whereclause = None

        self._validate_dialect_kwargs(dialect_kw)

    def get_children(self, **kwargs):
        if self._whereclause is not None:
            return self._whereclause,
        else:
            return ()

    @_generative
    def where(self, whereclause):
        """Add the given WHERE clause to a newly returned delete construct."""

        if self._whereclause is not None:
            self._whereclause = and_(self._whereclause,
                                     _literal_as_text(whereclause))
        else:
            self._whereclause = _literal_as_text(whereclause)

    def _copy_internals(self, clone=_clone, **kw):
        # TODO: coverage
        self._whereclause = clone(self._whereclause, **kw)
4403  sqlalchemy/sql/elements.py  (new file; diff suppressed because it is too large)

146  sqlalchemy/sql/naming.py  (new file)
@@ -0,0 +1,146 @@
# sqlalchemy/naming.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Establish constraint and index naming conventions.


"""

from .schema import Constraint, ForeignKeyConstraint, PrimaryKeyConstraint, \
    UniqueConstraint, CheckConstraint, Index, Table, Column
from .. import event, events
from .. import exc
from .elements import _truncated_label, _defer_name, _defer_none_name, conv
import re


class ConventionDict(object):

    def __init__(self, const, table, convention):
        self.const = const
        self._is_fk = isinstance(const, ForeignKeyConstraint)
        self.table = table
        self.convention = convention
        self._const_name = const.name

    def _key_table_name(self):
        return self.table.name

    def _column_X(self, idx):
        if self._is_fk:
            fk = self.const.elements[idx]
            return fk.parent
        else:
            return list(self.const.columns)[idx]

    def _key_constraint_name(self):
        if isinstance(self._const_name, (type(None), _defer_none_name)):
            raise exc.InvalidRequestError(
                "Naming convention including "
                "%(constraint_name)s token requires that "
                "constraint is explicitly named."
            )
        if not isinstance(self._const_name, conv):
            self.const.name = None
        return self._const_name

    def _key_column_X_name(self, idx):
        return self._column_X(idx).name

    def _key_column_X_label(self, idx):
        return self._column_X(idx)._label

    def _key_referred_table_name(self):
        fk = self.const.elements[0]
        refs = fk.target_fullname.split(".")
        if len(refs) == 3:
            refschema, reftable, refcol = refs
        else:
            reftable, refcol = refs
        return reftable

    def _key_referred_column_X_name(self, idx):
        fk = self.const.elements[idx]
        refs = fk.target_fullname.split(".")
        if len(refs) == 3:
            refschema, reftable, refcol = refs
        else:
            reftable, refcol = refs
        return refcol

    def __getitem__(self, key):
        if key in self.convention:
            return self.convention[key](self.const, self.table)
        elif hasattr(self, '_key_%s' % key):
            return getattr(self, '_key_%s' % key)()
        else:
            col_template = re.match(r".*_?column_(\d+)_.+", key)
            if col_template:
                idx = col_template.group(1)
                attr = "_key_" + key.replace(idx, "X")
                idx = int(idx)
                if hasattr(self, attr):
                    return getattr(self, attr)(idx)
        raise KeyError(key)

_prefix_dict = {
    Index: "ix",
    PrimaryKeyConstraint: "pk",
    CheckConstraint: "ck",
    UniqueConstraint: "uq",
    ForeignKeyConstraint: "fk"
}


def _get_convention(dict_, key):

    for super_ in key.__mro__:
        if super_ in _prefix_dict and _prefix_dict[super_] in dict_:
            return dict_[_prefix_dict[super_]]
        elif super_ in dict_:
            return dict_[super_]
    else:
        return None


def _constraint_name_for_table(const, table):
    metadata = table.metadata
    convention = _get_convention(metadata.naming_convention, type(const))

    if isinstance(const.name, conv):
        return const.name
    elif convention is not None and \
        not isinstance(const.name, conv) and \
        (
            const.name is None or
            "constraint_name" in convention or
            isinstance(const.name, _defer_name)):
        return conv(
            convention % ConventionDict(const, table,
                                        metadata.naming_convention)
        )
    elif isinstance(convention, _defer_none_name):
        return None


@event.listens_for(Constraint, "after_parent_attach")
@event.listens_for(Index, "after_parent_attach")
def _constraint_name(const, table):
    if isinstance(table, Column):
        # for column-attached constraint, set another event
        # to link the column attached to the table as this constraint
        # associated with the table.
        event.listen(table, "after_parent_attach",
                     lambda col, table: _constraint_name(const, table)
                     )
    elif isinstance(table, Table):
        if isinstance(const.name, (conv, _defer_name)):
            return

        newname = _constraint_name_for_table(const, table)
        if newname is not None:
            const.name = newname
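
# --- illustrative sketch, not part of the diff: applying the naming
# conventions this module implements.  Token names like ``column_0_name``
# are resolved by ConventionDict.__getitem__ above.
#
#     from sqlalchemy import MetaData, Table, Column, Integer, UniqueConstraint
#
#     metadata = MetaData(naming_convention={
#         "uq": "uq_%(table_name)s_%(column_0_name)s",
#         "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
#     })
#     t = Table('user', metadata,
#               Column('name', Integer),
#               UniqueConstraint('name'))
#     # the unique constraint is automatically named uq_user_name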
4027  sqlalchemy/sql/schema.py  (new file; diff suppressed because it is too large)
3716  sqlalchemy/sql/selectable.py  (new file; diff suppressed because it is too large)
2619  sqlalchemy/sql/sqltypes.py  (new file; diff suppressed because it is too large)
1307  sqlalchemy/sql/type_api.py  (new file; diff suppressed because it is too large)

36  sqlalchemy/testing/__init__.py  (new file)
@@ -0,0 +1,36 @@
# testing/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php


from .warnings import assert_warnings

from . import config

from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
    fails_on, fails_on_everything_except, skip, only_on, exclude, \
    against as _against, _server_version, only_if, fails


def against(*queries):
    return _against(config._current, *queries)

from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
    eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
    assert_raises_message, AssertsCompiledSQL, ComparesTables, \
    AssertsExecutionResults, expect_deprecated, expect_warnings, \
    in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false

from .util import run_as_contextmanager, rowset, fail, \
    provide_metadata, adict, force_drop_names, \
    teardown_events

crashes = skip

from .config import db
from .config import requirements as requires

from . import mock
520  sqlalchemy/testing/assertions.py  (new file)
@@ -0,0 +1,520 @@
# testing/assertions.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from __future__ import absolute_import

from . import util as testutil
from sqlalchemy import pool, orm, util
from sqlalchemy.engine import default, url
from sqlalchemy.util import decorator, compat
from sqlalchemy import types as sqltypes, schema, exc as sa_exc
import warnings
import re
from .exclusions import db_spec
from . import assertsql
from . import config
from .util import fail
import contextlib
from . import mock


def expect_warnings(*messages, **kw):
    """Context manager which expects one or more warnings.

    With no arguments, squelches all SAWarnings emitted via
    sqlalchemy.util.warn and sqlalchemy.util.warn_limited.   Otherwise
    pass string expressions that will match selected warnings via regex;
    all non-matching warnings are sent through.

    The expect version **asserts** that the warnings were in fact seen.

    Note that the test suite sets SAWarning warnings to raise exceptions.

    """
    return _expect_warnings(sa_exc.SAWarning, messages, **kw)
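
# --- illustrative sketch, not part of the diff: typical use of
# expect_warnings() inside a test, matching the warning text by regex.
#
#     with expect_warnings("Dialect .* does not support"):
#         do_something_that_warns()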


@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
    """Context manager which expects one or more warnings on specific
    dialects.

    The expect version **asserts** that the warnings were in fact seen.

    """
    spec = db_spec(db)

    if isinstance(db, util.string_types) and not spec(config._current):
        yield
    else:
        with expect_warnings(*messages, **kw):
            yield


def emits_warning(*messages):
    """Decorator form of expect_warnings().

    Note that emits_warning does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings(assert_=False, *messages):
            return fn(*args, **kw)

    return decorate


def expect_deprecated(*messages, **kw):
    return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)


def emits_warning_on(db, *messages):
    """Mark a test as emitting a warning on a specific dialect.

    With no arguments, squelches all SAWarning failures.  Or pass one or more
    strings; these will be matched to the root of the warning description by
    warnings.filterwarnings().

    Note that emits_warning_on does **not** assert that the warnings
    were in fact seen.

    """
    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings_on(db, assert_=False, *messages):
            return fn(*args, **kw)

    return decorate


def uses_deprecated(*messages):
    """Mark a test as immune from fatal deprecation warnings.

    With no arguments, squelches all SADeprecationWarning failures.
    Or pass one or more strings; these will be matched to the root
    of the warning description by warnings.filterwarnings().

    As a special case, you may pass a function name prefixed with //
    and it will be re-written as needed to match the standard warning
    verbiage emitted by the sqlalchemy.util.deprecated decorator.

    Note that uses_deprecated does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_deprecated(*messages, assert_=False):
            return fn(*args, **kw)
    return decorate


@contextlib.contextmanager
def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
                     py2konly=False):

    if regex:
        filters = [re.compile(msg, re.I | re.S) for msg in messages]
    else:
        filters = messages

    seen = set(filters)

    real_warn = warnings.warn

    def our_warn(msg, exception, *arg, **kw):
        if not issubclass(exception, exc_cls):
            return real_warn(msg, exception, *arg, **kw)

        if not filters:
            return

        for filter_ in filters:
            if (regex and filter_.match(msg)) or \
                    (not regex and filter_ == msg):
                seen.discard(filter_)
                break
        else:
            real_warn(msg, exception, *arg, **kw)

    with mock.patch("warnings.warn", our_warn):
        yield

    if assert_ and (not py2konly or not compat.py3k):
        assert not seen, "Warnings were not seen: %s" % \
            ", ".join("%r" % (s.pattern if regex else s) for s in seen)


def global_cleanup_assertions():
    """Check things that have to be finalized at the end of a test suite.

    Hardcoded at the moment, a modular system can be built here
    to support things like PG prepared transactions, tables all
    dropped, etc.

    """
    _assert_no_stray_pool_connections()

_STRAY_CONNECTION_FAILURES = 0


def _assert_no_stray_pool_connections():
    global _STRAY_CONNECTION_FAILURES

    # lazy gc on cPython means "do nothing."  pool connections
    # shouldn't be in cycles, should go away.
    testutil.lazy_gc()

    # however, once in awhile, on an EC2 machine usually,
    # there's a ref in there.  usually just one.
    if pool._refs:

        # OK, let's be somewhat forgiving.
        _STRAY_CONNECTION_FAILURES += 1

        print("Encountered a stray connection in test cleanup: %s"
              % str(pool._refs))
        # then do a real GC sweep.   We shouldn't even be here
        # so a single sweep should really be doing it, otherwise
        # there's probably a real unreachable cycle somewhere.
        testutil.gc_collect()

    # if we've already had two of these occurrences, or
    # after a hard gc sweep we still have pool._refs?!
    # now we have to raise.
    if pool._refs:
        err = str(pool._refs)

        # but clean out the pool refs collection directly,
        # reset the counter,
        # so the error doesn't at least keep happening.
        pool._refs.clear()
        _STRAY_CONNECTION_FAILURES = 0
        warnings.warn(
            "Stray connection refused to leave "
            "after gc.collect(): %s" % err)
    elif _STRAY_CONNECTION_FAILURES > 10:
        assert False, "Encountered more than 10 stray connections"
        _STRAY_CONNECTION_FAILURES = 0


def eq_regex(a, b, msg=None):
    assert re.match(b, a), msg or "%r !~ %r" % (a, b)


def eq_(a, b, msg=None):
    """Assert a == b, with repr messaging on failure."""
    assert a == b, msg or "%r != %r" % (a, b)


def ne_(a, b, msg=None):
    """Assert a != b, with repr messaging on failure."""
    assert a != b, msg or "%r == %r" % (a, b)


def le_(a, b, msg=None):
    """Assert a <= b, with repr messaging on failure."""
    assert a <= b, msg or "%r != %r" % (a, b)


def is_true(a, msg=None):
    is_(a, True, msg=msg)


def is_false(a, msg=None):
    is_(a, False, msg=msg)


def is_(a, b, msg=None):
    """Assert a is b, with repr messaging on failure."""
    assert a is b, msg or "%r is not %r" % (a, b)


def is_not_(a, b, msg=None):
    """Assert a is not b, with repr messaging on failure."""
    assert a is not b, msg or "%r is %r" % (a, b)


def in_(a, b, msg=None):
    """Assert a in b, with repr messaging on failure."""
    assert a in b, msg or "%r not in %r" % (a, b)


def not_in_(a, b, msg=None):
    """Assert a not in b, with repr messaging on failure."""
    assert a not in b, msg or "%r is in %r" % (a, b)


def startswith_(a, fragment, msg=None):
    """Assert a.startswith(fragment), with repr messaging on failure."""
    assert a.startswith(fragment), msg or "%r does not start with %r" % (
        a, fragment)


def eq_ignore_whitespace(a, b, msg=None):
    a = re.sub(r'^\s+?|\n', "", a)
    a = re.sub(r' {2,}', " ", a)
    b = re.sub(r'^\s+?|\n', "", b)
    b = re.sub(r' {2,}', " ", b)

    assert a == b, msg or "%r != %r" % (a, b)


def assert_raises(except_cls, callable_, *args, **kw):
    try:
        callable_(*args, **kw)
        success = False
    except except_cls:
        success = True

    # assert outside the block so it works for AssertionError too !
    assert success, "Callable did not raise an exception"


def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    try:
        callable_(*args, **kwargs)
        assert False, "Callable did not raise an exception"
    except except_cls as e:
        assert re.search(
            msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
        print(util.text_type(e).encode('utf-8'))
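
# --- illustrative sketch, not part of the diff: assert_raises_message()
# checks both the exception type and that its string form matches a regex.
#
#     assert_raises_message(
#         ValueError,
#         "invalid literal",
#         int, "not-a-number"
#     )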


class AssertsCompiledSQL(object):
    def assert_compile(self, clause, result, params=None,
                       checkparams=None, dialect=None,
                       checkpositional=None,
                       check_prefetch=None,
                       use_default_dialect=False,
                       allow_dialect_select=False,
                       literal_binds=False,
                       schema_translate_map=None):
        if use_default_dialect:
            dialect = default.DefaultDialect()
        elif allow_dialect_select:
            dialect = None
        else:
            if dialect is None:
                dialect = getattr(self, '__dialect__', None)

            if dialect is None:
                dialect = config.db.dialect
            elif dialect == 'default':
                dialect = default.DefaultDialect()
            elif dialect == 'default_enhanced':
                dialect = default.StrCompileDialect()
            elif isinstance(dialect, util.string_types):
                dialect = url.URL(dialect).get_dialect()()

        kw = {}
        compile_kwargs = {}

        if schema_translate_map:
            kw['schema_translate_map'] = schema_translate_map

        if params is not None:
            kw['column_keys'] = list(params)

        if literal_binds:
            compile_kwargs['literal_binds'] = True

        if isinstance(clause, orm.Query):
            context = clause._compile_context()
            context.statement.use_labels = True
            clause = context.statement

        if compile_kwargs:
            kw['compile_kwargs'] = compile_kwargs

        c = clause.compile(dialect=dialect, **kw)

        param_str = repr(getattr(c, 'params', {}))

        if util.py3k:
            param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
            print(
                ("\nSQL String:\n" +
                 util.text_type(c) +
                 param_str).encode('utf-8'))
        else:
            print(
                "\nSQL String:\n" +
                util.text_type(c).encode('utf-8') +
                param_str)

        cc = re.sub(r'[\n\t]', '', util.text_type(c))

        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))

        if checkparams is not None:
            eq_(c.construct_params(params), checkparams)
        if checkpositional is not None:
            p = c.construct_params(params)
            eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
        if check_prefetch is not None:
            eq_(c.prefetch, check_prefetch)
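
# --- illustrative sketch, not part of the diff: a test class mixing in
# AssertsCompiledSQL to verify the SQL string a construct compiles to
# (``users`` is a hypothetical two-column Table; fixtures.TestBase stands
# for the suite's base test class).
#
#     class MyCompileTest(fixtures.TestBase, AssertsCompiledSQL):
#         __dialect__ = 'default'
#
#         def test_update_compiles(self):
#             self.assert_compile(
#                 users.update().where(users.c.id == 5),
#                 "UPDATE users SET id=:id, name=:name "
#                 "WHERE users.id = :id_1"
#             )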


class ComparesTables(object):

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        assert len(table.c) == len(reflected_table.c)
        for c, reflected_c in zip(table.c, reflected_table.c):
            eq_(c.name, reflected_c.name)
            assert reflected_c is reflected_table.c[c.name]
            eq_(c.primary_key, reflected_c.primary_key)
            eq_(c.nullable, reflected_c.nullable)

            if strict_types:
                msg = "Type '%s' doesn't correspond to type '%s'"
                assert isinstance(reflected_c.type, type(c.type)), \
                    msg % (reflected_c.type, c.type)
            else:
                self.assert_types_base(reflected_c, c)

            if isinstance(c.type, sqltypes.String):
                eq_(c.type.length, reflected_c.type.length)

            eq_(
                set([f.column.name for f in c.foreign_keys]),
                set([f.column.name for f in reflected_c.foreign_keys])
            )
            if c.server_default:
                assert isinstance(reflected_c.server_default,
                                  schema.FetchedValue)

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for c in table.primary_key:
            assert reflected_table.primary_key.columns[c.name] is not None

    def assert_types_base(self, c1, c2):
        assert c1.type._compare_type_affinity(c2.type),\
            "On column %r, type '%s' doesn't correspond to type '%s'" % \
            (c1.name, c1.type, c2.type)


class AssertsExecutionResults(object):
    def assert_result(self, result, class_, *objects):
        result = list(result)
        print(repr(result))
        self.assert_list(result, class_, objects)

    def assert_list(self, result, class_, list):
        self.assert_(len(result) == len(list),
                     "result list is not the same size as test list, " +
                     "for class " + class_.__name__)
        for i in range(0, len(list)):
            self.assert_row(class_, result[i], list[i])

    def assert_row(self, class_, rowobj, desc):
        self.assert_(rowobj.__class__ is class_,
                     "item class is not " + repr(class_))
        for key, value in desc.items():
            if isinstance(value, tuple):
                if isinstance(value[1], list):
                    self.assert_list(getattr(rowobj, key), value[0], value[1])
                else:
                    self.assert_row(value[0], getattr(rowobj, key), value[1])
            else:
                self.assert_(getattr(rowobj, key) == value,
                             "attribute %s value %s does not match %s" % (
                                 key, getattr(rowobj, key), value))

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        The algorithm is very expensive but not a big deal for the small
        numbers of rows that the test suite manipulates.
        """

        class immutabledict(dict):
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        expected = set([immutabledict(e) for e in expected])

        for wrong in util.itertools_filterfalse(lambda o:
                                                isinstance(o, cls), found):
            fail('Unexpected type "%s", expected "%s"' % (
                type(wrong).__name__, cls.__name__))

        if len(found) != len(expected):
            fail('Unexpected object count "%s", expected "%s"' % (
                len(found), len(expected)))

        NOVALUE = object()

        def _compare_item(obj, spec):
            for key, value in spec.items():
                if isinstance(value, tuple):
                    try:
                        self.assert_unordered_result(
                            getattr(obj, key), value[0], *value[1])
                    except AssertionError:
                        return False
                else:
                    if getattr(obj, key, NOVALUE) != value:
                        return False
            return True

        for expected_item in expected:
            for found_item in found:
                if _compare_item(found_item, expected_item):
                    found.remove(found_item)
                    break
            else:
                fail(
                    "Expected %s instance with attributes %s not found." % (
                        cls.__name__, repr(expected_item)))
        return True

    def sql_execution_asserter(self, db=None):
        if db is None:
            from . import db as db

        return assertsql.assert_engine(db)

    def assert_sql_execution(self, db, callable_, *rules):
        with self.sql_execution_asserter(db) as asserter:
            callable_()
        asserter.assert_(*rules)

    def assert_sql(self, db, callable_, rules):

        newrules = []
        for rule in rules:
            if isinstance(rule, dict):
                newrule = assertsql.AllOf(*[
                    assertsql.CompiledSQL(k, v) for k, v in rule.items()
                ])
            else:
                newrule = assertsql.CompiledSQL(*rule)
            newrules.append(newrule)

        self.assert_sql_execution(db, callable_, *newrules)

    def assert_sql_count(self, db, callable_, count):
        self.assert_sql_execution(
            db, callable_, assertsql.CountStatements(count))

    @contextlib.contextmanager
    def assert_execution(self, *rules):
        assertsql.asserter.add_rules(rules)
        try:
            yield
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_statement_count(self, count):
        return self.assert_execution(assertsql.CountStatements(count))
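
# --- illustrative sketch, not part of the diff: counting the statements a
# block of code emits, via the AssertsExecutionResults mixin above
# (``testing.db`` and ``users`` are hypothetical suite fixtures).
#
#     class MyTest(fixtures.TestBase, AssertsExecutionResults):
#         def test_one_insert(self):
#             self.assert_sql_count(
#                 testing.db,
#                 lambda: testing.db.execute(users.insert(), name='ed'),
#                 1
#             )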
377  sqlalchemy/testing/assertsql.py  (new file)
@@ -0,0 +1,377 @@
# testing/assertsql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from ..engine.default import DefaultDialect
from .. import util
import re
import collections
import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
from sqlalchemy.engine import url


class AssertRule(object):

    is_consumed = False
    errormessage = None
    consume_statement = True

    def process_statement(self, execute_observed):
        pass

    def no_more_statements(self):
        assert False, 'All statements are complete, but pending '\
            'assertion rules remain'


class SQLMatchRule(AssertRule):
    pass


class CursorSQL(SQLMatchRule):
    consume_statement = False

    def __init__(self, statement, params=None):
        self.statement = statement
        self.params = params

    def process_statement(self, execute_observed):
        stmt = execute_observed.statements[0]
        if self.statement != stmt.statement or (
                self.params is not None and self.params != stmt.parameters):
            self.errormessage = \
                "Testing for exact SQL %s parameters %s received %s %s" % (
                    self.statement, self.params,
                    stmt.statement, stmt.parameters
                )
        else:
            execute_observed.statements.pop(0)
            self.is_consumed = True
            if not execute_observed.statements:
                self.consume_statement = True


class CompiledSQL(SQLMatchRule):

    def __init__(self, statement, params=None, dialect='default'):
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        stmt = re.sub(r'[\n\t]', '', self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        if self.dialect == 'default':
            return DefaultDialect()
        else:
            # ugh
            if self.dialect == 'postgresql':
                params = {'implicit_returning': True}
            else:
                params = {}
            return url.URL(self.dialect).get_dialect()(**params)

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect."""

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)
        if isinstance(context.compiled.statement, _DDLCompiles):
            compiled = \
                context.compiled.statement.compile(
                    dialect=compare_dialect,
                    schema_translate_map=context.
                    execution_options.get('schema_translate_map'))
        else:
            compiled = (
                context.compiled.statement.compile(
                    dialect=compare_dialect,
                    column_keys=context.compiled.column_keys,
                    inline=context.compiled.inline,
                    schema_translate_map=context.
                    execution_options.get('schema_translate_map'))
            )
        _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [compiled.construct_params()]
        else:
            _received_parameters = [
                compiled.construct_params(m) for m in parameters]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        context = execute_observed.context

        _received_statement, _received_parameters = \
            self._received_statement(execute_observed)
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if param_key not in received or \
                                    received[param_key] != param[param_key]:
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                'received_statement': _received_statement,
                'received_parameters': _received_parameters
            }

    def _all_params(self, context):
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, expected_params):
        return (
            'Testing for compiled statement %r partial params %r, '
            'received %%(received_statement)r with params '
            '%%(received_parameters)r' % (
                self.statement.replace('%', '%%'), expected_params
            )
        )


class RegexSQL(CompiledSQL):
    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params
        self.dialect = 'default'

    def _failure_message(self, expected_params):
        return (
            'Testing for compiled statement ~%r partial params %r, '
            'received %%(received_statement)r with params '
            '%%(received_parameters)r' % (
                self.orig_regex, expected_params
            )
        )

    def _compare_sql(self, execute_observed, received_statement):
        return bool(self.regex.match(received_statement))


class DialectSQL(CompiledSQL):
    def _compile_dialect(self, execute_observed):
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        stmt = re.sub(r'[\n\t]', '', real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        received_stmt, received_params = super(DialectSQL, self).\
            _received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt)

        return received_stmt, execute_observed.context.compiled_parameters

    def _compare_sql(self, execute_observed, received_statement):
        stmt = re.sub(r'[\n\t]', '', self.statement)
        # convert our comparison statement to have the
        # paramstyle of the received
        paramstyle = execute_observed.context.dialect.paramstyle
        if paramstyle == 'pyformat':
            stmt = re.sub(
                r':([\w_]+)', r"%(\1)s", stmt)
        else:
            # positional params
            repl = None
            if paramstyle == 'qmark':
                repl = "?"
            elif paramstyle == 'format':
                repl = r"%s"
            elif paramstyle == 'numeric':
                repl = None
            stmt = re.sub(r':([\w_]+)', repl, stmt)

        return received_statement == stmt
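
# --- illustrative sketch, not part of the diff: the paramstyle rewrite in
# DialectSQL._compare_sql means one rule can be written with :name params
# and still match a pyformat or qmark dialect, e.g.
#
#     DialectSQL(
#         "INSERT INTO users (name) VALUES (:name)",
#         [{"name": "ed"}]
#     )
#     # matches "INSERT INTO users (name) VALUES (%(name)s)" on pyformat
#     # drivers and "INSERT INTO users (name) VALUES (?)" on qmark drivers.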


class CountStatements(AssertRule):

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        self._statement_count += 1

    def no_more_statements(self):
        if self.count != self._statement_count:
            assert False, 'desired statement count %d does not match %d' \
                % (self.count, self._statement_count)


class AllOf(AssertRule):

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        for rule in list(self.rules):
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage


class Or(AllOf):

    def process_statement(self, execute_observed):
        for rule in self.rules:
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.is_consumed = True
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage


class SQLExecuteObserved(object):
    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement
        self.parameters = _distill_params(multiparams, params)
        self.statements = []


class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        ["statement", "parameters", "context", "executemany"])
):
    pass


class SQLAsserter(object):
    def __init__(self):
        self.accumulated = []

    def _close(self):
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        rules = list(rules)
        observed = list(self._final)

        while observed and rules:
            rule = rules[0]
            rule.process_statement(observed[0])
            if rule.is_consumed:
                rules.pop(0)
            elif rule.errormessage:
                assert False, rule.errormessage

            if rule.consume_statement:
                observed.pop(0)

        if not observed and rules:
            rules[0].no_more_statements()
        elif not rules and observed:
            assert False, "Additional SQL statements remain"


@contextlib.contextmanager
def assert_engine(engine):
    asserter = SQLAsserter()

    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(conn, clauseelement, multiparams, params):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(conn, cursor, statement, parameters,
                       context, executemany):
        if not context:
            return
        # then grab real cursor statements and associate them all
        # around a single context
        if asserter.accumulated and \
                asserter.accumulated[-1].context is context:
            obs = asserter.accumulated[-1]
        else:
            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            asserter.accumulated.append(obs)
        obs.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany)
        )

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
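
# --- illustrative sketch, not part of the diff: assert_engine() as a
# standalone context manager around arbitrary engine activity
# (``testing.db``, ``conn`` and ``users`` are hypothetical fixtures).
#
#     with assert_engine(testing.db) as asserter:
#         conn.execute(users.insert(), name='ed')
#     asserter.assert_(
#         CompiledSQL("INSERT INTO users (name) VALUES (:name)",
#                     [{"name": "ed"}])
#     )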
97  sqlalchemy/testing/config.py  (new file)
@@ -0,0 +1,97 @@
# testing/config.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import collections

requirements = None
db = None
db_url = None
db_opts = None
file_config = None
test_schema = None
test_schema_2 = None
_current = None

try:
    from unittest import SkipTest as _skip_test_exception
except ImportError:
    _skip_test_exception = None


class Config(object):
    def __init__(self, db, db_opts, options, file_config):
        self.db = db
        self.db_opts = db_opts
        self.options = options
        self.file_config = file_config
        self.test_schema = "test_schema"
        self.test_schema_2 = "test_schema_2"

    _stack = collections.deque()
    _configs = {}

    @classmethod
    def register(cls, db, db_opts, options, file_config):
        """add a config as one of the global configs.

        If there are no configs set up yet, this config also
        gets set as the "_current".
        """
        cfg = Config(db, db_opts, options, file_config)

        cls._configs[cfg.db.name] = cfg
        cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg
        cls._configs[cfg.db] = cfg
        return cfg

    @classmethod
    def set_as_current(cls, config, namespace):
        global db, _current, db_url, test_schema, test_schema_2, db_opts
        _current = config
        db_url = config.db.url
        db_opts = config.db_opts
        test_schema = config.test_schema
        test_schema_2 = config.test_schema_2
        namespace.db = db = config.db

    @classmethod
    def push_engine(cls, db, namespace):
        assert _current, "Can't push without a default Config set up"
        cls.push(
            Config(
                db, _current.db_opts, _current.options, _current.file_config),
            namespace
        )

    @classmethod
    def push(cls, config, namespace):
        cls._stack.append(_current)
        cls.set_as_current(config, namespace)

    @classmethod
    def reset(cls, namespace):
        if cls._stack:
            cls.set_as_current(cls._stack[0], namespace)
            cls._stack.clear()

    @classmethod
    def all_configs(cls):
        for cfg in set(cls._configs.values()):
            yield cfg

    @classmethod
    def all_dbs(cls):
        for cfg in cls.all_configs():
            yield cfg.db

    def skip_test(self, msg):
        skip_test(msg)


def skip_test(msg):
    raise _skip_test_exception(msg)
349  sqlalchemy/testing/engines.py  (new file)
@@ -0,0 +1,349 @@
# testing/engines.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from __future__ import absolute_import

import weakref
from . import config
from .util import decorator
from .. import event, pool
import re
import warnings


class ConnectionKiller(object):

    def __init__(self):
        self.proxy_refs = weakref.WeakKeyDictionary()
        self.testing_engines = weakref.WeakKeyDictionary()
        self.conns = set()

    def add_engine(self, engine):
        self.testing_engines[engine] = True

    def connect(self, dbapi_conn, con_record):
        self.conns.add((dbapi_conn, con_record))

    def checkout(self, dbapi_con, con_record, con_proxy):
        self.proxy_refs[con_proxy] = True

    def invalidate(self, dbapi_con, con_record, exception):
        self.conns.discard((dbapi_con, con_record))

    def _safe(self, fn):
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "testing_reaper couldn't "
                "rollback/close connection: %s" % e)

    def rollback_all(self):
        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self._safe(rec.rollback)

    def close_all(self):
        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self._safe(rec._close)

    def _after_test_ctx(self):
        # this can cause a deadlock with pg8000 - pg8000 acquires
        # prepared statement lock inside of rollback() - if async gc
        # is collecting in finalize_fairy, deadlock.
        # not sure if this should be if pypy/jython only.
        # note that firebird/fdb definitely needs this though
        for conn, rec in list(self.conns):
            self._safe(conn.rollback)

    def _stop_test_ctx(self):
        if config.options.low_connections:
            self._stop_test_ctx_minimal()
        else:
            self._stop_test_ctx_aggressive()

    def _stop_test_ctx_minimal(self):
        self.close_all()

        self.conns = set()

        for rec in list(self.testing_engines):
            if rec is not config.db:
                rec.dispose()

    def _stop_test_ctx_aggressive(self):
        self.close_all()
        for conn, rec in list(self.conns):
            self._safe(conn.close)
            rec.connection = None

        self.conns = set()
        for rec in list(self.testing_engines):
            rec.dispose()

    def assert_all_closed(self):
        for rec in self.proxy_refs:
            if rec.is_valid:
                assert False

testing_reaper = ConnectionKiller()


def drop_all_tables(metadata, bind):
    testing_reaper.close_all()
    if hasattr(bind, 'close'):
        bind.close()

    if not config.db.dialect.supports_alter:
        from . import assertions
        with assertions.expect_warnings(
                "Can't sort tables", assert_=False):
            metadata.drop_all(bind)
    else:
        metadata.drop_all(bind)


@decorator
def assert_conns_closed(fn, *args, **kw):
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.assert_all_closed()


@decorator
def rollback_open_connections(fn, *args, **kw):
    """Decorator that rolls back all open connections after fn execution."""

    try:
        fn(*args, **kw)
    finally:
        testing_reaper.rollback_all()


@decorator
def close_first(fn, *args, **kw):
    """Decorator that closes all connections before fn execution."""

    testing_reaper.close_all()
    fn(*args, **kw)


@decorator
def close_open_connections(fn, *args, **kw):
    """Decorator that closes all connections after fn execution."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.close_all()


def all_dialects(exclude=None):
    import sqlalchemy.databases as d
    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None)
        if not mod:
            mod = getattr(__import__(
                'sqlalchemy.databases.%s' % name).databases, name)
        yield mod.dialect()


class ReconnectFixture(object):

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []

    def __getattr__(self, key):
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        conn = self.dbapi.connect(*args, **kwargs)
        self.connections.append(conn)
        return conn

    def _safe(self, fn):
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "ReconnectFixture couldn't "
                "close connection: %s" % e)

    def shutdown(self):
        # TODO: this doesn't cover all cases
        # as nicely as we'd like, namely MySQLdb.
        # would need to implement R. Brewer's
        # proxy server idea to get better
        # coverage.
        for c in list(self.connections):
            self._safe(c.close)
        self.connections = []


def reconnecting_engine(url=None, options=None):
    url = url or config.db.url
    dbapi = config.db.dialect.dbapi
    if not options:
        options = {}
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    _dispose = engine.dispose

    def dispose():
        engine.dialect.dbapi.shutdown()
        _dispose()

    engine.test_shutdown = engine.dialect.dbapi.shutdown
    engine.dispose = dispose
    return engine


def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides."""

    from sqlalchemy import create_engine
    from sqlalchemy.engine.url import make_url

    if not options:
        use_reaper = True
    else:
        use_reaper = options.pop('use_reaper', True)

    url = url or config.db.url

    url = make_url(url)
    if options is None:
        if config.db is None or url.drivername == config.db.url.drivername:
            options = config.db_opts
        else:
            options = {}
    elif config.db is not None and url.drivername == config.db.url.drivername:
        default_opt = config.db_opts.copy()
        default_opt.update(options)

    engine = create_engine(url, **options)
    engine._has_events = True   # enable event blocks, helps with profiling

    if isinstance(engine.pool, pool.QueuePool):
        engine.pool._timeout = 0
        engine.pool._max_overflow = 0
    if use_reaper:
        event.listen(engine.pool, 'connect', testing_reaper.connect)
        event.listen(engine.pool, 'checkout', testing_reaper.checkout)
        event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
        testing_reaper.add_engine(engine)

    return engine
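
# --- illustrative sketch, not part of the diff: a throwaway engine with
# per-test options layered over the configured database URL.
#
#     eng = testing_engine(options={'echo': True})
#     try:
#         eng.execute("select 1")
#     finally:
#         eng.dispose()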


def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.

    """

    from sqlalchemy import create_engine

    if not dialect_name:
        dialect_name = config.db.name

    buffer = []

    def executor(sql, *a, **kw):
        buffer.append(sql)

    def assert_sql(stmts):
        recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
        assert recv == stmts, recv

    def print_sql():
        d = engine.dialect
        return "\n".join(
            str(s.compile(dialect=d))
            for s in engine.mock
        )

    engine = create_engine(dialect_name + '://',
                           strategy='mock', executor=executor)
    assert not hasattr(engine, 'mock')
    engine.mock = buffer
    engine.assert_sql = assert_sql
    engine.print_sql = print_sql
    return engine
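
# --- illustrative sketch, not part of the diff: capturing DDL with
# mock_engine() instead of executing it.  ``metadata`` is a hypothetical
# MetaData containing one two-column ``users`` table.
#
#     eng = mock_engine('sqlite')
#     metadata.create_all(eng, checkfirst=False)
#     eng.assert_sql(['CREATE TABLE users (id INTEGER NOT NULL, '
#                     'name VARCHAR(50), PRIMARY KEY (id))'])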


class DBAPIProxyCursor(object):
    """Proxy a DBAPI cursor.

    Tests can provide subclasses of this to intercept
    DBAPI-level cursor operations.

    """

    def __init__(self, engine, conn, *args, **kwargs):
        self.engine = engine
        self.connection = conn
        self.cursor = conn.cursor(*args, **kwargs)

    def execute(self, stmt, parameters=None, **kw):
        if parameters:
            return self.cursor.execute(stmt, parameters, **kw)
        else:
            return self.cursor.execute(stmt, **kw)

    def executemany(self, stmt, params, **kw):
        return self.cursor.executemany(stmt, params, **kw)

    def __getattr__(self, key):
        return getattr(self.cursor, key)


class DBAPIProxyConnection(object):
    """Proxy a DBAPI connection.

    Tests can provide subclasses of this to intercept
    DBAPI-level connection operations.

    """

    def __init__(self, engine, cursor_cls):
        self.conn = self._sqla_unwrap = engine.pool._creator()
        self.engine = engine
        self.cursor_cls = cursor_cls

    def cursor(self, *args, **kwargs):
        return self.cursor_cls(self.engine, self.conn, *args, **kwargs)

    def close(self):
        self.conn.close()

    def __getattr__(self, key):
        return getattr(self.conn, key)


def proxying_engine(conn_cls=DBAPIProxyConnection,
                    cursor_cls=DBAPIProxyCursor):
    """Produce an engine that provides proxy hooks for
    common methods.

    """
    def mock_conn():
        return conn_cls(config.db, cursor_cls)
    return testing_engine(options={'creator': mock_conn})
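
# --- illustrative sketch, not part of the diff: intercepting cursor
# execution by subclassing DBAPIProxyCursor and passing it to
# proxying_engine().
#
#     class CountingCursor(DBAPIProxyCursor):
#         statements = []
#
#         def execute(self, stmt, parameters=None, **kw):
#             self.statements.append(stmt)
#             return super(CountingCursor, self).execute(
#                 stmt, parameters, **kw)
#
#     eng = proxying_engine(cursor_cls=CountingCursor)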
101  sqlalchemy/testing/entities.py  (new file)
@@ -0,0 +1,101 @@
# testing/entities.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import sqlalchemy as sa
from sqlalchemy import exc as sa_exc

_repr_stack = set()


class BasicEntity(object):

    def __init__(self, **kw):
        for key, value in kw.items():
            setattr(self, key, value)

    def __repr__(self):
        if id(self) in _repr_stack:
            return object.__repr__(self)
        _repr_stack.add(id(self))
        try:
            return "%s(%s)" % (
                (self.__class__.__name__),
                ', '.join(["%s=%r" % (key, getattr(self, key))
                           for key in sorted(self.__dict__.keys())
                           if not key.startswith('_')]))
        finally:
            _repr_stack.remove(id(self))

_recursion_stack = set()


class ComparableEntity(BasicEntity):

    def __hash__(self):
        return hash(self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __eq__(self, other):
        """Deep, sparse compare.

        Deeply compare two entities, following the non-None attributes of the
        non-persisted object, if possible.

        """
        if other is self:
            return True
        elif not self.__class__ == other.__class__:
            return False

        if id(self) in _recursion_stack:
            return True
        _recursion_stack.add(id(self))

        try:
            # pick the entity that's not SA persisted as the source
            try:
                self_key = sa.orm.attributes.instance_state(self).key
            except sa.orm.exc.NO_STATE:
                self_key = None

            if other is None:
                a = self
                b = other
            elif self_key is not None:
                a = other
                b = self
            else:
                a = self
                b = other

            for attr in list(a.__dict__):
                if attr.startswith('_'):
                    continue
                value = getattr(a, attr)

                try:
                    # handle lazy loader errors
                    battr = getattr(b, attr)
                except (AttributeError, sa_exc.UnboundExecutionError):
                    return False

                if hasattr(value, '__iter__'):
                    if hasattr(value, '__getitem__') and not hasattr(
                            value, 'keys'):
                        if list(value) != list(battr):
                            return False
                    else:
                        if set(value) != set(battr):
                            return False
                else:
                    if value is not None and value != battr:
                        return False
            return True
        finally:
            _recursion_stack.remove(id(self))
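
# --- illustrative sketch, not part of the diff: ComparableEntity's sparse
# compare follows the attributes of the comparison's source entity, so it
# is deliberately asymmetric.
#
#     class User(ComparableEntity):
#         pass
#
#     a = User(name='ed', id=5)
#     b = User(name='ed')
#     # b == a is True: only b's 'name' attribute is checked against a.
#     # a == b is False: a also carries 'id', which b was never given.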
Some files were not shown because too many files have changed in this diff.