This commit is contained in:
20
dibbler/lib/query_helpers.py
Normal file
20
dibbler/lib/query_helpers.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import TypeVar
|
||||
from sqlalchemy import BindParameter, literal
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
def const(value: T) -> BindParameter[T]:
|
||||
"""
|
||||
Create a constant SQL literal bind parameter.
|
||||
|
||||
This is useful to avoid too many `?` bind parameters in SQL queries,
|
||||
when the input value is known to be safe.
|
||||
"""
|
||||
|
||||
return literal(value, literal_execute=True)
|
||||
|
||||
CONST_ZERO: BindParameter[int] = const(0)
|
||||
CONST_ONE: BindParameter[int] = const(1)
|
||||
CONST_TRUE: BindParameter[bool] = const(True)
|
||||
CONST_FALSE: BindParameter[bool] = const(False)
|
||||
CONST_NONE: BindParameter[None] = const(None)
|
||||
@@ -19,6 +19,17 @@ class TransactionType(StrEnum):
|
||||
THROW_PRODUCT = auto()
|
||||
TRANSFER = auto()
|
||||
|
||||
def as_literal_column(self):
|
||||
"""
|
||||
Return the transaction type as a SQL literal column.
|
||||
|
||||
This is useful to avoid too many `?` bind parameters in SQL queries,
|
||||
when the input value is known to be safe.
|
||||
"""
|
||||
from sqlalchemy import literal_column
|
||||
|
||||
return literal_column(f"'{self.value}'")
|
||||
|
||||
|
||||
TransactionTypeSQL = SQLEnum(
|
||||
TransactionType,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
CTE,
|
||||
BindParameter,
|
||||
and_,
|
||||
asc,
|
||||
bindparam,
|
||||
case,
|
||||
func,
|
||||
literal,
|
||||
@@ -12,6 +13,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
|
||||
from dibbler.models import (
|
||||
Product,
|
||||
Transaction,
|
||||
@@ -22,9 +24,9 @@ from dibbler.queries.product_stock import _product_stock_query
|
||||
|
||||
|
||||
def _product_owners_query(
|
||||
product_id: int,
|
||||
product_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until: datetime | None = None,
|
||||
until: BindParameter[datetime] | datetime | None = None,
|
||||
cte_name: str = "rec_cte",
|
||||
) -> CTE:
|
||||
"""
|
||||
@@ -34,6 +36,12 @@ def _product_owners_query(
|
||||
if use_cache:
|
||||
print("WARNING: Using cache for users owning product query is not implemented yet.")
|
||||
|
||||
if isinstance(product_id, int):
|
||||
product_id = bindparam("product_id", value=product_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until)
|
||||
|
||||
product_stock = _product_stock_query(
|
||||
product_id=product_id,
|
||||
use_cache=use_cache,
|
||||
@@ -54,26 +62,26 @@ def _product_owners_query(
|
||||
.where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT,
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
# TransactionType.BUY_PRODUCT,
|
||||
TransactionType.ADJUST_STOCK,
|
||||
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
# TransactionType.JOINT,
|
||||
# TransactionType.THROW_PRODUCT,
|
||||
]
|
||||
),
|
||||
Transaction.product_id == product_id,
|
||||
literal(True) if until is None else Transaction.time <= until,
|
||||
CONST_TRUE if until is None else Transaction.time <= until,
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
.subquery()
|
||||
)
|
||||
|
||||
initial_element = select(
|
||||
literal(0).label("i"),
|
||||
literal(0).label("time"),
|
||||
literal(None).label("transaction_id"),
|
||||
literal(None).label("user_id"),
|
||||
literal(0).label("product_count"),
|
||||
CONST_ZERO.label("i"),
|
||||
CONST_ZERO.label("time"),
|
||||
CONST_NONE.label("transaction_id"),
|
||||
CONST_NONE.label("user_id"),
|
||||
CONST_ZERO.label("product_count"),
|
||||
product_stock.scalar_subquery().label("products_left_to_account_for"),
|
||||
)
|
||||
|
||||
@@ -88,28 +96,31 @@ def _product_owners_query(
|
||||
case(
|
||||
# Someone adds the product -> they own it
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
trx_subset.c.user_id,
|
||||
),
|
||||
else_=None,
|
||||
else_=CONST_NONE,
|
||||
).label("user_id"),
|
||||
# How many products did they add (if any)
|
||||
case(
|
||||
# Someone adds the product -> they added a certain amount of products
|
||||
(trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.product_count),
|
||||
# Stock got adjusted upwards -> consider those products as added by nobody
|
||||
(
|
||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
|
||||
& (trx_subset.c.product_count > 0),
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
trx_subset.c.product_count,
|
||||
),
|
||||
else_=0,
|
||||
# Stock got adjusted upwards -> consider those products as added by nobody
|
||||
(
|
||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
|
||||
and (trx_subset.c.product_count > CONST_ZERO),
|
||||
trx_subset.c.product_count,
|
||||
),
|
||||
else_=CONST_ZERO,
|
||||
).label("product_count"),
|
||||
# How many products left to account for
|
||||
case(
|
||||
# Someone adds the product -> increase the number of products left to account for
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
||||
),
|
||||
# Someone buys/joins/throws the product -> decrease the number of products left to account for
|
||||
@@ -127,8 +138,8 @@ def _product_owners_query(
|
||||
# If adjusted upwards -> products owned by nobody, decrease products left to account for
|
||||
# If adjusted downwards -> products taken away from owners, decrease products left to account for
|
||||
(
|
||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
|
||||
and (trx_subset.c.product_count > 0),
|
||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
|
||||
and (trx_subset.c.product_count > CONST_ZERO),
|
||||
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
||||
),
|
||||
# (
|
||||
@@ -142,8 +153,8 @@ def _product_owners_query(
|
||||
.select_from(trx_subset)
|
||||
.where(
|
||||
and_(
|
||||
trx_subset.c.i == recursive_cte.c.i + 1,
|
||||
recursive_cte.c.products_left_to_account_for > 0,
|
||||
trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
|
||||
recursive_cte.c.products_left_to_account_for > CONST_ZERO,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -3,9 +3,9 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
BindParameter,
|
||||
ColumnElement,
|
||||
Integer,
|
||||
SQLColumnExpression,
|
||||
asc,
|
||||
case,
|
||||
cast,
|
||||
@@ -15,6 +15,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
|
||||
from dibbler.models import (
|
||||
Product,
|
||||
Transaction,
|
||||
@@ -26,8 +27,8 @@ from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||
def _product_price_query(
|
||||
product_id: int | ColumnElement[int],
|
||||
use_cache: bool = True,
|
||||
until: datetime | SQLColumnExpression[datetime] | None = None,
|
||||
until_including: bool = True,
|
||||
until: BindParameter[datetime] | datetime | None = None,
|
||||
until_including: BindParameter[bool] | bool = True,
|
||||
cte_name: str = "rec_cte",
|
||||
):
|
||||
"""
|
||||
@@ -37,12 +38,21 @@ def _product_price_query(
|
||||
if use_cache:
|
||||
print("WARNING: Using cache for product price query is not implemented yet.")
|
||||
|
||||
if isinstance(product_id, int):
|
||||
product_id = BindParameter("product_id", value=product_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until)
|
||||
|
||||
if isinstance(until_including, bool):
|
||||
until_including = BindParameter("until_including", value=until_including)
|
||||
|
||||
initial_element = select(
|
||||
literal(0).label("i"),
|
||||
literal(0).label("time"),
|
||||
literal(None).label("transaction_id"),
|
||||
literal(0).label("price"),
|
||||
literal(0).label("product_count"),
|
||||
CONST_ZERO.label("i"),
|
||||
CONST_ZERO.label("time"),
|
||||
CONST_NONE.label("transaction_id"),
|
||||
CONST_ZERO.label("price"),
|
||||
CONST_ZERO.label("product_count"),
|
||||
)
|
||||
|
||||
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
||||
@@ -60,19 +70,19 @@ def _product_price_query(
|
||||
.where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.BUY_PRODUCT,
|
||||
TransactionType.ADD_PRODUCT,
|
||||
TransactionType.ADJUST_STOCK,
|
||||
TransactionType.JOINT,
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
TransactionType.JOINT.as_literal_column(),
|
||||
]
|
||||
),
|
||||
Transaction.product_id == product_id,
|
||||
case(
|
||||
(literal(until_including), Transaction.time <= until),
|
||||
(until_including, Transaction.time <= until),
|
||||
else_=Transaction.time < until,
|
||||
)
|
||||
if until is not None
|
||||
else literal(True),
|
||||
else CONST_TRUE,
|
||||
)
|
||||
.order_by(Transaction.time.asc())
|
||||
.alias("trx_subset")
|
||||
@@ -85,22 +95,26 @@ def _product_price_query(
|
||||
trx_subset.c.id.label("transaction_id"),
|
||||
case(
|
||||
# Someone buys the product -> price remains the same.
|
||||
(trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price),
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.price,
|
||||
),
|
||||
# Someone adds the product -> price is recalculated based on
|
||||
# product count, previous price, and new price.
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
cast(
|
||||
func.ceil(
|
||||
(
|
||||
recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0)
|
||||
recursive_cte.c.price
|
||||
* func.max(recursive_cte.c.product_count, CONST_ZERO)
|
||||
+ trx_subset.c.per_product * trx_subset.c.product_count
|
||||
)
|
||||
/ (
|
||||
# The running product count can be negative if the accounting is bad.
|
||||
# This ensures that we never end up with negative prices or zero divisions
|
||||
# and other disastrous phenomena.
|
||||
func.max(recursive_cte.c.product_count, 0)
|
||||
func.max(recursive_cte.c.product_count, CONST_ZERO)
|
||||
+ trx_subset.c.product_count
|
||||
)
|
||||
),
|
||||
@@ -108,28 +122,31 @@ def _product_price_query(
|
||||
),
|
||||
),
|
||||
# Someone adjusts the stock -> price remains the same.
|
||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price),
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
recursive_cte.c.price,
|
||||
),
|
||||
# Should never happen
|
||||
else_=recursive_cte.c.price,
|
||||
).label("price"),
|
||||
case(
|
||||
# Someone buys the product -> product count is reduced.
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
|
||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.product_count - trx_subset.c.product_count,
|
||||
),
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.JOINT,
|
||||
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
|
||||
recursive_cte.c.product_count - trx_subset.c.product_count,
|
||||
),
|
||||
# Someone adds the product -> product count is increased.
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.product_count + trx_subset.c.product_count,
|
||||
),
|
||||
# Someone adjusts the stock -> product count is adjusted.
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_STOCK,
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
recursive_cte.c.product_count + trx_subset.c.product_count,
|
||||
),
|
||||
# Should never happen
|
||||
@@ -137,7 +154,7 @@ def _product_price_query(
|
||||
).label("product_count"),
|
||||
)
|
||||
.select_from(trx_subset)
|
||||
.where(trx_subset.c.i == recursive_cte.c.i + 1)
|
||||
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
|
||||
)
|
||||
|
||||
return recursive_cte.union_all(recursive_elements)
|
||||
@@ -222,7 +239,10 @@ def product_price(
|
||||
# - price should never be negative
|
||||
|
||||
result = sql_session.scalars(
|
||||
select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1)
|
||||
select(recursive_cte.c.price)
|
||||
.order_by(recursive_cte.c.i.desc())
|
||||
.limit(CONST_ONE)
|
||||
.offset(CONST_ZERO)
|
||||
).one_or_none()
|
||||
|
||||
if result is None:
|
||||
@@ -237,10 +257,10 @@ def product_price(
|
||||
select(Transaction.interest_rate_percent)
|
||||
.where(
|
||||
Transaction.type_ == TransactionType.ADJUST_INTEREST,
|
||||
literal(True) if until is None else Transaction.time <= until.time,
|
||||
CONST_TRUE if until is None else Transaction.time <= until.time,
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
.limit(1)
|
||||
.limit(CONST_ONE)
|
||||
)
|
||||
or DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||
)
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
BindParameter,
|
||||
Select,
|
||||
case,
|
||||
func,
|
||||
literal,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_TRUE
|
||||
from dibbler.models import (
|
||||
Product,
|
||||
Transaction,
|
||||
@@ -17,9 +18,9 @@ from dibbler.models import (
|
||||
|
||||
|
||||
def _product_stock_query(
|
||||
product_id: int,
|
||||
product_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until: datetime | None = None,
|
||||
until: BindParameter[datetime] | datetime | None = None,
|
||||
) -> Select:
|
||||
"""
|
||||
The inner query for calculating the product stock.
|
||||
@@ -28,27 +29,33 @@ def _product_stock_query(
|
||||
if use_cache:
|
||||
print("WARNING: Using cache for product stock query is not implemented yet.")
|
||||
|
||||
if isinstance(product_id, int):
|
||||
product_id = BindParameter("product_id", value=product_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until)
|
||||
|
||||
query = select(
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
Transaction.type_ == TransactionType.ADD_PRODUCT,
|
||||
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
Transaction.product_count,
|
||||
),
|
||||
(
|
||||
Transaction.type_ == TransactionType.ADJUST_STOCK,
|
||||
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
Transaction.product_count,
|
||||
),
|
||||
(
|
||||
Transaction.type_ == TransactionType.BUY_PRODUCT,
|
||||
Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
-Transaction.product_count,
|
||||
),
|
||||
(
|
||||
Transaction.type_ == TransactionType.JOINT,
|
||||
Transaction.type_ == TransactionType.JOINT.as_literal_column(),
|
||||
-Transaction.product_count,
|
||||
),
|
||||
(
|
||||
Transaction.type_ == TransactionType.THROW_PRODUCT,
|
||||
Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
-Transaction.product_count,
|
||||
),
|
||||
else_=0,
|
||||
@@ -57,15 +64,15 @@ def _product_stock_query(
|
||||
).where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT,
|
||||
TransactionType.ADJUST_STOCK,
|
||||
TransactionType.BUY_PRODUCT,
|
||||
TransactionType.JOINT,
|
||||
TransactionType.THROW_PRODUCT,
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.JOINT.as_literal_column(),
|
||||
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
]
|
||||
),
|
||||
Transaction.product_id == product_id,
|
||||
Transaction.time <= until if until is not None else literal(True),
|
||||
Transaction.time <= until if until is not None else CONST_TRUE,
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
@@ -3,10 +3,10 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
CTE,
|
||||
BindParameter,
|
||||
Float,
|
||||
Integer,
|
||||
and_,
|
||||
asc,
|
||||
case,
|
||||
cast,
|
||||
column,
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
|
||||
from dibbler.models import (
|
||||
Transaction,
|
||||
TransactionType,
|
||||
@@ -31,10 +32,10 @@ from dibbler.queries.product_price import _product_price_query
|
||||
|
||||
|
||||
def _user_balance_query(
|
||||
user_id: int,
|
||||
user_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until: datetime | None = None,
|
||||
until_including: bool = True,
|
||||
until: BindParameter[datetime] | BindParameter[None] | datetime | None = None,
|
||||
until_including: BindParameter[bool] | bool = True,
|
||||
cte_name: str = "rec_cte",
|
||||
) -> CTE:
|
||||
"""
|
||||
@@ -44,14 +45,23 @@ def _user_balance_query(
|
||||
if use_cache:
|
||||
print("WARNING: Using cache for user balance query is not implemented yet.")
|
||||
|
||||
if isinstance(user_id, int):
|
||||
user_id = BindParameter("user_id", value=user_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until, type_=datetime)
|
||||
|
||||
if isinstance(until_including, bool):
|
||||
until_including = BindParameter("until_including", value=until_including, type_=bool)
|
||||
|
||||
initial_element = select(
|
||||
literal(0).label("i"),
|
||||
literal(0).label("time"),
|
||||
literal(None).label("transaction_id"),
|
||||
literal(0).label("balance"),
|
||||
literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
|
||||
literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
|
||||
literal(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
|
||||
CONST_ZERO.label("i"),
|
||||
CONST_ZERO.label("time"),
|
||||
CONST_NONE.label("transaction_id"),
|
||||
CONST_ZERO.label("balance"),
|
||||
const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
|
||||
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
|
||||
const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
|
||||
)
|
||||
|
||||
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
||||
@@ -59,7 +69,7 @@ def _user_balance_query(
|
||||
# Subset of transactions that we'll want to iterate over.
|
||||
trx_subset = (
|
||||
select(
|
||||
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
|
||||
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
|
||||
Transaction.amount,
|
||||
Transaction.id,
|
||||
Transaction.interest_rate_percent,
|
||||
@@ -77,34 +87,34 @@ def _user_balance_query(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT,
|
||||
TransactionType.ADJUST_BALANCE,
|
||||
TransactionType.BUY_PRODUCT,
|
||||
TransactionType.TRANSFER,
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.TRANSFER.as_literal_column(),
|
||||
# TODO: join this with the JOINT transactions, and determine
|
||||
# how much the current user paid for the product.
|
||||
TransactionType.JOINT_BUY_PRODUCT,
|
||||
TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
|
||||
]
|
||||
),
|
||||
),
|
||||
and_(
|
||||
Transaction.type_ == TransactionType.TRANSFER,
|
||||
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||
Transaction.transfer_user_id == user_id,
|
||||
),
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.THROW_PRODUCT,
|
||||
TransactionType.ADJUST_INTEREST,
|
||||
TransactionType.ADJUST_PENALTY,
|
||||
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_INTEREST.as_literal_column(),
|
||||
TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||
]
|
||||
),
|
||||
),
|
||||
case(
|
||||
(literal(until_including), Transaction.time <= until),
|
||||
(until_including, Transaction.time <= until),
|
||||
else_=Transaction.time < until,
|
||||
)
|
||||
if until is not None
|
||||
else literal(True),
|
||||
else CONST_TRUE,
|
||||
)
|
||||
.order_by(Transaction.time.asc())
|
||||
.alias("trx_subset")
|
||||
@@ -118,17 +128,17 @@ def _user_balance_query(
|
||||
case(
|
||||
# Adjusts balance -> balance gets adjusted
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE,
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||
recursive_cte.c.balance + trx_subset.c.amount,
|
||||
),
|
||||
# Adds a product -> balance increases
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.balance + trx_subset.c.amount,
|
||||
),
|
||||
# Buys a product -> balance decreases
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
|
||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.balance
|
||||
- (
|
||||
trx_subset.c.product_count
|
||||
@@ -151,12 +161,12 @@ def _user_balance_query(
|
||||
)
|
||||
)
|
||||
.order_by(column("i").desc())
|
||||
.limit(1)
|
||||
.limit(CONST_ONE)
|
||||
).scalar_subquery()
|
||||
# TODO: should interest be applied before or after the penalty multiplier?
|
||||
# at the moment of writing, after sound right, but maybe ask someone?
|
||||
# Interest
|
||||
* (cast(recursive_cte.c.interest_rate_percent, Float) / 100)
|
||||
* (cast(recursive_cte.c.interest_rate_percent, Float) / const(100))
|
||||
# TODO: these should be added together, not multiplied, see specification
|
||||
# Penalty
|
||||
* case(
|
||||
@@ -164,10 +174,10 @@ def _user_balance_query(
|
||||
recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
|
||||
(
|
||||
cast(recursive_cte.c.penalty_multiplier_percent, Float)
|
||||
/ 100
|
||||
/ const(100)
|
||||
),
|
||||
),
|
||||
else_=1.0,
|
||||
else_=const(1.0),
|
||||
)
|
||||
),
|
||||
Integer,
|
||||
@@ -177,7 +187,7 @@ def _user_balance_query(
|
||||
# Transfers money to self -> balance increases
|
||||
(
|
||||
and_(
|
||||
trx_subset.c.type_ == TransactionType.TRANSFER,
|
||||
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||
trx_subset.c.transfer_user_id == user_id,
|
||||
),
|
||||
recursive_cte.c.balance + trx_subset.c.amount,
|
||||
@@ -185,7 +195,7 @@ def _user_balance_query(
|
||||
# Transfers money from self -> balance decreases
|
||||
(
|
||||
and_(
|
||||
trx_subset.c.type_ == TransactionType.TRANSFER,
|
||||
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||
trx_subset.c.transfer_user_id != user_id,
|
||||
),
|
||||
recursive_cte.c.balance - trx_subset.c.amount,
|
||||
@@ -202,28 +212,28 @@ def _user_balance_query(
|
||||
).label("balance"),
|
||||
case(
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST,
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(),
|
||||
trx_subset.c.interest_rate_percent,
|
||||
),
|
||||
else_=recursive_cte.c.interest_rate_percent,
|
||||
).label("interest_rate_percent"),
|
||||
case(
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||
trx_subset.c.penalty_threshold,
|
||||
),
|
||||
else_=recursive_cte.c.penalty_threshold,
|
||||
).label("penalty_threshold"),
|
||||
case(
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||
trx_subset.c.penalty_multiplier_percent,
|
||||
),
|
||||
else_=recursive_cte.c.penalty_multiplier_percent,
|
||||
).label("penalty_multiplier_percent"),
|
||||
)
|
||||
.select_from(trx_subset)
|
||||
.where(trx_subset.c.i == recursive_cte.c.i + 1)
|
||||
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
|
||||
)
|
||||
|
||||
return recursive_cte.union_all(recursive_elements)
|
||||
@@ -322,7 +332,10 @@ def user_balance(
|
||||
)
|
||||
|
||||
result = sql_session.scalar(
|
||||
select(recursive_cte.c.balance).order_by(recursive_cte.c.i.desc()).limit(1)
|
||||
select(recursive_cte.c.balance)
|
||||
.order_by(recursive_cte.c.i.desc())
|
||||
.limit(CONST_ONE)
|
||||
.offset(CONST_ZERO)
|
||||
)
|
||||
|
||||
if result is None:
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import pytest
|
||||
import sqlparse
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Base
|
||||
@@ -35,21 +36,27 @@ class SqlParseFormatter(logging.Formatter):
|
||||
return super().format(record)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sql_session(request):
|
||||
"""Create a new SQLAlchemy session for testing."""
|
||||
def pytest_configure(config):
|
||||
"""Setup pretty SQL logging if --echo is enabled."""
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger("sqlalchemy.engine")
|
||||
|
||||
# TODO: it would be nice not to duplicate these logs.
|
||||
# logging.NullHandler() does not seem to work here
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(SqlParseFormatter())
|
||||
logger.addHandler(handler)
|
||||
|
||||
echo = request.config.getoption("--echo")
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
echo=echo,
|
||||
)
|
||||
echo = config.getoption("--echo")
|
||||
if echo:
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sql_session(request):
|
||||
"""Create a new SQLAlchemy session for testing."""
|
||||
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, _connection_record):
|
||||
@@ -60,3 +67,17 @@ def sql_session(request):
|
||||
Base.metadata.create_all(engine)
|
||||
with Session(engine) as sql_session:
|
||||
yield sql_session
|
||||
|
||||
|
||||
# FIXME: Declaring this hook seems to have a side effect where the database does not
|
||||
# get reset between tests.
|
||||
# @pytest.hookimpl(trylast=True)
|
||||
# def pytest_runtest_call(item: pytest.Item):
|
||||
# """Hook to format SQL statements in OperationalError exceptions."""
|
||||
# try:
|
||||
# item.runtest()
|
||||
# except OperationalError as e:
|
||||
# if e.statement is not None:
|
||||
# formatted_sql = sqlparse.format(e.statement, reindent=True, keyword_case="upper")
|
||||
# e.statement = "\n" + formatted_sql + "\n"
|
||||
# raise e
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
from pprint import pprint
|
||||
|
||||
import pytest
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
@@ -104,8 +103,8 @@ def test_user_balance_with_transfers(sql_session: Session) -> None:
|
||||
assert user2_balance == 50 - 30
|
||||
|
||||
|
||||
def test_user_balance_complex_history(sql_session: Session) -> None:
|
||||
raise NotImplementedError("This test is not implemented yet.")
|
||||
@pytest.mark.skip(reason="Not yet implemented")
|
||||
def test_user_balance_complex_history(sql_session: Session) -> None: ...
|
||||
|
||||
|
||||
def test_user_balance_penalty(sql_session: Session) -> None:
|
||||
|
||||
Reference in New Issue
Block a user