diff --git a/dibbler/queries/product_stock.py b/dibbler/queries/product_stock.py index 44d904d..976d691 100644 --- a/dibbler/queries/product_stock.py +++ b/dibbler/queries/product_stock.py @@ -1,6 +1,12 @@ from datetime import datetime -from sqlalchemy import case, func, literal, select +from sqlalchemy import ( + Select, + case, + func, + literal, + select, +) from sqlalchemy.orm import Session from dibbler.models import ( @@ -14,7 +20,7 @@ def _product_stock_query( product_id: int, use_cache: bool = True, until: datetime | None = None, -): +) -> Select: """ The inner query for calculating the product stock. """ @@ -73,6 +79,8 @@ def product_stock( ) -> int: """ Returns the number of products in stock. + + If 'until' is given, only transactions up to that time are considered. """ query = _product_stock_query( diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py index d590b47..0778dec 100644 --- a/dibbler/queries/user_balance.py +++ b/dibbler/queries/user_balance.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from datetime import datetime from sqlalchemy import ( + CTE, Float, Integer, and_, @@ -35,7 +36,7 @@ def _user_balance_query( until: datetime | None = None, until_including: bool = True, cte_name: str = "rec_cte", -): +) -> CTE: """ The inner query for calculating the user's balance. """ @@ -259,6 +260,8 @@ def user_balance_log( ) -> list[UserBalanceLogEntry]: """ Returns a log of the user's balance over time, including interest and penalty adjustments. + + If 'until' is given, only transactions up to that time are considered. """ recursive_cte = _user_balance_query( @@ -309,6 +312,8 @@ def user_balance( ) -> int: """ Calculates the balance of a user. + + If 'until' is given, only transactions up to that time are considered. """ recursive_cte = _user_balance_query( diff --git a/dibbler/queries/users_owning_product.py b/dibbler/queries/users_owning_product.py index e69de29..a4947a9 100644 --- a/dibbler/queries/users_owning_product.py +++ b/dibbler/queries/users_owning_product.py @@ -0,0 +1,164 @@ +from datetime import datetime + +from sqlalchemy import ( + CTE, + and_, + case, + literal, + select, +) +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, + User, +) +from dibbler.queries.product_stock import _product_stock_query + + +def _users_owning_product_query( + product_id: int, + use_cache: bool = True, + until: datetime | None = None, + cte_name: str = "rec_cte", +) -> CTE: + """ + The inner query for calculating the users owning a given product. + """ + + if use_cache: + print("WARNING: Using cache for users owning product query is not implemented yet.") + + product_stock = _product_stock_query( + product_id=product_id, + use_cache=use_cache, + until=until, + ) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select(Transaction) + .where( + Transaction.type_.in_( + [ + TransactionType.ADD_PRODUCT, + TransactionType.BUY_PRODUCT, + TransactionType.ADJUST_STOCK, + TransactionType.JOINT, + TransactionType.THROW_PRODUCT, + ] + ), + Transaction.product_id == product_id, + literal(True) if until is None else Transaction.time <= until, + ) + .order_by(Transaction.time.desc()) + .subquery() + ) + + initial_element = select( + literal(0).label("i"), + literal(0).label("time"), + literal(None).label("transaction_id"), + literal(None).label("user_id"), + literal(0).label("product_count"), + product_stock.as_scalar().label("products_left_to_account_for"), + ) + + recursive_cte = initial_element.cte(name=cte_name, recursive=True) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), + # Who added the product + case( + # Someone adds the product -> they own it + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + trx_subset.c.user_id, + ), + else_=None, + ).label("user_id"), + # How many products did they add + case( + # Someone adds the product -> they added a certain amount of products + (trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.product_count), + # Stock got adjusted upwards -> consider those products as added by nobody + ( + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK) + & (trx_subset.c.product_count > 0), + trx_subset.c.product_count, + ), + else_=None, + ).label("product_count"), + # How many products left to account for + case( + # Someone adds the product -> increase the number of products left to account for + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, + ), + # Someone buys/joins/throws the product -> decrease the number of products left to account for + ( + trx_subset.c.type_.in_( + [ + TransactionType.BUY_PRODUCT, + TransactionType.JOINT, + TransactionType.THROW_PRODUCT, + ] + ), + recursive_cte.c.products_left_to_account_for + trx_subset.c.product_count, + ), + # Someone adjusts the stock -> + # If adjusted upwards -> products owned by nobody, decrease products left to account for + # If adjusted downwards -> products taken away from owners, decrease products left to account for + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK, + recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, + ), + else_=recursive_cte.c.products_left_to_account_for, + ).label("products_left_to_account_for"), + ) + .select_from(trx_subset) + .where( + and_( + trx_subset.c.i == recursive_cte.c.i + 1, + recursive_cte.c.products_left_to_account_for > 0, + ) + ) + ) + + return recursive_cte.union_all(recursive_elements) + + +def users_owning_product( + sql_session: Session, + product: Product, + use_cache: bool = True, + until: datetime | None = None, +) -> list[User | None]: + """ + Returns an ordered list of users owning the given product. + + If 'until' is given, only transactions up to that time are considered. + """ + + recursive_cte = _users_owning_product_query( + product_id=product.id, + use_cache=use_cache, + until=until, + ) + + result = sql_session.scalars( + select( + recursive_cte.c.user_id, + recursive_cte.c.product_count, + ) + .distinct() + .order_by(recursive_cte.c.i.desc()) + ).all() + + return list(result)