diff --git a/dibbler/models/Transaction.py b/dibbler/models/Transaction.py index 50d2d5b..c59929c 100644 --- a/dibbler/models/Transaction.py +++ b/dibbler/models/Transaction.py @@ -27,6 +27,13 @@ if TYPE_CHECKING: from .Product import Product from .User import User + +# NOTE: these only matter when there are no adjustments made in the database. + +DEFAULT_INTEREST_RATE_PERCENTAGE = 100 +DEFAULT_PENALTY_THRESHOLD = -100 +DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200 + # TODO: allow for joint transactions? # dibbler allows joint transactions (e.g. buying more than one product at once, several people buying the same product, etc.) # instead of having the software split the transactions up, making them hard to reconnect, @@ -51,7 +58,7 @@ _EXPECTED_FIELDS: dict[TransactionType, set[str]] = { TransactionType.ADJUST_STOCK: {"product_count", "product_id"}, # TODO: remove amount from BUY_PRODUCT # this requires modifications to user credit calculations - TransactionType.BUY_PRODUCT: {"amount", "product_count", "product_id"}, + TransactionType.BUY_PRODUCT: {"product_count", "product_id"}, TransactionType.TRANSFER: {"amount", "transfer_user_id"}, } @@ -379,7 +386,6 @@ class Transaction(Base): @classmethod def buy_product( cls: type[Self], - amount: int, user_id: int, product_id: int, product_count: int, @@ -389,7 +395,6 @@ class Transaction(Base): return cls( time=time, type_=TransactionType.BUY_PRODUCT, - amount=amount, user_id=user_id, product_id=product_id, product_count=product_count, diff --git a/dibbler/queries/buy_product.py b/dibbler/queries/buy_product.py index 99605ea..ec3bb1d 100644 --- a/dibbler/queries/buy_product.py +++ b/dibbler/queries/buy_product.py @@ -1,13 +1,16 @@ +import math from datetime import datetime from sqlalchemy.orm import Session from dibbler.models import ( - Transaction, - TransactionType, - User, Product, + Transaction, + User, ) +from dibbler.queries.current_interest import current_interest +from dibbler.queries.current_penalty import current_penalty +from dibbler.queries.user_balance import user_balance from .product_price import product_price @@ -24,12 +27,26 @@ def buy_product( Creates a BUY_PRODUCT transaction with the amount automatically calculated based on the product's current price. """ - price = product_price(sql_session, product) + # balance = user_balance(sql_session, user) - return Transaction( + # price = product_price(sql_session, product) + + # interest_rate = current_interest(sql_session) + + # penalty_threshold, penalty_multiplier_percent = current_penalty(sql_session) + + # price *= product_count + + # price *= 1 + interest_rate / 100 + + # if balance < penalty_threshold: + # price *= 1 + penalty_multiplier_percent / 100 + + # price = math.ceil(price) + + return Transaction.buy_product( time=time, - type_=TransactionType.BUY_PRODUCT, - amount=price * product_count, + # amount=price, user_id=user.id, product_id=product.id, product_count=product_count, diff --git a/dibbler/queries/current_interest.py b/dibbler/queries/current_interest.py new file mode 100644 index 0000000..51272bf --- /dev/null +++ b/dibbler/queries/current_interest.py @@ -0,0 +1,21 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, TransactionType +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE + + +def current_interest(sql_session: Session) -> int: + result = sql_session.scalars( + select(Transaction) + .where(Transaction.type_ == TransactionType.ADJUST_INTEREST) + .order_by(Transaction.time.desc()) + .limit(1) + ).one_or_none() + + if result is None: + return DEFAULT_INTEREST_RATE_PERCENTAGE + + assert result.interest_rate_percent is not None, "Interest rate percent must be set" + + return result.interest_rate_percent diff --git a/dibbler/queries/current_penalty.py b/dibbler/queries/current_penalty.py new file mode 100644 index 0000000..5658b06 --- /dev/null +++ b/dibbler/queries/current_penalty.py @@ -0,0 +1,25 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, TransactionType +from dibbler.models.Transaction import ( + DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE, + DEFAULT_PENALTY_THRESHOLD, +) + + +def current_penalty(sql_session: Session) -> tuple[int, int]: + result = sql_session.scalars( + select(Transaction) + .where(Transaction.type_ == TransactionType.ADJUST_PENALTY) + .order_by(Transaction.time.desc()) + .limit(1) + ).one_or_none() + + if result is None: + return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE + + assert result.penalty_threshold is not None, "Penalty threshold must be set" + assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set" + + return result.penalty_threshold, result.penalty_multiplier_percent diff --git a/dibbler/queries/product_price.py b/dibbler/queries/product_price.py index 621efc4..e0c4d3d 100644 --- a/dibbler/queries/product_price.py +++ b/dibbler/queries/product_price.py @@ -9,7 +9,6 @@ from sqlalchemy import ( literal, select, ) - from sqlalchemy.orm import Session from dibbler.models import ( @@ -18,14 +17,20 @@ from dibbler.models import ( TransactionType, ) + def _product_price_query( - product: Product, - # use_cache: bool = True, - # until: datetime | None = None, + product_id: int, + use_cache: bool = True, + until: datetime | None = None, + cte_name: str = "rec_cte", ): """ The inner query for calculating the product price. """ + + if use_cache: + print("WARNING: Using cache for product price query is not implemented yet.") + initial_element = select( literal(0).label("i"), literal(0).label("time"), @@ -33,7 +38,7 @@ def _product_price_query( literal(0).label("product_count"), ) - recursive_cte = initial_element.cte(name="rec_cte", recursive=True) + recursive_cte = initial_element.cte(name=cte_name, recursive=True) # Subset of transactions that we'll want to iterate over. trx_subset = ( @@ -52,11 +57,8 @@ def _product_price_query( TransactionType.ADJUST_STOCK, ] ), - Transaction.product_id == product.id, - # TODO: - # If we have a transaction to limit the price calculation to, use it. - # If not, use all transactions for the product. - # (Transaction.time <= until.time) if until else True, + Transaction.product_id == product_id, + Transaction.time <= until if until is not None else 1 == 1, ) .order_by(Transaction.time.asc()) .alias("trx_subset") @@ -122,15 +124,18 @@ def _product_price_query( def product_price_log( sql_session: Session, product: Product, - # use_cache: bool = True, - # Optional: calculate the price until a certain transaction. - # until: Transaction | None = None, + use_cache: bool = True, + until: Transaction | None = None, ) -> list[tuple[int, datetime, int, int]]: """ Calculates the price of a product and returns a log of the price changes. """ - recursive_cte = _product_price_query(product) + recursive_cte = _product_price_query( + product.id, + use_cache=use_cache, + until=until.time if until else None, + ) result = sql_session.execute( select( @@ -154,15 +159,18 @@ def product_price_log( def product_price( sql_session: Session, product: Product, - # use_cache: bool = True, - # Optional: calculate the price until a certain transaction. - # until: Transaction | None = None, + use_cache: bool = True, + until: Transaction | None = None, ) -> int: """ Calculates the price of a product. """ - recursive_cte = _product_price_query(product) # , until=until) + recursive_cte = _product_price_query( + product.id, + use_cache=use_cache, + until=until.time if until else None, + ) # TODO: optionally verify subresults: # - product_count should never be negative (but this happens sometimes, so just a warning) diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py index dcedd5e..e35c068 100644 --- a/dibbler/queries/user_balance.py +++ b/dibbler/queries/user_balance.py @@ -1,4 +1,17 @@ -from sqlalchemy import func, select +from datetime import datetime + +from sqlalchemy import ( + Integer, + and_, + asc, + case, + cast, + column, + func, + literal, + or_, + select, +) from sqlalchemy.orm import Session from dibbler.models import ( @@ -6,92 +19,236 @@ from dibbler.models import ( TransactionType, User, ) +from dibbler.models.Transaction import ( + DEFAULT_INTEREST_RATE_PERCENTAGE, + DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries.product_price import _product_price_query -# TODO: rename to 'balance' everywhere def _user_balance_query( user: User, - # use_cache: bool = True, - # until: datetime | None = None, + use_cache: bool = True, + until: datetime | None = None, + cte_name: str = "rec_cte", ): """ The inner query for calculating the user's balance. """ - balance_adjustments = ( - select(func.coalesce(func.sum(Transaction.amount).label("balance_adjustments"), 0)) - .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.ADJUST_BALANCE, + if use_cache: + print("WARNING: Using cache for user balance query is not implemented yet.") + + initial_element = select( + literal(0).label("i"), + literal(0).label("time"), + literal(0).label("balance"), + literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"), + literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), + literal(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"), + ) + + recursive_cte = initial_element.cte(name=cte_name, recursive=True) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select( + func.row_number().over(order_by=asc(Transaction.time)).label("i"), + Transaction.time, + Transaction.type_, + Transaction.amount, + Transaction.product_count, + Transaction.product_id, + Transaction.transfer_user_id, + Transaction.interest_rate_percent, + Transaction.penalty_multiplier_percent, + Transaction.penalty_threshold, ) - .scalar_subquery() - ) - - transfers_to_other_users = ( - select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_other_users"), 0)) .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.TRANSFER, + or_( + and_( + Transaction.user_id == user.id, + Transaction.type_.in_( + [ + TransactionType.ADD_PRODUCT, + TransactionType.ADJUST_BALANCE, + TransactionType.BUY_PRODUCT, + TransactionType.TRANSFER, + ] + ), + ), + and_( + Transaction.type_ == TransactionType.TRANSFER, + Transaction.transfer_user_id == user.id, + ), + Transaction.type_.in_( + [ + TransactionType.ADJUST_INTEREST, + TransactionType.ADJUST_PENALTY, + ] + ), + ), + Transaction.time <= until if until is not None else 1 == 1, ) - .scalar_subquery() + .order_by(Transaction.time.asc()) + .alias("trx_subset") ) - transfers_to_self = ( - select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_self"), 0)) - .where( - Transaction.transfer_user_id == user.id, - Transaction.type_ == TransactionType.TRANSFER, + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + case( + # Adjusts balance -> balance gets adjusted + ( + trx_subset.c.type_ == TransactionType.ADJUST_BALANCE, + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Adds a product -> balance increases + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Buys a product -> balance decreases + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT, + recursive_cte.c.balance + - ( + trx_subset.c.product_count + # Price of a single product, accounted for penalties and interest. + * cast( + func.ceil( + # TODO: This can get quite expensive real quick, so we should do some caching of the + # product prices somehow. + # Base price + ( + select(column("price")) + .select_from( + _product_price_query( + trx_subset.c.product_id, + use_cache=use_cache, + until=trx_subset.c.time, + cte_name="product_price_cte", + ) + ) + .order_by(column("i").desc()) + .limit(1) + ).scalar_subquery() + # TODO: should interest be applied before or after the penalty multiplier? + # at the moment of writing, after sound right, but maybe ask someone? + # Interest + * (recursive_cte.c.interest_rate_percent / 100) + # Penalty + * case( + ( + recursive_cte.c.balance < recursive_cte.c.penalty_threshold, + (recursive_cte.c.penalty_multiplier_percent / 100), + ), + else_=1.0, + ) + ), + Integer, + ) + ), + ), + # Transfers money to self -> balance increases + ( + trx_subset.c.type_ == TransactionType.TRANSFER + and trx_subset.c.transfer_user_id == user.id, + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Transfers money from self -> balance decreases + ( + trx_subset.c.type_ == TransactionType.TRANSFER + and trx_subset.c.transfer_user_id != user.id, + recursive_cte.c.balance - trx_subset.c.amount, + ), + # Interest adjustment -> balance stays the same + # Penalty adjustment -> balance stays the same + else_=recursive_cte.c.balance, + ).label("balance"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_INTEREST, + trx_subset.c.interest_rate_percent, + ), + else_=recursive_cte.c.interest_rate_percent, + ).label("interest_rate_percent"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY, + trx_subset.c.penalty_threshold, + ), + else_=recursive_cte.c.penalty_threshold, + ).label("penalty_threshold"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY, + trx_subset.c.penalty_multiplier_percent, + ), + else_=recursive_cte.c.penalty_multiplier_percent, + ).label("penalty_multiplier_percent"), ) - .scalar_subquery() + .select_from(trx_subset) + .where(trx_subset.c.i == recursive_cte.c.i + 1) ) - add_products = ( - select(func.coalesce(func.sum(Transaction.amount).label("add_products"), 0)) - .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.ADD_PRODUCT, - ) - .scalar_subquery() - ) - - buy_products = ( - select(func.coalesce(func.sum(Transaction.amount).label("buy_products"), 0)) - .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.BUY_PRODUCT, - ) - .scalar_subquery() - ) - - query = select( - # TODO: clearly define and fix the sign of the amount - ( - 0 - + balance_adjustments - - transfers_to_other_users - + transfers_to_self - + add_products - - buy_products - ).label("balance") - ) - - return query + return recursive_cte.union_all(recursive_elements) -def user_balance( +def user_balance_log( sql_session: Session, user: User, - # use_cache: bool = True, - # Optional: calculate the balance until a certain transaction. - # until: Transaction | None = None, -) -> int: - """ - Calculates the balance of a user. - """ + use_cache: bool = True, + until: Transaction | None = None, +) -> list[tuple[int, datetime, int, int, int, int]]: + recursive_cte = _user_balance_query( + user, + use_cache=use_cache, + until=until.time if until else None, + ) - query = _user_balance_query(user) # , until=until) - - result = sql_session.scalar(query) + result = sql_session.execute( + select( + recursive_cte.c.i, + recursive_cte.c.time, + recursive_cte.c.balance, + recursive_cte.c.interest_rate_percent, + recursive_cte.c.penalty_threshold, + recursive_cte.c.penalty_multiplier_percent, + ).order_by(recursive_cte.c.i.asc()) + ).all() + + if result is None: + # If there are no transactions for this user, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." + ) + + return result + + +def user_balance( + sql_session: Session, + user: User, + use_cache: bool = True, + # Optional: calculate the balance until a certain transaction. + until: Transaction | None = None, +) -> int: + """ + Calculates the balance of a user. + """ + + recursive_cte = _user_balance_query( + user, + use_cache=use_cache, + until=until.time if until else None, + ) + + result = sql_session.scalar( + select(recursive_cte.c.balance).order_by(recursive_cte.c.i.desc()).limit(1) + ) if result is None: # If there are no transactions for this user, the query should return 0, not None. diff --git a/tests/models/test_transaction.py b/tests/models/test_transaction.py index cbcc303..a7c4eb2 100644 --- a/tests/models/test_transaction.py +++ b/tests/models/test_transaction.py @@ -40,34 +40,3 @@ def test_transaction_no_duplicate_timestamps(sql_session: Session): with pytest.raises(IntegrityError): sql_session.commit() - - -def test_transaction_buy_product_wrong_amount(sql_session: Session) -> None: - user, product = insert_test_data(sql_session) - - # Set price by adding a product - transaction = Transaction.add_product( - time=datetime(2023, 10, 1, 12, 0, 0), - user_id=user.id, - product_id=product.id, - amount=27, - per_product=27, - product_count=1, - ) - - sql_session.add(transaction) - sql_session.commit() - - # Attempt to buy product with wrong amount - transaction2 = Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 1), - user_id=user.id, - product_id=product.id, - amount=(27 * 2) + 1, # Wrong amount - product_count=2, - ) - - sql_session.add(transaction2) - - with pytest.raises(ValueError): - sql_session.commit() diff --git a/tests/models/test_user.py b/tests/models/test_user.py index dda5930..2b6aa9e 100644 --- a/tests/models/test_user.py +++ b/tests/models/test_user.py @@ -59,7 +59,6 @@ def test_user_transactions(sql_session: Session): ), Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 1), - amount=27, product_count=1, user_id=user2.id, product_id=product.id, diff --git a/tests/queries/test_buy_product.py b/tests/queries/test_buy_product.py index ff80ddd..4791d6b 100644 --- a/tests/queries/test_buy_product.py +++ b/tests/queries/test_buy_product.py @@ -4,7 +4,6 @@ from datetime import datetime from sqlalchemy.orm import Session from dibbler.models import Product, Transaction, User -from dibbler.queries.buy_product import buy_product from dibbler.queries.product_stock import product_stock from dibbler.queries.user_balance import user_balance @@ -48,11 +47,10 @@ def insert_test_data(sql_session: Session) -> tuple[User, Product]: def test_buy_product_basic(sql_session: Session) -> None: user, product = insert_test_data(sql_session) - transaction = buy_product( - sql_session=sql_session, + transaction = Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 0), - user=user, - product=product, + user_id=user.id, + product_id=product.id, product_count=1, ) @@ -73,11 +71,10 @@ def test_buy_product_with_penalty(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() - transaction = buy_product( - sql_session=sql_session, + transaction = Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 0), - user=user, - product=product, + user_id=user.id, + product_id=product.id, product_count=1, ) sql_session.add(transaction) @@ -99,11 +96,10 @@ def test_buy_product_with_interest(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() - transaction = buy_product( - sql_session=sql_session, + transaction = Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 0), - user=user, - product=product, + user_id=user.id, + product_id=product.id, product_count=1, ) sql_session.add(transaction) @@ -125,11 +121,10 @@ def test_buy_product_with_changing_penalty(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() - transaction = buy_product( - sql_session=sql_session, + transaction = Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 0), - user=user, - product=product, + user_id=user.id, + product_id=product.id, product_count=1, ) sql_session.add(transaction) @@ -146,11 +141,10 @@ def test_buy_product_with_changing_penalty(sql_session: Session) -> None: sql_session.add(adjust_penalty) sql_session.commit() - transaction = buy_product( - sql_session=sql_session, + transaction = Transaction.buy_product( time=datetime(2023, 10, 1, 14, 0, 0), - user=user, - product=product, + user_id=user.id, + product_id=product.id, product_count=1, ) sql_session.add(transaction) @@ -170,12 +164,11 @@ def test_buy_product_with_penalty_interest_combined(sql_session: Session) -> Non def test_buy_product_more_than_stock(sql_session: Session) -> None: user, product = insert_test_data(sql_session) - transaction = buy_product( - sql_session=sql_session, + transaction = Transaction.buy_product( time=datetime(2023, 10, 1, 13, 0, 0), product_count=10, - user=user, - product=product, + user_id=user.id, + product_id=product.id, ) sql_session.add(transaction) diff --git a/tests/queries/test_product_price.py b/tests/queries/test_product_price.py index c2a247c..a8153e0 100644 --- a/tests/queries/test_product_price.py +++ b/tests/queries/test_product_price.py @@ -49,7 +49,6 @@ def insert_test_data(sql_session: Session) -> None: ), Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 1), - amount=27, product_count=1, user_id=user2.id, product_id=product1.id, @@ -76,7 +75,6 @@ def insert_test_data(sql_session: Session) -> None: ), Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 5), - amount=50, product_count=1, user_id=user1.id, product_id=product3.id, @@ -139,7 +137,6 @@ def test_product_price_with_negative_stock_single_addition(sql_session: Session) transaction = Transaction.buy_product( time=datetime(2023, 10, 1, 13, 0, 0), - amount=27 * 5, product_count=10, user_id=user1.id, product_id=product1.id, diff --git a/tests/queries/test_product_stock.py b/tests/queries/test_product_stock.py index 14f2d3f..d8ba334 100644 --- a/tests/queries/test_product_stock.py +++ b/tests/queries/test_product_stock.py @@ -60,7 +60,6 @@ def test_product_stock_complex_history(sql_session: Session) -> None: ), Transaction.buy_product( time=datetime(2023, 10, 1, 13, 0, 1), - amount=27 * 3, user_id=user1.id, product_id=product.id, product_count=3, @@ -123,7 +122,6 @@ def test_negative_product_stock(sql_session: Session) -> None: ), Transaction.buy_product( time=datetime(2023, 10, 1, 14, 0, 1), - amount=50, user_id=user1.id, product_id=product.id, product_count=2, diff --git a/tests/queries/test_user_balance.py b/tests/queries/test_user_balance.py index 3d5b313..b0ce281 100644 --- a/tests/queries/test_user_balance.py +++ b/tests/queries/test_user_balance.py @@ -4,7 +4,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from dibbler.models import Product, Transaction, User -from dibbler.queries.user_balance import user_balance +from dibbler.queries.user_balance import user_balance, user_balance_log def insert_test_data(sql_session: Session) -> None: @@ -50,7 +50,6 @@ def insert_test_data(sql_session: Session) -> None: ), Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 1), - amount=27, product_count=1, user_id=user2.id, product_id=product1.id,