fixup! WIP

This commit is contained in:
2025-12-09 21:21:36 +09:00
parent 2a9ace4263
commit fead6257c7
3 changed files with 128 additions and 20 deletions

View File

@@ -7,6 +7,7 @@ __all__ = [
"current_penalty",
"joint_buy_product",
"product_owners",
"product_owners_log",
"product_price",
"product_price_log",
"product_stock",
@@ -25,7 +26,7 @@ from .adjust_penalty import adjust_penalty
from .current_interest import current_interest
from .current_penalty import current_penalty
from .joint_buy_product import joint_buy_product
from .product_owners import product_owners
from .product_owners import product_owners, product_owners_log
from .product_price import product_price, product_price_log
from .product_stock import product_stock

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from dataclasses import dataclass
from sqlalchemy import (
CTE,
@@ -42,12 +43,13 @@ def _product_owners_query(
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.time,
Transaction.id,
Transaction.type_,
Transaction.user_id,
Transaction.product_count, )
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.time,
Transaction.id,
Transaction.type_,
Transaction.user_id,
Transaction.product_count,
)
.where(
Transaction.type_.in_(
[
@@ -100,7 +102,7 @@ def _product_owners_query(
& (trx_subset.c.product_count > 0),
trx_subset.c.product_count,
),
else_=None,
else_=0,
).label("product_count"),
# How many products left to account for
case(
@@ -124,11 +126,13 @@ def _product_owners_query(
# If adjusted upwards -> products owned by nobody, decrease products left to account for
# If adjusted downwards -> products taken away from owners, decrease products left to account for
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK) and (trx_subset.c.product_count > 0),
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
and (trx_subset.c.product_count > 0),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK) and (trx_subset.c.product_count < 0),
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
and (trx_subset.c.product_count < 0),
recursive_cte.c.products_left_to_account_for + trx_subset.c.product_count,
),
else_=recursive_cte.c.products_left_to_account_for,
@@ -146,6 +150,63 @@ def _product_owners_query(
return recursive_cte.union_all(recursive_elements)
@dataclass
class ProductOwnersLogEntry:
transaction: Transaction
user: User | None
def product_owners_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: Transaction | None = None,
) -> list[ProductOwnersLogEntry]:
"""
Returns a log of the product ownership calculation for the given product.
If 'until' is given, only transactions up to that time are considered.
"""
recursive_cte = _product_owners_query(
product_id=product.id,
use_cache=use_cache,
until=until.time if until else None,
)
result = sql_session.execute(
select(
Transaction,
User,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.join(
User,
onclause=User.id == recursive_cte.c.user_id,
isouter=True,
)
.order_by(recursive_cte.c.i.desc())
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})."
)
return [
ProductOwnersLogEntry(
transaction=row[0],
user=row[1],
)
for row in result
]
def product_owners(
sql_session: Session,
product: Product,
@@ -166,16 +227,46 @@ def product_owners(
db_result = sql_session.execute(
select(
recursive_cte.c.products_left_to_account_for,
recursive_cte.c.product_count,
User,
)
.join(User, User.id == recursive_cte.c.user_id)
.join(User, User.id == recursive_cte.c.user_id, isouter=True)
.order_by(recursive_cte.c.i.desc())
).all()
print(db_result)
result: list[User | None] = []
for user_count, user in db_result:
result.extend([user] * user_count)
none_count = 0
# We are moving backwards through history, but this is the order we want to return the list
# There are 3 cases:
# User is not none -> add user product_count times
# User is none, and product_count is not 0 -> add None product_count times
# User is none, and product_count is 0 -> check how much products are left to account for,
for products_left_to_account_for, product_count, user in db_result:
if user is not None:
if products_left_to_account_for < 0:
result.extend([user] * (product_count + products_left_to_account_for))
else:
result.extend([user] * product_count)
elif product_count != 0:
if products_left_to_account_for < 0:
none_count += product_count + products_left_to_account_for
else:
none_count += product_count
else:
pass
# none_count += user_count
# else:
result.extend([None] * none_count)
# # NOTE: if the last line exeeds the product count, we need to truncate it
# result.extend([user] * min(user_count, products_left_to_account_for))
# redistribute the user counts to a list of users

View File

@@ -1,8 +1,10 @@
from pprint import pprint
from sqlalchemy.orm import Session
from dibbler.models import Product, User
from dibbler.models.Transaction import Transaction
from dibbler.queries import product_owners
from dibbler.queries import product_owners, product_owners_log
def insert_test_data(sql_session: Session) -> tuple[Product, User]:
@@ -20,8 +22,9 @@ def insert_test_data(sql_session: Session) -> tuple[Product, User]:
def test_product_owners_no_transactions(sql_session: Session) -> None:
product, _ = insert_test_data(sql_session)
owners = product_owners(sql_session, product)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == []
@@ -40,8 +43,9 @@ def test_product_owners_add_products(sql_session: Session) -> None:
sql_session.add_all(transactions)
sql_session.commit()
owners = product_owners(sql_session, product)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user, user]
@@ -65,6 +69,8 @@ def test_product_owners_add_and_buy_products(sql_session: Session) -> None:
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user]
@@ -89,6 +95,8 @@ def test_product_owners_add_and_throw_products(sql_session: Session) -> None:
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user]
@@ -118,6 +126,8 @@ def test_product_owners_multiple_users(sql_session: Session) -> None:
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user2, user2, user2, user1, user1]
@@ -142,8 +152,9 @@ def test_product_owners_adjust_stock_down(sql_session: Session) -> None:
sql_session.add_all(transactions)
sql_session.commit()
owners = product_owners(sql_session, product)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user, user]
@@ -167,8 +178,9 @@ def test_product_owners_adjust_stock_up(sql_session: Session) -> None:
sql_session.add_all(transactions)
sql_session.commit()
owners = product_owners(sql_session, product)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user, None, None, None]
@@ -193,9 +205,9 @@ def test_product_owners_negative_stock(sql_session: Session) -> None:
sql_session.commit()
owners = product_owners(sql_session, product)
assert owners == []
def test_product_owners_add_products_from_negative_stock(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
@@ -216,10 +228,12 @@ def test_product_owners_add_products_from_negative_stock(sql_session: Session) -
sql_session.add_all(transactions)
sql_session.commit()
owners = product_owners(sql_session, product)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user]
def test_product_owners_interleaved_users(sql_session: Session) -> None:
product, user1 = insert_test_data(sql_session)
user2 = User("testuser2")
@@ -257,5 +271,7 @@ def test_product_owners_interleaved_users(sql_session: Session) -> None:
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user1, user2, user2, user1, user1]