This commit is contained in:
2025-06-10 20:59:38 +02:00
parent 9f5999854f
commit 745db277ec
20 changed files with 1324 additions and 254 deletions

View File

@@ -10,12 +10,16 @@ from sqlalchemy.orm.collections import (
)
def _pascal_case_to_snake_case(name: str) -> str:
return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
class Base(DeclarativeBase):
metadata = MetaData(
naming_convention={
"ix": "ix_%(column_0_label)s",
"ix": "ix_%(table_name)s_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_`%(constraint_name)s`",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
@@ -23,7 +27,7 @@ class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return cls.__name__
return _pascal_case_to_snake_case(cls.__name__)
def __repr__(self) -> str:
columns = ", ".join(

View File

@@ -0,0 +1,10 @@
from datetime import datetime
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from dibbler.models import Base
class InterestRate(Base):
timestamp: Mapped[datetime] = mapped_column(DateTime)
percentage: Mapped[int] = mapped_column(Integer)

View File

@@ -1,47 +1,129 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Self
from sqlalchemy import (
Boolean,
Integer,
String,
case,
func,
select,
)
from sqlalchemy.orm import (
Mapped,
Session,
mapped_column,
relationship,
)
from .Base import Base
import dibbler.models.User as user
if TYPE_CHECKING:
from .PurchaseEntry import PurchaseEntry
from .UserProducts import UserProducts
from .Base import Base
from .Transaction import Transaction
from .TransactionType import TransactionType
# if TYPE_CHECKING:
# from .PurchaseEntry import PurchaseEntry
# from .UserProducts import UserProducts
class Product(Base):
__tablename__ = "products"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
bar_code: Mapped[str] = mapped_column(String(13))
bar_code: Mapped[str] = mapped_column(String(13), unique=True)
name: Mapped[str] = mapped_column(String(45))
price: Mapped[int] = mapped_column(Integer)
stock: Mapped[int] = mapped_column(Integer)
hidden: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# price: Mapped[int] = mapped_column(Integer)
# stock: Mapped[int] = mapped_column(Integer)
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
purchases: Mapped[set[PurchaseEntry]] = relationship(back_populates="product")
users: Mapped[set[UserProducts]] = relationship(back_populates="product")
bar_code_re = r"[0-9]+"
name_re = r".+"
name_length = 45
def __init__(self, bar_code, name, price, stock=0, hidden=False):
self.name = name
def __init__(
self: Self,
bar_code: str,
name: str,
hidden: bool = False,
) -> None:
self.bar_code = bar_code
self.price = price
self.stock = stock
self.name = name
self.hidden = hidden
def __str__(self):
return self.name
# - count (virtual)
def stock(self: Self, sql_session: Session) -> int:
"""
Returns the number of products in stock.
"""
result = sql_session.scalars(
select(
func.sum(
case(
(
Transaction.type_ == TransactionType.ADD_PRODUCT,
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.BUY_PRODUCT,
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.ADJUST_STOCK,
Transaction.product_count,
),
else_=0,
)
)
).where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT,
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
]
),
Transaction.product_id == self.id,
)
).one_or_none()
return result or 0
def remaining_with_exact_price(self: Self, sql_session: Session) -> list[int]:
"""
Retrieves the remaining products with their exact price as they were bought.
"""
stock = self.stock(sql_session)
# TODO: only retrieve as many transactions as exists in the stock
last_added = sql_session.scalars(
select(
func.row_number(),
Transaction.time,
Transaction.per_product,
Transaction.product_count,
)
.where(
Transaction.type_ == TransactionType.ADD_PRODUCT,
Transaction.product_id == self.id,
)
.order_by(Transaction.time.desc())
).all()
# result = []
# while stock > 0 and last_added:
...
def price(self: Self, sql_session: Session) -> int:
"""
Returns the price of the product.
Average price over the last bought products.
"""
return Transaction.product_price(sql_session=sql_session, product=self)
def owned_by_user(self: Self, sql_session: Session) -> dict[user.User, int]:
"""
Returns an overview of how many of the remaining products are owned by which user.
"""
...

View File

@@ -0,0 +1,11 @@
from datetime import datetime
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from dibbler.models import Base
class ProductPriceCache(Base):
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
timestamp: Mapped[datetime] = mapped_column(DateTime)
price: Mapped[int] = mapped_column(Integer)

View File

@@ -1,70 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from datetime import datetime
import math
from sqlalchemy import (
DateTime,
Integer,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
from .Transaction import Transaction
if TYPE_CHECKING:
from .PurchaseEntry import PurchaseEntry
class Purchase(Base):
__tablename__ = "purchases"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
time: Mapped[datetime] = mapped_column(DateTime)
price: Mapped[int] = mapped_column(Integer)
transactions: Mapped[set[Transaction]] = relationship(
back_populates="purchase", order_by="Transaction.user_name"
)
entries: Mapped[set[PurchaseEntry]] = relationship(back_populates="purchase")
def __init__(self):
pass
def is_complete(self):
return len(self.transactions) > 0 and len(self.entries) > 0
def price_per_transaction(self, round_up=True):
if round_up:
return int(math.ceil(float(self.price) / len(self.transactions)))
else:
return int(math.floor(float(self.price) / len(self.transactions)))
def set_price(self, round_up=True):
self.price = 0
for entry in self.entries:
self.price += entry.amount * entry.product.price
if len(self.transactions) > 0:
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
def perform_purchase(self, ignore_penalty=False, round_up=True):
self.time = datetime.datetime.now()
self.set_price(round_up=round_up)
for t in self.transactions:
t.perform_transaction(ignore_penalty=ignore_penalty)
for entry in self.entries:
entry.product.stock -= entry.amount
def perform_soft_purchase(self, price, round_up=True):
self.time = datetime.datetime.now()
self.price = price
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
for t in self.transactions:
t.perform_transaction()

View File

@@ -1,37 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Integer,
ForeignKey,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .Product import Product
from .Purchase import Purchase
class PurchaseEntry(Base):
__tablename__ = "purchase_entries"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
amount: Mapped[int] = mapped_column(Integer)
product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"))
purchase_id: Mapped[int] = mapped_column(ForeignKey("purchases.id"))
product: Mapped[Product] = relationship(lazy="joined")
purchase: Mapped[Purchase] = relationship(lazy="joined")
def __init__(self, purchase, product, amount):
self.product = product
self.product_bar_code = product.bar_code
self.purchase = purchase
self.amount = amount

View File

@@ -1,52 +1,601 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from datetime import datetime
from typing import TYPE_CHECKING, Self
from sqlalchemy import (
CheckConstraint,
DateTime,
ForeignKey,
Integer,
String,
Text,
asc,
case,
cast,
func,
literal,
select,
)
from sqlalchemy import (
Enum as SQLEnum,
)
from sqlalchemy.orm import (
Mapped,
Session,
mapped_column,
relationship,
)
from sqlalchemy.sql.schema import Index
from .Base import Base
from .TransactionType import TransactionType
if TYPE_CHECKING:
from .Product import Product
from .User import User
from .Purchase import Purchase
# TODO: allow for joint transactions?
# dibbler allows joint transactions (e.g. buying more than one product at once, several people buying the same product, etc.)
# instead of having the software split the transactions up, making them hard to reconnect,
# maybe we should add some sort of joint transaction id field to allow multiple transactions to be grouped together?
_DYNAMIC_FIELDS: set[str] = {
"per_product",
"user_id",
"transfer_user_id",
"product_id",
"product_count",
}
_EXPECTED_FIELDS: dict[TransactionType, set[str]] = {
TransactionType.ADJUST_BALANCE: {"user_id"},
TransactionType.ADJUST_STOCK: {"user_id", "product_id", "product_count"},
TransactionType.TRANSFER: {"user_id", "transfer_user_id"},
TransactionType.ADD_PRODUCT: {"user_id", "product_id", "per_product", "product_count"},
TransactionType.BUY_PRODUCT: {"user_id", "product_id", "product_count"},
}
def _transaction_type_field_constraints(
transaction_type: TransactionType,
expected_fields: set[str],
) -> CheckConstraint:
unexpected_fields = _DYNAMIC_FIELDS - expected_fields
expected_constraints = ["{} IS NOT NULL".format(field) for field in expected_fields]
unexpected_constraints = ["{} IS NULL".format(field) for field in unexpected_fields]
constraints = expected_constraints + unexpected_constraints
# TODO: use sqlalchemy's `and_` and `or_` to build the constraints
return CheckConstraint(
f"type <> '{transaction_type}' OR ({' AND '.join(constraints)})",
name=f"trx_type_{transaction_type.value}_expected_fields",
)
class Transaction(Base):
__tablename__ = "transactions"
__table_args__ = (
*[
_transaction_type_field_constraints(transaction_type, expected_fields)
for transaction_type, expected_fields in _EXPECTED_FIELDS.items()
],
# Speed up product count calculation
Index("product_user_time", "product_id", "user_id", "time"),
# Speed up product owner calculation
Index("user_product_time", "user_id", "product_id", "time"),
# Speed up user transaction list / credit calculation
Index("user_time", "user_id", "time"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
time: Mapped[datetime] = mapped_column(DateTime, unique=True)
message: Mapped[str | None] = mapped_column(Text, nullable=True)
time: Mapped[datetime] = mapped_column(DateTime)
# The type of transaction
type_: Mapped[TransactionType] = mapped_column(SQLEnum(TransactionType), name="type")
# TODO: this should be inferred
# If buying products, is the user penalized for having too low credit?
# penalty: Mapped[Boolean] = mapped_column(Boolean, default=False)
# The amount of money being added or subtracted from the user's credit
# This amount means different things depending on the transaction type:
# - ADJUST_BALANCE: The amount of credit to add or subtract from the user's balance
# - ADJUST_STOCK: The amount of money which disappeared with this stock adjustment
# (i.e. current price * product_count)
# - TRANSFER: The amount of credit to transfer to another user
# - ADD_PRODUCT: The real amount spent on the products
# (i.e. not per_product * product_count, which should be rounded up)
# - BUY_PRODUCT: The amount of credit spent on the product
amount: Mapped[int] = mapped_column(Integer)
penalty: Mapped[int] = mapped_column(Integer)
description: Mapped[str | None] = mapped_column(String(50))
user_name: Mapped[str] = mapped_column(ForeignKey("users.name"))
purchase_id: Mapped[int | None] = mapped_column(ForeignKey("purchases.id"))
# If adding products, how much is each product worth
per_product: Mapped[int | None] = mapped_column(Integer)
user: Mapped[User] = relationship(lazy="joined")
purchase: Mapped[Purchase] = relationship(lazy="joined")
# The user who performs the transaction
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
user: Mapped[User | None] = relationship(
lazy="joined",
foreign_keys=[user_id],
)
def __init__(self, user, amount=0, description=None, purchase=None, penalty=1):
self.user = user
# Receiving user when moving credit from one user to another
transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
transfer_user: Mapped[User | None] = relationship(
lazy="joined",
foreign_keys=[transfer_user_id],
)
# The product that is either being added or bought
product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"))
product: Mapped[Product | None] = relationship(lazy="joined")
# The amount of products being added or bought
product_count: Mapped[int | None] = mapped_column(Integer)
def __init__(
self: Self,
type_: TransactionType,
user_id: int,
amount: int,
time: datetime | None = None,
message: str | None = None,
product_id: int | None = None,
transfer_user_id: int | None = None,
per_product: int | None = None,
product_count: int | None = None,
# penalty: bool = False
) -> None:
if time is None:
time = datetime.now()
self.time = time
self.message = message
self.type_ = type_
self.amount = amount
self.description = description
self.purchase = purchase
self.penalty = penalty
self.user_id = user_id
self.product_id = product_id
self.transfer_user_id = transfer_user_id
self.per_product = per_product
self.product_count = product_count
# self.penalty = penalty
def perform_transaction(self, ignore_penalty=False):
self.time = datetime.datetime.now()
if not ignore_penalty:
self.amount *= self.penalty
self.user.credit -= self.amount
self._validate_by_transaction_type()
def _validate_by_transaction_type(self: Self) -> None:
"""
Validates the transaction based on its type.
Raises ValueError if the transaction is invalid.
"""
# TODO: do we allow free products?
if self.amount == 0:
raise ValueError("Amount must not be zero.")
for field in _EXPECTED_FIELDS[self.type_]:
if getattr(self, field) is None:
raise ValueError(f"{field} must not be None for {self.type_.value} transactions.")
for field in _DYNAMIC_FIELDS - _EXPECTED_FIELDS[self.type_]:
if getattr(self, field) is not None:
raise ValueError(f"{field} must be None for {self.type_.value} transactions.")
if self.per_product is not None and self.per_product <= 0:
raise ValueError("per_product must be greater than zero.")
if (
self.per_product is not None
and self.product_count is not None
and self.amount > self.per_product * self.product_count
):
raise ValueError(
"The real amount of the transaction must be less than the total value of the products."
)
###################
# FACTORY METHODS #
###################
@classmethod
def adjust_balance(
cls: type[Self],
amount: int,
user_id: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
"""
Creates an ADJUST transaction.
"""
return cls(
time=time,
type_=TransactionType.ADJUST_BALANCE,
amount=amount,
user_id=user_id,
message=message,
)
@classmethod
def adjust_stock(
cls: type[Self],
amount: int,
user_id: int,
product_id: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
"""
Creates an ADJUST_STOCK transaction.
"""
return cls(
time=time,
type_=TransactionType.ADJUST_STOCK,
amount=amount,
user_id=user_id,
product_id=product_id,
product_count=product_count,
message=message,
)
@classmethod
def adjust_stock_auto_amount(
cls: type[Self],
sql_session: Session,
user_id: int,
product_id: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
"""
Creates an ADJUST_STOCK transaction with the amount automatically calculated based on the product's current price.
"""
from .Product import Product
product = sql_session.scalar(select(Product).where(Product.id == product_id))
if product is None:
raise ValueError(f"Product with id {product_id} does not exist.")
price = product.price(sql_session)
return cls(
time=time,
type_=TransactionType.ADJUST_STOCK,
amount=price * product_count,
user_id=user_id,
product_id=product_id,
product_count=product_count,
message=message,
)
@classmethod
def transfer(
cls: type[Self],
amount: int,
user_id: int,
transfer_user_id: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
"""
Creates a TRANSFER transaction.
"""
return cls(
time=time,
type_=TransactionType.TRANSFER,
amount=amount,
user_id=user_id,
transfer_user_id=transfer_user_id,
message=message,
)
@classmethod
def add_product(
cls: type[Self],
amount: int,
user_id: int,
product_id: int,
per_product: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
"""
Creates an ADD_PRODUCT transaction.
"""
return cls(
time=time,
type_=TransactionType.ADD_PRODUCT,
amount=amount,
user_id=user_id,
product_id=product_id,
per_product=per_product,
product_count=product_count,
message=message,
)
@classmethod
def buy_product(
cls: type[Self],
amount: int,
user_id: int,
product_id: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
"""
Creates a BUY_PRODUCT transaction.
"""
return cls(
time=time,
type_=TransactionType.BUY_PRODUCT,
amount=amount,
user_id=user_id,
product_id=product_id,
product_count=product_count,
message=message,
)
@classmethod
def buy_product_auto_amount(
cls: type[Self],
sql_session: Session,
user_id: int,
product_id: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
"""
Creates a BUY_PRODUCT transaction with the amount automatically calculated based on the product's current price.
"""
from .Product import Product
product = sql_session.scalar(select(Product).where(Product.id == product_id))
if product is None:
raise ValueError(f"Product with id {product_id} does not exist.")
price = product.price(sql_session)
return cls(
time=time,
type_=TransactionType.BUY_PRODUCT,
amount=price * product_count,
user_id=user_id,
product_id=product_id,
product_count=product_count,
message=message,
)
############################
# USER BALANCE CALCULATION #
############################
@staticmethod
def _user_balance_query(
user: User,
# until: datetime | None = None,
):
"""
The inner query for calculating the user's balance.
This is used both directly via user_balance() and in Transaction CHECK constraints.
"""
balance_adjustments = (
select(func.coalesce(func.sum(Transaction.amount).label("balance_adjustments"), 0))
.where(
Transaction.user_id == user.id,
Transaction.type_ == TransactionType.ADJUST_BALANCE,
)
.scalar_subquery()
)
transfers_to_other_users = (
select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_other_users"), 0))
.where(
Transaction.user_id == user.id,
Transaction.type_ == TransactionType.TRANSFER,
)
.scalar_subquery()
)
transfers_to_self = (
select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_self"), 0))
.where(
Transaction.transfer_user_id == user.id,
Transaction.type_ == TransactionType.TRANSFER,
)
.scalar_subquery()
)
add_products = (
select(func.coalesce(func.sum(Transaction.amount).label("add_products"), 0))
.where(
Transaction.user_id == user.id,
Transaction.type_ == TransactionType.ADD_PRODUCT,
)
.scalar_subquery()
)
buy_products = (
select(func.coalesce(func.sum(Transaction.amount).label("buy_products"), 0))
.where(
Transaction.user_id == user.id,
Transaction.type_ == TransactionType.BUY_PRODUCT,
)
.scalar_subquery()
)
query = select(
# TODO: clearly define and fix the sign of the amount
(
0
+ balance_adjustments
- transfers_to_other_users
+ transfers_to_self
+ add_products
- buy_products
).label("credit")
)
return query
@staticmethod
def user_balance(
sql_session: Session,
user: User,
# Optional: calculate the balance until a certain transaction.
# until: Transaction | None = None,
) -> int:
"""
Calculates the balance of a user.
"""
query = Transaction._user_balance_query(user) # , until=until)
result = sql_session.scalar(query)
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
)
return result
#############################
# PRODUCT PRICE CALCULATION #
#############################
@staticmethod
def _product_price_query(
product: Product,
# until: datetime | None = None,
):
"""
The inner query for calculating the product price.
This is used both directly via product_price() and in Transaction CHECK constraints.
"""
initial_element = select(
literal(0).label("i"),
literal(0).label("time"),
literal(0).label("price"),
literal(0).label("product_count"),
)
recursive_cte = initial_element.cte(name="rec_cte", recursive=True)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.time,
Transaction.type_,
Transaction.product_count,
Transaction.per_product,
)
.where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT,
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
]
),
Transaction.product_id == product.id,
# TODO:
# If we have a transaction to limit the price calculation to, use it.
# If not, use all transactions for the product.
# (Transaction.time <= until.time) if until else True,
)
.order_by(Transaction.time.asc())
.alias("trx_subset")
)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
case(
# Someone buys the product -> price remains the same.
(trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price),
# Someone adds the product -> price is recalculated based on
# product count, previous price, and new price.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
cast(
func.ceil(
(trx_subset.c.per_product * trx_subset.c.product_count)
/ (
# The running product count can be negative if the accounting is bad.
# This ensures that we never end up with negative prices or zero divisions
# and other disastrous phenomena.
func.min(recursive_cte.c.product_count, 0)
+ trx_subset.c.product_count
)
),
Integer,
),
),
# Someone adjusts the stock -> price remains the same.
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price),
# Should never happen
else_=recursive_cte.c.price,
).label("price"),
case(
# Someone buys the product -> product count is reduced.
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
recursive_cte.c.product_count - trx_subset.c.product_count,
),
# Someone adds the product -> product count is increased.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Someone adjusts the stock -> product count is adjusted.
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK,
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Should never happen
else_=recursive_cte.c.product_count,
).label("product_count"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + 1)
)
return recursive_cte.union_all(recursive_elements)
@staticmethod
def product_price(
sql_session: Session,
product: Product,
# Optional: calculate the price until a certain transaction.
# until: Transaction | None = None,
) -> int:
"""
Calculates the price of a product.
"""
recursive_cte = Transaction._product_price_query(product) # , until=until)
# TODO: optionally verify subresults:
# - product_count should never be negative (but this happens sometimes, so just a warning)
# - price should never be negative
result = sql_session.scalar(
select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1)
)
if result is None:
# If there are no transactions for this product, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
)
return result

View File

@@ -0,0 +1,13 @@
from enum import Enum
class TransactionType(Enum):
"""
Enum for transaction types.
"""
ADJUST_BALANCE = "adjust_balance"
ADJUST_STOCK = "adjust_stock"
TRANSFER = "transfer"
ADD_PRODUCT = "add_product"
BUY_PRODUCT = "buy_product"

View File

@@ -1,49 +1,75 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Self
from sqlalchemy import (
Integer,
String,
select,
)
from sqlalchemy.orm import (
Mapped,
Session,
mapped_column,
relationship,
)
from .Base import Base
import dibbler.models.Product as product
if TYPE_CHECKING:
from .UserProducts import UserProducts
from .Transaction import Transaction
from .Base import Base
from .Transaction import Transaction
class User(Base):
__tablename__ = "users"
name: Mapped[str] = mapped_column(String(10), primary_key=True)
credit: Mapped[str] = mapped_column(Integer)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String(20), unique=True)
card: Mapped[str | None] = mapped_column(String(20))
rfid: Mapped[str | None] = mapped_column(String(20))
products: Mapped[set[UserProducts]] = relationship(back_populates="user")
transactions: Mapped[set[Transaction]] = relationship(back_populates="user")
# name_re = r"[a-z]+"
# card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?"
# rfid_re = r"[0-9a-fA-F]*"
name_re = r"[a-z]+"
card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?"
rfid_re = r"[0-9a-fA-F]*"
def __init__(self, name, card, rfid=None, credit=0):
def __init__(self: Self, name: str, card: str | None = None, rfid: str | None = None) -> None:
self.name = name
if card == "":
card = None
self.card = card
if rfid == "":
rfid = None
self.rfid = rfid
self.credit = credit
def __str__(self):
return self.name
# def __str__(self):
# return self.name
def is_anonymous(self):
return self.card == "11122233"
# def is_anonymous(self):
# return self.card == "11122233"
# TODO: rename to 'balance' everywhere
def credit(self, sql_session: Session) -> int:
"""
Returns the current credit of the user.
"""
result = Transaction.user_balance(
sql_session=sql_session,
user=self,
)
return result
def products(self, sql_session: Session) -> list[tuple[product.Product, int]]:
"""
Returns the products that the user has put into the system (and has not been purchased yet)
"""
...
def transactions(self, sql_session: Session) -> list[Transaction]:
"""
Returns the transactions of the user in chronological order.
"""
return list(
sql_session.scalars(
select(Transaction)
.where(Transaction.user_id == self.id)
.order_by(Transaction.time.asc())
).all()
)

View File

@@ -0,0 +1,11 @@
from datetime import datetime
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from dibbler.models import Base
class UserBalanceCache(Base):
user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
timestamp: Mapped[datetime] = mapped_column(DateTime)
balance: Mapped[int] = mapped_column(Integer)

View File

@@ -1,31 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Integer,
ForeignKey,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .User import User
from .Product import Product
class UserProducts(Base):
__tablename__ = "user_products"
user_name: Mapped[str] = mapped_column(ForeignKey("users.name"), primary_key=True)
product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"), primary_key=True)
count: Mapped[int] = mapped_column(Integer)
sign: Mapped[int] = mapped_column(Integer)
user: Mapped[User] = relationship()
product: Mapped[Product] = relationship()

View File

@@ -1,17 +1,12 @@
__all__ = [
"Base",
"Product",
"Purchase",
"PurchaseEntry",
"Transaction",
"User",
"UserProducts",
]
from .Base import Base
from .Product import Product
from .Purchase import Purchase
from .PurchaseEntry import PurchaseEntry
from .Transaction import Transaction
from .TransactionType import TransactionType
from .User import User
from .UserProducts import UserProducts

View File

View File

@@ -1,48 +1,107 @@
import json
from dibbler.db import Session
from datetime import datetime
from pathlib import Path
from dibbler.models.Product import Product
from dibbler.models.User import User
from dibbler.db import Session
from dibbler.models import Product, Transaction, TransactionType, User
JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json"
def clear_db(session):
session.query(Product).delete()
session.query(User).delete()
session.commit()
# TODO: integrate this as a part of create-db, either asking interactively
# whether to seed test data, or by using command line arguments for
# automatating the answer.
def clear_db(sql_session):
sql_session.query(Product).delete()
sql_session.query(User).delete()
sql_session.commit()
def main():
session = Session()
clear_db(session)
product_items = []
user_items = []
# TODO: There is some leftover json data in the mock_data.json file.
# It should be dealt with before merging this PR, either by removing
# it or using it here.
sql_session = Session()
clear_db(sql_session)
with open(JSON_FILE) as f:
json_obj = json.load(f)
# Add users
user1 = User("Test User 1")
user2 = User("Test User 2")
for product in json_obj["products"]:
product_item = Product(
bar_code=product["bar_code"],
name=product["name"],
price=product["price"],
stock=product["stock"],
)
product_items.append(product_item)
sql_session.add(user1)
sql_session.add(user2)
sql_session.commit()
for user in json_obj["users"]:
user_item = User(
name=user["name"],
card=user["card"],
rfid=user["rfid"],
credit=user["credit"],
)
user_items.append(user_item)
# Add products
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
sql_session.add(product1)
sql_session.add(product2)
sql_session.commit()
session.add_all(product_items)
session.add_all(user_items)
session.commit()
# Add transactions
transactions = [
Transaction(
time=datetime(2023, 10, 1, 10, 0, 0),
type_=TransactionType.ADJUST_BALANCE,
amount=100,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 1),
type_=TransactionType.ADJUST_BALANCE,
amount=50,
user_id=user2.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 2),
type_=TransactionType.ADJUST_BALANCE,
amount=-50,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 0),
type_=TransactionType.ADD_PRODUCT,
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user1.id,
product_id=product1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 1),
type_=TransactionType.BUY_PRODUCT,
amount=27,
product_count=1,
user_id=user2.id,
product_id=product1.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
# Note: These constructors depend on the content of the previous transactions,
# so they cannot be part of the initial transaction list.
transaction = Transaction.adjust_stock_auto_amount(
sql_session=sql_session,
time=datetime(2023, 10, 1, 12, 0, 2),
product_count=3,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()
transaction = Transaction.adjust_stock_auto_amount(
sql_session=sql_session,
time=datetime(2023, 10, 1, 12, 0, 3),
product_count=-2,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()

View File

@@ -6,7 +6,7 @@ input_encoding = 'utf8'
[database]
# url = "postgresql://robertem@127.0.0.1/pvvvv"
url = "sqlite:///test.db"
url = sqlite:///test.db
[limits]
low_credit_warning_limit = -100

0
tests/__init__.py Normal file
View File

27
tests/conftest.py Normal file
View File

@@ -0,0 +1,27 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from dibbler.models import Base
def pytest_addoption(parser):
parser.addoption(
"--echo",
action="store_true",
help="Enable SQLAlchemy echo mode for debugging",
)
@pytest.fixture(scope="function")
def sql_session(request):
"""Create a new SQLAlchemy session for testing."""
echo = request.config.getoption("--echo")
engine = create_engine(
"sqlite:///:memory:",
echo=echo,
)
Base.metadata.create_all(engine)
with Session(engine) as sql_session:
yield sql_session

216
tests/test_product.py Normal file
View File

@@ -0,0 +1,216 @@
from datetime import datetime
import pytest
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, TransactionType, User
def insert_test_data(sql_session: Session) -> None:
# Add users
user1 = User("Test User 1")
user2 = User("Test User 2")
sql_session.add_all([user1, user2])
sql_session.commit()
# Add products
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
product3 = Product("1111111111111", "Test Product 3")
sql_session.add_all([product1, product2, product3])
sql_session.commit()
# Add transactions
transactions = [
Transaction(
time=datetime(2023, 10, 1, 10, 0, 0),
type_=TransactionType.ADJUST_BALANCE,
amount=100,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 1),
type_=TransactionType.ADJUST_BALANCE,
amount=50,
user_id=user2.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 2),
type_=TransactionType.ADJUST_BALANCE,
amount=-50,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 0),
type_=TransactionType.ADD_PRODUCT,
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user1.id,
product_id=product1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 1),
type_=TransactionType.BUY_PRODUCT,
amount=27,
product_count=1,
user_id=user2.id,
product_id=product1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 2),
type_=TransactionType.ADD_PRODUCT,
amount=50,
per_product=50,
product_count=1,
user_id=user1.id,
product_id=product3.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 3),
type_=TransactionType.BUY_PRODUCT,
amount=50,
product_count=1,
user_id=user1.id,
product_id=product3.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 4),
type_=TransactionType.ADJUST_BALANCE,
amount=1000,
user_id=user1.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
# Note: These constructors depend on the content of the previous transactions,
# so they cannot be part of the initial transaction list.
transaction = Transaction.adjust_stock_auto_amount(
sql_session=sql_session,
time=datetime(2023, 10, 1, 13, 0, 0),
product_count=3,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()
transaction = Transaction.adjust_stock_auto_amount(
sql_session=sql_session,
time=datetime(2023, 10, 1, 13, 0, 1),
product_count=-2,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()
def test_no_duplicate_products(sql_session: Session):
insert_test_data(sql_session)
product1 = Product("1234567890123", "Test Product 1")
sql_session.add(product1)
with pytest.raises(IntegrityError):
sql_session.commit()
def test_product_stock(sql_session: Session):
insert_test_data(sql_session)
product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one()
product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one()
assert product1.stock(sql_session) == 2 - 1 + 3 - 2
assert product2.stock(sql_session) == 0
def test_product_price(sql_session: Session):
insert_test_data(sql_session)
product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one()
assert product1.price(sql_session) == 27
def test_product_no_transactions_price(sql_session: Session):
insert_test_data(sql_session)
product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one()
assert product2.price(sql_session) == 0
def test_product_sold_out_price(sql_session: Session):
insert_test_data(sql_session)
product3 = sql_session.scalars(select(Product).where(Product.name == "Test Product 3")).one()
assert product3.price(sql_session) == 50
def test_allowed_to_buy_more_than_stock(sql_session: Session):
insert_test_data(sql_session)
product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one()
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 6),
amount = 27 * 5,
product_count=10,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()
product1_stock = product1.stock(sql_session)
assert product1_stock < 0 # Should be negative, as we bought more than available stock
product1_price = product1.price(sql_session)
assert product1_price == 27 # Price should remain the same, as it is based on previous transactions
transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 8),
amount=22,
per_product=22,
product_count=1,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()
product1_price = product1.price(sql_session)
assert product1_price == 22 # Price should now be updated to the new price of the added product
def test_not_allowed_to_buy_with_incorrect_amount(sql_session: Session):
insert_test_data(sql_session)
product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one()
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
product1_price = product1.price(sql_session)
with pytest.raises(IntegrityError):
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 7),
amount= product1_price * 4 + 1, # Incorrect amount
product_count=4,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()
def test_not_allowed_to_buy_with_too_little_balance(sql_session: Session):
...

97
tests/test_transaction.py Normal file
View File

@@ -0,0 +1,97 @@
from datetime import datetime
import pytest
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, TransactionType, User
def insert_test_data(sql_session: Session) -> None:
# Add users
user1 = User("Test User 1")
user2 = User("Test User 2")
sql_session.add(user1)
sql_session.add(user2)
sql_session.commit()
# Add products
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
sql_session.add(product1)
sql_session.add(product2)
sql_session.commit()
# Add transactions
transactions = [
Transaction(
time=datetime(2023, 10, 1, 10, 0, 0),
type_=TransactionType.ADJUST_BALANCE,
amount=100,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 1),
type_=TransactionType.ADJUST_BALANCE,
amount=50,
user_id=user2.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 2),
type_=TransactionType.ADJUST_BALANCE,
amount=-50,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 0),
type_=TransactionType.ADD_PRODUCT,
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user1.id,
product_id=product1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 1),
type_=TransactionType.BUY_PRODUCT,
amount=27,
product_count=1,
user_id=user2.id,
product_id=product1.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
def test_no_duplicate_timestamps(sql_session: Session):
"""
Ensure that no two transactions have the same timestamp.
"""
# Insert test data
insert_test_data(sql_session)
user1 = sql_session.scalar(
select(User).where(User.name == "Test User 1")
)
assert user1 is not None, "Test User 1 should exist"
transaction_to_duplicate = sql_session.scalar(
select(Transaction).limit(1)
)
assert transaction_to_duplicate is not None, "There should be at least one transaction"
duplicate_timestamp_transaction = Transaction.adjust_balance(
time=transaction_to_duplicate.time, # Use the same timestamp as an existing transaction
amount=50,
user_id=user1.id,
)
with pytest.raises(IntegrityError):
sql_session.add(duplicate_timestamp_transaction)
sql_session.commit()

108
tests/test_user.py Normal file
View File

@@ -0,0 +1,108 @@
from datetime import datetime
import pytest
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, TransactionType, User
def insert_test_data(sql_session: Session) -> None:
# Add users
user1 = User("Test User 1")
user2 = User("Test User 2")
sql_session.add(user1)
sql_session.add(user2)
sql_session.commit()
# Add products
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
sql_session.add(product1)
sql_session.add(product2)
sql_session.commit()
# Add transactions
transactions = [
Transaction(
time=datetime(2023, 10, 1, 10, 0, 0),
type_=TransactionType.ADJUST_BALANCE,
amount=100,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 1),
type_=TransactionType.ADJUST_BALANCE,
amount=50,
user_id=user2.id,
),
Transaction(
time=datetime(2023, 10, 1, 10, 0, 2),
type_=TransactionType.ADJUST_BALANCE,
amount=-50,
user_id=user1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 0),
type_=TransactionType.ADD_PRODUCT,
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user1.id,
product_id=product1.id,
),
Transaction(
time=datetime(2023, 10, 1, 12, 0, 1),
type_=TransactionType.BUY_PRODUCT,
amount=27,
product_count=1,
user_id=user2.id,
product_id=product1.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
def test_ensure_no_duplicate_users(sql_session: Session):
insert_test_data(sql_session)
user1 = User("Test User 1")
sql_session.add(user1)
with pytest.raises(IntegrityError):
sql_session.commit()
def test_user_credit(sql_session: Session):
insert_test_data(sql_session)
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one()
assert user1.credit(sql_session) == 100 - 50 + 27 * 2
assert user2.credit(sql_session) == 50 - 27
def test_user_transactions(sql_session: Session):
insert_test_data(sql_session)
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one()
user1_transactions = user1.transactions(sql_session)
user2_transactions = user2.transactions(sql_session)
assert len(user1_transactions) == 3
assert len(user2_transactions) == 2
def test_user_not_allowed_to_transfer_to_self(sql_session: Session):
insert_test_data(sql_session)
...
# user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
# with pytest.raises(ValueError, match="Cannot transfer to self"):
# user1.transfer(sql_session, user1, 10) # Attempting to transfer to self