diff --git a/dibbler/queries/product_stock.py b/dibbler/queries/product_stock.py index 39bc984..082478f 100644 --- a/dibbler/queries/product_stock.py +++ b/dibbler/queries/product_stock.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Tuple from sqlalchemy import ( BindParameter, @@ -21,7 +22,7 @@ def _product_stock_query( product_id: BindParameter[int] | int, use_cache: bool = True, until: BindParameter[datetime] | datetime | None = None, -) -> Select: +) -> Select[Tuple[int]]: """ The inner query for calculating the product stock. """ diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py index 747e0b3..c41c583 100644 --- a/dibbler/queries/user_balance.py +++ b/dibbler/queries/user_balance.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from datetime import datetime +from typing import Tuple from sqlalchemy import ( CTE, BindParameter, Float, Integer, + Select, and_, case, cast, @@ -14,7 +16,7 @@ from sqlalchemy import ( or_, select, ) -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, aliased from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const from dibbler.models import ( @@ -30,10 +32,184 @@ from dibbler.models.Transaction import ( from dibbler.queries.product_price import _product_price_query +def _joint_transaction_query( + user_id: BindParameter[int] | int, + use_cache: bool, + until: BindParameter[datetime] | None, + until_including: BindParameter[bool], +) -> Select[Tuple[int, int, int]]: + """ + The inner query for getting joint transactions relevant to a user. + + This scans for JOINT_BUY_PRODUCT transactions made by the user, + then finds the corresponding JOINT transactions, and counts how many "shares" + of the joint transaction the user has, as well as the total number of shares. + """ + + # First, select all joint buy product transactions for the given user + # sub_joint_transaction = aliased(Transaction, name="right_trx") + sub_joint_transaction = ( + select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id")) + .where( + Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(), + Transaction.user_id == user_id, + ) + .subquery("sub_joint_transaction") + ) + + # Join those with their main joint transaction + # (just use Transaction) + + # Then, count how many users are involved in each joint transaction + joint_transaction_count = aliased(Transaction, name="count_trx") + + joint_transaction = ( + select( + Transaction.id, + # Shares the user has in the transaction, + func.sum( + case( + (joint_transaction_count.user_id == user_id, CONST_ONE), + else_=CONST_ZERO, + ) + ).label("user_shares"), + # The total number of shares in the transaction, + func.count(joint_transaction_count.id).label("user_count"), + ) + .select_from(sub_joint_transaction) + .join( + Transaction, + onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id, + ) + .join( + joint_transaction_count, + onclause=joint_transaction_count.joint_transaction_id == Transaction.id, + ) + .where( + case( + (until_including, sub_joint_transaction.time <= until), + else_=sub_joint_transaction.time < until, + ) + if until is not None + else CONST_TRUE, + ) + .group_by(joint_transaction_count.joint_transaction_id) + ) + + return joint_transaction + + +def _non_joint_transaction_query( + user_id: BindParameter[int] | int, + use_cache: bool, + until: BindParameter[datetime] | None, + until_including: BindParameter[bool], +) -> Select[Tuple[int, None, None]]: + """ + The inner query for getting non-joint transactions relevant to a user. + """ + + query = select( + Transaction.id, + CONST_NONE.label("user_shares"), + CONST_NONE.label("user_count"), + ).where( + or_( + and_( + Transaction.user_id == user_id, + Transaction.type_.in_( + [ + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_BALANCE.as_literal_column(), + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.TRANSFER.as_literal_column(), + ] + ), + ), + and_( + Transaction.type_ == TransactionType.TRANSFER.as_literal_column(), + Transaction.transfer_user_id == user_id, + ), + Transaction.type_.in_( + [ + TransactionType.THROW_PRODUCT.as_literal_column(), + TransactionType.ADJUST_INTEREST.as_literal_column(), + TransactionType.ADJUST_PENALTY.as_literal_column(), + ] + ), + case( + (until_including, Transaction.time <= until), + else_=Transaction.time < until, + ) + if until is not None + else CONST_TRUE, + ), + ) + + return query + + +def _product_cost_expression( + product_count_column, + product_id_column, + interest_rate_percent_column, + user_balance_column, + penalty_threshold_column, + penalty_multiplier_percent_column, + use_cache: bool, + until: BindParameter[datetime] | None, + cte_name: str = "product_price_cte", + trx_subset_name: str = "product_price_trx_subset", +): + expression = ( + product_count_column + # Price of a single product, accounted for penalties and interest. + * cast( + func.ceil( + # TODO: This can get quite expensive real quick, so we should do some caching of the + # product prices somehow. + # Base price + ( + # FIXME: this always returns 0 for some reason... + select(cast(column("price"), Float)) + .select_from( + _product_price_query( + product_id_column, + use_cache=use_cache, + until=until, + until_including=False, + cte_name=cte_name, + trx_subset_name=trx_subset_name, + ) + ) + .order_by(column("i").desc()) + .limit(CONST_ONE) + ).scalar_subquery() + # TODO: should interest be applied before or after the penalty multiplier? + # at the moment of writing, after sound right, but maybe ask someone? + # Interest + * (cast(interest_rate_percent_column, Float) / const(100)) + # TODO: these should be added together, not multiplied, see specification + # Penalty + * case( + ( + user_balance_column < penalty_threshold_column, + (cast(penalty_multiplier_percent_column, Float) / const(100)), + ), + else_=const(1.0), + ) + ), + Integer, + ) + ) + + return expression + + def _user_balance_query( user_id: BindParameter[int] | int, use_cache: bool = True, - until: BindParameter[datetime] | BindParameter[None] | datetime | None = None, + until: BindParameter[datetime] | datetime | None = None, until_including: BindParameter[bool] | bool = True, cte_name: str = "rec_cte", trx_subset_name: str = "trx_subset", @@ -49,10 +225,10 @@ def _user_balance_query( user_id = BindParameter("user_id", value=user_id) if isinstance(until, datetime): - until = BindParameter("until", value=until, type_=datetime) + until = BindParameter("until", value=until) if isinstance(until_including, bool): - until_including = BindParameter("until_including", value=until_including, type_=bool) + until_including = BindParameter("until_including", value=until_including) initial_element = select( CONST_ZERO.label("i"), @@ -66,12 +242,30 @@ def _user_balance_query( recursive_cte = initial_element.cte(name=cte_name, recursive=True) + trx_subset_subset = ( + _non_joint_transaction_query( + user_id=user_id, + use_cache=use_cache, + until=until, + until_including=until_including, + ) + .union_all( + _joint_transaction_query( + user_id=user_id, + use_cache=use_cache, + until=until, + until_including=until_including, + ) + ) + .subquery(f"{trx_subset_name}_subset") + ) + # Subset of transactions that we'll want to iterate over. trx_subset = ( select( func.row_number().over(order_by=Transaction.time.asc()).label("i"), - Transaction.amount, Transaction.id, + Transaction.amount, Transaction.interest_rate_percent, Transaction.penalty_multiplier_percent, Transaction.penalty_threshold, @@ -80,41 +274,13 @@ def _user_balance_query( Transaction.time, Transaction.transfer_user_id, Transaction.type_, + trx_subset_subset.c.user_shares, + trx_subset_subset.c.user_count, ) - .where( - or_( - and_( - Transaction.user_id == user_id, - Transaction.type_.in_( - [ - TransactionType.ADD_PRODUCT.as_literal_column(), - TransactionType.ADJUST_BALANCE.as_literal_column(), - TransactionType.BUY_PRODUCT.as_literal_column(), - TransactionType.TRANSFER.as_literal_column(), - # TODO: join this with the JOINT transactions, and determine - # how much the current user paid for the product. - TransactionType.JOINT_BUY_PRODUCT.as_literal_column(), - ] - ), - ), - and_( - Transaction.type_ == TransactionType.TRANSFER.as_literal_column(), - Transaction.transfer_user_id == user_id, - ), - Transaction.type_.in_( - [ - TransactionType.THROW_PRODUCT.as_literal_column(), - TransactionType.ADJUST_INTEREST.as_literal_column(), - TransactionType.ADJUST_PENALTY.as_literal_column(), - ] - ), - ), - case( - (until_including, Transaction.time <= until), - else_=Transaction.time < until, - ) - if until is not None - else CONST_TRUE, + .select_from(trx_subset_subset) + .join( + Transaction, + onclause=Transaction.id == trx_subset_subset.c.id, ) .order_by(Transaction.time.asc()) .subquery(trx_subset_name) @@ -140,49 +306,39 @@ def _user_balance_query( ( trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), recursive_cte.c.balance - - ( - trx_subset.c.product_count - # Price of a single product, accounted for penalties and interest. - * cast( - func.ceil( - # TODO: This can get quite expensive real quick, so we should do some caching of the - # product prices somehow. - # Base price - ( - # FIXME: this always returns 0 for some reason... - select(cast(column("price"), Float)) - .select_from( - _product_price_query( - trx_subset.c.product_id, - use_cache=use_cache, - until=trx_subset.c.time, - until_including=False, - cte_name="product_price_cte", - trx_subset_name="product_price_trx_subset", - ) - ) - .order_by(column("i").desc()) - .limit(CONST_ONE) - ).scalar_subquery() - # TODO: should interest be applied before or after the penalty multiplier? - # at the moment of writing, after sound right, but maybe ask someone? - # Interest - * (cast(recursive_cte.c.interest_rate_percent, Float) / const(100)) - # TODO: these should be added together, not multiplied, see specification - # Penalty - * case( - ( - recursive_cte.c.balance < recursive_cte.c.penalty_threshold, - ( - cast(recursive_cte.c.penalty_multiplier_percent, Float) - / const(100) - ), - ), - else_=const(1.0), - ) - ), - Integer, + - _product_cost_expression( + product_count_column=trx_subset.c.product_count, + product_id_column=trx_subset.c.product_id, + interest_rate_percent_column=recursive_cte.c.interest_rate_percent, + user_balance_column=recursive_cte.c.balance, + penalty_threshold_column=recursive_cte.c.penalty_threshold, + penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, + use_cache=use_cache, + until=until, + cte_name=f"{cte_name}_price", + trx_subset_name=f"{trx_subset_name}_price", + ), + ), + # Joint transaction -> balance decreases proportionally + ( + trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(), + recursive_cte.c.balance + - func.ceil( + _product_cost_expression( + product_count_column=trx_subset.c.product_count, + product_id_column=trx_subset.c.product_id, + interest_rate_percent_column=recursive_cte.c.interest_rate_percent, + user_balance_column=recursive_cte.c.balance, + penalty_threshold_column=recursive_cte.c.penalty_threshold, + penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, + use_cache=use_cache, + until=until, + cte_name=f"{cte_name}_joint_price", + trx_subset_name=f"{trx_subset_name}_joint_price", ) + # TODO: move this inside of the product cost expression + * trx_subset.c.user_shares + / trx_subset.c.user_count ), ), # Transfers money to self -> balance increases