Write a set of queries to go along with the event sourcing model

This commit is contained in:
2025-12-10 16:00:36 +09:00
parent 39c05258ef
commit 243da23cde
39 changed files with 5783 additions and 74 deletions

0
dibbler/lib/__init__.py Normal file
View File

View File

@@ -1,79 +1,7 @@
import pwd
import subprocess
import os
import pwd
import signal
from sqlalchemy import or_, and_
from ..models import User, Product
def search_user(string, session, ignorethisflag=None):
string = string.lower()
exact_match = (
session.query(User)
.filter(or_(User.name == string, User.card == string, User.rfid == string))
.first()
)
if exact_match:
return exact_match
user_list = (
session.query(User)
.filter(
or_(
User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"),
)
)
.all()
)
return user_list
def search_product(string, session, find_hidden_products=True):
if find_hidden_products:
exact_match = (
session.query(Product)
.filter(or_(Product.bar_code == string, Product.name == string))
.first()
)
else:
exact_match = (
session.query(Product)
.filter(
or_(
Product.bar_code == string,
and_(Product.name == string, Product.hidden is False),
)
)
.first()
)
if exact_match:
return exact_match
if find_hidden_products:
product_list = (
session.query(Product)
.filter(
or_(
Product.bar_code.ilike(f"%{string}%"),
Product.name.ilike(f"%{string}%"),
)
)
.all()
)
else:
product_list = (
session.query(Product)
.filter(
or_(
Product.bar_code.ilike(f"%{string}%"),
and_(Product.name.ilike(f"%{string}%"), Product.hidden is False),
)
)
.all()
)
return product_list
import subprocess
def system_user_exists(username):

View File

@@ -0,0 +1,37 @@
__all__ = [
# "add_product",
# "add_user",
"adjust_interest",
"adjust_penalty",
"current_interest",
"current_penalty",
"joint_buy_product",
"product_owners",
"product_owners_log",
"product_price",
"product_price_log",
"product_stock",
# "products_owned_by_user",
"search_product",
"search_user",
"transaction_log",
"user_balance",
"user_balance_log",
]
# from .add_product import add_product
# from .add_user import add_user
from .adjust_interest import adjust_interest
from .adjust_penalty import adjust_penalty
from .current_interest import current_interest
from .current_penalty import current_penalty
from .joint_buy_product import joint_buy_product
from .product_owners import product_owners, product_owners_log
from .product_price import product_price, product_price_log
from .product_stock import product_stock
# from .products_owned_by_user import products_owned_by_user
from .search_product import search_product
from .search_user import search_user
from .transaction_log import transaction_log
from .user_balance import user_balance, user_balance_log

View File

@@ -0,0 +1,51 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def add_product(
sql_session: Session,
user: User,
product: Product,
amount: int,
per_product: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
if per_product <= 0:
raise ValueError("Per product price must be positive.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
if per_product * product_count < amount:
raise ValueError("Total per product price must be at least equal to amount.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=amount,
per_product=per_product,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def adjust_balance(
sql_session: Session,
user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if amount == 0:
raise ValueError("Amount must be non-zero.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_balance(
user_id=user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
# TODO: this type of transaction should be password protected.
# the password can be set as a string literal in the config file.
def adjust_interest(
sql_session: Session,
user: User,
new_interest: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if new_interest < 0:
raise ValueError("Interest rate cannot be negative")
if user.id is None:
raise ValueError("User must be persisted in the database.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_interest(
user_id=user.id,
interest_rate_percent=new_interest,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,49 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
from dibbler.queries.current_penalty import current_penalty
# TODO: this type of transaction should be password protected.
# the password can be set as a string literal in the config file.
def adjust_penalty(
sql_session: Session,
user: User,
new_penalty: int | None = None,
new_penalty_multiplier: int | None = None,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if new_penalty is None and new_penalty_multiplier is None:
raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided")
if new_penalty_multiplier is not None and new_penalty_multiplier < 100:
raise ValueError("Penalty multiplier cannot be less than 100%")
if user.id is None:
raise ValueError("User must be persisted in the database.")
if new_penalty is None or new_penalty_multiplier is None:
existing_penalty, existing_penalty_multiplier = current_penalty(sql_session)
if new_penalty is None:
new_penalty = existing_penalty
if new_penalty_multiplier is None:
new_penalty_multiplier = existing_penalty_multiplier
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_penalty(
user_id=user.id,
penalty_threshold=new_penalty,
penalty_multiplier_percent=new_penalty_multiplier,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def adjust_stock(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count == 0:
raise ValueError("Product count must be non-zero.")
# TODO: it should not be possible to reduce stock below zero.
#
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def buy_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,25 @@
from sqlalchemy.orm import Session
from dibbler.models import Product
def create_product(
sql_session: Session,
name: str,
barcode: str,
) -> Product:
if not name:
raise ValueError("Name cannot be empty.")
if not barcode:
raise ValueError("Barcode cannot be empty.")
# TODO: check for duplicate names, barcodes
# TODO: add more validation for barcode
product = Product(barcode, name)
sql_session.add(product)
sql_session.commit()
return product

View File

@@ -0,0 +1,21 @@
from sqlalchemy.orm import Session
from dibbler.models import User
def create_user(
sql_session: Session,
name: str,
card: str | None,
rfid: str | None,
) -> User:
if not name:
raise ValueError("Name cannot be empty.")
# TODO: check for duplicate names, cards, rfids
user = User(name=name, card=card, rfid=rfid)
sql_session.add(user)
sql_session.commit()
return user

View File

@@ -0,0 +1,55 @@
from datetime import datetime
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.queries.query_helpers import until_filter
def current_interest(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Get the current interest rate percentage as of a given time or transaction.
Returns the interest rate percentage as an integer.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
result = sql_session.scalars(
select(Transaction)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(1)
).one_or_none()
if result is None:
return DEFAULT_INTEREST_RATE_PERCENT
elif result.interest_rate_percent is None:
return DEFAULT_INTEREST_RATE_PERCENT
else:
return result.interest_rate_percent

View File

@@ -0,0 +1,59 @@
from datetime import datetime
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries.query_helpers import until_filter
def current_penalty(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> tuple[int, int]:
"""
Get the current penalty settings (threshold and multiplier percentage) as of a given time or transaction.
Returns a tuple of `(penalty_threshold, penalty_multiplier_percentage)`.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
result = sql_session.scalars(
select(Transaction)
.where(
Transaction.type_ == TransactionType.ADJUST_PENALTY,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(1)
).one_or_none()
if result is None:
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENT
assert result.penalty_threshold is not None, "Penalty threshold must be set"
assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set"
return result.penalty_threshold, result.penalty_multiplier_percent

View File

@@ -0,0 +1,68 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
User,
)
def joint_buy_product(
sql_session: Session,
product: Product,
product_count: int,
instigator: User,
users: list[User],
time: datetime | None = None,
message: str | None = None,
) -> list[Transaction]:
"""
Create buy product transactions for multiple users at once.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if instigator.id is None:
raise ValueError("Instigator must be persisted in the database.")
if len(users) == 0:
raise ValueError("At least bying one user must be specified.")
if any(user.id is None for user in users):
raise ValueError("All users must be persisted in the database.")
if instigator not in users:
raise ValueError("Instigator must be in the list of users buying the product.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
joint_transaction = Transaction.joint(
user_id=instigator.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(joint_transaction)
sql_session.flush() # Ensure joint_transaction gets an ID
transactions = [joint_transaction]
for user in users:
buy_transaction = Transaction.joint_buy_product(
user_id=user.id,
joint_transaction_id=joint_transaction.id,
time=time,
message=message,
)
sql_session.add(buy_transaction)
transactions.append(buy_transaction)
sql_session.commit()
return transactions

View File

@@ -0,0 +1,309 @@
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
CTE,
BindParameter,
and_,
bindparam,
case,
func,
or_,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
User,
)
from dibbler.queries.product_stock import _product_stock_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
until_filter,
)
def _product_owners_query(
product_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE:
"""
The inner query for inferring the owners of a given product.
"""
if use_cache:
print("WARNING: Using cache for users owning product query is not implemented yet.")
if isinstance(product_id, int):
product_id = bindparam("product_id", value=product_id)
if until_time is not None and until_transaction is not None:
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = bindparam("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
product_stock = _product_stock_query(
product_id=product_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=Transaction.time.desc()).label("i"),
Transaction.time,
Transaction.id,
Transaction.type_,
Transaction.user_id,
Transaction.product_count,
)
.where(
or_(
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
and_(
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
Transaction.product_count > CONST_ZERO,
),
),
Transaction.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.subquery(trx_subset_name)
)
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_NONE.label("user_id"),
CONST_ZERO.label("product_count"),
product_stock.scalar_subquery().label("products_left_to_account_for"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
# Who added the product (if any)
case(
# Someone adds the product -> they own it
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
trx_subset.c.user_id,
),
else_=CONST_NONE,
).label("user_id"),
# How many products did they add (if any)
case(
# Someone adds the product -> they added a certain amount of products
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
trx_subset.c.product_count,
),
# Stock got adjusted upwards -> consider those products as added by nobody
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
and (trx_subset.c.product_count > CONST_ZERO),
trx_subset.c.product_count,
),
else_=CONST_ZERO,
).label("product_count"),
# How many products left to account for
case(
# Someone adds the product -> known owner, decrease the number of products left to account for
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
# Stock got adjusted upwards -> none owner, decrease the number of products left to account for
(
and_(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
trx_subset.c.product_count > CONST_ZERO,
),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
else_=recursive_cte.c.products_left_to_account_for,
).label("products_left_to_account_for"),
)
.select_from(trx_subset)
.where(
and_(
trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
# Base case: stop if we've accounted for all products
recursive_cte.c.products_left_to_account_for > CONST_ZERO,
)
)
)
return recursive_cte.union_all(recursive_elements)
@dataclass
class ProductOwnersLogEntry:
transaction: Transaction
user: User | None
products_left_to_account_for: int
def product_owners_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductOwnersLogEntry]:
"""
Returns a log of the product ownership calculation for the given product.
If 'until' is given, only transactions up to that time are considered.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
recursive_cte = _product_owners_query(
product_id=product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
select(
Transaction,
User,
recursive_cte.c.products_left_to_account_for,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.join(
User,
onclause=User.id == recursive_cte.c.user_id,
isouter=True,
)
.order_by(recursive_cte.c.time.desc())
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})."
)
return [
ProductOwnersLogEntry(
transaction=row[0],
user=row[1],
products_left_to_account_for=row[2],
)
for row in result
]
def product_owners(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[User | None]:
"""
Returns an ordered list of users owning the given product.
If 'until' is given, only transactions up to that time are considered.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
recursive_cte = _product_owners_query(
product_id=product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
db_result = sql_session.execute(
select(
recursive_cte.c.products_left_to_account_for,
recursive_cte.c.product_count,
User,
)
.join(User, User.id == recursive_cte.c.user_id, isouter=True)
.order_by(recursive_cte.c.time.desc())
).all()
print(db_result)
result: list[User | None] = []
none_count = 0
# We are moving backwards through history, but this is the order we want to return the list
# There are 3 cases:
# User is not none -> add user product_count times
# User is none, and product_count is not 0 -> add None product_count times
# User is none, and product_count is 0 -> check how much products are left to account for,
# TODO: embed this into the query itself?
for products_left_to_account_for, product_count, user in db_result:
if user is not None:
if products_left_to_account_for < 0:
result.extend([user] * (product_count + products_left_to_account_for))
else:
result.extend([user] * product_count)
elif product_count != 0:
if products_left_to_account_for < 0:
none_count += product_count + products_left_to_account_for
else:
none_count += product_count
else:
pass
# none_count += user_count
# else:
result.extend([None] * none_count)
# # NOTE: if the last line exeeds the product count, we need to truncate it
# result.extend([user] * min(user_count, products_left_to_account_for))
# redistribute the user counts to a list of users
return list(result)

View File

@@ -0,0 +1,309 @@
import math
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
BindParameter,
ColumnElement,
Integer,
bindparam,
case,
cast,
func,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
)
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
until_filter,
)
def _product_price_query(
product_id: int | ColumnElement[int],
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
):
"""
The inner query for calculating the product price.
"""
if use_cache:
print("WARNING: Using cache for product price query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.id,
Transaction.time,
Transaction.type_,
Transaction.product_count,
Transaction.per_product,
)
.where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
]
),
Transaction.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.asc())
.subquery(trx_subset_name)
)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case(
# Someone buys the product -> price remains the same.
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.price,
),
# Someone adds the product -> price is recalculated based on
# product count, previous price, and new price.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
cast(
func.ceil(
(
recursive_cte.c.price
* func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.per_product * trx_subset.c.product_count
)
/ (
# The running product count can be negative if the accounting is bad.
# This ensures that we never end up with negative prices or zero divisions
# and other disastrous phenomena.
func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.product_count
)
),
Integer,
),
),
# Someone adjusts the stock -> price remains the same.
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.price,
),
# Should never happen
else_=recursive_cte.c.price,
).label("price"),
case(
# Someone buys the product -> product count is reduced.
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count,
),
(
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count,
),
# Someone adds the product -> product count is increased.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Someone adjusts the stock -> product count is adjusted.
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Should never happen
else_=recursive_cte.c.product_count,
).label("product_count"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
)
return recursive_cte.union_all(recursive_elements)
# TODO: create a function for the log that pretty prints the log entries
# for debugging purposes
@dataclass
class ProductPriceLogEntry:
transaction: Transaction
price: int
product_count: int
def product_price_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductPriceLogEntry]:
"""
Calculates the price of a product and returns a log of the price changes.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
select(
Transaction,
recursive_cte.c.price,
recursive_cte.c.product_count,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})."
)
return [
ProductPriceLogEntry(
transaction=row[0],
price=row.price,
product_count=row.product_count,
)
for row in result
]
def product_price(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
include_interest: bool = False,
) -> int:
"""
Calculates the price of a product.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
# TODO: optionally verify subresults:
# - product_count should never be negative (but this happens sometimes, so just a warning)
# - price should never be negative
result = sql_session.scalars(
select(recursive_cte.c.price)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
).one_or_none()
if result is None:
# If there are no transactions for this product, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
)
if include_interest:
interest_rate = (
sql_session.scalar(
select(Transaction.interest_rate_percent)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(CONST_ONE)
)
or DEFAULT_INTEREST_RATE_PERCENT
)
result = math.ceil(result * interest_rate / 100)
return result

View File

@@ -0,0 +1,126 @@
from datetime import datetime
from typing import Tuple
from sqlalchemy import (
BindParameter,
Select,
bindparam,
case,
func,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
)
from dibbler.queries.query_helpers import until_filter
def _product_stock_query(
product_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[Tuple[int]]:
"""
The inner query for calculating the product stock.
"""
if use_cache:
print("WARNING: Using cache for product stock query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select(
func.sum(
case(
(
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.JOINT.as_literal_column(),
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(),
-Transaction.product_count,
),
else_=0,
)
).label("stock")
).where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT.as_literal_column(),
]
),
Transaction.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
return query
def product_stock(
sql_session: Session,
product: Product,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Returns the number of products in stock.
If 'until' is given, only transactions up to that time are considered.
"""
if product.id is None:
raise ValueError("Product must be persisted in the database.")
query = _product_stock_query(
product_id=product.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.scalars(query).one_or_none()
return result or 0

View File

@@ -0,0 +1,80 @@
from datetime import datetime
from typing import TypeVar
from sqlalchemy import (
BindParameter,
ColumnExpressionArgument,
literal,
select,
)
from sqlalchemy.orm import QueryableAttribute
from dibbler.models import Transaction
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
"""A constant SQL expression `0`. This will render as a literal `0` in SQL queries."""
CONST_ONE: BindParameter[int] = const(1)
"""A constant SQL expression `1`. This will render as a literal `1` in SQL queries."""
CONST_TRUE: BindParameter[bool] = const(True)
"""A constant SQL expression `TRUE`. This will render as a literal `TRUE` in SQL queries."""
CONST_FALSE: BindParameter[bool] = const(False)
"""A constant SQL expression `FALSE`. This will render as a literal `FALSE` in SQL queries."""
CONST_NONE: BindParameter[None] = const(None)
"""A constant SQL expression `NULL`. This will render as a literal `NULL` in SQL queries."""
def until_filter(
until_time: BindParameter[datetime] | None = None,
until_transaction_id: BindParameter[int] | None = None,
until_inclusive: bool = True,
transaction_time: QueryableAttribute = Transaction.time,
) -> ColumnExpressionArgument[bool]:
"""
Create a filter condition for transactions up to a given time or transaction.
Only one of `until_time` or `until_transaction_id` may be specified.
"""
assert not (until_time is not None and until_transaction_id is not None), (
"Cannot filter by both until_time and until_transaction_id."
)
match (until_time, until_transaction_id, until_inclusive):
case (BindParameter(), None, True):
return transaction_time <= until_time
case (BindParameter(), None, False):
return transaction_time < until_time
case (None, BindParameter(), True):
return (
transaction_time
<= select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
case (None, BindParameter(), False):
return (
transaction_time
< select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
return CONST_TRUE

View File

@@ -0,0 +1,42 @@
from sqlalchemy import and_, literal, not_, or_, select
from sqlalchemy.orm import Session
from dibbler.models import Product
def search_product(
string: str,
sql_session: Session,
find_hidden_products=False,
) -> Product | list[Product]:
if not string:
raise ValueError("Search string cannot be empty.")
exact_match = sql_session.scalars(
select(Product).where(
or_(
Product.bar_code == string,
and_(
Product.name == string,
literal(True) if find_hidden_products else not_(Product.hidden),
),
)
)
).first()
if exact_match:
return exact_match
product_list = sql_session.scalars(
select(Product).where(
or_(
Product.bar_code.ilike(f"%{string}%"),
and_(
Product.name.ilike(f"%{string}%"),
literal(True) if find_hidden_products else not_(Product.hidden),
),
)
)
).all()
return list(product_list)

View File

@@ -0,0 +1,39 @@
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from dibbler.models import User
def search_user(
string: str,
sql_session: Session,
) -> User | list[User]:
if not string:
raise ValueError("Search string cannot be empty.")
string = string.lower()
exact_match = sql_session.scalars(
select(User).where(
or_(
User.name == string,
User.card == string,
User.rfid == string,
)
)
).first()
if exact_match:
return exact_match
user_list = sql_session.scalars(
select(User).where(
or_(
User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"),
)
)
).all()
return list(user_list)

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def throw_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
raise NotImplementedError(
"Please don't use this function until relevant calculations have been added to user_balance."
)
transaction = Transaction.throw_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,142 @@
from datetime import datetime
from sqlalchemy import BindParameter, select
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
User,
)
# TODO: should this include full joint transactions that involve a user?
# TODO: should this involve throw-away transactions that affects a user?
def transaction_log(
sql_session: Session,
user: User | None = None,
product: Product | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
after_time: BindParameter[datetime] | datetime | None = None,
after_transaction: Transaction | None = None,
after_inlcusive: bool = True,
transaction_type: list[TransactionType] | None = None,
negate_transaction_type_filter: bool = False,
limit: int | None = None,
) -> list[Transaction]:
"""
Retrieve the transaction log, optionally filtered.
Only one of `user` or `product` may be specified.
Only one of `until_time` or `until_transaction_id` may be specified.
Only one of `after_time` or `after_transaction_id` may be specified.
The after and after filters are inclusive by default.
"""
if not (user is None or product is None):
raise ValueError("Cannot filter by both user and product.")
if isinstance(user, User):
if user.id is None:
raise ValueError("User must be persisted in the database.")
user_id = BindParameter("user_id", value=user.id)
else:
user_id = None
if isinstance(product, Product):
if product.id is None:
raise ValueError("Product must be persisted in the database.")
product_id = BindParameter("product_id", value=product.id)
else:
product_id = None
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
if limit is not None and limit <= 0:
raise ValueError("Limit must be positive.")
query = select(Transaction)
if user is not None:
query = query.where(Transaction.user_id == user_id)
if product is not None:
query = query.where(Transaction.product_id == product_id)
match (until_time, until_transaction_id, until_inclusive):
case (BindParameter(), None, True):
query = query.where(Transaction.time <= until_time)
case (BindParameter(), None, False):
query = query.where(Transaction.time < until_time)
case (None, BindParameter(), True):
query = query.where(Transaction.id <= until_transaction_id)
case (None, BindParameter(), False):
query = query.where(Transaction.id < until_transaction_id)
case _:
pass
match (after_time, after_transaction_id, after_inlcusive):
case (BindParameter(), None, True):
query = query.where(Transaction.time >= after_time)
case (BindParameter(), None, False):
query = query.where(Transaction.time > after_time)
case (None, BindParameter(), True):
query = query.where(Transaction.id >= after_transaction_id)
case (None, BindParameter(), False):
query = query.where(Transaction.id > after_transaction_id)
case _:
pass
if transaction_type is not None:
if negate_transaction_type_filter:
query = query.where(~Transaction.type_.in_(transaction_type))
else:
query = query.where(Transaction.type_.in_(transaction_type))
if limit is not None:
query = query.limit(limit)
query = query.order_by(Transaction.time.asc(), Transaction.id.asc())
result = sql_session.scalars(query).all()
return list(result)

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def transfer(
sql_session: Session,
from_user: User,
to_user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if from_user.id is None:
raise ValueError("From user must be persisted in the database.")
if to_user.id is None:
raise ValueError("To user must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.transfer(
user_id=from_user.id,
transfer_user_id=to_user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,567 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Tuple
from sqlalchemy import (
CTE,
BindParameter,
Float,
Integer,
Select,
and_,
bindparam,
case,
cast,
column,
func,
or_,
select,
)
from sqlalchemy.orm import Session, aliased
from sqlalchemy.sql.elements import KeyedColumnElement
from dibbler.models import (
Transaction,
TransactionType,
User,
)
from dibbler.models.Transaction import (
DEFAULT_INTEREST_RATE_PERCENT,
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries.product_price import _product_price_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
const,
until_filter,
)
def _joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[Tuple[int, int, int]]:
"""
The inner query for getting joint transactions relevant to a user.
This scans for JOINT_BUY_PRODUCT transactions made by the user,
then finds the corresponding JOINT transactions, and counts how many "shares"
of the joint transaction the user has, as well as the total number of shares.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
# First, select all joint buy product transactions for the given user
# sub_joint_transaction = aliased(Transaction, name="right_trx")
sub_joint_transaction = (
select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id"))
.where(
Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
Transaction.user_id == user_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
transaction_time=Transaction.time,
),
)
.subquery("sub_joint_transaction")
)
# Join those with their main joint transaction
# (just use Transaction)
# Then, count how many users are involved in each joint transaction
joint_transaction_count = aliased(Transaction, name="count_trx")
joint_transaction = (
select(
Transaction.id,
# Shares the user has in the transaction,
func.sum(
case(
(joint_transaction_count.user_id == user_id, CONST_ONE),
else_=CONST_ZERO,
)
).label("user_shares"),
# The total number of shares in the transaction,
func.count(joint_transaction_count.id).label("user_count"),
)
.select_from(sub_joint_transaction)
.join(
Transaction,
onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id,
)
.join(
joint_transaction_count,
onclause=joint_transaction_count.joint_transaction_id == Transaction.id,
)
.group_by(joint_transaction_count.joint_transaction_id)
)
return joint_transaction
def _non_joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[Tuple[int, None, None]]:
"""
The inner query for getting non-joint transactions relevant to a user.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select(
Transaction.id,
CONST_NONE.label("user_shares"),
CONST_NONE.label("user_count"),
).where(
or_(
and_(
Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
]
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY.as_literal_column(),
]
),
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
return query
def _product_cost_expression(
product_count_column: KeyedColumnElement[int],
product_id_column: KeyedColumnElement[int],
interest_rate_percent_column: KeyedColumnElement[int],
user_balance_column: KeyedColumnElement[int],
penalty_threshold_column: KeyedColumnElement[int],
penalty_multiplier_percent_column: KeyedColumnElement[int],
joint_user_shares_column: KeyedColumnElement[int],
joint_user_count_column: KeyedColumnElement[int],
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "product_price_cte",
trx_subset_name: str = "product_price_trx_subset",
):
# TODO: This can get quite expensive real quick, so we should do some caching of the
# product prices somehow.
expression = (
select(
cast(
func.ceil(
# Base price
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
# Interest
+ (
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
* cast(interest_rate_percent_column - const(100), Float)
/ const(100.0)
)
# Penalty
+ (
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
* cast(penalty_multiplier_percent_column - const(100), Float)
/ const(100.0)
* cast(user_balance_column < penalty_threshold_column, Integer)
)
),
Integer,
)
)
.select_from(
_product_price_query(
product_id_column,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=cte_name,
trx_subset_name=trx_subset_name,
)
)
.order_by(column("i").desc())
.limit(CONST_ONE)
.scalar_subquery()
)
return expression
def _user_balance_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE:
"""
The inner query for calculating the user's balance.
"""
if use_cache:
print("WARNING: Using cache for user balance query is not implemented yet.")
if isinstance(user_id, int):
user_id = BindParameter("user_id", value=user_id)
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("balance"),
const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"),
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
trx_subset_subset = (
_non_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
.union_all(
_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
)
.subquery(f"{trx_subset_name}_subset")
)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.id,
Transaction.amount,
Transaction.interest_rate_percent,
Transaction.penalty_multiplier_percent,
Transaction.penalty_threshold,
Transaction.product_count,
Transaction.product_id,
Transaction.time,
Transaction.transfer_user_id,
Transaction.type_,
trx_subset_subset.c.user_shares,
trx_subset_subset.c.user_count,
)
.select_from(trx_subset_subset)
.join(
Transaction,
onclause=Transaction.id == trx_subset_subset.c.id,
)
.order_by(Transaction.time.asc())
.subquery(trx_subset_name)
)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case(
# Adjusts balance -> balance gets adjusted
(
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Adds a product -> balance increases
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Buys a product -> balance decreases
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.balance
- _product_cost_expression(
product_count_column=trx_subset.c.product_count,
product_id_column=trx_subset.c.product_id,
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
user_balance_column=recursive_cte.c.balance,
penalty_threshold_column=recursive_cte.c.penalty_threshold,
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
joint_user_shares_column=CONST_ONE,
joint_user_count_column=CONST_ONE,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=f"{cte_name}_price",
trx_subset_name=f"{trx_subset_name}_price",
).label("product_cost"),
),
# Joint transaction -> balance decreases proportionally
(
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
recursive_cte.c.balance
- _product_cost_expression(
product_count_column=trx_subset.c.product_count,
product_id_column=trx_subset.c.product_id,
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
user_balance_column=recursive_cte.c.balance,
penalty_threshold_column=recursive_cte.c.penalty_threshold,
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
joint_user_shares_column=trx_subset.c.user_shares,
joint_user_count_column=trx_subset.c.user_count,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=f"{cte_name}_joint_price",
trx_subset_name=f"{trx_subset_name}_joint_price",
).label("joint_product_cost"),
),
# Transfers money to self -> balance increases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id == user_id,
),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Transfers money from self -> balance decreases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id != user_id,
),
recursive_cte.c.balance - trx_subset.c.amount,
),
# Throws a product -> if the user is considered to have bought it, balance increases
# TODO: # (
# trx_subset.c.type_ == TransactionType.THROW_PRODUCT,
# recursive_cte.c.balance + trx_subset.c.amount,
# ),
# Interest adjustment -> balance stays the same
# Penalty adjustment -> balance stays the same
else_=recursive_cte.c.balance,
).label("balance"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(),
trx_subset.c.interest_rate_percent,
),
else_=recursive_cte.c.interest_rate_percent,
).label("interest_rate_percent"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_threshold,
),
else_=recursive_cte.c.penalty_threshold,
).label("penalty_threshold"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_multiplier_percent,
),
else_=recursive_cte.c.penalty_multiplier_percent,
).label("penalty_multiplier_percent"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
)
return recursive_cte.union_all(recursive_elements)
# TODO: create a function for the log that pretty prints the log entries
# for debugging purposes
@dataclass
class UserBalanceLogEntry:
transaction: Transaction
balance: int
interest_rate_percent: int
penalty_threshold: int
penalty_multiplier_percent: int
def is_penalized(self) -> bool:
"""
Returns whether this exact transaction is penalized.
"""
raise NotImplementedError("is_penalized is not implemented yet.")
def user_balance_log(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[UserBalanceLogEntry]:
"""
Returns a log of the user's balance over time, including interest and penalty adjustments.
If 'until' is given, only transactions up to that time are considered.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
select(
Transaction,
recursive_cte.c.balance,
recursive_cte.c.interest_rate_percent,
recursive_cte.c.penalty_threshold,
recursive_cte.c.penalty_multiplier_percent,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all()
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
)
return [
UserBalanceLogEntry(
transaction=row[0],
balance=row.balance,
interest_rate_percent=row.interest_rate_percent,
penalty_threshold=row.penalty_threshold,
penalty_multiplier_percent=row.penalty_multiplier_percent,
)
for row in result
]
def user_balance(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Calculates the balance of a user.
If 'until' is given, only transactions up to that time are considered.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.scalar(
select(recursive_cte.c.balance)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
)
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
)
return result

View File

@@ -0,0 +1,48 @@
from datetime import datetime
from sqlalchemy import BindParameter, bindparam
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
# NOTE: This absoulutely needs a cache, else we can't stop recursing until we know all owners for all products...
#
# Since we know that the non-owned products will not get renowned by the user by other means,
# we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user.
# between now and the cached time.
#
# However, the opposite way is more difficult. The cache will store which products are owned by which users,
# but we still need to check if the user passes out of ownership for the item, without needing to check past
# the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if
# a user has products multiple places in the queue, interleaved with other users?
def user_products(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[tuple[Product, int]]:
"""
Returns the list of products owned by the user, along with how many of each product they own.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
raise NotImplementedError("Not implemented yet, needs caching system first.")

28
tests/helpers.py Normal file
View File

@@ -0,0 +1,28 @@
from datetime import datetime, timedelta
from dibbler.models import Transaction
def assign_times(
transactions: list[Transaction],
start_time: datetime = datetime(2024, 1, 1, 0, 0, 0),
delta: timedelta = timedelta(minutes=1),
) -> None:
"""Assigns datetimes to a list of transactions starting from start_time and incrementing by delta."""
current_time = start_time
for transaction in transactions:
transaction.time = current_time
current_time += delta
def assert_id_order_similar_to_time_order(transactions: list[Transaction]) -> None:
"""Asserts that the order of transaction IDs is similar to the order of their timestamps."""
sorted_by_time = sorted(transactions, key=lambda t: t.time)
sorted_by_id = sorted(transactions, key=lambda t: t.id)
for t1, t2 in zip(sorted_by_time, sorted_by_id):
assert t1.id == t2.id or t1.time == t2.time, (
f"Transaction ID order does not match time order:\n"
f"ID {t1.id} at time {t1.time}\n"
f"ID {t2.id} at time {t2.time}"
)

View File

View File

@@ -0,0 +1,85 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
from dibbler.queries import adjust_interest, current_interest
def insert_test_data(sql_session: Session) -> User:
user = User("Test User")
sql_session.add(user)
sql_session.commit()
return user
def test_adjust_interest_unitialized_user(sql_session: Session) -> None:
user = User("Uninitialized User")
with pytest.raises(ValueError, match="User must be persisted in the database."):
adjust_interest(
sql_session,
user=user,
new_interest=4,
message="Attempting to adjust interest for uninitialized user",
)
def test_adjust_interest_no_history(sql_session: Session) -> None:
user = insert_test_data(sql_session)
adjust_interest(
sql_session,
user=user,
new_interest=3,
message="Setting initial interest rate",
)
sql_session.commit()
current_interest_rate = current_interest(sql_session)
assert current_interest_rate == 3
def test_adjust_interest_existing_history(sql_session: Session) -> None:
user = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 9, 0, 0),
user_id=user.id,
interest_rate_percent=5,
message="Initial interest rate",
),
]
sql_session.add_all(transactions)
sql_session.commit()
current_interest_rate = current_interest(sql_session)
assert current_interest_rate == 5
adjust_interest(
sql_session,
user=user,
new_interest=2,
message="Adjusting interest rate",
time=transactions[-1].time + timedelta(days=1),
)
sql_session.commit()
current_interest_rate = current_interest(sql_session)
assert current_interest_rate == 2
def test_adjust_interest_negative_failure(sql_session: Session) -> None:
user = insert_test_data(sql_session)
with pytest.raises(ValueError, match="Interest rate cannot be negative"):
adjust_interest(
sql_session,
user=user,
new_interest=-1,
message="Attempting to set negative interest rate",
)

View File

@@ -0,0 +1,179 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries import adjust_penalty, current_penalty
def insert_test_data(sql_session: Session) -> User:
user = User("Test User")
sql_session.add(user)
sql_session.commit()
return user
def test_adjust_penalty_empty_not_allowed(sql_session: Session) -> None:
user = insert_test_data(sql_session)
with pytest.raises(ValueError):
adjust_penalty(
sql_session,
user=user,
message="No penalty or multiplier provided",
)
def test_adjust_penalty_unitialized_user(sql_session: Session) -> None:
user = User("Uninitialized User")
with pytest.raises(ValueError):
adjust_penalty(
sql_session,
user=user,
new_penalty=-100,
new_penalty_multiplier=110,
message="Attempting to adjust penalty for uninitialized user",
)
def test_adjust_penalty_no_history(sql_session: Session) -> None:
user = insert_test_data(sql_session)
adjust_penalty(
sql_session,
user=user,
new_penalty=-200,
message="Setting initial interest rate",
)
sql_session.commit()
(penalty, multiplier) = current_penalty(sql_session)
assert penalty == -200
assert multiplier == DEFAULT_PENALTY_MULTIPLIER_PERCENT
def test_adjust_penalty_multiplier_no_history(sql_session: Session) -> None:
user = insert_test_data(sql_session)
adjust_penalty(
sql_session,
user=user,
new_penalty_multiplier=125,
message="Setting initial interest rate",
)
sql_session.commit()
(penalty, multiplier) = current_penalty(sql_session)
assert penalty == DEFAULT_PENALTY_THRESHOLD
assert multiplier == 125
def test_adjust_penalty_multiplier_less_than_100_fail(sql_session: Session) -> None:
user = insert_test_data(sql_session)
adjust_penalty(
sql_session,
user=user,
new_penalty_multiplier=100,
message="Setting initial interest rate",
)
sql_session.commit()
(_, multiplier) = current_penalty(sql_session)
assert multiplier == 100
with pytest.raises(ValueError, match="Penalty multiplier cannot be less than 100%"):
adjust_penalty(
sql_session,
user=user,
new_penalty_multiplier=99,
message="Setting initial interest rate",
)
def test_adjust_penalty_existing_history(sql_session: Session) -> None:
user = insert_test_data(sql_session)
transactions = [
Transaction.adjust_penalty(
time=datetime(2024, 1, 1, 10, 0, 0),
user_id=user.id,
penalty_threshold=-150,
penalty_multiplier_percent=110,
message="Initial penalty settings",
),
]
sql_session.add_all(transactions)
sql_session.commit()
(penalty, _) = current_penalty(sql_session)
assert penalty == -150
adjust_penalty(
sql_session,
user=user,
new_penalty=-250,
message="Adjusting penalty threshold",
time=transactions[-1].time + timedelta(days=1),
)
sql_session.commit()
(penalty, _) = current_penalty(sql_session)
assert penalty == -250
def test_adjust_penalty_multiplier_existing_history(sql_session: Session) -> None:
user = insert_test_data(sql_session)
transactions = [
Transaction.adjust_penalty(
time=datetime(2024, 1, 1, 10, 0, 0),
user_id=user.id,
penalty_threshold=-150,
penalty_multiplier_percent=110,
message="Initial penalty settings",
),
]
sql_session.add_all(transactions)
sql_session.commit()
(_, multiplier) = current_penalty(sql_session)
assert multiplier == 110
adjust_penalty(
sql_session,
user=user,
new_penalty_multiplier=130,
message="Adjusting penalty multiplier",
time=transactions[-1].time + timedelta(days=1),
)
sql_session.commit()
(_, multiplier) = current_penalty(sql_session)
assert multiplier == 130
def test_adjust_penalty_and_multiplier(sql_session: Session) -> None:
user = insert_test_data(sql_session)
adjust_penalty(
sql_session,
user=user,
new_penalty=-300,
new_penalty_multiplier=150,
message="Setting both penalty and multiplier",
)
sql_session.commit()
(penalty, multiplier) = current_penalty(sql_session)
assert penalty == -300
assert multiplier == 150

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.models import Transaction, User
from dibbler.queries import current_interest
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def test_current_interest_no_history(sql_session: Session) -> None:
assert current_interest(sql_session) == DEFAULT_INTEREST_RATE_PERCENT
def test_current_interest_with_history(sql_session: Session) -> None:
user = User("Admin User")
sql_session.add(user)
sql_session.commit()
transactions = [
Transaction.adjust_interest(
interest_rate_percent=5,
user_id=user.id,
),
Transaction.adjust_interest(
interest_rate_percent=7,
user_id=user.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert current_interest(sql_session) == 7

View File

@@ -0,0 +1,46 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries import current_penalty
from tests.helpers import assign_times, assert_id_order_similar_to_time_order
def test_current_penalty_no_history(sql_session: Session) -> None:
assert current_penalty(sql_session) == (
DEFAULT_PENALTY_THRESHOLD,
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
)
def test_current_penalty_with_history(sql_session: Session) -> None:
user = User("Admin User")
sql_session.add(user)
sql_session.commit()
transactions = [
Transaction.adjust_penalty(
penalty_threshold=-200,
penalty_multiplier_percent=150,
user_id=user.id,
),
Transaction.adjust_penalty(
penalty_threshold=-300,
penalty_multiplier_percent=200,
user_id=user.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert current_penalty(sql_session) == (-300, 200)

View File

@@ -0,0 +1,199 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries import joint_buy_product
def insert_test_data(sql_session: Session) -> tuple[User, User, User, Product]:
user1 = User("Test User 1")
user2 = User("Test User 2")
user3 = User("Test User 3")
product = Product("1234567890123", "Test Product")
sql_session.add_all([user1, user2, user3, product])
sql_session.commit()
transactions = [
Transaction.add_product(
user_id=user1.id,
product_id=product.id,
amount=30,
per_product=10,
product_count=3,
time=datetime(2024, 1, 1, 10, 0, 0),
)
]
sql_session.add_all(transactions)
sql_session.commit()
return user1, user2, user3, product
def test_joint_buy_product_uninitialized_product(sql_session: Session) -> None:
user = User("Test User 1")
sql_session.add(user)
sql_session.commit()
product = Product("1234567890123", "Uninitialized Product")
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=user,
users=[user],
product=product,
product_count=1,
)
def test_joint_buy_product_no_users(sql_session: Session) -> None:
user, _, _, product = insert_test_data(sql_session)
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=user,
users=[],
product=product,
product_count=1,
)
def test_joint_buy_product_uninitialized_instigator(sql_session: Session) -> None:
user, user2, _, product = insert_test_data(sql_session)
uninitialized_user = User("Uninitialized User")
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=uninitialized_user,
users=[user, user2],
product=product,
product_count=1,
)
def test_joint_buy_product_uninitialized_user_in_list(sql_session: Session) -> None:
user, _, _, product = insert_test_data(sql_session)
uninitialized_user = User("Uninitialized User")
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=user,
users=[user, uninitialized_user],
product=product,
product_count=1,
)
def test_joint_buy_product_invalid_product_count(sql_session: Session) -> None:
user, _, _, product = insert_test_data(sql_session)
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=user,
users=[user],
product=product,
product_count=0,
)
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=user,
users=[user],
product=product,
product_count=-1,
)
def test_joint_single_user(sql_session: Session) -> None:
user, _, _, product = insert_test_data(sql_session)
joint_buy_product(
sql_session,
instigator=user,
users=[user],
product=product,
product_count=1,
)
def test_joint_buy_product(sql_session: Session) -> None:
user, user2, user3, product = insert_test_data(sql_session)
joint_buy_product(
sql_session,
instigator=user,
users=[user, user2, user3],
product=product,
product_count=1,
)
def test_joint_buy_product_more_than_in_stock(sql_session: Session) -> None:
user, user2, user3, product = insert_test_data(sql_session)
joint_buy_product(
sql_session,
instigator=user,
users=[user, user2, user3],
product=product,
product_count=5,
)
def test_joint_buy_product_out_of_stock(sql_session: Session) -> None:
user, user2, user3, product = insert_test_data(sql_session)
transactions = [
Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=3,
time=datetime(2024, 1, 2, 10, 0, 0),
)
]
sql_session.add_all(transactions)
sql_session.commit()
joint_buy_product(
sql_session,
instigator=user,
users=[user, user2, user3],
product=product,
product_count=10,
time=transactions[-1].time + timedelta(days=1),
)
def test_joint_buy_product_duplicate_user(sql_session: Session) -> None:
user, user2, _, product = insert_test_data(sql_session)
joint_buy_product(
sql_session,
instigator=user,
users=[user, user, user2],
product=product,
product_count=1,
)
def test_joint_buy_product_non_involved_instigator(sql_session: Session) -> None:
user, user2, user3, product = insert_test_data(sql_session)
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=user,
users=[user2, user3],
product=product,
product_count=1,
)

View File

@@ -0,0 +1,335 @@
from datetime import datetime
from pprint import pprint
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, User
from dibbler.models.Transaction import Transaction
from dibbler.queries import product_owners, product_owners_log, product_stock
from tests.helpers import assign_times, assert_id_order_similar_to_time_order
def insert_test_data(sql_session: Session) -> tuple[Product, User]:
user = User("testuser")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit()
return product, user
def test_product_owners_unitilialized_product(sql_session: Session) -> None:
user = User("testuser")
sql_session.add(user)
sql_session.commit()
product = Product("1234567890123", "Uninitialized Product")
with pytest.raises(ValueError):
product_owners(sql_session, product)
def test_product_owners_no_transactions(sql_session: Session) -> None:
product, _ = insert_test_data(sql_session)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == []
def test_product_owners_add_products(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=30,
per_product=10,
product_count=3,
)
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user, user]
def test_product_owners_add_and_buy_products(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=30,
per_product=10,
product_count=3,
),
Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user]
def test_product_owners_add_and_throw_products(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=40,
per_product=10,
product_count=4,
),
Transaction.throw_product(
user_id=user.id,
product_id=product.id,
product_count=2,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user]
def test_product_owners_multiple_users(sql_session: Session) -> None:
product, user1 = insert_test_data(sql_session)
user2 = User("testuser2")
sql_session.add(user2)
sql_session.commit()
transactions = [
Transaction.add_product(
user_id=user1.id,
product_id=product.id,
amount=20,
per_product=10,
product_count=2,
),
Transaction.add_product(
user_id=user2.id,
product_id=product.id,
amount=30,
per_product=10,
product_count=3,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user2, user2, user2, user1, user1]
def test_product_owners_adjust_stock_down(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=50,
per_product=10,
product_count=5,
),
Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=-2,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
assert product_stock(sql_session, product) == 3
owners = product_owners(sql_session, product)
assert owners == [user, user, user]
def test_product_owners_adjust_stock_up(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=20,
per_product=10,
product_count=2,
),
Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=3,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user, user, None, None, None]
def test_product_owners_negative_stock(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=10,
per_product=10,
product_count=1,
),
Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=2,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
owners = product_owners(sql_session, product)
assert owners == []
def test_product_owners_add_products_from_negative_stock(sql_session: Session) -> None:
product, user = insert_test_data(sql_session)
transactions = [
Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=2,
),
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=30,
per_product=10,
product_count=3,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user]
def test_product_owners_interleaved_users(sql_session: Session) -> None:
product, user1 = insert_test_data(sql_session)
user2 = User("testuser2")
sql_session.add(user2)
sql_session.commit()
transactions = [
Transaction.add_product(
user_id=user1.id,
product_id=product.id,
amount=20,
per_product=10,
product_count=2,
),
Transaction.add_product(
user_id=user2.id,
product_id=product.id,
amount=30,
per_product=10,
product_count=3,
),
Transaction.buy_product(
user_id=user1.id,
product_id=product.id,
product_count=1,
),
Transaction.add_product(
user_id=user1.id,
product_id=product.id,
amount=10,
per_product=10,
product_count=1,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_owners_log(sql_session, product))
owners = product_owners(sql_session, product)
assert owners == [user1, user2, user2, user2, user1]

View File

@@ -0,0 +1,447 @@
import math
from datetime import datetime, timedelta
from pprint import pprint
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries import product_price, product_price_log, joint_buy_product
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
# TODO: see if we can use pytest_runtest_makereport to print the "product_price_log"s
# only on failures instead of inlining it in every test function
def insert_test_data(sql_session: Session) -> tuple[User, Product]:
user = User("Test User")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit()
return user, product
def test_product_price_uninitialized_product(sql_session: Session) -> None:
user = User("Test User")
sql_session.add(user)
sql_session.commit()
product = Product("1234567890123", "Uninitialized Product")
with pytest.raises(ValueError, match="Product must be persisted in the database."):
product_price(sql_session, product)
def test_product_price_no_transactions(sql_session: Session) -> None:
_, product = insert_test_data(sql_session)
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 0
def test_product_price_basic_history(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 27
def test_product_price_sold_out(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 27
def test_product_price_interest(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
interest_rate_percent=110,
user_id=user.id,
),
Transaction.add_product(
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
product_price_interest = product_price(sql_session, product, include_interest=True)
assert product_price_ == 27
assert product_price_interest == math.ceil(27 * 1.1)
def test_product_price_changing_interest(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
interest_rate_percent=110,
user_id=user.id,
),
Transaction.add_product(
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.adjust_interest(
interest_rate_percent=120,
user_id=user.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_price_log(sql_session, product))
product_price_interest = product_price(sql_session, product, include_interest=True)
assert product_price_interest == math.ceil(27 * 1.2)
def test_product_price_old_transaction(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
amount=38 * 3,
per_product=38,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged upwards
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
until_transaction = transactions[0]
pprint(
product_price_log(
sql_session,
product,
until_transaction=until_transaction,
)
)
product_price_ = product_price(
sql_session,
product,
until_transaction=until_transaction,
)
assert product_price_ == 27
# Price goes up and gets rounded up to the next integer
def test_product_price_round_up_from_below(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
amount=38 * 3,
per_product=38,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged upwards
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
assert product_price_ == math.ceil((27 * 2 + 38 * 3) / (2 + 3))
# Price goes down and gets rounded up to the next integer
def test_product_price_round_up_from_above(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
amount=20 * 3,
per_product=20,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged downwards
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
assert product_price_ == math.ceil((27 * 2 + 20 * 3) / (2 + 3))
def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=1,
per_product=10,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
product_count=10,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
amount=22,
per_product=22,
product_count=1,
user_id=user.id,
product_id=product.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_price_log(sql_session, product))
# Stock went subzero, price should be the last added product price
product1_price = product_price(sql_session, product)
assert product1_price == 22
def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=1,
per_product=10,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
product_count=10,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
amount=22,
per_product=22,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
amount=29,
per_product=29,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
pprint(product_price_log(sql_session, product))
# Stock went subzero, price should be the last added product price
product1_price = product_price(sql_session, product)
assert product1_price == math.ceil(29)
def test_product_price_joint_transactions(sql_session: Session) -> None:
user1, product = insert_test_data(sql_session)
user2 = User("Test User 2")
sql_session.add(user2)
sql_session.commit()
transactions = [
Transaction.add_product(
amount=30 * 3,
per_product=30,
product_count=3,
user_id=user1.id,
product_id=product.id,
),
Transaction.add_product(
amount=20 * 2,
per_product=20,
product_count=2,
user_id=user2.id,
product_id=product.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
product_price_ = product_price(sql_session, product)
assert product_price_ == math.ceil((30 * 3 + 20 * 2) / (3 + 2))
transactions += joint_buy_product(
sql_session,
instigator=user1,
users=[user1, user2],
product=product,
product_count=2,
time=transactions[-1].time + timedelta(seconds=1),
)
pprint(product_price_log(sql_session, product))
old_product_price = product_price_
product_price_ = product_price(sql_session, product)
assert product_price_ == old_product_price, (
"Joint buy transactions should not affect product price"
)
transactions = [
Transaction.add_product(
amount=25 * 4,
per_product=25,
product_count=4,
user_id=user1.id,
product_id=product.id,
time=transactions[-1].time + timedelta(seconds=1),
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
# Expected state:
# Added products:
# Count: 3 + 2 = 5, Price: (30 * 3 + 20 * 2) / 5 = 26
# Joint bought products:
# Count: 5 - 2 = 3, Price: n/a (should not affect price)
# Added products:
# Count: 3 + 4 = 7, Price: (26 * 3 + 25 * 4) / (3 + 4) = 25.57 -> 26
assert product_price_ == math.ceil((26 * 3 + 25 * 4) / (3 + 4))
def test_product_price_until(sql_session: Session) -> None: ...

View File

@@ -0,0 +1,285 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries import joint_buy_product, product_stock
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def insert_test_data(sql_session: Session) -> tuple[User, Product]:
user = User("Test User 1")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit()
return user, product
def test_product_stock_uninitialized_product(sql_session: Session) -> None:
user = User("Test User 1")
sql_session.add(user)
sql_session.commit()
product = Product("1234567890123", "Uninitialized Product")
with pytest.raises(ValueError):
product_stock(sql_session, product)
def test_product_stock_until_datetime_and_transaction_id_not_allowed(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.add_product(
amount=10,
per_product=10,
user_id=user.id,
product_id=product.id,
product_count=1,
)
with pytest.raises(ValueError):
product_stock(
sql_session,
product,
until_time=datetime.now(),
until_transaction=transaction,
)
def test_product_stock_basic_history(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
sql_session.commit()
transactions = [
Transaction.add_product(
amount=10,
per_product=10,
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
assert product_stock(sql_session, product) == 1
def test_product_stock_adjust_stock_up(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=50,
per_product=10,
product_count=5,
),
Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=2,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert product_stock(sql_session, product) == 5 + 2
def test_product_stock_adjust_stock_down(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=50,
per_product=10,
product_count=5,
),
Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=-2,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert product_stock(sql_session, product) == 5 - 2
def test_product_stock_complex_history(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=27 * 2,
per_product=27,
user_id=user.id,
product_id=product.id,
product_count=2,
),
Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=3,
),
Transaction.add_product(
amount=50 * 4,
per_product=50,
user_id=user.id,
product_id=product.id,
product_count=4,
),
Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=3,
),
Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=-2,
),
Transaction.throw_product(
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert product_stock(sql_session, product) == 2 - 3 + 4 + 3 - 2 - 1
def test_product_stock_no_transactions(sql_session: Session) -> None:
_, product = insert_test_data(sql_session)
assert product_stock(sql_session, product) == 0
def test_negative_product_stock(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=50,
per_product=50,
user_id=user.id,
product_id=product.id,
product_count=1,
),
Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=2,
),
Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=-1,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
# The stock should be negative because we added and bought the product
assert product_stock(sql_session, product) == 1 - 2 - 1
def test_product_stock_joint_transaction(sql_session: Session) -> None:
user1, product = insert_test_data(sql_session)
user2 = User("Test User 2")
sql_session.add(user2)
sql_session.commit()
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 17, 0, 0),
amount=100,
per_product=100,
user_id=user1.id,
product_id=product.id,
product_count=5,
),
]
sql_session.add_all(transactions)
sql_session.commit()
joint_buy_product(
sql_session,
time=transactions[0].time + timedelta(seconds=1),
instigator=user1,
users=[user1, user2],
product=product,
product_count=3,
)
assert product_stock(sql_session, product) == 5 - 3
def test_product_stock_until_time(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=10,
per_product=10,
user_id=user.id,
product_id=product.id,
product_count=1,
),
Transaction.add_product(
amount=20,
per_product=10,
user_id=user.id,
product_id=product.id,
product_count=2,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert (
product_stock(
sql_session,
product,
until_time=transactions[-1].time - timedelta(seconds=1),
)
== 1
)

View File

@@ -0,0 +1,96 @@
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product
from dibbler.queries import search_product
def insert_test_data(sql_session: Session) -> list[Product]:
products = [
Product("1234567890123", "Test Product A"),
Product("2345678901234", "Test Product B"),
Product("3456789012345", "Another Product"),
Product("4567890123456", "Hidden Product", hidden=True),
]
sql_session.add_all(products)
sql_session.commit()
return products
def test_search_product_empty_not_allowed(sql_session: Session) -> None:
insert_test_data(sql_session)
with pytest.raises(ValueError):
search_product("", sql_session)
def test_search_product_no_products(sql_session: Session) -> None:
result = search_product("Nonexistent Product", sql_session)
assert isinstance(result, list)
assert len(result) == 0
def test_search_product_name_exact_match(sql_session: Session) -> None:
insert_test_data(sql_session)
result = search_product("Test Product A", sql_session)
assert isinstance(result, Product)
assert result.bar_code == "1234567890123"
def test_search_product_name_partial_match(sql_session: Session) -> None:
insert_test_data(sql_session)
result = search_product("Test Product", sql_session)
assert isinstance(result, list)
assert len(result) == 2
names = {product.name for product in result}
assert names == {"Test Product A", "Test Product B"}
def test_search_product_name_no_match(sql_session: Session) -> None:
insert_test_data(sql_session)
result = search_product("Nonexistent", sql_session)
assert isinstance(result, list)
assert len(result) == 0
def test_search_product_barcode_exact_match(sql_session: Session) -> None:
products = insert_test_data(sql_session)
product = products[1] # Test Product B
result = search_product(product.bar_code, sql_session)
assert isinstance(result, Product)
assert result.name == product.name
# Should not be able to find hidden products
def test_search_product_hidden_products(sql_session: Session) -> None:
insert_test_data(sql_session)
result = search_product("Hidden Product", sql_session)
assert isinstance(result, list)
assert len(result) == 0
# Should be able to find hidden products if specified
def test_search_product_find_hidden_products(sql_session: Session) -> None:
insert_test_data(sql_session)
result = search_product("Hidden Product", sql_session, find_hidden_products=True)
assert isinstance(result, Product)
assert result.name == "Hidden Product"
# Should be able to find hidden products by barcode despite not specified
def test_search_product_hidden_products_by_barcode(sql_session: Session) -> None:
products = insert_test_data(sql_session)
hidden_product = products[3] # Hidden Product
result = search_product(hidden_product.bar_code, sql_session)
assert isinstance(result, Product)
assert result.name == "Hidden Product"

View File

@@ -0,0 +1,86 @@
from sqlalchemy.orm import Session
import pytest
from dibbler.models import User
from dibbler.queries import search_user
USER = [
("alice", 123),
("bob", 125),
("charlie", 126),
("david", 127),
("eve", 128),
("evey", 129),
("evy", 130),
("-symbol-man", 131),
("user_123", 132),
]
def setup_users(sql_session: Session) -> None:
for username, rfid in USER:
user = User(name=username, rfid=str(rfid))
sql_session.add(user)
sql_session.commit()
def test_search_user_empty_not_allowed(sql_session: Session) -> None:
setup_users(sql_session)
with pytest.raises(ValueError):
search_user("", sql_session)
def test_search_user_exact_match(sql_session: Session) -> None:
setup_users(sql_session)
user = search_user("alice", sql_session)
assert user is not None
assert isinstance(user, User)
assert user.name == "alice"
user = search_user("125", sql_session)
assert user is not None
assert isinstance(user, User)
assert user.name == "bob"
def test_search_user_partial_match(sql_session: Session) -> None:
setup_users(sql_session)
users = search_user("ev", sql_session)
assert isinstance(users, list)
assert len(users) == 3
names = {user.name for user in users}
assert names == {"eve", "evey", "evy"}
users = search_user("user", sql_session)
assert isinstance(users, list)
assert len(users) == 1
assert users[0].name == "user_123"
def test_search_user_no_match(sql_session: Session) -> None:
setup_users(sql_session)
result = search_user("nonexistent", sql_session)
assert isinstance(result, list)
assert len(result) == 0
def test_search_user_special_characters(sql_session: Session) -> None:
setup_users(sql_session)
user = search_user("-symbol-man", sql_session)
assert user is not None
assert isinstance(user, User)
assert user.name == "-symbol-man"
def test_search_by_rfid(sql_session: Session) -> None:
setup_users(sql_session)
user = search_user("130", sql_session)
assert user is not None
assert isinstance(user, User)
assert user.name == "evy"

View File

@@ -0,0 +1,687 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
User,
)
from dibbler.queries import transaction_log
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def insert_test_data(sql_session: Session) -> tuple[User, User, Product, Product]:
user1 = User("Test User 1")
user2 = User("Test User 2")
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
sql_session.add_all([user1, user2, product1, product2])
sql_session.commit()
return user1, user2, product1, product2
def insert_default_test_transactions(
sql_session: Session,
user1: User,
user2: User,
product1: Product,
product2: Product,
) -> list[Transaction]:
transactions = [
Transaction.adjust_balance(
amount=100,
user_id=user1.id,
),
Transaction.adjust_balance(
amount=50,
user_id=user2.id,
),
Transaction.adjust_balance(
amount=-50,
user_id=user1.id,
),
Transaction.add_product(
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user1.id,
product_id=product1.id,
),
Transaction.buy_product(
product_count=1,
user_id=user2.id,
product_id=product2.id,
),
Transaction.add_product(
amount=15 * 1,
per_product=15,
product_count=1,
user_id=user2.id,
product_id=product2.id,
),
Transaction.transfer(
amount=30,
user_id=user1.id,
transfer_user_id=user2.id,
),
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
return transactions
def test_transaction_log_invalid_limit(sql_session: Session) -> None:
with pytest.raises(ValueError):
transaction_log(sql_session, limit=0)
with pytest.raises(ValueError):
transaction_log(sql_session, limit=-1)
def test_transaction_log_uninitialized_user(sql_session: Session) -> None:
user = User("Uninitialized User")
with pytest.raises(ValueError):
transaction_log(sql_session, user=user)
def test_transaction_log_uninitialized_product(sql_session: Session) -> None:
product = Product("1234567890123", "Uninitialized Product")
with pytest.raises(ValueError):
transaction_log(sql_session, product=product)
def test_transaction_log_uninitialized_after_until_transaction(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
uninitialized_transaction = Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user.id,
)
with pytest.raises(ValueError):
transaction_log(
sql_session,
user=user,
after_transaction=uninitialized_transaction,
)
with pytest.raises(ValueError):
transaction_log(
sql_session,
user=user,
until_transaction=uninitialized_transaction,
)
def test_transaction_log_product_and_user_not_allowed(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
with pytest.raises(ValueError):
transaction_log(
sql_session,
user=user,
product=product,
)
def test_transaction_log_until_datetime_and_transaction_id_not_allowed(
sql_session: Session,
) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
trx = Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user.id,
)
sql_session.add(trx)
sql_session.commit()
with pytest.raises(ValueError):
transaction_log(
sql_session,
user=user,
until_time=datetime(2023, 10, 1, 11, 0, 0),
until_transaction=trx,
)
def test_transaction_log_after_datetime_and_transaction_id_not_allowed(
sql_session: Session,
) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
trx = Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user.id,
)
sql_session.add(trx)
sql_session.commit()
with pytest.raises(ValueError):
transaction_log(
sql_session,
user=user,
after_time=datetime(2023, 10, 1, 15, 0, 0),
after_transaction=trx,
)
def test_user_transactions_no_transactions(sql_session: Session) -> None:
insert_test_data(sql_session)
transactions = transaction_log(sql_session)
assert len(transactions) == 0
def test_transaction_log_basic(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
assert len(transaction_log(sql_session)) == 7
def test_transaction_log_filtered_by_user(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
assert len(transaction_log(sql_session, user=user)) == 4
assert len(transaction_log(sql_session, user=user2)) == 3
def test_transaction_log_filtered_by_product(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
assert len(transaction_log(sql_session, product=product)) == 1
assert len(transaction_log(sql_session, product=product2)) == 2
def test_transaction_log_after_datetime(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
after_time=transactions[2].time,
)
)
== len(transactions) - 2
)
def test_transaction_log_after_datetime_no_transactions(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
after_time=transactions[-1].time + timedelta(seconds=1),
)
)
== 0
)
def test_transaction_log_after_datetime_exclusive(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
after_time=transactions[2].time,
after_inlcusive=False,
)
)
== len(transactions) - 3
)
def test_transaction_log_after_transaction_id(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
first_transaction = transactions[0]
assert len(
transaction_log(
sql_session,
after_transaction=first_transaction,
)
) == len(transactions)
def test_transaction_log_after_transaction_id_one_transaction(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
last_transaction = transactions[-1]
assert (
len(
transaction_log(
sql_session,
after_transaction=last_transaction,
)
)
== 1
)
def test_transaction_log_after_transaction_id_exclusive(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
third_transaction = transactions[2]
assert (
len(
transaction_log(
sql_session,
after_transaction=third_transaction,
after_inlcusive=False,
)
)
== len(transactions) - 3
)
def test_transaction_log_until_datetime(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
until_time=transactions[-3].time,
)
)
== len(transactions) - 2
)
def test_transaction_log_until_datetime_no_transactions(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
until_time=transactions[0].time - timedelta(seconds=1),
)
)
== 0
)
def test_transaction_log_until_datetime_exclusive(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
until_time=transactions[-3].time,
until_inclusive=False,
)
)
== len(transactions) - 3
)
def test_transaction_log_until_transaction(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
last_transaction = transactions[-3]
assert (
len(
transaction_log(
sql_session,
until_transaction=last_transaction,
)
)
== len(transactions) - 2
)
def test_transaction_log_until_transaction_one_transaction(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
first_transaction = transactions[0]
assert (
len(
transaction_log(
sql_session,
until_transaction=first_transaction,
)
)
== 1
)
def test_transaction_log_until_transaction_exclusive(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
last_transaction = transactions[-3]
assert (
len(
transaction_log(
sql_session,
until_transaction=last_transaction,
until_inclusive=False,
)
)
== len(transactions) - 3
)
def test_transaction_log_after_until_datetime_illegal_order(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
fifth_transaction = transactions[4]
with pytest.raises(ValueError):
transaction_log(
sql_session,
after_time=fifth_transaction.time,
until_time=second_transaction.time,
)
def test_transaction_log_after_until_datetime_combined(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
fifth_transaction = transactions[4]
assert (
len(
transaction_log(
sql_session,
after_time=second_transaction.time,
until_time=fifth_transaction.time,
)
)
== 4
)
def test_transaction_log_after_until_transaction_illegal_order(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
fifth_transaction = transactions[4]
with pytest.raises(ValueError):
transaction_log(
sql_session,
after_transaction=fifth_transaction,
until_transaction=second_transaction,
)
def test_transaction_log_after_until_transaction_combined(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
fifth_transaction = transactions[4]
assert (
len(
transaction_log(
sql_session,
after_transaction=second_transaction,
until_transaction=fifth_transaction,
)
)
== 4
)
def test_transaction_log_after_date_until_transaction(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
fifth_transaction = transactions[4]
assert (
len(
transaction_log(
sql_session,
after_time=second_transaction.time,
until_transaction=fifth_transaction,
)
)
== 4
)
def test_transaction_log_after_transaction_until_date(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
fifth_transaction = transactions[4]
assert (
len(
transaction_log(
sql_session,
after_transaction=second_transaction,
until_time=fifth_transaction.time,
)
)
== 4
)
def test_transaction_log_limit(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert len(transaction_log(sql_session, limit=3)) == 3
assert len(transaction_log(sql_session, limit=len(transactions) + 3)) == len(transactions)
def test_transaction_log_filtered_by_transaction_type(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
transaction_type=[TransactionType.ADJUST_BALANCE],
)
)
== 3
)
assert (
len(
transaction_log(
sql_session,
transaction_type=[TransactionType.ADD_PRODUCT],
)
)
== 2
)
assert (
len(
transaction_log(
sql_session,
transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT],
)
)
== 3
)
def test_transaction_log_filtered_by_transaction_type_negated(sql_session: Session) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
assert (
len(
transaction_log(
sql_session,
transaction_type=[TransactionType.ADJUST_BALANCE],
negate_transaction_type_filter=True,
)
)
== len(transactions) - 3
)
assert (
len(
transaction_log(
sql_session,
transaction_type=[TransactionType.ADD_PRODUCT],
negate_transaction_type_filter=True,
)
)
== len(transactions) - 2
)
assert (
len(
transaction_log(
sql_session,
transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT],
negate_transaction_type_filter=True,
)
)
== len(transactions) - 3
)
def test_transaction_log_combined_filter_user_datetime_transaction_type_limit(
sql_session: Session,
) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
sixth_transaction = transactions[5]
result = transaction_log(
sql_session,
user=user,
after_time=second_transaction.time,
until_time=sixth_transaction.time,
transaction_type=[TransactionType.ADJUST_BALANCE, TransactionType.ADD_PRODUCT],
limit=2,
)
assert len(result) == 2
def test_transaction_log_combined_filter_user_transaction_transaction_type_limit(
sql_session: Session,
) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
sixth_transaction = transactions[5]
result = transaction_log(
sql_session,
user=user,
after_transaction=second_transaction,
until_transaction=sixth_transaction,
transaction_type=[TransactionType.ADJUST_BALANCE, TransactionType.ADD_PRODUCT],
limit=2,
)
assert len(result) == 2
def test_transaction_log_combined_filter_product_datetime_transaction_type_limit(
sql_session: Session,
) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
sixth_transaction = transactions[5]
result = transaction_log(
sql_session,
product=product2,
after_time=second_transaction.time,
until_time=sixth_transaction.time,
transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT],
limit=2,
)
assert len(result) == 2
def test_transaction_log_combined_filter_product_transaction_transaction_type_limit(
sql_session: Session,
) -> None:
user, user2, product, product2 = insert_test_data(sql_session)
transactions = insert_default_test_transactions(sql_session, user, user2, product, product2)
second_transaction = transactions[1]
sixth_transaction = transactions[5]
result = transaction_log(
sql_session,
product=product2,
after_transaction=second_transaction,
until_transaction=sixth_transaction,
transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT],
limit=2,
)
assert len(result) == 2
# NOTE: see the corresponding TODO's above the function definition
@pytest.mark.skip(reason="Not yet implemented")
def test_transaction_log_filtered_by_user_joint_transactions(sql_session: Session) -> None: ...
@pytest.mark.skip(reason="Not yet implemented")
def test_transaction_log_filtered_by_user_throw_away_transactions(sql_session: Session) -> None: ...

File diff suppressed because it is too large Load Diff