Write a set of queries to go along with the event sourcing model

This commit is contained in:
2025-12-10 16:00:36 +09:00
parent d2890c23cc
commit 6ec094b43f
39 changed files with 5836 additions and 74 deletions

0
dibbler/lib/__init__.py Normal file
View File

View File

@@ -1,79 +1,7 @@
import pwd
import subprocess
import os
import pwd
import signal
from sqlalchemy import or_, and_
from ..models import User, Product
def search_user(string, session, ignorethisflag=None):
string = string.lower()
exact_match = (
session.query(User)
.filter(or_(User.name == string, User.card == string, User.rfid == string))
.first()
)
if exact_match:
return exact_match
user_list = (
session.query(User)
.filter(
or_(
User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"),
)
)
.all()
)
return user_list
def search_product(string, session, find_hidden_products=True):
if find_hidden_products:
exact_match = (
session.query(Product)
.filter(or_(Product.bar_code == string, Product.name == string))
.first()
)
else:
exact_match = (
session.query(Product)
.filter(
or_(
Product.bar_code == string,
and_(Product.name == string, Product.hidden is False),
)
)
.first()
)
if exact_match:
return exact_match
if find_hidden_products:
product_list = (
session.query(Product)
.filter(
or_(
Product.bar_code.ilike(f"%{string}%"),
Product.name.ilike(f"%{string}%"),
)
)
.all()
)
else:
product_list = (
session.query(Product)
.filter(
or_(
Product.bar_code.ilike(f"%{string}%"),
and_(Product.name.ilike(f"%{string}%"), Product.hidden is False),
)
)
.all()
)
return product_list
import subprocess
def system_user_exists(username):

View File

@@ -0,0 +1,46 @@
__all__ = [
"add_product",
"adjust_balance",
"adjust_interest",
"adjust_penalty",
"adjust_stock",
"create_product",
"create_user",
"current_interest",
"current_penalty",
"joint_buy_product",
"product_owners",
"product_owners_log",
"product_price",
"product_price_log",
"product_stock",
"search_product",
"search_user",
"throw_product",
"transaction_log",
"transfer",
"user_balance",
"user_balance_log",
"user_products",
]
from .add_product import add_product
from .adjust_balance import adjust_balance
from .adjust_interest import adjust_interest
from .adjust_penalty import adjust_penalty
from .adjust_stock import adjust_stock
from .create_product import create_product
from .create_user import create_user
from .current_interest import current_interest
from .current_penalty import current_penalty
from .joint_buy_product import joint_buy_product
from .product_owners import product_owners, product_owners_log
from .product_price import product_price, product_price_log
from .product_stock import product_stock
from .search_product import search_product
from .search_user import search_user
from .throw_product import throw_product
from .transaction_log import transaction_log
from .transfer import transfer
from .user_balance import user_balance, user_balance_log
from .user_products import user_products

View File

@@ -0,0 +1,51 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def add_product(
sql_session: Session,
user: User,
product: Product,
amount: int,
per_product: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
if per_product <= 0:
raise ValueError("Per product price must be positive.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
if per_product * product_count < amount:
raise ValueError("Total per product price must be at least equal to amount.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=amount,
per_product=per_product,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def adjust_balance(
sql_session: Session,
user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if amount == 0:
raise ValueError("Amount must be non-zero.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_balance(
user_id=user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
# TODO: this type of transaction should be password protected.
# the password can be set as a string literal in the config file.
def adjust_interest(
sql_session: Session,
user: User,
new_interest: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if new_interest < 0:
raise ValueError("Interest rate cannot be negative")
if user.id is None:
raise ValueError("User must be persisted in the database.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_interest(
user_id=user.id,
interest_rate_percent=new_interest,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,49 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
from dibbler.queries.current_penalty import current_penalty
# TODO: this type of transaction should be password protected.
# the password can be set as a string literal in the config file.
def adjust_penalty(
sql_session: Session,
user: User,
new_penalty: int | None = None,
new_penalty_multiplier: int | None = None,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if new_penalty is None and new_penalty_multiplier is None:
raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided")
if new_penalty_multiplier is not None and new_penalty_multiplier < 100:
raise ValueError("Penalty multiplier cannot be less than 100%")
if user.id is None:
raise ValueError("User must be persisted in the database.")
if new_penalty is None or new_penalty_multiplier is None:
existing_penalty, existing_penalty_multiplier = current_penalty(sql_session)
if new_penalty is None:
new_penalty = existing_penalty
if new_penalty_multiplier is None:
new_penalty_multiplier = existing_penalty_multiplier
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_penalty(
user_id=user.id,
penalty_threshold=new_penalty,
penalty_multiplier_percent=new_penalty_multiplier,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def adjust_stock(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count == 0:
raise ValueError("Product count must be non-zero.")
# TODO: it should not be possible to reduce stock below zero.
#
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def buy_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,25 @@
from sqlalchemy.orm import Session
from dibbler.models import Product
def create_product(
sql_session: Session,
name: str,
barcode: str,
) -> Product:
if not name:
raise ValueError("Name cannot be empty.")
if not barcode:
raise ValueError("Barcode cannot be empty.")
# TODO: check for duplicate names, barcodes
# TODO: add more validation for barcode
product = Product(barcode, name)
sql_session.add(product)
sql_session.commit()
return product

View File

@@ -0,0 +1,21 @@
from sqlalchemy.orm import Session
from dibbler.models import User
def create_user(
sql_session: Session,
name: str,
card: str | None,
rfid: str | None,
) -> User:
if not name:
raise ValueError("Name cannot be empty.")
# TODO: check for duplicate names, cards, rfids
user = User(name=name, card=card, rfid=rfid)
sql_session.add(user)
sql_session.commit()
return user

View File

@@ -0,0 +1,55 @@
from datetime import datetime
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.queries.query_helpers import until_filter
def current_interest(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Get the current interest rate percentage as of a given time or transaction.
Returns the interest rate percentage as an integer.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
result = sql_session.scalars(
select(Transaction)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(1)
).one_or_none()
if result is None:
return DEFAULT_INTEREST_RATE_PERCENT
elif result.interest_rate_percent is None:
return DEFAULT_INTEREST_RATE_PERCENT
else:
return result.interest_rate_percent

View File

@@ -0,0 +1,59 @@
from datetime import datetime
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries.query_helpers import until_filter
def current_penalty(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> tuple[int, int]:
"""
Get the current penalty settings (threshold and multiplier percentage) as of a given time or transaction.
Returns a tuple of `(penalty_threshold, penalty_multiplier_percentage)`.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
result = sql_session.scalars(
select(Transaction)
.where(
Transaction.type_ == TransactionType.ADJUST_PENALTY,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(1)
).one_or_none()
if result is None:
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENT
assert result.penalty_threshold is not None, "Penalty threshold must be set"
assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set"
return result.penalty_threshold, result.penalty_multiplier_percent

View File

@@ -0,0 +1,68 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
User,
)
def joint_buy_product(
sql_session: Session,
product: Product,
product_count: int,
instigator: User,
users: list[User],
time: datetime | None = None,
message: str | None = None,
) -> list[Transaction]:
"""
Create buy product transactions for multiple users at once.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if instigator.id is None:
raise ValueError("Instigator must be persisted in the database.")
if len(users) == 0:
raise ValueError("At least bying one user must be specified.")
if any(user.id is None for user in users):
raise ValueError("All users must be persisted in the database.")
if instigator not in users:
raise ValueError("Instigator must be in the list of users buying the product.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
joint_transaction = Transaction.joint(
user_id=instigator.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(joint_transaction)
sql_session.flush() # Ensure joint_transaction gets an ID
transactions = [joint_transaction]
for user in users:
buy_transaction = Transaction.joint_buy_product(
user_id=user.id,
joint_transaction_id=joint_transaction.id,
time=time,
message=message,
)
sql_session.add(buy_transaction)
transactions.append(buy_transaction)
sql_session.commit()
return transactions

View File

@@ -0,0 +1,309 @@
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
CTE,
BindParameter,
and_,
bindparam,
case,
func,
or_,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
User,
)
from dibbler.queries.product_stock import _product_stock_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
until_filter,
)
def _product_owners_query(
product_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE:
"""
The inner query for inferring the owners of a given product.
"""
if use_cache:
print("WARNING: Using cache for users owning product query is not implemented yet.")
if isinstance(product_id, int):
product_id = bindparam("product_id", value=product_id)
if until_time is not None and until_transaction is not None:
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = bindparam("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
product_stock = _product_stock_query(
product_id=product_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=Transaction.time.desc()).label("i"),
Transaction.time,
Transaction.id,
Transaction.type_,
Transaction.user_id,
Transaction.product_count,
)
.where(
or_(
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
and_(
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
Transaction.product_count > CONST_ZERO,
),
),
Transaction.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.subquery(trx_subset_name)
)
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_NONE.label("user_id"),
CONST_ZERO.label("product_count"),
product_stock.scalar_subquery().label("products_left_to_account_for"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
# Who added the product (if any)
case(
# Someone adds the product -> they own it
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
trx_subset.c.user_id,
),
else_=CONST_NONE,
).label("user_id"),
# How many products did they add (if any)
case(
# Someone adds the product -> they added a certain amount of products
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
trx_subset.c.product_count,
),
# Stock got adjusted upwards -> consider those products as added by nobody
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
and (trx_subset.c.product_count > CONST_ZERO),
trx_subset.c.product_count,
),
else_=CONST_ZERO,
).label("product_count"),
# How many products left to account for
case(
# Someone adds the product -> known owner, decrease the number of products left to account for
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
# Stock got adjusted upwards -> none owner, decrease the number of products left to account for
(
and_(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
trx_subset.c.product_count > CONST_ZERO,
),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
else_=recursive_cte.c.products_left_to_account_for,
).label("products_left_to_account_for"),
)
.select_from(trx_subset)
.where(
and_(
trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
# Base case: stop if we've accounted for all products
recursive_cte.c.products_left_to_account_for > CONST_ZERO,
)
)
)
return recursive_cte.union_all(recursive_elements)
@dataclass
class ProductOwnersLogEntry:
transaction: Transaction
user: User | None
products_left_to_account_for: int
def product_owners_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductOwnersLogEntry]:
"""
Returns a log of the product ownership calculation for the given product.
If 'until' is given, only transactions up to that time are considered.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
recursive_cte = _product_owners_query(
product_id=product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
select(
Transaction,
User,
recursive_cte.c.products_left_to_account_for,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.join(
User,
onclause=User.id == recursive_cte.c.user_id,
isouter=True,
)
.order_by(recursive_cte.c.time.desc())
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})."
)
return [
ProductOwnersLogEntry(
transaction=row[0],
user=row[1],
products_left_to_account_for=row[2],
)
for row in result
]
def product_owners(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[User | None]:
"""
Returns an ordered list of users owning the given product.
If 'until' is given, only transactions up to that time are considered.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
recursive_cte = _product_owners_query(
product_id=product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
db_result = sql_session.execute(
select(
recursive_cte.c.products_left_to_account_for,
recursive_cte.c.product_count,
User,
)
.join(User, User.id == recursive_cte.c.user_id, isouter=True)
.order_by(recursive_cte.c.time.desc())
).all()
print(db_result)
result: list[User | None] = []
none_count = 0
# We are moving backwards through history, but this is the order we want to return the list
# There are 3 cases:
# User is not none -> add user product_count times
# User is none, and product_count is not 0 -> add None product_count times
# User is none, and product_count is 0 -> check how much products are left to account for,
# TODO: embed this into the query itself?
for products_left_to_account_for, product_count, user in db_result:
if user is not None:
if products_left_to_account_for < 0:
result.extend([user] * (product_count + products_left_to_account_for))
else:
result.extend([user] * product_count)
elif product_count != 0:
if products_left_to_account_for < 0:
none_count += product_count + products_left_to_account_for
else:
none_count += product_count
else:
pass
# none_count += user_count
# else:
result.extend([None] * none_count)
# # NOTE: if the last line exceeds the product count, we need to truncate it
# result.extend([user] * min(user_count, products_left_to_account_for))
# redistribute the user counts to a list of users
return list(result)

View File

@@ -0,0 +1,358 @@
import math
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
BindParameter,
ColumnElement,
Integer,
bindparam,
case,
cast,
func,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
LastCacheTransaction,
Product,
ProductCache,
Transaction,
TransactionType,
)
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
until_filter, after_filter,
)
def _product_price_query(
product_id: int | ColumnElement[int],
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
):
"""
The inner query for calculating the product price.
"""
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if use_cache:
initial_element_fields = (
select(
Transaction.time.label("time"),
Transaction.id.label("transaction_id"),
ProductCache.price.label("price"),
ProductCache.stock.label("product_count"),
)
.select_from(ProductCache)
.join(
LastCacheTransaction,
ProductCache.last_cache_transaction_id == LastCacheTransaction.id,
)
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
.where(
ProductCache.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.union(
select(
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
)
)
.order_by(Transaction.time.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
.subquery()
.alias("initial_element_fields")
)
initial_element = select(
CONST_ZERO.label("i"),
initial_element_fields.c.time,
initial_element_fields.c.transaction_id,
initial_element_fields.c.price,
initial_element_fields.c.product_count,
).select_from(initial_element_fields)
else:
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.id,
Transaction.time,
Transaction.type_,
Transaction.product_count,
Transaction.per_product,
)
.where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
]
),
Transaction.product_id == product_id,
after_filter(
after_time=None,
after_transaction_id=recursive_cte.c.transaction_id,
after_inclusive=False,
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.asc())
.subquery(trx_subset_name)
)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case(
# Someone buys the product -> price remains the same.
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.price,
),
# Someone adds the product -> price is recalculated based on
# product count, previous price, and new price.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
cast(
func.ceil(
(
recursive_cte.c.price
* func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.per_product * trx_subset.c.product_count
)
/ (
# The running product count can be negative if the accounting is bad.
# This ensures that we never end up with negative prices or zero divisions
# and other disastrous phenomena.
func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.product_count
)
),
Integer,
),
),
# Someone adjusts the stock -> price remains the same.
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.price,
),
# Should never happen
else_=recursive_cte.c.price,
).label("price"),
case(
# Someone buys the product -> product count is reduced.
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count,
),
(
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count,
),
# Someone adds the product -> product count is increased.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Someone adjusts the stock -> product count is adjusted.
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Should never happen
else_=recursive_cte.c.product_count,
).label("product_count"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
)
return recursive_cte.union_all(recursive_elements)
# TODO: create a function for the log that pretty prints the log entries
# for debugging purposes
@dataclass
class ProductPriceLogEntry:
transaction: Transaction
price: int
product_count: int
def product_price_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductPriceLogEntry]:
"""
Calculates the price of a product and returns a log of the price changes.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
select(
Transaction,
recursive_cte.c.price,
recursive_cte.c.product_count,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})."
)
return [
ProductPriceLogEntry(
transaction=row[0],
price=row.price,
product_count=row.product_count,
)
for row in result
]
def product_price(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
include_interest: bool = False,
) -> int:
"""
Calculates the price of a product.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
# TODO: optionally verify subresults:
# - product_count should never be negative (but this happens sometimes, so just a warning)
# - price should never be negative
result = sql_session.scalars(
select(recursive_cte.c.price)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
).one_or_none()
if result is None:
# If there are no transactions for this product, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
)
if include_interest:
interest_rate = (
sql_session.scalar(
select(Transaction.interest_rate_percent)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(CONST_ONE)
)
or DEFAULT_INTEREST_RATE_PERCENT
)
result = math.ceil(result * interest_rate / 100)
return result

View File

@@ -0,0 +1,126 @@
from datetime import datetime
from typing import Tuple
from sqlalchemy import (
BindParameter,
Select,
bindparam,
case,
func,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
)
from dibbler.queries.query_helpers import until_filter
def _product_stock_query(
product_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[Tuple[int]]:
"""
The inner query for calculating the product stock.
"""
if use_cache:
print("WARNING: Using cache for product stock query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select(
func.sum(
case(
(
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.JOINT.as_literal_column(),
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(),
-Transaction.product_count,
),
else_=0,
)
).label("stock")
).where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT.as_literal_column(),
]
),
Transaction.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
return query
def product_stock(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Returns the number of products in stock.
If 'until' is given, only transactions up to that time are considered.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
query = _product_stock_query(
product_id=product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.scalars(query).one_or_none()
return result or 0

View File

@@ -0,0 +1,80 @@
from datetime import datetime
from typing import TypeVar
from sqlalchemy import (
BindParameter,
ColumnExpressionArgument,
literal,
select,
)
from sqlalchemy.orm import QueryableAttribute
from dibbler.models import Transaction
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
"""A constant SQL expression `0`. This will render as a literal `0` in SQL queries."""
CONST_ONE: BindParameter[int] = const(1)
"""A constant SQL expression `1`. This will render as a literal `1` in SQL queries."""
CONST_TRUE: BindParameter[bool] = const(True)
"""A constant SQL expression `TRUE`. This will render as a literal `TRUE` in SQL queries."""
CONST_FALSE: BindParameter[bool] = const(False)
"""A constant SQL expression `FALSE`. This will render as a literal `FALSE` in SQL queries."""
CONST_NONE: BindParameter[None] = const(None)
"""A constant SQL expression `NULL`. This will render as a literal `NULL` in SQL queries."""
def until_filter(
until_time: BindParameter[datetime] | None = None,
until_transaction_id: BindParameter[int] | None = None,
until_inclusive: bool = True,
transaction_time: QueryableAttribute = Transaction.time,
) -> ColumnExpressionArgument[bool]:
"""
Create a filter condition for transactions up to a given time or transaction.
Only one of `until_time` or `until_transaction_id` may be specified.
"""
assert not (until_time is not None and until_transaction_id is not None), (
"Cannot filter by both until_time and until_transaction_id."
)
match (until_time, until_transaction_id, until_inclusive):
case (BindParameter(), None, True):
return transaction_time <= until_time
case (BindParameter(), None, False):
return transaction_time < until_time
case (None, BindParameter(), True):
return (
transaction_time
<= select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
case (None, BindParameter(), False):
return (
transaction_time
< select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
return CONST_TRUE

View File

@@ -0,0 +1,42 @@
from sqlalchemy import and_, literal, not_, or_, select
from sqlalchemy.orm import Session
from dibbler.models import Product
def search_product(
string: str,
sql_session: Session,
find_hidden_products=False,
) -> Product | list[Product]:
if not string:
raise ValueError("Search string cannot be empty.")
exact_match = sql_session.scalars(
select(Product).where(
or_(
Product.bar_code == string,
and_(
Product.name == string,
literal(True) if find_hidden_products else not_(Product.hidden),
),
)
)
).first()
if exact_match:
return exact_match
product_list = sql_session.scalars(
select(Product).where(
or_(
Product.bar_code.ilike(f"%{string}%"),
and_(
Product.name.ilike(f"%{string}%"),
literal(True) if find_hidden_products else not_(Product.hidden),
),
)
)
).all()
return list(product_list)

View File

@@ -0,0 +1,39 @@
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from dibbler.models import User
def search_user(
string: str,
sql_session: Session,
) -> User | list[User]:
if not string:
raise ValueError("Search string cannot be empty.")
string = string.lower()
exact_match = sql_session.scalars(
select(User).where(
or_(
User.name == string,
User.card == string,
User.rfid == string,
)
)
).first()
if exact_match:
return exact_match
user_list = sql_session.scalars(
select(User).where(
or_(
User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"),
)
)
).all()
return list(user_list)

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def throw_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
raise NotImplementedError(
"Please don't use this function until relevant calculations have been added to user_balance."
)
transaction = Transaction.throw_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,142 @@
from datetime import datetime
from sqlalchemy import BindParameter, select
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
User,
)
# TODO: should this include full joint transactions that involve a user?
# TODO: should this involve throw-away transactions that affects a user?
def transaction_log(
sql_session: Session,
user: User | None = None,
product: Product | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
after_time: BindParameter[datetime] | datetime | None = None,
after_transaction: Transaction | None = None,
after_inclusive: bool = True,
transaction_type: list[TransactionType] | None = None,
negate_transaction_type_filter: bool = False,
limit: int | None = None,
) -> list[Transaction]:
"""
Retrieve the transaction log, optionally filtered.
Only one of `user` or `product` may be specified.
Only one of `until_time` or `until_transaction_id` may be specified.
Only one of `after_time` or `after_transaction_id` may be specified.
The after and after filters are inclusive by default.
"""
if not (user is None or product is None):
raise ValueError("Cannot filter by both user and product.")
if isinstance(user, User):
if user.id is None:
raise ValueError("User must be persisted in the database.")
user_id = BindParameter("user_id", value=user.id)
else:
user_id = None
if isinstance(product, Product):
if product.id is None:
raise ValueError("Product must be persisted in the database.")
product_id = BindParameter("product_id", value=product.id)
else:
product_id = None
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
if limit is not None and limit <= 0:
raise ValueError("Limit must be positive.")
query = select(Transaction)
if user is not None:
query = query.where(Transaction.user_id == user_id)
if product is not None:
query = query.where(Transaction.product_id == product_id)
match (until_time, until_transaction_id, until_inclusive):
case (BindParameter(), None, True):
query = query.where(Transaction.time <= until_time)
case (BindParameter(), None, False):
query = query.where(Transaction.time < until_time)
case (None, BindParameter(), True):
query = query.where(Transaction.id <= until_transaction_id)
case (None, BindParameter(), False):
query = query.where(Transaction.id < until_transaction_id)
case _:
pass
match (after_time, after_transaction_id, after_inclusive):
case (BindParameter(), None, True):
query = query.where(Transaction.time >= after_time)
case (BindParameter(), None, False):
query = query.where(Transaction.time > after_time)
case (None, BindParameter(), True):
query = query.where(Transaction.id >= after_transaction_id)
case (None, BindParameter(), False):
query = query.where(Transaction.id > after_transaction_id)
case _:
pass
if transaction_type is not None:
if negate_transaction_type_filter:
query = query.where(~Transaction.type_.in_(transaction_type))
else:
query = query.where(Transaction.type_.in_(transaction_type))
if limit is not None:
query = query.limit(limit)
query = query.order_by(Transaction.time.asc(), Transaction.id.asc())
result = sql_session.scalars(query).all()
return list(result)

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def transfer(
sql_session: Session,
from_user: User,
to_user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if from_user.id is None:
raise ValueError("From user must be persisted in the database.")
if to_user.id is None:
raise ValueError("To user must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.transfer(
user_id=from_user.id,
transfer_user_id=to_user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,567 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Tuple
from sqlalchemy import (
CTE,
BindParameter,
Float,
Integer,
Select,
and_,
bindparam,
case,
cast,
column,
func,
or_,
select,
)
from sqlalchemy.orm import Session, aliased
from sqlalchemy.sql.elements import KeyedColumnElement
from dibbler.models import (
Transaction,
TransactionType,
User,
)
from dibbler.models.Transaction import (
DEFAULT_INTEREST_RATE_PERCENT,
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries.product_price import _product_price_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
const,
until_filter,
)
def _joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[Tuple[int, int, int]]:
"""
The inner query for getting joint transactions relevant to a user.
This scans for JOINT_BUY_PRODUCT transactions made by the user,
then finds the corresponding JOINT transactions, and counts how many "shares"
of the joint transaction the user has, as well as the total number of shares.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
# First, select all joint buy product transactions for the given user
# sub_joint_transaction = aliased(Transaction, name="right_trx")
sub_joint_transaction = (
select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id"))
.where(
Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
Transaction.user_id == user_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
transaction_time=Transaction.time,
),
)
.subquery("sub_joint_transaction")
)
# Join those with their main joint transaction
# (just use Transaction)
# Then, count how many users are involved in each joint transaction
joint_transaction_count = aliased(Transaction, name="count_trx")
joint_transaction = (
select(
Transaction.id,
# Shares the user has in the transaction,
func.sum(
case(
(joint_transaction_count.user_id == user_id, CONST_ONE),
else_=CONST_ZERO,
)
).label("user_shares"),
# The total number of shares in the transaction,
func.count(joint_transaction_count.id).label("user_count"),
)
.select_from(sub_joint_transaction)
.join(
Transaction,
onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id,
)
.join(
joint_transaction_count,
onclause=joint_transaction_count.joint_transaction_id == Transaction.id,
)
.group_by(joint_transaction_count.joint_transaction_id)
)
return joint_transaction
def _non_joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[Tuple[int, None, None]]:
"""
The inner query for getting non-joint transactions relevant to a user.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select(
Transaction.id,
CONST_NONE.label("user_shares"),
CONST_NONE.label("user_count"),
).where(
or_(
and_(
Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
]
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY.as_literal_column(),
]
),
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
return query
def _product_cost_expression(
product_count_column: KeyedColumnElement[int],
product_id_column: KeyedColumnElement[int],
interest_rate_percent_column: KeyedColumnElement[int],
user_balance_column: KeyedColumnElement[int],
penalty_threshold_column: KeyedColumnElement[int],
penalty_multiplier_percent_column: KeyedColumnElement[int],
joint_user_shares_column: KeyedColumnElement[int],
joint_user_count_column: KeyedColumnElement[int],
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "product_price_cte",
trx_subset_name: str = "product_price_trx_subset",
):
# TODO: This can get quite expensive real quick, so we should do some caching of the
# product prices somehow.
expression = (
select(
cast(
func.ceil(
# Base price
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
# Interest
+ (
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
* cast(interest_rate_percent_column - const(100), Float)
/ const(100.0)
)
# Penalty
+ (
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
* cast(penalty_multiplier_percent_column - const(100), Float)
/ const(100.0)
* cast(user_balance_column < penalty_threshold_column, Integer)
)
),
Integer,
)
)
.select_from(
_product_price_query(
product_id_column,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=cte_name,
trx_subset_name=trx_subset_name,
)
)
.order_by(column("i").desc())
.limit(CONST_ONE)
.scalar_subquery()
)
return expression
def _user_balance_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE:
"""
The inner query for calculating the user's balance.
"""
if use_cache:
print("WARNING: Using cache for user balance query is not implemented yet.")
if isinstance(user_id, int):
user_id = BindParameter("user_id", value=user_id)
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("balance"),
const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"),
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
trx_subset_subset = (
_non_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
.union_all(
_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
)
.subquery(f"{trx_subset_name}_subset")
)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.id,
Transaction.amount,
Transaction.interest_rate_percent,
Transaction.penalty_multiplier_percent,
Transaction.penalty_threshold,
Transaction.product_count,
Transaction.product_id,
Transaction.time,
Transaction.transfer_user_id,
Transaction.type_,
trx_subset_subset.c.user_shares,
trx_subset_subset.c.user_count,
)
.select_from(trx_subset_subset)
.join(
Transaction,
onclause=Transaction.id == trx_subset_subset.c.id,
)
.order_by(Transaction.time.asc())
.subquery(trx_subset_name)
)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case(
# Adjusts balance -> balance gets adjusted
(
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Adds a product -> balance increases
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Buys a product -> balance decreases
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.balance
- _product_cost_expression(
product_count_column=trx_subset.c.product_count,
product_id_column=trx_subset.c.product_id,
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
user_balance_column=recursive_cte.c.balance,
penalty_threshold_column=recursive_cte.c.penalty_threshold,
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
joint_user_shares_column=CONST_ONE,
joint_user_count_column=CONST_ONE,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=f"{cte_name}_price",
trx_subset_name=f"{trx_subset_name}_price",
).label("product_cost"),
),
# Joint transaction -> balance decreases proportionally
(
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
recursive_cte.c.balance
- _product_cost_expression(
product_count_column=trx_subset.c.product_count,
product_id_column=trx_subset.c.product_id,
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
user_balance_column=recursive_cte.c.balance,
penalty_threshold_column=recursive_cte.c.penalty_threshold,
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
joint_user_shares_column=trx_subset.c.user_shares,
joint_user_count_column=trx_subset.c.user_count,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=f"{cte_name}_joint_price",
trx_subset_name=f"{trx_subset_name}_joint_price",
).label("joint_product_cost"),
),
# Transfers money to self -> balance increases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id == user_id,
),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Transfers money from self -> balance decreases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id != user_id,
),
recursive_cte.c.balance - trx_subset.c.amount,
),
# Throws a product -> if the user is considered to have bought it, balance increases
# TODO: # (
# trx_subset.c.type_ == TransactionType.THROW_PRODUCT,
# recursive_cte.c.balance + trx_subset.c.amount,
# ),
# Interest adjustment -> balance stays the same
# Penalty adjustment -> balance stays the same
else_=recursive_cte.c.balance,
).label("balance"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(),
trx_subset.c.interest_rate_percent,
),
else_=recursive_cte.c.interest_rate_percent,
).label("interest_rate_percent"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_threshold,
),
else_=recursive_cte.c.penalty_threshold,
).label("penalty_threshold"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_multiplier_percent,
),
else_=recursive_cte.c.penalty_multiplier_percent,
).label("penalty_multiplier_percent"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
)
return recursive_cte.union_all(recursive_elements)
# TODO: create a function for the log that pretty prints the log entries
# for debugging purposes
@dataclass
class UserBalanceLogEntry:
transaction: Transaction
balance: int
interest_rate_percent: int
penalty_threshold: int
penalty_multiplier_percent: int
def is_penalized(self) -> bool:
"""
Returns whether this exact transaction is penalized.
"""
raise NotImplementedError("is_penalized is not implemented yet.")
def user_balance_log(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[UserBalanceLogEntry]:
"""
Returns a log of the user's balance over time, including interest and penalty adjustments.
If 'until' is given, only transactions up to that time are considered.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
select(
Transaction,
recursive_cte.c.balance,
recursive_cte.c.interest_rate_percent,
recursive_cte.c.penalty_threshold,
recursive_cte.c.penalty_multiplier_percent,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all()
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
)
return [
UserBalanceLogEntry(
transaction=row[0],
balance=row.balance,
interest_rate_percent=row.interest_rate_percent,
penalty_threshold=row.penalty_threshold,
penalty_multiplier_percent=row.penalty_multiplier_percent,
)
for row in result
]
def user_balance(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Calculates the balance of a user.
If 'until' is given, only transactions up to that time are considered.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.scalar(
select(recursive_cte.c.balance)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
)
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
)
return result

View File

@@ -0,0 +1,48 @@
from datetime import datetime
from sqlalchemy import BindParameter, bindparam
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
# NOTE: This absolutely needs a cache, else we can't stop recursing until we know all owners for all products...
#
# Since we know that the non-owned products will not get renowned by the user by other means,
# we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user.
# between now and the cached time.
#
# However, the opposite way is more difficult. The cache will store which products are owned by which users,
# but we still need to check if the user passes out of ownership for the item, without needing to check past
# the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if
# a user has products multiple places in the queue, interleaved with other users?
def user_products(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[tuple[Product, int]]:
"""
Returns the list of products owned by the user, along with how many of each product they own.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
raise NotImplementedError("Not implemented yet, needs caching system first.")