diff --git a/dibbler/queries/__init__.py b/dibbler/queries/__init__.py index 67f5110..dfba1e5 100644 --- a/dibbler/queries/__init__.py +++ b/dibbler/queries/__init__.py @@ -7,6 +7,7 @@ __all__ = [ "current_penalty", "joint_buy_product", "product_owners", + "product_owners_log", "product_price", "product_price_log", "product_stock", @@ -25,7 +26,7 @@ from .adjust_penalty import adjust_penalty from .current_interest import current_interest from .current_penalty import current_penalty from .joint_buy_product import joint_buy_product -from .product_owners import product_owners +from .product_owners import product_owners, product_owners_log from .product_price import product_price, product_price_log from .product_stock import product_stock diff --git a/dibbler/queries/product_owners.py b/dibbler/queries/product_owners.py index 6760dbf..38d46f4 100644 --- a/dibbler/queries/product_owners.py +++ b/dibbler/queries/product_owners.py @@ -1,4 +1,5 @@ from datetime import datetime +from dataclasses import dataclass from sqlalchemy import ( CTE, @@ -42,12 +43,13 @@ def _product_owners_query( # Subset of transactions that we'll want to iterate over. trx_subset = ( select( - func.row_number().over(order_by=asc(Transaction.time)).label("i"), - Transaction.time, - Transaction.id, - Transaction.type_, - Transaction.user_id, - Transaction.product_count, ) + func.row_number().over(order_by=asc(Transaction.time)).label("i"), + Transaction.time, + Transaction.id, + Transaction.type_, + Transaction.user_id, + Transaction.product_count, + ) .where( Transaction.type_.in_( [ @@ -100,7 +102,7 @@ def _product_owners_query( & (trx_subset.c.product_count > 0), trx_subset.c.product_count, ), - else_=None, + else_=0, ).label("product_count"), # How many products left to account for case( @@ -124,11 +126,13 @@ def _product_owners_query( # If adjusted upwards -> products owned by nobody, decrease products left to account for # If adjusted downwards -> products taken away from owners, decrease products left to account for ( - (trx_subset.c.type_ == TransactionType.ADJUST_STOCK) and (trx_subset.c.product_count > 0), + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK) + and (trx_subset.c.product_count > 0), recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, ), ( - (trx_subset.c.type_ == TransactionType.ADJUST_STOCK) and (trx_subset.c.product_count < 0), + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK) + and (trx_subset.c.product_count < 0), recursive_cte.c.products_left_to_account_for + trx_subset.c.product_count, ), else_=recursive_cte.c.products_left_to_account_for, @@ -146,6 +150,63 @@ def _product_owners_query( return recursive_cte.union_all(recursive_elements) +@dataclass +class ProductOwnersLogEntry: + transaction: Transaction + user: User | None + + +def product_owners_log( + sql_session: Session, + product: Product, + use_cache: bool = True, + until: Transaction | None = None, +) -> list[ProductOwnersLogEntry]: + """ + Returns a log of the product ownership calculation for the given product. + + If 'until' is given, only transactions up to that time are considered. + """ + + recursive_cte = _product_owners_query( + product_id=product.id, + use_cache=use_cache, + until=until.time if until else None, + ) + + result = sql_session.execute( + select( + Transaction, + User, + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .join( + User, + onclause=User.id == recursive_cte.c.user_id, + isouter=True, + ) + .order_by(recursive_cte.c.i.desc()) + ).all() + + if result is None: + # If there are no transactions for this product, the query should return an empty list, not None. + raise RuntimeError( + f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})." + ) + + return [ + ProductOwnersLogEntry( + transaction=row[0], + user=row[1], + ) + for row in result + ] + + def product_owners( sql_session: Session, product: Product, @@ -166,16 +227,46 @@ def product_owners( db_result = sql_session.execute( select( + recursive_cte.c.products_left_to_account_for, recursive_cte.c.product_count, User, ) - .join(User, User.id == recursive_cte.c.user_id) + .join(User, User.id == recursive_cte.c.user_id, isouter=True) .order_by(recursive_cte.c.i.desc()) ).all() + print(db_result) + result: list[User | None] = [] - for user_count, user in db_result: - result.extend([user] * user_count) + none_count = 0 + + # We are moving backwards through history, but this is the order we want to return the list + # There are 3 cases: + # User is not none -> add user product_count times + # User is none, and product_count is not 0 -> add None product_count times + # User is none, and product_count is 0 -> check how much products are left to account for, + + for products_left_to_account_for, product_count, user in db_result: + if user is not None: + if products_left_to_account_for < 0: + result.extend([user] * (product_count + products_left_to_account_for)) + else: + result.extend([user] * product_count) + elif product_count != 0: + if products_left_to_account_for < 0: + none_count += product_count + products_left_to_account_for + else: + none_count += product_count + else: + pass + + # none_count += user_count + # else: + + result.extend([None] * none_count) + + # # NOTE: if the last line exeeds the product count, we need to truncate it + # result.extend([user] * min(user_count, products_left_to_account_for)) # redistribute the user counts to a list of users diff --git a/tests/queries/test_product_owners.py b/tests/queries/test_product_owners.py index 9a6ba5a..ab3eea6 100644 --- a/tests/queries/test_product_owners.py +++ b/tests/queries/test_product_owners.py @@ -1,8 +1,10 @@ +from pprint import pprint + from sqlalchemy.orm import Session from dibbler.models import Product, User from dibbler.models.Transaction import Transaction -from dibbler.queries import product_owners +from dibbler.queries import product_owners, product_owners_log def insert_test_data(sql_session: Session) -> tuple[Product, User]: @@ -20,8 +22,9 @@ def insert_test_data(sql_session: Session) -> tuple[Product, User]: def test_product_owners_no_transactions(sql_session: Session) -> None: product, _ = insert_test_data(sql_session) - owners = product_owners(sql_session, product) + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [] @@ -40,8 +43,9 @@ def test_product_owners_add_products(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() - owners = product_owners(sql_session, product) + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user, user, user] @@ -65,6 +69,8 @@ def test_product_owners_add_and_buy_products(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user, user] @@ -89,6 +95,8 @@ def test_product_owners_add_and_throw_products(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user, user] @@ -118,6 +126,8 @@ def test_product_owners_multiple_users(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user2, user2, user2, user1, user1] @@ -142,8 +152,9 @@ def test_product_owners_adjust_stock_down(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() - owners = product_owners(sql_session, product) + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user, user, user] @@ -167,8 +178,9 @@ def test_product_owners_adjust_stock_up(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() - owners = product_owners(sql_session, product) + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user, user, None, None, None] @@ -193,9 +205,9 @@ def test_product_owners_negative_stock(sql_session: Session) -> None: sql_session.commit() owners = product_owners(sql_session, product) - assert owners == [] + def test_product_owners_add_products_from_negative_stock(sql_session: Session) -> None: product, user = insert_test_data(sql_session) @@ -216,10 +228,12 @@ def test_product_owners_add_products_from_negative_stock(sql_session: Session) - sql_session.add_all(transactions) sql_session.commit() - owners = product_owners(sql_session, product) + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user] + def test_product_owners_interleaved_users(sql_session: Session) -> None: product, user1 = insert_test_data(sql_session) user2 = User("testuser2") @@ -257,5 +271,7 @@ def test_product_owners_interleaved_users(sql_session: Session) -> None: sql_session.add_all(transactions) sql_session.commit() + pprint(product_owners_log(sql_session, product)) + owners = product_owners(sql_session, product) assert owners == [user1, user2, user2, user1, user1]