This commit is contained in:
20
dibbler/lib/query_helpers.py
Normal file
20
dibbler/lib/query_helpers.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import TypeVar
|
||||||
|
from sqlalchemy import BindParameter, literal
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
def const(value: T) -> BindParameter[T]:
|
||||||
|
"""
|
||||||
|
Create a constant SQL literal bind parameter.
|
||||||
|
|
||||||
|
This is useful to avoid too many `?` bind parameters in SQL queries,
|
||||||
|
when the input value is known to be safe.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return literal(value, literal_execute=True)
|
||||||
|
|
||||||
|
CONST_ZERO: BindParameter[int] = const(0)
|
||||||
|
CONST_ONE: BindParameter[int] = const(1)
|
||||||
|
CONST_TRUE: BindParameter[bool] = const(True)
|
||||||
|
CONST_FALSE: BindParameter[bool] = const(False)
|
||||||
|
CONST_NONE: BindParameter[None] = const(None)
|
||||||
@@ -19,6 +19,17 @@ class TransactionType(StrEnum):
|
|||||||
THROW_PRODUCT = auto()
|
THROW_PRODUCT = auto()
|
||||||
TRANSFER = auto()
|
TRANSFER = auto()
|
||||||
|
|
||||||
|
def as_literal_column(self):
|
||||||
|
"""
|
||||||
|
Return the transaction type as a SQL literal column.
|
||||||
|
|
||||||
|
This is useful to avoid too many `?` bind parameters in SQL queries,
|
||||||
|
when the input value is known to be safe.
|
||||||
|
"""
|
||||||
|
from sqlalchemy import literal_column
|
||||||
|
|
||||||
|
return literal_column(f"'{self.value}'")
|
||||||
|
|
||||||
|
|
||||||
TransactionTypeSQL = SQLEnum(
|
TransactionTypeSQL = SQLEnum(
|
||||||
TransactionType,
|
TransactionType,
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from datetime import datetime
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
CTE,
|
CTE,
|
||||||
|
BindParameter,
|
||||||
and_,
|
and_,
|
||||||
asc,
|
bindparam,
|
||||||
case,
|
case,
|
||||||
func,
|
func,
|
||||||
literal,
|
literal,
|
||||||
@@ -12,6 +13,7 @@ from sqlalchemy import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
|
||||||
from dibbler.models import (
|
from dibbler.models import (
|
||||||
Product,
|
Product,
|
||||||
Transaction,
|
Transaction,
|
||||||
@@ -22,9 +24,9 @@ from dibbler.queries.product_stock import _product_stock_query
|
|||||||
|
|
||||||
|
|
||||||
def _product_owners_query(
|
def _product_owners_query(
|
||||||
product_id: int,
|
product_id: BindParameter[int] | int,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
until: datetime | None = None,
|
until: BindParameter[datetime] | datetime | None = None,
|
||||||
cte_name: str = "rec_cte",
|
cte_name: str = "rec_cte",
|
||||||
) -> CTE:
|
) -> CTE:
|
||||||
"""
|
"""
|
||||||
@@ -34,6 +36,12 @@ def _product_owners_query(
|
|||||||
if use_cache:
|
if use_cache:
|
||||||
print("WARNING: Using cache for users owning product query is not implemented yet.")
|
print("WARNING: Using cache for users owning product query is not implemented yet.")
|
||||||
|
|
||||||
|
if isinstance(product_id, int):
|
||||||
|
product_id = bindparam("product_id", value=product_id)
|
||||||
|
|
||||||
|
if isinstance(until, datetime):
|
||||||
|
until = BindParameter("until", value=until)
|
||||||
|
|
||||||
product_stock = _product_stock_query(
|
product_stock = _product_stock_query(
|
||||||
product_id=product_id,
|
product_id=product_id,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -54,26 +62,26 @@ def _product_owners_query(
|
|||||||
.where(
|
.where(
|
||||||
Transaction.type_.in_(
|
Transaction.type_.in_(
|
||||||
[
|
[
|
||||||
TransactionType.ADD_PRODUCT,
|
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
# TransactionType.BUY_PRODUCT,
|
# TransactionType.BUY_PRODUCT,
|
||||||
TransactionType.ADJUST_STOCK,
|
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||||
# TransactionType.JOINT,
|
# TransactionType.JOINT,
|
||||||
# TransactionType.THROW_PRODUCT,
|
# TransactionType.THROW_PRODUCT,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
Transaction.product_id == product_id,
|
Transaction.product_id == product_id,
|
||||||
literal(True) if until is None else Transaction.time <= until,
|
CONST_TRUE if until is None else Transaction.time <= until,
|
||||||
)
|
)
|
||||||
.order_by(Transaction.time.desc())
|
.order_by(Transaction.time.desc())
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_element = select(
|
initial_element = select(
|
||||||
literal(0).label("i"),
|
CONST_ZERO.label("i"),
|
||||||
literal(0).label("time"),
|
CONST_ZERO.label("time"),
|
||||||
literal(None).label("transaction_id"),
|
CONST_NONE.label("transaction_id"),
|
||||||
literal(None).label("user_id"),
|
CONST_NONE.label("user_id"),
|
||||||
literal(0).label("product_count"),
|
CONST_ZERO.label("product_count"),
|
||||||
product_stock.scalar_subquery().label("products_left_to_account_for"),
|
product_stock.scalar_subquery().label("products_left_to_account_for"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,28 +96,31 @@ def _product_owners_query(
|
|||||||
case(
|
case(
|
||||||
# Someone adds the product -> they own it
|
# Someone adds the product -> they own it
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
trx_subset.c.user_id,
|
trx_subset.c.user_id,
|
||||||
),
|
),
|
||||||
else_=None,
|
else_=CONST_NONE,
|
||||||
).label("user_id"),
|
).label("user_id"),
|
||||||
# How many products did they add (if any)
|
# How many products did they add (if any)
|
||||||
case(
|
case(
|
||||||
# Someone adds the product -> they added a certain amount of products
|
# Someone adds the product -> they added a certain amount of products
|
||||||
(trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.product_count),
|
|
||||||
# Stock got adjusted upwards -> consider those products as added by nobody
|
|
||||||
(
|
(
|
||||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
|
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
& (trx_subset.c.product_count > 0),
|
|
||||||
trx_subset.c.product_count,
|
trx_subset.c.product_count,
|
||||||
),
|
),
|
||||||
else_=0,
|
# Stock got adjusted upwards -> consider those products as added by nobody
|
||||||
|
(
|
||||||
|
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
|
||||||
|
and (trx_subset.c.product_count > CONST_ZERO),
|
||||||
|
trx_subset.c.product_count,
|
||||||
|
),
|
||||||
|
else_=CONST_ZERO,
|
||||||
).label("product_count"),
|
).label("product_count"),
|
||||||
# How many products left to account for
|
# How many products left to account for
|
||||||
case(
|
case(
|
||||||
# Someone adds the product -> increase the number of products left to account for
|
# Someone adds the product -> increase the number of products left to account for
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
||||||
),
|
),
|
||||||
# Someone buys/joins/throws the product -> decrease the number of products left to account for
|
# Someone buys/joins/throws the product -> decrease the number of products left to account for
|
||||||
@@ -127,8 +138,8 @@ def _product_owners_query(
|
|||||||
# If adjusted upwards -> products owned by nobody, decrease products left to account for
|
# If adjusted upwards -> products owned by nobody, decrease products left to account for
|
||||||
# If adjusted downwards -> products taken away from owners, decrease products left to account for
|
# If adjusted downwards -> products taken away from owners, decrease products left to account for
|
||||||
(
|
(
|
||||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
|
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
|
||||||
and (trx_subset.c.product_count > 0),
|
and (trx_subset.c.product_count > CONST_ZERO),
|
||||||
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
||||||
),
|
),
|
||||||
# (
|
# (
|
||||||
@@ -142,8 +153,8 @@ def _product_owners_query(
|
|||||||
.select_from(trx_subset)
|
.select_from(trx_subset)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
trx_subset.c.i == recursive_cte.c.i + 1,
|
trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
|
||||||
recursive_cte.c.products_left_to_account_for > 0,
|
recursive_cte.c.products_left_to_account_for > CONST_ZERO,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
BindParameter,
|
||||||
ColumnElement,
|
ColumnElement,
|
||||||
Integer,
|
Integer,
|
||||||
SQLColumnExpression,
|
|
||||||
asc,
|
asc,
|
||||||
case,
|
case,
|
||||||
cast,
|
cast,
|
||||||
@@ -15,6 +15,7 @@ from sqlalchemy import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
|
||||||
from dibbler.models import (
|
from dibbler.models import (
|
||||||
Product,
|
Product,
|
||||||
Transaction,
|
Transaction,
|
||||||
@@ -26,8 +27,8 @@ from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
|
|||||||
def _product_price_query(
|
def _product_price_query(
|
||||||
product_id: int | ColumnElement[int],
|
product_id: int | ColumnElement[int],
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
until: datetime | SQLColumnExpression[datetime] | None = None,
|
until: BindParameter[datetime] | datetime | None = None,
|
||||||
until_including: bool = True,
|
until_including: BindParameter[bool] | bool = True,
|
||||||
cte_name: str = "rec_cte",
|
cte_name: str = "rec_cte",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -37,12 +38,21 @@ def _product_price_query(
|
|||||||
if use_cache:
|
if use_cache:
|
||||||
print("WARNING: Using cache for product price query is not implemented yet.")
|
print("WARNING: Using cache for product price query is not implemented yet.")
|
||||||
|
|
||||||
|
if isinstance(product_id, int):
|
||||||
|
product_id = BindParameter("product_id", value=product_id)
|
||||||
|
|
||||||
|
if isinstance(until, datetime):
|
||||||
|
until = BindParameter("until", value=until)
|
||||||
|
|
||||||
|
if isinstance(until_including, bool):
|
||||||
|
until_including = BindParameter("until_including", value=until_including)
|
||||||
|
|
||||||
initial_element = select(
|
initial_element = select(
|
||||||
literal(0).label("i"),
|
CONST_ZERO.label("i"),
|
||||||
literal(0).label("time"),
|
CONST_ZERO.label("time"),
|
||||||
literal(None).label("transaction_id"),
|
CONST_NONE.label("transaction_id"),
|
||||||
literal(0).label("price"),
|
CONST_ZERO.label("price"),
|
||||||
literal(0).label("product_count"),
|
CONST_ZERO.label("product_count"),
|
||||||
)
|
)
|
||||||
|
|
||||||
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
||||||
@@ -60,19 +70,19 @@ def _product_price_query(
|
|||||||
.where(
|
.where(
|
||||||
Transaction.type_.in_(
|
Transaction.type_.in_(
|
||||||
[
|
[
|
||||||
TransactionType.BUY_PRODUCT,
|
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||||
TransactionType.ADD_PRODUCT,
|
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
TransactionType.ADJUST_STOCK,
|
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||||
TransactionType.JOINT,
|
TransactionType.JOINT.as_literal_column(),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
Transaction.product_id == product_id,
|
Transaction.product_id == product_id,
|
||||||
case(
|
case(
|
||||||
(literal(until_including), Transaction.time <= until),
|
(until_including, Transaction.time <= until),
|
||||||
else_=Transaction.time < until,
|
else_=Transaction.time < until,
|
||||||
)
|
)
|
||||||
if until is not None
|
if until is not None
|
||||||
else literal(True),
|
else CONST_TRUE,
|
||||||
)
|
)
|
||||||
.order_by(Transaction.time.asc())
|
.order_by(Transaction.time.asc())
|
||||||
.alias("trx_subset")
|
.alias("trx_subset")
|
||||||
@@ -85,22 +95,26 @@ def _product_price_query(
|
|||||||
trx_subset.c.id.label("transaction_id"),
|
trx_subset.c.id.label("transaction_id"),
|
||||||
case(
|
case(
|
||||||
# Someone buys the product -> price remains the same.
|
# Someone buys the product -> price remains the same.
|
||||||
(trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price),
|
(
|
||||||
|
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||||
|
recursive_cte.c.price,
|
||||||
|
),
|
||||||
# Someone adds the product -> price is recalculated based on
|
# Someone adds the product -> price is recalculated based on
|
||||||
# product count, previous price, and new price.
|
# product count, previous price, and new price.
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
cast(
|
cast(
|
||||||
func.ceil(
|
func.ceil(
|
||||||
(
|
(
|
||||||
recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0)
|
recursive_cte.c.price
|
||||||
|
* func.max(recursive_cte.c.product_count, CONST_ZERO)
|
||||||
+ trx_subset.c.per_product * trx_subset.c.product_count
|
+ trx_subset.c.per_product * trx_subset.c.product_count
|
||||||
)
|
)
|
||||||
/ (
|
/ (
|
||||||
# The running product count can be negative if the accounting is bad.
|
# The running product count can be negative if the accounting is bad.
|
||||||
# This ensures that we never end up with negative prices or zero divisions
|
# This ensures that we never end up with negative prices or zero divisions
|
||||||
# and other disastrous phenomena.
|
# and other disastrous phenomena.
|
||||||
func.max(recursive_cte.c.product_count, 0)
|
func.max(recursive_cte.c.product_count, CONST_ZERO)
|
||||||
+ trx_subset.c.product_count
|
+ trx_subset.c.product_count
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -108,28 +122,31 @@ def _product_price_query(
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
# Someone adjusts the stock -> price remains the same.
|
# Someone adjusts the stock -> price remains the same.
|
||||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price),
|
(
|
||||||
|
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||||
|
recursive_cte.c.price,
|
||||||
|
),
|
||||||
# Should never happen
|
# Should never happen
|
||||||
else_=recursive_cte.c.price,
|
else_=recursive_cte.c.price,
|
||||||
).label("price"),
|
).label("price"),
|
||||||
case(
|
case(
|
||||||
# Someone buys the product -> product count is reduced.
|
# Someone buys the product -> product count is reduced.
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
|
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||||
recursive_cte.c.product_count - trx_subset.c.product_count,
|
recursive_cte.c.product_count - trx_subset.c.product_count,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.JOINT,
|
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
|
||||||
recursive_cte.c.product_count - trx_subset.c.product_count,
|
recursive_cte.c.product_count - trx_subset.c.product_count,
|
||||||
),
|
),
|
||||||
# Someone adds the product -> product count is increased.
|
# Someone adds the product -> product count is increased.
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
recursive_cte.c.product_count + trx_subset.c.product_count,
|
recursive_cte.c.product_count + trx_subset.c.product_count,
|
||||||
),
|
),
|
||||||
# Someone adjusts the stock -> product count is adjusted.
|
# Someone adjusts the stock -> product count is adjusted.
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADJUST_STOCK,
|
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||||
recursive_cte.c.product_count + trx_subset.c.product_count,
|
recursive_cte.c.product_count + trx_subset.c.product_count,
|
||||||
),
|
),
|
||||||
# Should never happen
|
# Should never happen
|
||||||
@@ -137,7 +154,7 @@ def _product_price_query(
|
|||||||
).label("product_count"),
|
).label("product_count"),
|
||||||
)
|
)
|
||||||
.select_from(trx_subset)
|
.select_from(trx_subset)
|
||||||
.where(trx_subset.c.i == recursive_cte.c.i + 1)
|
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
|
||||||
)
|
)
|
||||||
|
|
||||||
return recursive_cte.union_all(recursive_elements)
|
return recursive_cte.union_all(recursive_elements)
|
||||||
@@ -222,7 +239,10 @@ def product_price(
|
|||||||
# - price should never be negative
|
# - price should never be negative
|
||||||
|
|
||||||
result = sql_session.scalars(
|
result = sql_session.scalars(
|
||||||
select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1)
|
select(recursive_cte.c.price)
|
||||||
|
.order_by(recursive_cte.c.i.desc())
|
||||||
|
.limit(CONST_ONE)
|
||||||
|
.offset(CONST_ZERO)
|
||||||
).one_or_none()
|
).one_or_none()
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
@@ -237,10 +257,10 @@ def product_price(
|
|||||||
select(Transaction.interest_rate_percent)
|
select(Transaction.interest_rate_percent)
|
||||||
.where(
|
.where(
|
||||||
Transaction.type_ == TransactionType.ADJUST_INTEREST,
|
Transaction.type_ == TransactionType.ADJUST_INTEREST,
|
||||||
literal(True) if until is None else Transaction.time <= until.time,
|
CONST_TRUE if until is None else Transaction.time <= until.time,
|
||||||
)
|
)
|
||||||
.order_by(Transaction.time.desc())
|
.order_by(Transaction.time.desc())
|
||||||
.limit(1)
|
.limit(CONST_ONE)
|
||||||
)
|
)
|
||||||
or DEFAULT_INTEREST_RATE_PERCENTAGE
|
or DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
BindParameter,
|
||||||
Select,
|
Select,
|
||||||
case,
|
case,
|
||||||
func,
|
func,
|
||||||
literal,
|
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from dibbler.lib.query_helpers import CONST_TRUE
|
||||||
from dibbler.models import (
|
from dibbler.models import (
|
||||||
Product,
|
Product,
|
||||||
Transaction,
|
Transaction,
|
||||||
@@ -17,9 +18,9 @@ from dibbler.models import (
|
|||||||
|
|
||||||
|
|
||||||
def _product_stock_query(
|
def _product_stock_query(
|
||||||
product_id: int,
|
product_id: BindParameter[int] | int,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
until: datetime | None = None,
|
until: BindParameter[datetime] | datetime | None = None,
|
||||||
) -> Select:
|
) -> Select:
|
||||||
"""
|
"""
|
||||||
The inner query for calculating the product stock.
|
The inner query for calculating the product stock.
|
||||||
@@ -28,27 +29,33 @@ def _product_stock_query(
|
|||||||
if use_cache:
|
if use_cache:
|
||||||
print("WARNING: Using cache for product stock query is not implemented yet.")
|
print("WARNING: Using cache for product stock query is not implemented yet.")
|
||||||
|
|
||||||
|
if isinstance(product_id, int):
|
||||||
|
product_id = BindParameter("product_id", value=product_id)
|
||||||
|
|
||||||
|
if isinstance(until, datetime):
|
||||||
|
until = BindParameter("until", value=until)
|
||||||
|
|
||||||
query = select(
|
query = select(
|
||||||
func.sum(
|
func.sum(
|
||||||
case(
|
case(
|
||||||
(
|
(
|
||||||
Transaction.type_ == TransactionType.ADD_PRODUCT,
|
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
Transaction.product_count,
|
Transaction.product_count,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
Transaction.type_ == TransactionType.ADJUST_STOCK,
|
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||||
Transaction.product_count,
|
Transaction.product_count,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
Transaction.type_ == TransactionType.BUY_PRODUCT,
|
Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||||
-Transaction.product_count,
|
-Transaction.product_count,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
Transaction.type_ == TransactionType.JOINT,
|
Transaction.type_ == TransactionType.JOINT.as_literal_column(),
|
||||||
-Transaction.product_count,
|
-Transaction.product_count,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
Transaction.type_ == TransactionType.THROW_PRODUCT,
|
Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||||
-Transaction.product_count,
|
-Transaction.product_count,
|
||||||
),
|
),
|
||||||
else_=0,
|
else_=0,
|
||||||
@@ -57,15 +64,15 @@ def _product_stock_query(
|
|||||||
).where(
|
).where(
|
||||||
Transaction.type_.in_(
|
Transaction.type_.in_(
|
||||||
[
|
[
|
||||||
TransactionType.ADD_PRODUCT,
|
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
TransactionType.ADJUST_STOCK,
|
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||||
TransactionType.BUY_PRODUCT,
|
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||||
TransactionType.JOINT,
|
TransactionType.JOINT.as_literal_column(),
|
||||||
TransactionType.THROW_PRODUCT,
|
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
Transaction.product_id == product_id,
|
Transaction.product_id == product_id,
|
||||||
Transaction.time <= until if until is not None else literal(True),
|
Transaction.time <= until if until is not None else CONST_TRUE,
|
||||||
)
|
)
|
||||||
|
|
||||||
return query
|
return query
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ from datetime import datetime
|
|||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
CTE,
|
CTE,
|
||||||
|
BindParameter,
|
||||||
Float,
|
Float,
|
||||||
Integer,
|
Integer,
|
||||||
and_,
|
and_,
|
||||||
asc,
|
|
||||||
case,
|
case,
|
||||||
cast,
|
cast,
|
||||||
column,
|
column,
|
||||||
@@ -17,6 +17,7 @@ from sqlalchemy import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
|
||||||
from dibbler.models import (
|
from dibbler.models import (
|
||||||
Transaction,
|
Transaction,
|
||||||
TransactionType,
|
TransactionType,
|
||||||
@@ -31,10 +32,10 @@ from dibbler.queries.product_price import _product_price_query
|
|||||||
|
|
||||||
|
|
||||||
def _user_balance_query(
|
def _user_balance_query(
|
||||||
user_id: int,
|
user_id: BindParameter[int] | int,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
until: datetime | None = None,
|
until: BindParameter[datetime] | BindParameter[None] | datetime | None = None,
|
||||||
until_including: bool = True,
|
until_including: BindParameter[bool] | bool = True,
|
||||||
cte_name: str = "rec_cte",
|
cte_name: str = "rec_cte",
|
||||||
) -> CTE:
|
) -> CTE:
|
||||||
"""
|
"""
|
||||||
@@ -44,14 +45,23 @@ def _user_balance_query(
|
|||||||
if use_cache:
|
if use_cache:
|
||||||
print("WARNING: Using cache for user balance query is not implemented yet.")
|
print("WARNING: Using cache for user balance query is not implemented yet.")
|
||||||
|
|
||||||
|
if isinstance(user_id, int):
|
||||||
|
user_id = BindParameter("user_id", value=user_id)
|
||||||
|
|
||||||
|
if isinstance(until, datetime):
|
||||||
|
until = BindParameter("until", value=until, type_=datetime)
|
||||||
|
|
||||||
|
if isinstance(until_including, bool):
|
||||||
|
until_including = BindParameter("until_including", value=until_including, type_=bool)
|
||||||
|
|
||||||
initial_element = select(
|
initial_element = select(
|
||||||
literal(0).label("i"),
|
CONST_ZERO.label("i"),
|
||||||
literal(0).label("time"),
|
CONST_ZERO.label("time"),
|
||||||
literal(None).label("transaction_id"),
|
CONST_NONE.label("transaction_id"),
|
||||||
literal(0).label("balance"),
|
CONST_ZERO.label("balance"),
|
||||||
literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
|
const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
|
||||||
literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
|
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
|
||||||
literal(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
|
const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
|
||||||
)
|
)
|
||||||
|
|
||||||
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
||||||
@@ -59,7 +69,7 @@ def _user_balance_query(
|
|||||||
# Subset of transactions that we'll want to iterate over.
|
# Subset of transactions that we'll want to iterate over.
|
||||||
trx_subset = (
|
trx_subset = (
|
||||||
select(
|
select(
|
||||||
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
|
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
|
||||||
Transaction.amount,
|
Transaction.amount,
|
||||||
Transaction.id,
|
Transaction.id,
|
||||||
Transaction.interest_rate_percent,
|
Transaction.interest_rate_percent,
|
||||||
@@ -77,34 +87,34 @@ def _user_balance_query(
|
|||||||
Transaction.user_id == user_id,
|
Transaction.user_id == user_id,
|
||||||
Transaction.type_.in_(
|
Transaction.type_.in_(
|
||||||
[
|
[
|
||||||
TransactionType.ADD_PRODUCT,
|
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
TransactionType.ADJUST_BALANCE,
|
TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||||
TransactionType.BUY_PRODUCT,
|
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||||
TransactionType.TRANSFER,
|
TransactionType.TRANSFER.as_literal_column(),
|
||||||
# TODO: join this with the JOINT transactions, and determine
|
# TODO: join this with the JOINT transactions, and determine
|
||||||
# how much the current user paid for the product.
|
# how much the current user paid for the product.
|
||||||
TransactionType.JOINT_BUY_PRODUCT,
|
TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
and_(
|
and_(
|
||||||
Transaction.type_ == TransactionType.TRANSFER,
|
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||||
Transaction.transfer_user_id == user_id,
|
Transaction.transfer_user_id == user_id,
|
||||||
),
|
),
|
||||||
Transaction.type_.in_(
|
Transaction.type_.in_(
|
||||||
[
|
[
|
||||||
TransactionType.THROW_PRODUCT,
|
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||||
TransactionType.ADJUST_INTEREST,
|
TransactionType.ADJUST_INTEREST.as_literal_column(),
|
||||||
TransactionType.ADJUST_PENALTY,
|
TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
case(
|
case(
|
||||||
(literal(until_including), Transaction.time <= until),
|
(until_including, Transaction.time <= until),
|
||||||
else_=Transaction.time < until,
|
else_=Transaction.time < until,
|
||||||
)
|
)
|
||||||
if until is not None
|
if until is not None
|
||||||
else literal(True),
|
else CONST_TRUE,
|
||||||
)
|
)
|
||||||
.order_by(Transaction.time.asc())
|
.order_by(Transaction.time.asc())
|
||||||
.alias("trx_subset")
|
.alias("trx_subset")
|
||||||
@@ -118,17 +128,17 @@ def _user_balance_query(
|
|||||||
case(
|
case(
|
||||||
# Adjusts balance -> balance gets adjusted
|
# Adjusts balance -> balance gets adjusted
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE,
|
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||||
recursive_cte.c.balance + trx_subset.c.amount,
|
recursive_cte.c.balance + trx_subset.c.amount,
|
||||||
),
|
),
|
||||||
# Adds a product -> balance increases
|
# Adds a product -> balance increases
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||||
recursive_cte.c.balance + trx_subset.c.amount,
|
recursive_cte.c.balance + trx_subset.c.amount,
|
||||||
),
|
),
|
||||||
# Buys a product -> balance decreases
|
# Buys a product -> balance decreases
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
|
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||||
recursive_cte.c.balance
|
recursive_cte.c.balance
|
||||||
- (
|
- (
|
||||||
trx_subset.c.product_count
|
trx_subset.c.product_count
|
||||||
@@ -151,12 +161,12 @@ def _user_balance_query(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
.order_by(column("i").desc())
|
.order_by(column("i").desc())
|
||||||
.limit(1)
|
.limit(CONST_ONE)
|
||||||
).scalar_subquery()
|
).scalar_subquery()
|
||||||
# TODO: should interest be applied before or after the penalty multiplier?
|
# TODO: should interest be applied before or after the penalty multiplier?
|
||||||
# at the moment of writing, after sound right, but maybe ask someone?
|
# at the moment of writing, after sound right, but maybe ask someone?
|
||||||
# Interest
|
# Interest
|
||||||
* (cast(recursive_cte.c.interest_rate_percent, Float) / 100)
|
* (cast(recursive_cte.c.interest_rate_percent, Float) / const(100))
|
||||||
# TODO: these should be added together, not multiplied, see specification
|
# TODO: these should be added together, not multiplied, see specification
|
||||||
# Penalty
|
# Penalty
|
||||||
* case(
|
* case(
|
||||||
@@ -164,10 +174,10 @@ def _user_balance_query(
|
|||||||
recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
|
recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
|
||||||
(
|
(
|
||||||
cast(recursive_cte.c.penalty_multiplier_percent, Float)
|
cast(recursive_cte.c.penalty_multiplier_percent, Float)
|
||||||
/ 100
|
/ const(100)
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
else_=1.0,
|
else_=const(1.0),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
Integer,
|
Integer,
|
||||||
@@ -177,7 +187,7 @@ def _user_balance_query(
|
|||||||
# Transfers money to self -> balance increases
|
# Transfers money to self -> balance increases
|
||||||
(
|
(
|
||||||
and_(
|
and_(
|
||||||
trx_subset.c.type_ == TransactionType.TRANSFER,
|
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||||
trx_subset.c.transfer_user_id == user_id,
|
trx_subset.c.transfer_user_id == user_id,
|
||||||
),
|
),
|
||||||
recursive_cte.c.balance + trx_subset.c.amount,
|
recursive_cte.c.balance + trx_subset.c.amount,
|
||||||
@@ -185,7 +195,7 @@ def _user_balance_query(
|
|||||||
# Transfers money from self -> balance decreases
|
# Transfers money from self -> balance decreases
|
||||||
(
|
(
|
||||||
and_(
|
and_(
|
||||||
trx_subset.c.type_ == TransactionType.TRANSFER,
|
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||||
trx_subset.c.transfer_user_id != user_id,
|
trx_subset.c.transfer_user_id != user_id,
|
||||||
),
|
),
|
||||||
recursive_cte.c.balance - trx_subset.c.amount,
|
recursive_cte.c.balance - trx_subset.c.amount,
|
||||||
@@ -202,28 +212,28 @@ def _user_balance_query(
|
|||||||
).label("balance"),
|
).label("balance"),
|
||||||
case(
|
case(
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST,
|
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(),
|
||||||
trx_subset.c.interest_rate_percent,
|
trx_subset.c.interest_rate_percent,
|
||||||
),
|
),
|
||||||
else_=recursive_cte.c.interest_rate_percent,
|
else_=recursive_cte.c.interest_rate_percent,
|
||||||
).label("interest_rate_percent"),
|
).label("interest_rate_percent"),
|
||||||
case(
|
case(
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
|
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||||
trx_subset.c.penalty_threshold,
|
trx_subset.c.penalty_threshold,
|
||||||
),
|
),
|
||||||
else_=recursive_cte.c.penalty_threshold,
|
else_=recursive_cte.c.penalty_threshold,
|
||||||
).label("penalty_threshold"),
|
).label("penalty_threshold"),
|
||||||
case(
|
case(
|
||||||
(
|
(
|
||||||
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
|
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||||
trx_subset.c.penalty_multiplier_percent,
|
trx_subset.c.penalty_multiplier_percent,
|
||||||
),
|
),
|
||||||
else_=recursive_cte.c.penalty_multiplier_percent,
|
else_=recursive_cte.c.penalty_multiplier_percent,
|
||||||
).label("penalty_multiplier_percent"),
|
).label("penalty_multiplier_percent"),
|
||||||
)
|
)
|
||||||
.select_from(trx_subset)
|
.select_from(trx_subset)
|
||||||
.where(trx_subset.c.i == recursive_cte.c.i + 1)
|
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
|
||||||
)
|
)
|
||||||
|
|
||||||
return recursive_cte.union_all(recursive_elements)
|
return recursive_cte.union_all(recursive_elements)
|
||||||
@@ -322,7 +332,10 @@ def user_balance(
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = sql_session.scalar(
|
result = sql_session.scalar(
|
||||||
select(recursive_cte.c.balance).order_by(recursive_cte.c.i.desc()).limit(1)
|
select(recursive_cte.c.balance)
|
||||||
|
.order_by(recursive_cte.c.i.desc())
|
||||||
|
.limit(CONST_ONE)
|
||||||
|
.offset(CONST_ZERO)
|
||||||
)
|
)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
import pytest
|
import pytest
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from sqlalchemy import create_engine, event
|
from sqlalchemy import create_engine, event
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from dibbler.models import Base
|
from dibbler.models import Base
|
||||||
@@ -35,21 +36,27 @@ class SqlParseFormatter(logging.Formatter):
|
|||||||
return super().format(record)
|
return super().format(record)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
def pytest_configure(config):
|
||||||
def sql_session(request):
|
"""Setup pretty SQL logging if --echo is enabled."""
|
||||||
"""Create a new SQLAlchemy session for testing."""
|
|
||||||
|
|
||||||
logging.basicConfig()
|
|
||||||
logger = logging.getLogger("sqlalchemy.engine")
|
logger = logging.getLogger("sqlalchemy.engine")
|
||||||
|
|
||||||
|
# TODO: it would be nice not to duplicate these logs.
|
||||||
|
# logging.NullHandler() does not seem to work here
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(SqlParseFormatter())
|
handler.setFormatter(SqlParseFormatter())
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
echo = request.config.getoption("--echo")
|
echo = config.getoption("--echo")
|
||||||
engine = create_engine(
|
if echo:
|
||||||
"sqlite:///:memory:",
|
logger.setLevel(logging.INFO)
|
||||||
echo=echo,
|
|
||||||
)
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def sql_session(request):
|
||||||
|
"""Create a new SQLAlchemy session for testing."""
|
||||||
|
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
@event.listens_for(engine, "connect")
|
@event.listens_for(engine, "connect")
|
||||||
def set_sqlite_pragma(dbapi_connection, _connection_record):
|
def set_sqlite_pragma(dbapi_connection, _connection_record):
|
||||||
@@ -60,3 +67,17 @@ def sql_session(request):
|
|||||||
Base.metadata.create_all(engine)
|
Base.metadata.create_all(engine)
|
||||||
with Session(engine) as sql_session:
|
with Session(engine) as sql_session:
|
||||||
yield sql_session
|
yield sql_session
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME: Declaring this hook seems to have a side effect where the database does not
|
||||||
|
# get reset between tests.
|
||||||
|
# @pytest.hookimpl(trylast=True)
|
||||||
|
# def pytest_runtest_call(item: pytest.Item):
|
||||||
|
# """Hook to format SQL statements in OperationalError exceptions."""
|
||||||
|
# try:
|
||||||
|
# item.runtest()
|
||||||
|
# except OperationalError as e:
|
||||||
|
# if e.statement is not None:
|
||||||
|
# formatted_sql = sqlparse.format(e.statement, reindent=True, keyword_case="upper")
|
||||||
|
# e.statement = "\n" + formatted_sql + "\n"
|
||||||
|
# raise e
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
|||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from dibbler.models import Product, Transaction, User
|
from dibbler.models import Product, Transaction, User
|
||||||
@@ -104,8 +103,8 @@ def test_user_balance_with_transfers(sql_session: Session) -> None:
|
|||||||
assert user2_balance == 50 - 30
|
assert user2_balance == 50 - 30
|
||||||
|
|
||||||
|
|
||||||
def test_user_balance_complex_history(sql_session: Session) -> None:
|
@pytest.mark.skip(reason="Not yet implemented")
|
||||||
raise NotImplementedError("This test is not implemented yet.")
|
def test_user_balance_complex_history(sql_session: Session) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
def test_user_balance_penalty(sql_session: Session) -> None:
|
def test_user_balance_penalty(sql_session: Session) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user