23 Commits

Author SHA1 Message Date
3e4c3a44d2 fixup! WIP 2025-06-13 23:15:50 +02:00
f5c4959e51 fixup! WIP 2025-06-12 23:29:38 +02:00
b7df9a8640 fixup! WIP 2025-06-12 23:19:05 +02:00
d4511981ff fixup! WIP 2025-06-12 22:22:11 +02:00
3c9bca8b55 fixup! WIP 2025-06-12 21:47:05 +02:00
745db277ec WIP 2025-06-11 18:00:49 +02:00
9f5999854f .gitignore: add pytest-cov data 2025-06-10 20:59:10 +02:00
042bb58fbd {nix,pyproject.toml}: add pytest, pytest-cov 2025-06-10 20:58:49 +02:00
4a4f0e6947 module.nix: config -> settings 2025-05-05 14:52:28 +02:00
a4d10ad0c7 Merge pull request 'Seed test data' (#16) from seed_test into master
Reviewed-on: #16
Reviewed-by: Oystein Kristoffer Tveit <oysteikt@pvv.ntnu.no>
2025-03-30 21:45:37 +02:00
a654baba11 ruff format 2025-03-30 21:44:37 +02:00
e69d04dcd0 mock script, og mock data. 2025-03-29 22:48:30 +01:00
b2a6384f31 la tilbake uv, en project manager 2025-03-29 22:46:30 +01:00
4f89765070 ignorer bifiler fra hatchling 2025-03-29 22:42:36 +01:00
914e5b4e50 fjerner __pyachce__, fra repo tracking 2025-03-29 22:37:11 +01:00
de20bad7dd remove conf.py 2025-03-19 18:47:23 +01:00
4bab5e7e21 treewide: fix brother-ql usage 2025-03-19 18:47:16 +01:00
b85a6535fe shell.nix: add python with all packages 2025-03-19 18:14:42 +01:00
22a09b4177 README: add more information 2025-03-19 18:06:40 +01:00
c39b15d1a8 .envrc: init 2025-03-19 17:50:48 +01:00
122ac2ab18 treewide: update everything nix 2025-03-19 17:50:14 +01:00
28228beccd pyproject.toml: remove invalid license
This license field was added without any of the earlier contributors
consent on accident. It is not valid
2025-03-17 21:03:55 +01:00
8a6a0c12ba Merge pull request #3 from Programvareverkstedet/restructure-project
Restructure project
2023-09-02 21:18:04 +02:00
58 changed files with 2827 additions and 409 deletions

1
.envrc Normal file
View File

@@ -0,0 +1 @@
use flake

7
.gitignore vendored
View File

@@ -1,8 +1,11 @@
result
result-*
**/__pycache__
dibbler.egg-info
dist
test.db
.ruff_cache
.ruff_cache
.coverage

View File

@@ -2,13 +2,31 @@
EDB-system for PVVVV
## Hva er dette?
Dibbler er et system laget av PVVere for PVVere for å byttelåne både matvarer og godis.
Det er designet for en gammeldags VT terminal, og er laget for å være enkelt både å bruke og å hacke på.
Programmet er skrevet i Python, og bruker en sql database for å lagre data.
Samlespleiseboden er satt opp slik at folk kjøper inn varer, og får dibblerkreditt, og så kan man bruke
denne kreditten til å kjøpe ut andre varer. Det er ikke noen form for authentisering, så hele systemet er basert på tillit.
Det er anbefalt å koble en barkodeleser til systemet for å gjøre det enklere å både legge til og kjøpe varer.
## Kom i gang
Installer python, og lag og aktiver et venv. Installer så avhengighetene med `pip install`.
Deretter kan du kjøre programmet med
```console
python -m dibbler -c example-config.ini create-db
python -m dibbler -c example-config.ini loop
```
## Nix
### Hvordan kjøre
`nix run github:Prograrmvarverkstedet/dibbler`
### Bygge nytt image
### Bygge nytt image
For å bygge et image trenger du en builder som takler å bygge for arkitekturen du skal lage et image for.
@@ -16,16 +34,16 @@ For å bygge et image trenger du en builder som takler å bygge for arkitekturen
Flaket exposer en modul som autologger inn med en bruker som automatisk kjører dibbler, og setter opp et minimalistisk miljø.
Før du bygger imaget burde du endre conf.py lokalt til å inneholde instillingene dine. **NB: Denne kommer til å ligge i nix storen.**
Før du bygger imaget burde du kopiere og endre `example-config.ini` lokalt til å inneholde instillingene dine. **NB: Denne kommer til å ligge i nix storen, ikke si noe her som du ikke vil at moren din skal høre.**
Du kan også endre hvilken conf.py som blir brukt direkte i pakken eller i modulen.
Du kan også endre hvilken config-fil som blir brukt direkte i pakken eller i modulen.
Se eksempelet for hvordan skrot er satt opp i flake.nix
Se eksempelet for hvordan skrot er satt opp i `flake.nix` og `nix/skrott.nix`
### Bygge image for skrot
Skrot har et image definert i flake.nix:
1. endre conf.py
1. endre `example-config.ini`
2. `nix build .#images.skrot`
3. ???
4. non-profit
4. non-profit

13
conf.py
View File

@@ -1,13 +0,0 @@
db_url = "postgresql://robertem@127.0.0.1/pvvvv"
quit_allowed = True
stop_allowed = False
show_tracebacks = True
input_encoding = "utf8"
low_credit_warning_limit = -100
user_recent_transaction_limit = 100
# See https://pypi.org/project/brother_ql/ for label types
# Set rotate to False for endless labels
label_type = "62"
label_rotate = False

View File

@@ -1,4 +0,0 @@
{ pkgs ? import <nixos-unstable> { } }:
{
dibbler = pkgs.callPackage ./nix/dibbler.nix { };
}

View File

@@ -2,7 +2,7 @@ import os
from PIL import ImageFont
from barcode.writer import ImageWriter, mm2px
from brother_ql.devicedependent import label_type_specs
from brother_ql.labels import ALL_LABELS
def px2mm(px, dpi=300):
@@ -12,14 +12,15 @@ def px2mm(px, dpi=300):
class BrotherLabelWriter(ImageWriter):
def __init__(self, typ="62", max_height=350, rot=False, text=None):
super(BrotherLabelWriter, self).__init__()
assert typ in label_type_specs
label = next([l for l in ALL_LABELS if l.identifier == typ])
assert label is not None
self.rot = rot
if self.rot:
self._h, self._w = label_type_specs[typ]["dots_printable"]
self._h, self._w = label.dots_printable
if self._w == 0 or self._w > max_height:
self._w = min(max_height, self._h / 2)
else:
self._w, self._h = label_type_specs[typ]["dots_printable"]
self._w, self._h = label.dots_printable
if self._h == 0 or self._h > max_height:
self._h = min(max_height, self._w / 2)
self._xo = 0.0

View File

@@ -1,79 +1,7 @@
import pwd
import subprocess
import os
import pwd
import signal
from sqlalchemy import or_, and_
from ..models import User, Product
def search_user(string, session, ignorethisflag=None):
string = string.lower()
exact_match = (
session.query(User)
.filter(or_(User.name == string, User.card == string, User.rfid == string))
.first()
)
if exact_match:
return exact_match
user_list = (
session.query(User)
.filter(
or_(
User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"),
)
)
.all()
)
return user_list
def search_product(string, session, find_hidden_products=True):
if find_hidden_products:
exact_match = (
session.query(Product)
.filter(or_(Product.bar_code == string, Product.name == string))
.first()
)
else:
exact_match = (
session.query(Product)
.filter(
or_(
Product.bar_code == string,
and_(Product.name == string, Product.hidden is False),
)
)
.first()
)
if exact_match:
return exact_match
if find_hidden_products:
product_list = (
session.query(Product)
.filter(
or_(
Product.bar_code.ilike(f"%{string}%"),
Product.name.ilike(f"%{string}%"),
)
)
.all()
)
else:
product_list = (
session.query(Product)
.filter(
or_(
Product.bar_code.ilike(f"%{string}%"),
and_(Product.name.ilike(f"%{string}%"), Product.hidden is False),
)
)
.all()
)
return product_list
import subprocess
def system_user_exists(username):

View File

@@ -2,9 +2,10 @@ import os
import datetime
import barcode
from brother_ql import BrotherQLRaster, create_label
from brother_ql.brother_ql_create import create_label
from brother_ql.raster import BrotherQLRaster
from brother_ql.backends import backend_factory
from brother_ql.devicedependent import label_type_specs
from brother_ql.labels import ALL_LABELS
from PIL import Image, ImageDraw, ImageFont
from .barcode_helpers import BrotherLabelWriter
@@ -17,10 +18,11 @@ def print_name_label(
label_type="62",
printer_type="QL-700",
):
label = next([l for l in ALL_LABELS if l.identifier == label_type])
if not rotate:
width, height = label_type_specs[label_type]["dots_printable"]
width, height = label.dots_printable
else:
height, width = label_type_specs[label_type]["dots_printable"]
height, width = label.dots_printable
font_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ChopinScript.ttf")
fs = 2000

View File

@@ -76,12 +76,8 @@ class Database:
personDatoVerdi = defaultdict(list) # dict->array
personUkedagVerdi = defaultdict(list)
# for global
personPosTransactions = (
{}
) # personPosTransactions[trygvrad] == 100 #trygvrad har lagt 100kr i boksen
personNegTransactions = (
{}
) # personNegTransactions[trygvrad» == 70 #trygvrad har tatt 70kr fra boksen
personPosTransactions = {} # personPosTransactions[trygvrad] == 100 #trygvrad har lagt 100kr i boksen
personNegTransactions = {} # personNegTransactions[trygvrad» == 70 #trygvrad har tatt 70kr fra boksen
globalVareAntall = {} # globalVareAntall[Oreo] == 3
globalVareVerdi = {} # globalVareVerdi[Oreo] == 30 #[kr]
globalPersonAntall = {} # globalPersonAntall[trygvrad] == 3

View File

@@ -20,6 +20,7 @@ subparsers = parser.add_subparsers(
subparsers.add_parser("loop", help="Run the dibbler loop")
subparsers.add_parser("create-db", help="Create the database")
subparsers.add_parser("slabbedasker", help="Find out who is slabbedasker")
subparsers.add_parser("seed-data", help="Fill with mock data")
def main():
@@ -41,6 +42,11 @@ def main():
slabbedasker.main()
elif args.subcommand == "seed-data":
import dibbler.subcommands.seed_test_data as seed_test_data
seed_test_data.main()
if __name__ == "__main__":
main()

View File

@@ -180,7 +180,7 @@ When finished, write an empty line to confirm the purchase.\n"""
print(f"User {t.user.name}'s credit is now {t.user.credit:d} kr")
if t.user.credit < config.getint("limits", "low_credit_warning_limit"):
print(
f'USER {t.user.name} HAS LOWER CREDIT THAN {config.getint("limits", "low_credit_warning_limit"):d},',
f"USER {t.user.name} HAS LOWER CREDIT THAN {config.getint('limits', 'low_credit_warning_limit'):d},",
"AND SHOULD CONSIDER PUTTING SOME MONEY IN THE BOX.",
)

View File

@@ -10,12 +10,16 @@ from sqlalchemy.orm.collections import (
)
def _pascal_case_to_snake_case(name: str) -> str:
return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
class Base(DeclarativeBase):
metadata = MetaData(
naming_convention={
"ix": "ix_%(column_0_label)s",
"ix": "ix_%(table_name)s_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_`%(constraint_name)s`",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
@@ -23,7 +27,7 @@ class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return cls.__name__
return _pascal_case_to_snake_case(cls.__name__)
def __repr__(self) -> str:
columns = ", ".join(

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Self
from sqlalchemy import (
Boolean,
@@ -9,39 +10,44 @@ from sqlalchemy import (
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .PurchaseEntry import PurchaseEntry
from .UserProducts import UserProducts
class Product(Base):
__tablename__ = "products"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID"""
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
bar_code: Mapped[str] = mapped_column(String(13))
name: Mapped[str] = mapped_column(String(45))
price: Mapped[int] = mapped_column(Integer)
stock: Mapped[int] = mapped_column(Integer)
hidden: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
bar_code: Mapped[str] = mapped_column(String(13), unique=True)
"""
The bar code of the product.
purchases: Mapped[set[PurchaseEntry]] = relationship(back_populates="product")
users: Mapped[set[UserProducts]] = relationship(back_populates="product")
This is a unique identifier for the product, typically a 13-digit
EAN-13 code.
"""
bar_code_re = r"[0-9]+"
name_re = r".+"
name_length = 45
name: Mapped[str] = mapped_column(String(45), unique=True)
"""
The name of the product.
def __init__(self, bar_code, name, price, stock=0, hidden=False):
self.name = name
Please don't write fanfics here, this is not a place for that.
"""
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
"""
Whether the product is hidden from the user interface.
Hidden products are not shown in the product list, but can still be
used in transactions.
"""
def __init__(
self: Self,
bar_code: str,
name: str,
hidden: bool = False,
) -> None:
self.bar_code = bar_code
self.price = price
self.stock = stock
self.name = name
self.hidden = hidden
def __str__(self):
return self.name

View File

@@ -0,0 +1,15 @@
from datetime import datetime
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from dibbler.models import Base
class ProductCache(Base):
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
price: Mapped[int] = mapped_column(Integer)
price_timestamp: Mapped[datetime] = mapped_column(DateTime)
stock: Mapped[int] = mapped_column(Integer)
stock_timestamp: Mapped[datetime] = mapped_column(DateTime)

View File

@@ -1,70 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from datetime import datetime
import math
from sqlalchemy import (
DateTime,
Integer,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
from .Transaction import Transaction
if TYPE_CHECKING:
from .PurchaseEntry import PurchaseEntry
class Purchase(Base):
__tablename__ = "purchases"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
time: Mapped[datetime] = mapped_column(DateTime)
price: Mapped[int] = mapped_column(Integer)
transactions: Mapped[set[Transaction]] = relationship(
back_populates="purchase", order_by="Transaction.user_name"
)
entries: Mapped[set[PurchaseEntry]] = relationship(back_populates="purchase")
def __init__(self):
pass
def is_complete(self):
return len(self.transactions) > 0 and len(self.entries) > 0
def price_per_transaction(self, round_up=True):
if round_up:
return int(math.ceil(float(self.price) / len(self.transactions)))
else:
return int(math.floor(float(self.price) / len(self.transactions)))
def set_price(self, round_up=True):
self.price = 0
for entry in self.entries:
self.price += entry.amount * entry.product.price
if len(self.transactions) > 0:
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
def perform_purchase(self, ignore_penalty=False, round_up=True):
self.time = datetime.datetime.now()
self.set_price(round_up=round_up)
for t in self.transactions:
t.perform_transaction(ignore_penalty=ignore_penalty)
for entry in self.entries:
entry.product.stock -= entry.amount
def perform_soft_purchase(self, price, round_up=True):
self.time = datetime.datetime.now()
self.price = price
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
for t in self.transactions:
t.perform_transaction()

View File

@@ -1,37 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Integer,
ForeignKey,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .Product import Product
from .Purchase import Purchase
class PurchaseEntry(Base):
__tablename__ = "purchase_entries"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
amount: Mapped[int] = mapped_column(Integer)
product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"))
purchase_id: Mapped[int] = mapped_column(ForeignKey("purchases.id"))
product: Mapped[Product] = relationship(lazy="joined")
purchase: Mapped[Purchase] = relationship(lazy="joined")
def __init__(self, purchase, product, amount):
self.product = product
self.product_bar_code = product.bar_code
self.purchase = purchase
self.amount = amount

View File

@@ -1,52 +1,463 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from datetime import datetime
from typing import TYPE_CHECKING, Self
from sqlalchemy import (
CheckConstraint,
DateTime,
ForeignKey,
Integer,
String,
Text,
and_,
column,
or_,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from sqlalchemy.orm.collections import (
InstrumentedDict,
InstrumentedList,
InstrumentedSet,
)
from sqlalchemy.sql.schema import Index
from .Base import Base
from .TransactionType import TransactionType, TransactionTypeSQL
if TYPE_CHECKING:
from .Product import Product
from .User import User
from .Purchase import Purchase
# TODO: rename to *_PERCENT
# NOTE: these only matter when there are no adjustments made in the database.
DEFAULT_INTEREST_RATE_PERCENTAGE = 100
DEFAULT_PENALTY_THRESHOLD = -100
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200
# TODO: allow for joint transactions?
# dibbler allows joint transactions (e.g. buying more than one product at once, several people buying the same product, etc.)
# instead of having the software split the transactions up, making them hard to reconnect,
# maybe we should add some sort of joint transaction id field to allow multiple transactions to be grouped together?
_DYNAMIC_FIELDS: set[str] = {
"amount",
"interest_rate_percent",
"penalty_multiplier_percent",
"penalty_threshold",
"per_product",
"product_count",
"product_id",
"transfer_user_id",
}
_EXPECTED_FIELDS: dict[TransactionType, set[str]] = {
TransactionType.ADD_PRODUCT: {"amount", "per_product", "product_count", "product_id"},
TransactionType.ADJUST_BALANCE: {"amount"},
TransactionType.ADJUST_INTEREST: {"interest_rate_percent"},
TransactionType.ADJUST_PENALTY: {"penalty_multiplier_percent", "penalty_threshold"},
TransactionType.ADJUST_STOCK: {"product_count", "product_id"},
TransactionType.BUY_PRODUCT: {"product_count", "product_id"},
TransactionType.TRANSFER: {"amount", "transfer_user_id"},
}
assert all(x <= _DYNAMIC_FIELDS for x in _EXPECTED_FIELDS.values()), (
"All expected fields must be part of _DYNAMIC_FIELDS."
)
def _transaction_type_field_constraints(
transaction_type: TransactionType,
expected_fields: set[str],
) -> CheckConstraint:
unexpected_fields = _DYNAMIC_FIELDS - expected_fields
return CheckConstraint(
or_(
column("type") != transaction_type.value,
and_(
*[column(field) != None for field in expected_fields],
*[column(field) == None for field in unexpected_fields],
),
),
name=f"trx_type_{transaction_type.value}_expected_fields",
)
class Transaction(Base):
__tablename__ = "transactions"
__table_args__ = (
*[
_transaction_type_field_constraints(transaction_type, expected_fields)
for transaction_type, expected_fields in _EXPECTED_FIELDS.items()
],
CheckConstraint(
or_(
column("type") != TransactionType.TRANSFER.value,
column("user_id") != column("transfer_user_id"),
),
name="trx_type_transfer_no_self_transfers",
),
# Speed up product count calculation
Index("product_user_time", "product_id", "user_id", "time"),
# Speed up product owner calculation
Index("user_product_time", "user_id", "product_id", "time"),
# Speed up user transaction list / credit calculation
Index("user_time", "user_id", "time"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""
A unique identifier for the transaction.
time: Mapped[datetime] = mapped_column(DateTime)
amount: Mapped[int] = mapped_column(Integer)
penalty: Mapped[int] = mapped_column(Integer)
description: Mapped[str | None] = mapped_column(String(50))
Not used for anything else than identifying the transaction in the database.
"""
user_name: Mapped[str] = mapped_column(ForeignKey("users.name"))
purchase_id: Mapped[int | None] = mapped_column(ForeignKey("purchases.id"))
time: Mapped[datetime] = mapped_column(DateTime, unique=True)
"""
The time when the transaction took place.
user: Mapped[User] = relationship(lazy="joined")
purchase: Mapped[Purchase] = relationship(lazy="joined")
This is used to order transactions chronologically, and to calculate
all kinds of state.
"""
def __init__(self, user, amount=0, description=None, purchase=None, penalty=1):
self.user = user
message: Mapped[str | None] = mapped_column(Text, nullable=True)
"""
A message that can be set by the user to describe the reason
behind the transaction (or potentially a place to write som fan fiction).
This is not used for any calculations, but can be useful for debugging.
"""
type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type")
"""
Which type of transaction this is.
The type determines which fields are expected to be set.
"""
amount: Mapped[int | None] = mapped_column(Integer)
"""
This field means different things depending on the transaction type:
- `ADD_PRODUCT`: The real amount spent on the products.
- `ADJUST_BALANCE`: The amount of credit to add or subtract from the user's balance.
- `BUY_PRODUCT`: The amount of credit spent on the product.
Note that this includes any penalties and interest that the user
had to pay as well.
- `TRANSFER`: The amount of balance to transfer to another user.
"""
per_product: Mapped[int | None] = mapped_column(Integer)
"""
If adding products, how much is each product worth
Note that this is distinct from the total amount of the transaction,
because this gets rounded up to the nearest integer, while the total amount
that the user paid in the store would be stored in the `amount` field.
"""
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
"""The user who performs the transaction. See `user` for more details."""
user: Mapped[User] = relationship(
lazy="joined",
foreign_keys=[user_id],
)
"""
The user who performs the transaction.
For some transaction types, like `TRANSFER` and `ADD_PRODUCT`, this is a
functional field with "real world consequences" for price calculations.
For others, like `ADJUST_PENALTY` and `ADJUST_STOCK`, this is just a record of who
performed the transaction, and does not affect any state calculations.
"""
# Receiving user when moving credit from one user to another
transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
"""The user who receives money in a `TRANSFER` transaction."""
transfer_user: Mapped[User | None] = relationship(
lazy="joined",
foreign_keys=[transfer_user_id],
)
"""The user who receives money in a `TRANSFER` transaction."""
# The product that is either being added or bought
product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"))
"""The product being added or bought."""
product: Mapped[Product | None] = relationship(lazy="joined")
"""The product being added or bought."""
# The amount of products being added or bought
product_count: Mapped[int | None] = mapped_column(Integer)
"""
The amount of products being added or bought.
"""
penalty_threshold: Mapped[int | None] = mapped_column(Integer, nullable=True)
"""
On `ADJUST_PENALTY` transactions, this is the threshold in krs for when the user
should start getting penalized for low credit.
See also `penalty_multiplier`.
"""
penalty_multiplier_percent: Mapped[int | None] = mapped_column(Integer, nullable=True)
"""
On `ADJUST_PENALTY` transactions, this is the multiplier for the amount of
money the user has to pay when they have too low credit.
The multiplier is a percentage, so `100` means the user has to pay the full
price of the product, `200` means they have to pay double, etc.
See also `penalty_threshold`.
"""
# TODO: this should be inferred
# Assuming this is a BUY_PRODUCT transaction, was the user penalized for having
# too low credit in this transaction?
# is_penalized: Mapped[Boolean] = mapped_column(Boolean, default=False)
interest_rate_percent: Mapped[int | None] = mapped_column(Integer, nullable=True)
"""
On `ADJUST_INTEREST` transactions, this is the interest rate in percent
that the user has to pay on their balance.
The interest rate is a percentage, so `100` means the user has to pay the full
price of the product, `200` means they have to pay double, etc.
"""
def __init__(
self: Self,
type_: TransactionType,
user_id: int,
amount: int | None = None,
time: datetime | None = None,
message: str | None = None,
product_id: int | None = None,
transfer_user_id: int | None = None,
per_product: int | None = None,
product_count: int | None = None,
penalty_threshold: int | None = None,
penalty_multiplier_percent: int | None = None,
interest_rate_percent: int | None = None,
) -> None:
"""
Please do not call this constructor directly, use the factory methods instead.
"""
if time is None:
time = datetime.now()
self.time = time
self.message = message
self.type_ = type_
self.amount = amount
self.description = description
self.purchase = purchase
self.penalty = penalty
self.user_id = user_id
self.product_id = product_id
self.transfer_user_id = transfer_user_id
self.per_product = per_product
self.product_count = product_count
self.penalty_threshold = penalty_threshold
self.penalty_multiplier_percent = penalty_multiplier_percent
self.interest_rate_percent = interest_rate_percent
def perform_transaction(self, ignore_penalty=False):
self.time = datetime.datetime.now()
if not ignore_penalty:
self.amount *= self.penalty
self.user.credit -= self.amount
self._validate_by_transaction_type()
def _validate_by_transaction_type(self: Self) -> None:
"""
Validates the transaction's fields based on its type.
Raises `ValueError` if the transaction is invalid.
"""
# TODO: do we allow free products?
if self.amount == 0:
raise ValueError("Amount must not be zero.")
for field in _EXPECTED_FIELDS[self.type_]:
if getattr(self, field) is None:
raise ValueError(f"{field} must not be None for {self.type_.value} transactions.")
for field in _DYNAMIC_FIELDS - _EXPECTED_FIELDS[self.type_]:
if getattr(self, field) is not None:
raise ValueError(f"{field} must be None for {self.type_.value} transactions.")
if self.per_product is not None and self.per_product <= 0:
raise ValueError("per_product must be greater than zero.")
if (
self.per_product is not None
and self.product_count is not None
and self.amount is not None
and self.amount > self.per_product * self.product_count
):
raise ValueError(
"The real amount of the transaction must be less than the total value of the products."
)
# TODO: improve printing further
def __repr__(self) -> str:
sort_order = [
"id",
"time",
]
columns = ", ".join(
f"{k}={repr(v)}"
for k, v in sorted(
self.__dict__.items(),
key=lambda item: chr(sort_order.index(item[0]))
if item[0] in sort_order
else item[0],
)
if not any(
[
k == "type_",
(k == "message" and v is None),
k.startswith("_"),
# Ensure that we don't try to print out the entire list of
# relationships, which could create an infinite loop
isinstance(v, Base),
isinstance(v, InstrumentedList),
isinstance(v, InstrumentedSet),
isinstance(v, InstrumentedDict),
*[k in (_DYNAMIC_FIELDS - _EXPECTED_FIELDS[self.type_])],
]
)
)
return f"{self.type_.upper()}({columns})"
###################
# FACTORY METHODS #
###################
@classmethod
def adjust_balance(
cls: type[Self],
amount: int,
user_id: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
return cls(
time=time,
type_=TransactionType.ADJUST_BALANCE,
amount=amount,
user_id=user_id,
message=message,
)
@classmethod
def adjust_interest(
cls: type[Self],
interest_rate_percent: int,
user_id: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
return cls(
time=time,
type_=TransactionType.ADJUST_INTEREST,
interest_rate_percent=interest_rate_percent,
user_id=user_id,
message=message,
)
@classmethod
def adjust_penalty(
cls: type[Self],
penalty_multiplier_percent: int,
penalty_threshold: int,
user_id: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
return cls(
time=time,
type_=TransactionType.ADJUST_PENALTY,
penalty_multiplier_percent=penalty_multiplier_percent,
penalty_threshold=penalty_threshold,
user_id=user_id,
message=message,
)
@classmethod
def adjust_stock(
cls: type[Self],
user_id: int,
product_id: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
return cls(
time=time,
type_=TransactionType.ADJUST_STOCK,
user_id=user_id,
product_id=product_id,
product_count=product_count,
message=message,
)
@classmethod
def add_product(
cls: type[Self],
amount: int,
user_id: int,
product_id: int,
per_product: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
return cls(
time=time,
type_=TransactionType.ADD_PRODUCT,
amount=amount,
user_id=user_id,
product_id=product_id,
per_product=per_product,
product_count=product_count,
message=message,
)
@classmethod
def buy_product(
cls: type[Self],
user_id: int,
product_id: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
return cls(
time=time,
type_=TransactionType.BUY_PRODUCT,
user_id=user_id,
product_id=product_id,
product_count=product_count,
message=message,
)
@classmethod
def transfer(
cls: type[Self],
amount: int,
user_id: int,
transfer_user_id: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
return cls(
time=time,
type_=TransactionType.TRANSFER,
amount=amount,
user_id=user_id,
transfer_user_id=transfer_user_id,
message=message,
)

View File

@@ -0,0 +1,26 @@
from enum import StrEnum, auto
from sqlalchemy import Enum as SQLEnum
class TransactionType(StrEnum):
"""
Enum for transaction types.
"""
ADD_PRODUCT = auto()
ADJUST_BALANCE = auto()
ADJUST_INTEREST = auto()
ADJUST_PENALTY = auto()
ADJUST_STOCK = auto()
BUY_PRODUCT = auto()
TRANSFER = auto()
TransactionTypeSQL = SQLEnum(
TransactionType,
native_enum=True,
create_constraint=True,
validate_strings=True,
values_callable=lambda x: [i.value for i in x],
)

View File

@@ -1,49 +1,47 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self
from sqlalchemy import (
Integer,
String,
select,
)
from sqlalchemy.orm import (
Mapped,
Session,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .UserProducts import UserProducts
from .Transaction import Transaction
class User(Base):
__tablename__ = "users"
name: Mapped[str] = mapped_column(String(10), primary_key=True)
credit: Mapped[str] = mapped_column(Integer)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID"""
name: Mapped[str] = mapped_column(String(20), unique=True)
"""
The PVV username of the user.
"""
card: Mapped[str | None] = mapped_column(String(20))
rfid: Mapped[str | None] = mapped_column(String(20))
products: Mapped[set[UserProducts]] = relationship(back_populates="user")
transactions: Mapped[set[Transaction]] = relationship(back_populates="user")
# name_re = r"[a-z]+"
# card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?"
# rfid_re = r"[0-9a-fA-F]*"
name_re = r"[a-z]+"
card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?"
rfid_re = r"[0-9a-fA-F]*"
def __init__(self, name, card, rfid=None, credit=0):
def __init__(self: Self, name: str, card: str | None = None, rfid: str | None = None) -> None:
self.name = name
if card == "":
card = None
self.card = card
if rfid == "":
rfid = None
self.rfid = rfid
self.credit = credit
def __str__(self):
return self.name
# def __str__(self):
# return self.name
def is_anonymous(self):
return self.card == "11122233"
# def is_anonymous(self):
# return self.card == "11122233"

View File

@@ -0,0 +1,13 @@
from datetime import datetime
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from dibbler.models import Base
# More like user balance cash money flow, amirite?
class UserBalanceCache(Base):
user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
balance: Mapped[int] = mapped_column(Integer)
timestamp: Mapped[datetime] = mapped_column(DateTime)

View File

@@ -1,31 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Integer,
ForeignKey,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .User import User
from .Product import Product
class UserProducts(Base):
__tablename__ = "user_products"
user_name: Mapped[str] = mapped_column(ForeignKey("users.name"), primary_key=True)
product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"), primary_key=True)
count: Mapped[int] = mapped_column(Integer)
sign: Mapped[int] = mapped_column(Integer)
user: Mapped[User] = relationship()
product: Mapped[Product] = relationship()

View File

@@ -1,17 +1,12 @@
__all__ = [
'Base',
'Product',
'Purchase',
'PurchaseEntry',
'Transaction',
'User',
'UserProducts',
"Base",
"Product",
"Transaction",
"User",
]
from .Base import Base
from .Product import Product
from .Purchase import Purchase
from .PurchaseEntry import PurchaseEntry
from .Transaction import Transaction
from .TransactionType import TransactionType
from .User import User
from .UserProducts import UserProducts

View File

View File

View File

@@ -0,0 +1,2 @@
# NOTE: this type of transaction should be password protected.
# the password can be set as a string literal in the config file.

View File

@@ -0,0 +1,2 @@
# NOTE: this type of transaction should be password protected.
# the password can be set as a string literal in the config file.

View File

@@ -0,0 +1,19 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
def current_interest(sql_session: Session) -> int:
result = sql_session.scalars(
select(Transaction)
.where(Transaction.type_ == TransactionType.ADJUST_INTEREST)
.order_by(Transaction.time.desc())
.limit(1)
).one_or_none()
if result is None:
return DEFAULT_INTEREST_RATE_PERCENTAGE
return result.interest_rate_percent

View File

@@ -0,0 +1,25 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_PENALTY_THRESHOLD,
)
def current_penalty(sql_session: Session) -> tuple[int, int]:
result = sql_session.scalars(
select(Transaction)
.where(Transaction.type_ == TransactionType.ADJUST_PENALTY)
.order_by(Transaction.time.desc())
.limit(1)
).one_or_none()
if result is None:
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE
assert result.penalty_threshold is not None, "Penalty threshold must be set"
assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set"
return result.penalty_threshold, result.penalty_multiplier_percent

View File

@@ -0,0 +1,245 @@
import math
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
ColumnElement,
Integer,
SQLColumnExpression,
asc,
case,
cast,
func,
literal,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
)
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
def _product_price_query(
product_id: int | ColumnElement[int],
use_cache: bool = True,
until: datetime | SQLColumnExpression[datetime] | None = None,
until_including: bool = True,
cte_name: str = "rec_cte",
):
"""
The inner query for calculating the product price.
"""
if use_cache:
print("WARNING: Using cache for product price query is not implemented yet.")
initial_element = select(
literal(0).label("i"),
literal(0).label("time"),
literal(None).label("transaction_id"),
literal(0).label("price"),
literal(0).label("product_count"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.id,
Transaction.time,
Transaction.type_,
Transaction.product_count,
Transaction.per_product,
)
.where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT,
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
]
),
Transaction.product_id == product_id,
case(
(literal(until_including), Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else literal(True),
)
.order_by(Transaction.time.asc())
.alias("trx_subset")
)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case(
# Someone buys the product -> price remains the same.
(trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price),
# Someone adds the product -> price is recalculated based on
# product count, previous price, and new price.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
cast(
func.ceil(
(
recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0)
+ trx_subset.c.per_product * trx_subset.c.product_count
)
/ (
# The running product count can be negative if the accounting is bad.
# This ensures that we never end up with negative prices or zero divisions
# and other disastrous phenomena.
func.max(recursive_cte.c.product_count, 0)
+ trx_subset.c.product_count
)
),
Integer,
),
),
# Someone adjusts the stock -> price remains the same.
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price),
# Should never happen
else_=recursive_cte.c.price,
).label("price"),
case(
# Someone buys the product -> product count is reduced.
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
recursive_cte.c.product_count - trx_subset.c.product_count,
),
# Someone adds the product -> product count is increased.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Someone adjusts the stock -> product count is adjusted.
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK,
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Should never happen
else_=recursive_cte.c.product_count,
).label("product_count"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + 1)
)
return recursive_cte.union_all(recursive_elements)
# TODO: create a function for the log that pretty prints the log entries
# for debugging purposes
@dataclass
class ProductPriceLogEntry:
transaction: Transaction
price: int
product_count: int
def product_price_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: Transaction | None = None,
) -> list[ProductPriceLogEntry]:
"""
Calculates the price of a product and returns a log of the price changes.
"""
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until=until.time if until else None,
)
result = sql_session.execute(
select(
Transaction,
recursive_cte.c.price,
recursive_cte.c.product_count,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})."
)
return [
ProductPriceLogEntry(
transaction=row[0],
price=row.price,
product_count=row.product_count,
)
for row in result
]
@staticmethod
def product_price(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: Transaction | None = None,
include_interest: bool = False,
) -> int:
"""
Calculates the price of a product.
"""
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until=until.time if until else None,
)
# TODO: optionally verify subresults:
# - product_count should never be negative (but this happens sometimes, so just a warning)
# - price should never be negative
result = sql_session.scalars(
select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1)
).one_or_none()
if result is None:
# If there are no transactions for this product, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
)
if include_interest:
interest_rate = (
sql_session.scalar(
select(Transaction.interest_rate_percent)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
literal(True) if until is None else Transaction.time <= until.time,
)
.order_by(Transaction.time.desc())
.limit(1)
)
or DEFAULT_INTEREST_RATE_PERCENTAGE
)
result = math.ceil(result * interest_rate / 100)
return result

View File

@@ -0,0 +1,76 @@
from datetime import datetime
from sqlalchemy import case, func, literal, select
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
Transaction,
TransactionType,
)
def _product_stock_query(
product_id: int,
use_cache: bool = True,
until: datetime | None = None,
):
"""
The inner query for calculating the product stock.
"""
if use_cache:
print("WARNING: Using cache for product stock query is not implemented yet.")
query = select(
func.sum(
case(
(
Transaction.type_ == TransactionType.ADD_PRODUCT,
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.BUY_PRODUCT,
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.ADJUST_STOCK,
Transaction.product_count,
),
else_=0,
)
)
).where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT,
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
]
),
Transaction.product_id == product_id,
Transaction.time <= until if until is not None else literal(True),
)
return query
def product_stock(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: datetime | None = None,
) -> int:
"""
Returns the number of products in stock.
"""
query = _product_stock_query(
product_id=product.id,
use_cache=use_cache,
until=until,
)
result = sql_session.scalars(query).one_or_none()
return result or 0

View File

@@ -0,0 +1,39 @@
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from dibbler.models import Product
def search_product(
string: str,
sql_session: Session,
find_hidden_products=True,
) -> Product | list[Product]:
exact_match = sql_session.scalars(
select(Product).where(
or_(
Product.bar_code == string,
and_(
Product.name == string,
literal(True) if find_hidden_products else not Product.hidden,
),
)
)
).first()
if exact_match:
return exact_match
product_list = sql_session.scalars(
select(Product).where(
or_(
Product.bar_code.ilike(f"%{string}%"),
and_(
Product.name.ilike(f"%{string}%"),
literal(True) if find_hidden_products else not Product.hidden,
),
)
)
).all()
return list(product_list)

View File

@@ -0,0 +1,37 @@
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from dibbler.models import User
def search_user(
string: str,
sql_session: Session,
ignorethisflag=None,
) -> User | list[User]:
string = string.lower()
exact_match = sql_session.scalars(
select(User).where(
or_(
User.name == string,
User.card == string,
User.rfid == string,
)
)
).first()
if exact_match:
return exact_match
user_list = sql_session.scalars(
select(User).where(
or_(
User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"),
)
)
).all()
return list(user_list)

View File

@@ -0,0 +1,319 @@
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
Float,
Integer,
and_,
asc,
case,
cast,
column,
func,
literal,
or_,
select,
)
from sqlalchemy.orm import Session
from dibbler.models import (
Transaction,
TransactionType,
User,
)
from dibbler.models.Transaction import (
DEFAULT_INTEREST_RATE_PERCENTAGE,
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries.product_price import _product_price_query
def _user_balance_query(
user_id: int,
use_cache: bool = True,
until: datetime | None = None,
until_including: bool = True,
cte_name: str = "rec_cte",
):
"""
The inner query for calculating the user's balance.
"""
if use_cache:
print("WARNING: Using cache for user balance query is not implemented yet.")
initial_element = select(
literal(0).label("i"),
literal(0).label("time"),
literal(None).label("transaction_id"),
literal(0).label("balance"),
literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
literal(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.amount,
Transaction.id,
Transaction.interest_rate_percent,
Transaction.penalty_multiplier_percent,
Transaction.penalty_threshold,
Transaction.product_count,
Transaction.product_id,
Transaction.time,
Transaction.transfer_user_id,
Transaction.type_,
)
.where(
or_(
and_(
Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_BALANCE,
TransactionType.BUY_PRODUCT,
TransactionType.TRANSFER,
]
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER,
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.ADJUST_INTEREST,
TransactionType.ADJUST_PENALTY,
]
),
),
case(
(literal(until_including), Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else literal(True),
)
.order_by(Transaction.time.asc())
.alias("trx_subset")
)
recursive_elements = (
select(
trx_subset.c.i,
trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case(
# Adjusts balance -> balance gets adjusted
(
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE,
recursive_cte.c.balance + trx_subset.c.amount,
),
# Adds a product -> balance increases
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
recursive_cte.c.balance + trx_subset.c.amount,
),
# Buys a product -> balance decreases
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
recursive_cte.c.balance
- (
trx_subset.c.product_count
# Price of a single product, accounted for penalties and interest.
* cast(
func.ceil(
# TODO: This can get quite expensive real quick, so we should do some caching of the
# product prices somehow.
# Base price
(
# FIXME: this always returns 0 for some reason...
select(cast(column("price"), Float))
.select_from(
_product_price_query(
trx_subset.c.product_id,
use_cache=use_cache,
until=trx_subset.c.time,
until_including=False,
cte_name="product_price_cte",
)
)
.order_by(column("i").desc())
.limit(1)
).scalar_subquery()
# TODO: should interest be applied before or after the penalty multiplier?
# at the moment of writing, after sound right, but maybe ask someone?
# Interest
* (cast(recursive_cte.c.interest_rate_percent, Float) / 100)
# Penalty
* case(
(
# TODO: should this be <= or <?
recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
(
cast(recursive_cte.c.penalty_multiplier_percent, Float)
/ 100
),
),
else_=1.0,
)
),
Integer,
)
),
),
# Transfers money to self -> balance increases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER,
trx_subset.c.transfer_user_id == user_id,
),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Transfers money from self -> balance decreases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER,
trx_subset.c.transfer_user_id != user_id,
),
recursive_cte.c.balance - trx_subset.c.amount,
),
# Interest adjustment -> balance stays the same
# Penalty adjustment -> balance stays the same
else_=recursive_cte.c.balance,
).label("balance"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST,
trx_subset.c.interest_rate_percent,
),
else_=recursive_cte.c.interest_rate_percent,
).label("interest_rate_percent"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
trx_subset.c.penalty_threshold,
),
else_=recursive_cte.c.penalty_threshold,
).label("penalty_threshold"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
trx_subset.c.penalty_multiplier_percent,
),
else_=recursive_cte.c.penalty_multiplier_percent,
).label("penalty_multiplier_percent"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + 1)
)
return recursive_cte.union_all(recursive_elements)
# TODO: create a function for the log that pretty prints the log entries
# for debugging purposes
@dataclass
class UserBalanceLogEntry:
transaction: Transaction
balance: int
interest_rate_percent: int
penalty_threshold: int
penalty_multiplier_percent: int
def is_penalized(self) -> bool:
"""
Returns whether this exact transaction is penalized.
"""
return False
# return self.transaction.type_ == TransactionType.BUY_PRODUCT and prev?
def user_balance_log(
sql_session: Session,
user: User,
use_cache: bool = True,
until: Transaction | None = None,
) -> list[UserBalanceLogEntry]:
"""
Returns a log of the user's balance over time, including interest and penalty adjustments.
"""
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until=until.time if until else None,
)
result = sql_session.execute(
select(
Transaction,
recursive_cte.c.balance,
recursive_cte.c.interest_rate_percent,
recursive_cte.c.penalty_threshold,
recursive_cte.c.penalty_multiplier_percent,
)
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all()
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
)
return [
UserBalanceLogEntry(
transaction=row[0],
balance=row.balance,
interest_rate_percent=row.interest_rate_percent,
penalty_threshold=row.penalty_threshold,
penalty_multiplier_percent=row.penalty_multiplier_percent,
)
for row in result
]
def user_balance(
sql_session: Session,
user: User,
use_cache: bool = True,
until: Transaction | None = None,
) -> int:
"""
Calculates the balance of a user.
"""
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until=until.time if until else None,
)
result = sql_session.scalar(
select(recursive_cte.c.balance).order_by(recursive_cte.c.i.desc()).limit(1)
)
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
)
return result

View File

@@ -0,0 +1,20 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
# TODO: allow filtering out 'special transactions' like 'ADJUST_INTEREST' and 'ADJUST_PENALTY'
def user_transactions(sql_session: Session, user: User) -> list[Transaction]:
"""
Returns the transactions of the user in chronological order.
"""
return list(
sql_session.scalars(
select(Transaction)
.where(Transaction.user_id == user.id)
.order_by(Transaction.time.asc())
).all()
)

View File

View File

View File

@@ -0,0 +1,77 @@
from datetime import datetime
from pathlib import Path
from dibbler.db import Session
from dibbler.models import Product, Transaction, User
JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json"
# TODO: integrate this as a part of create-db, either asking interactively
# whether to seed test data, or by using command line arguments for
# automatating the answer.
def clear_db(sql_session):
sql_session.query(Product).delete()
sql_session.query(User).delete()
sql_session.commit()
def main():
# TODO: There is some leftover json data in the mock_data.json file.
# It should be dealt with before merging this PR, either by removing
# it or using it here.
sql_session = Session()
clear_db(sql_session)
# Add users
user1 = User("Test User 1")
user2 = User("Test User 2")
sql_session.add(user1)
sql_session.add(user2)
sql_session.commit()
# Add products
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
sql_session.add(product1)
sql_session.add(product2)
sql_session.commit()
# Add transactions
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user1.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1),
amount=50,
user_id=user2.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 2),
amount=-50,
user_id=user1.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user1.id,
product_id=product1.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=1,
user_id=user2.id,
product_id=product1.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()

View File

@@ -5,7 +5,7 @@ show_tracebacks = true
input_encoding = 'utf8'
[database]
; url = postgresql://robertem@127.0.0.1/pvvvv
# url = "postgresql://robertem@127.0.0.1/pvvvv"
url = sqlite:///test.db
[limits]

23
flake.lock generated
View File

@@ -5,31 +5,32 @@
"systems": "systems"
},
"locked": {
"lastModified": 1692799911,
"narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=",
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
"id": "flake-utils",
"type": "indirect"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1693145325,
"narHash": "sha256-Gat9xskErH1zOcLjYMhSDBo0JTBZKfGS0xJlIRnj6Rc=",
"lastModified": 1749285348,
"narHash": "sha256-frdhQvPbmDYaScPFiCnfdh3B/Vh81Uuoo0w5TkWmmjU=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "cddebdb60de376c1bdb7a4e6ee3d98355453fe56",
"rev": "3e3afe5174c561dee0df6f2c2b2236990146329f",
"type": "github"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {

View File

@@ -1,77 +1,65 @@
{
description = "Dibbler samspleisebod";
inputs.flake-utils.url = "github:numtide/flake-utils";
inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
outputs = { self, nixpkgs, flake-utils }:
flake-utils.lib.eachDefaultSystem (system: let
outputs = { self, nixpkgs, flake-utils }: let
inherit (nixpkgs) lib;
systems = [
"x86_64-linux"
"aarch64-linux"
"x86_64-darwin"
"aarch64-darwin"
];
forAllSystems = f: lib.genAttrs systems (system: let
pkgs = nixpkgs.legacyPackages.${system};
in {
packages = {
in f system pkgs);
in {
packages = forAllSystems (system: pkgs: {
default = self.packages.${system}.dibbler;
dibbler = pkgs.callPackage ./nix/dibbler.nix {
python3Packages = pkgs.python311Packages;
python3Packages = pkgs.python312Packages;
};
};
skrot = self.nixosConfigurations.skrot.config.system.build.sdImage;
});
apps = {
apps = forAllSystems (system: pkgs: {
default = self.apps.${system}.dibbler;
dibbler = flake-utils.lib.mkApp {
drv = self.packages.${system}.dibbler;
};
};
});
devShells = {
default = self.devShells.${system}.dibbler;
dibbler = pkgs.mkShell {
packages = with pkgs; [
python311Packages.black
ruff
];
overlays = {
default = self.overlays.dibbler;
dibbler = final: prev: {
inherit (self.packages.${prev.system}) dibbler;
};
};
})
//
devShells = forAllSystems (system: pkgs: {
default = self.devShells.${system}.dibbler;
dibbler = pkgs.callPackage ./nix/shell.nix {
python = pkgs.python312;
};
});
{
# Note: using the module requires that you have applied the
# overlay first
# Note: using the module requires that you have applied the overlay first
nixosModules.default = import ./nix/module.nix;
images.skrot = self.nixosConfigurations.skrot.config.system.build.sdImage;
nixosConfigurations.skrot = nixpkgs.lib.nixosSystem {
nixosConfigurations.skrot = nixpkgs.lib.nixosSystem (rec {
system = "aarch64-linux";
pkgs = import nixpkgs {
inherit system;
overlays = [ self.overlays.dibbler ];
};
modules = [
(nixpkgs + "/nixos/modules/installer/sd-card/sd-image-aarch64.nix")
self.nixosModules.default
({...}: {
system.stateVersion = "22.05";
networking = {
hostName = "skrot";
domain = "pvv.ntnu.no";
nameservers = [ "129.241.0.200" "129.241.0.201" ];
defaultGateway = "129.241.210.129";
interfaces.eth0 = {
useDHCP = false;
ipv4.addresses = [{
address = "129.241.210.235";
prefixLength = 25;
}];
};
};
# services.resolved.enable = true;
# systemd.network.enable = true;
# systemd.network.networks."30-network" = {
# matchConfig.Name = "*";
# DHCP = "no";
# address = [ "129.241.210.235/25" ];
# gateway = [ "129.241.210.129" ];
# };
})
./nix/skrott.nix
];
};
});
};
}

76
mock_data.json Normal file
View File

@@ -0,0 +1,76 @@
{
"products": [
{
"product_id": 1,
"bar_code": "1234567890123",
"name": "Wireless Mouse",
"price": 2999,
"stock": 150,
"hidden": false
},
{
"product_id": 2,
"bar_code": "9876543210987",
"name": "Mechanical Keyboard",
"price": 5999,
"stock": 75,
"hidden": false
},
{
"product_id": 3,
"bar_code": "1112223334445",
"name": "Gaming Monitor",
"price": 19999,
"stock": 20,
"hidden": false
},
{
"product_id": 4,
"bar_code": "5556667778889",
"name": "USB-C Docking Station",
"price": 8999,
"stock": 50,
"hidden": true
},
{
"product_id": 5,
"bar_code": "4445556667771",
"name": "Noise Cancelling Headphones",
"price": 12999,
"stock": 30,
"hidden": true
}
],
"users": [
{
"name": "Albert",
"credit": 42069,
"card": "NTU12345678",
"rfid": "a1b2c3d4e5"
},
{
"name": "lorem",
"credit": 2000,
"card": "9876543210",
"rfid": "f6e7d8c9b0"
},
{
"name": "ibsum",
"credit": 1000,
"card": "11122233",
"rfid": ""
},
{
"name": "dave",
"credit": 7500,
"card": "NTU56789012",
"rfid": "1234abcd5678"
},
{
"name": "eve",
"credit": 3000,
"card": null,
"rfid": "deadbeef1234"
}
]
}

View File

@@ -4,11 +4,23 @@
}:
python3Packages.buildPythonApplication {
pname = "dibbler";
version = "unstable-2021-09-07";
version = "unstable";
src = lib.cleanSource ../.;
format = "pyproject";
# brother-ql is breaky breaky
# https://github.com/NixOS/nixpkgs/issues/285234
dontCheckRuntimeDeps = true;
pythonImportsCheck = [];
doCheck = true;
nativeCheckInputs = with python3Packages; [
pytest
pytestCheckHook
];
nativeBuildInputs = with python3Packages; [ setuptools ];
propagatedBuildInputs = with python3Packages; [
brother-ql

View File

@@ -1,16 +1,31 @@
{ config, pkgs, lib, ... }: let
cfg = config.services.dibbler;
format = pkgs.formats.ini { };
in {
options.services.dibbler = {
enable = lib.mkEnableOption "dibbler, the little kiosk computer";
package = lib.mkPackageOption pkgs "dibbler" { };
config = lib.mkOption {
default = ../conf.py;
settings = lib.mkOption {
description = "Configuration for dibbler";
default = { };
type = lib.types.submodule {
freeformType = format.type;
};
};
};
config = let
screen = "${pkgs.screen}/bin/screen";
in {
in lib.mkIf cfg.enable {
services.dibbler.settings = lib.pipe ../example-config.ini [
builtins.readFile
builtins.fromTOML
(lib.mapAttrsRecursive (_: lib.mkDefault))
];
boot = {
consoleLogLevel = 0;
enableContainers = false;
@@ -23,7 +38,7 @@ in {
group = "dibbler";
extraGroups = [ "lp" ];
isNormalUser = true;
shell = ((pkgs.writeShellScriptBin "login-shell" "${screen} -x dibbler") // {shellPath = "/bin/login-shell";});
shell = (pkgs.writeShellScriptBin "login-shell" "${screen} -x dibbler") // {shellPath = "/bin/login-shell";};
};
};
@@ -32,7 +47,9 @@ in {
wantedBy = [ "default.target" ];
serviceConfig = {
ExecStartPre = "-${screen} -X -S dibbler kill";
ExecStart = "${screen} -dmS dibbler -O -l ${cfg.package}/bin/dibbler --config ${cfg.config} loop";
ExecStart = let
config = format.generate "dibbler-config.ini" cfg.settings;
in "${screen} -dmS dibbler -O -l ${cfg.package}/bin/dibbler --config ${config} loop";
ExecStartPost = "${screen} -X -S dibbler width 42 80";
User = "dibbler";
Group = "dibbler";
@@ -69,7 +86,7 @@ in {
console.keyMap = "no";
programs.command-not-found.enable = false;
i18n.supportedLocales = [ "en_US.UTF-8/UTF-8" ];
environment.noXlibs = true;
# environment.noXlibs = true;
documentation = {
info.enable = false;

23
nix/shell.nix Normal file
View File

@@ -0,0 +1,23 @@
{
mkShell,
python,
ruff,
uv,
}:
mkShell {
packages = [
ruff
uv
(python.withPackages (ps: with ps; [
brother-ql
matplotlib
psycopg2
python-barcode
sqlalchemy
pytest
pytest-cov
]))
];
}

27
nix/skrott.nix Normal file
View File

@@ -0,0 +1,27 @@
{...}: {
system.stateVersion = "25.05";
services.dibbler.enable = true;
networking = {
hostName = "skrot";
domain = "pvv.ntnu.no";
nameservers = [ "129.241.0.200" "129.241.0.201" ];
defaultGateway = "129.241.210.129";
interfaces.eth0 = {
useDHCP = false;
ipv4.addresses = [{
address = "129.241.210.235";
prefixLength = 25;
}];
};
};
# services.resolved.enable = true;
# systemd.network.enable = true;
# systemd.network.networks."30-network" = {
# matchConfig.Name = "*";
# DHCP = "no";
# address = [ "129.241.210.235/25" ];
# gateway = [ "129.241.210.129" ];
# };
}

View File

@@ -8,7 +8,6 @@ authors = []
description = "EDB-system for PVV"
readme = "README.md"
requires-python = ">=3.11"
license = {text = "BSD-3-Clause"}
classifiers = [
"Programming Language :: Python :: 3",
]
@@ -21,6 +20,12 @@ dependencies = [
]
dynamic = ["version"]
[project.optional-dependencies]
dev = [
"pytest",
"pytest-cov",
]
[tool.setuptools.packages.find]
include = ["dibbler*"]
@@ -32,4 +37,3 @@ line-length = 100
[tool.ruff]
line-length = 100

0
tests/__init__.py Normal file
View File

36
tests/conftest.py Normal file
View File

@@ -0,0 +1,36 @@
import pytest
from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session
from dibbler.models import Base
def pytest_addoption(parser):
parser.addoption(
"--echo",
action="store_true",
help="Enable SQLAlchemy echo mode for debugging",
)
@pytest.fixture(scope="function")
def sql_session(request):
"""Create a new SQLAlchemy session for testing."""
echo = request.config.getoption("--echo")
engine = create_engine(
"sqlite:///:memory:",
echo=echo,
)
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, _connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Base.metadata.create_all(engine)
with Session(engine) as sql_session:
yield sql_session

0
tests/models/__init__.py Normal file
View File

View File

@@ -0,0 +1,32 @@
import pytest
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from dibbler.models import Product
def insert_test_data(sql_session: Session) -> Product:
product = Product("1234567890123", "Test Product")
sql_session.add(product)
sql_session.commit()
return product
def test_product_no_duplicate_barcodes(sql_session: Session):
product = insert_test_data(sql_session)
duplicate_product = Product(product.bar_code, "Hehe >:)")
sql_session.add(duplicate_product)
with pytest.raises(IntegrityError):
sql_session.commit()
def test_product_no_duplicate_names(sql_session: Session):
product = insert_test_data(sql_session)
duplicate_product = Product("1918238911928", product.name)
sql_session.add(duplicate_product)
with pytest.raises(IntegrityError):
sql_session.commit()

View File

@@ -0,0 +1,199 @@
from datetime import datetime
import pytest
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries.product_stock import product_stock
def insert_test_data(sql_session: Session) -> tuple[User, Product]:
user = User("Test User")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit()
return user, product
def test_transaction_no_duplicate_timestamps(sql_session: Session):
user, _ = insert_test_data(sql_session)
transaction1 = Transaction.adjust_balance(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
amount=100,
)
sql_session.add(transaction1)
sql_session.commit()
transaction2 = Transaction.adjust_balance(
time=transaction1.time,
user_id=user.id,
amount=-50,
)
sql_session.add(transaction2)
with pytest.raises(IntegrityError):
sql_session.commit()
def test_user_not_allowed_to_transfer_to_self(sql_session: Session) -> None:
user, _ = insert_test_data(sql_session)
transaction = Transaction.transfer(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
transfer_user_id=user.id,
amount=50,
)
sql_session.add(transaction)
with pytest.raises(IntegrityError):
sql_session.commit()
def test_product_foreign_key_constraint(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
)
sql_session.add(transaction)
sql_session.commit()
# Attempt to add a transaction with a non-existent product
invalid_transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
user_id=user.id,
product_id=9999, # Non-existent product ID
amount=27,
per_product=27,
product_count=1,
)
sql_session.add(invalid_transaction)
with pytest.raises(IntegrityError):
sql_session.commit()
def test_user_foreign_key_constraint(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
)
sql_session.add(transaction)
sql_session.commit()
# Attempt to add a transaction with a non-existent user
invalid_transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
user_id=9999, # Non-existent user ID
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
)
sql_session.add(invalid_transaction)
with pytest.raises(IntegrityError):
sql_session.commit()
def test_transaction_buy_product_more_than_stock(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 0),
product_count=10,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
assert product_stock(sql_session, product) == 1 - 10
def test_transaction_buy_product_dont_allow_no_add_product_transactions(
sql_session: Session,
) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
product_count=1,
user_id=user.id,
product_id=product.id,
)
sql_session.add(transaction)
with pytest.raises(ValueError):
sql_session.commit()
def test_transaction_add_product_deny_amount_over_per_product_times_product_count(
sql_session: Session,
) -> None:
user, product = insert_test_data(sql_session)
with pytest.raises(ValueError):
_transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27 * 2 + 1, # Invalid amount
per_product=27,
product_count=2,
)
def test_transaction_add_product_allow_amount_under_per_product_times_product_count(
sql_session: Session,
) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27 * 2 - 1, # Valid amount
per_product=27,
product_count=2,
)
sql_session.add(transaction)
sql_session.commit()

25
tests/models/test_user.py Normal file
View File

@@ -0,0 +1,25 @@
from datetime import datetime
import pytest
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def insert_test_data(sql_session: Session) -> User:
user = User("Test User")
sql_session.add(user)
sql_session.commit()
return user
def test_ensure_no_duplicate_user_names(sql_session: Session):
user = insert_test_data(sql_session)
user2 = User(user.name)
sql_session.add(user2)
with pytest.raises(IntegrityError):
sql_session.commit()

View File

View File

@@ -0,0 +1,342 @@
import math
from datetime import datetime
from pprint import pprint
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries.product_price import product_price, product_price_log
# TODO: see if we can use pytest_runtest_makereport to print the "product_price_log"s
# only on failures instead of inlining it in every test function
def insert_test_data(sql_session: Session) -> tuple[User, Product]:
user = User("Test User")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit()
return user, product
def test_product_price_no_transactions(sql_session: Session) -> None:
_, product = insert_test_data(sql_session)
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 0
def test_product_price_basic_history(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 27
def test_product_price_sold_out(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 27
def test_product_price_interest(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 12, 0, 0),
interest_rate_percent=110,
user_id=user.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
product_price_interest = product_price(sql_session, product, include_interest=True)
assert product_price_ == 27
assert product_price_interest == math.ceil(27 * 1.1)
def test_product_price_changing_interest(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 12, 0, 0),
interest_rate_percent=110,
user_id=user.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 12, 0, 2),
interest_rate_percent=120,
user_id=user.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_interest = product_price(sql_session, product, include_interest=True)
assert product_price_interest == math.ceil(27 * 1.2)
def test_product_price_old_transaction(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 2),
amount=38 * 3,
per_product=38,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged upwards
]
sql_session.add_all(transactions)
sql_session.commit()
until_transaction = transactions[0]
pprint(
product_price_log(
sql_session,
product,
until=until_transaction,
)
)
product_price_ = product_price(
sql_session,
product,
until=until_transaction,
)
assert product_price_ == 27
# Price goes up and gets rounded up to the next integer
def test_product_price_round_up_from_below(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 2),
amount=38 * 3,
per_product=38,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged upwards
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
assert product_price_ == math.ceil((27 * 2 + 38 * 3) / (2 + 3))
# Price goes down and gets rounded up to the next integer
def test_product_price_round_up_from_above(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 2),
amount=20 * 3,
per_product=20,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged downwards
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
assert product_price_ == math.ceil((27 * 2 + 20 * 3) / (2 + 3))
def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 0),
amount=1,
per_product=10,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 1),
product_count=10,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 2),
amount=22,
per_product=22,
product_count=1,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
# Stock went subzero, price should be the last added product price
product1_price = product_price(sql_session, product)
assert product1_price == 22
# TODO: what happens when stock is still negative and yet new products are added?
def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 0),
amount=1,
per_product=10,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 1),
product_count=10,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 2),
amount=22,
per_product=22,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 3),
amount=29,
per_product=29,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
# Stock went subzero, price should be the ceiled average of the last added products
product1_price = product_price(sql_session, product)
assert product1_price == math.ceil((22 + 29 * 2) / (1 + 2))

View File

@@ -0,0 +1,141 @@
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries.product_stock import product_stock
def insert_test_data(sql_session: Session) -> None:
user1 = User("Test User 1")
sql_session.add(user1)
sql_session.commit()
def test_product_stock_basic_history(sql_session: Session) -> None:
insert_test_data(sql_session)
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
product = Product("1234567890123", "Test Product")
sql_session.add(product)
sql_session.commit()
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=10,
per_product=10,
user_id=user1.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
assert product_stock(sql_session, product) == 1
def test_product_stock_complex_history(sql_session: Session) -> None:
insert_test_data(sql_session)
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
product = Product("1234567890123", "Test Product")
sql_session.add(product)
sql_session.commit()
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 0),
amount=27 * 2,
per_product=27,
user_id=user1.id,
product_id=product.id,
product_count=2,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 1),
user_id=user1.id,
product_id=product.id,
product_count=3,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 2),
amount=50 * 4,
per_product=50,
user_id=user1.id,
product_id=product.id,
product_count=4,
),
Transaction.adjust_stock(
time=datetime(2023, 10, 1, 15, 0, 0),
user_id=user1.id,
product_id=product.id,
product_count=3,
),
Transaction.adjust_stock(
time=datetime(2023, 10, 1, 15, 0, 1),
user_id=user1.id,
product_id=product.id,
product_count=-2,
),
]
sql_session.add_all(transactions)
sql_session.commit()
assert product_stock(sql_session, product) == 2 - 3 + 4 + 3 - 2
def test_product_stock_no_transactions(sql_session: Session) -> None:
insert_test_data(sql_session)
product = Product("1234567890123", "Test Product")
sql_session.add(product)
sql_session.commit()
assert product_stock(sql_session, product) == 0
def test_negative_product_stock(sql_session: Session) -> None:
insert_test_data(sql_session)
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one()
product = Product("1234567890123", "Test Product")
sql_session.add(product)
sql_session.commit()
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 14, 0, 0),
amount=50,
per_product=50,
user_id=user1.id,
product_id=product.id,
product_count=1,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 14, 0, 1),
user_id=user1.id,
product_id=product.id,
product_count=2,
),
Transaction.adjust_stock(
time=datetime(2023, 10, 1, 16, 0, 0),
user_id=user1.id,
product_id=product.id,
product_count=-1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
# The stock should be negative because we added and bought the product
assert product_stock(sql_session, product) == 1 - 2 - 1

View File

@@ -0,0 +1,306 @@
import math
from datetime import datetime
from pprint import pprint
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries.user_balance import user_balance, user_balance_log
# TODO: see if we can use pytest_runtest_makereport to print the "user_balance_log"s
# only on failures instead of inlining it in every test function
def insert_test_data(sql_session: Session) -> tuple[User, Product]:
user = User("Test User")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit()
return user, product
def test_user_balance_no_transactions(sql_session: Session) -> None:
user, _ = insert_test_data(sql_session)
pprint(user_balance_log(sql_session, user))
balance = user_balance(sql_session, user)
assert balance == 0
def test_user_balance_basic_history(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
amount=100,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 1),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
balance = user_balance(sql_session, user)
assert balance == 100 + 27
def test_user_balance_with_transfers(sql_session: Session) -> None:
user1, product = insert_test_data(sql_session)
user2 = User("Test User 2")
sql_session.add(user2)
sql_session.commit()
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user1.id,
amount=100,
),
Transaction.transfer(
time=datetime(2023, 10, 1, 10, 0, 1),
user_id=user1.id,
transfer_user_id=user2.id,
amount=50,
),
Transaction.transfer(
time=datetime(2023, 10, 1, 10, 0, 2),
user_id=user2.id,
transfer_user_id=user1.id,
amount=30,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user1))
user1_balance = user_balance(sql_session, user1)
assert user1_balance == 100 - 50 + 30
pprint(user_balance_log(sql_session, user2))
user2_balance = user_balance(sql_session, user2)
assert user2_balance == 50 - 30
def test_user_balance_complex_history(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.")
def test_user_balance_penalty(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
amount=-200,
),
# Penalized, pays 2x the price (default penalty)
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 - 200 - (27 * 2)
def test_user_balance_changing_penalty(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
amount=-200,
),
# Penalized, pays 2x the price (default penalty)
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
Transaction.adjust_penalty(
time=datetime(2023, 10, 1, 13, 0, 0),
user_id=user.id,
penalty_multiplier_percent=300,
penalty_threshold=-100,
),
# Penalized, pays 3x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 14, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 - 200 - (27 * 2) - (27 * 3)
def test_user_balance_interest(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
interest_rate_percent=110,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 - math.ceil(27 * 1.1)
def test_user_balance_changing_interest(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27 * 3,
per_product=27,
product_count=3,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
interest_rate_percent=110,
),
# Pays 1.1x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 13, 0, 0),
user_id=user.id,
interest_rate_percent=120,
),
# Pays 1.2x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 14, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 * 3 - math.ceil(27 * 1.1) - math.ceil(27 * 1.2)
def test_user_balance_penalty_interest_combined(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
interest_rate_percent=110,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
amount=-200,
),
# Penalized, pays 2x the price (default penalty)
# Pays 1.1x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == (
27
- 200
- math.ceil(27 * 2 * 1.1)
)

View File

@@ -0,0 +1,60 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries.user_transactions import user_transactions
def insert_test_data(sql_session: Session) -> User:
user = User("Test User")
sql_session.add(user)
sql_session.commit()
return user
def test_user_transactions(sql_session: Session):
user = insert_test_data(sql_session)
product = Product("1234567890123", "Test Product")
user2 = User("Test User 2")
sql_session.add_all([product, user2])
sql_session.commit()
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1),
amount=50,
user_id=user2.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 2),
amount=-50,
user_id=user.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=1,
user_id=user2.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
assert len(user_transactions(sql_session, user)) == 3
assert len(user_transactions(sql_session, user2)) == 2