52 Commits

Author SHA1 Message Date
a265fb921c WIP: caching
Some checks failed
Run tests / run-tests (push) Failing after 3m2s
Run benchmarks / run-tests (push) Successful in 34m40s
2026-02-12 17:02:40 +09:00
aed85b4a06 uv.lock: update 2026-02-12 17:02:38 +09:00
3fc7d78c1c nix/package: add missing test deps 2026-02-12 16:55:51 +09:00
38e9066300 Add benchmarks 2026-02-12 16:55:50 +09:00
a9070fc680 Add subcommand for displaying the transaction log 2026-02-12 16:55:50 +09:00
2ac7d26bcd seed_test_data: use new queries 2026-02-12 16:55:50 +09:00
c85a11eb89 README: add overview of project structure 2026-02-12 16:55:49 +09:00
57f7d25cdf Write a set of queries to go along with the event sourcing model 2026-02-12 16:55:49 +09:00
2a05bd7a58 Update models for event sourcing 2026-02-12 16:54:08 +09:00
00afede3d9 .gitea/workflows: init test pipeline 2026-02-12 16:54:04 +09:00
19ee9bebc2 Set up testing with pytest and co 2026-02-12 16:54:03 +09:00
acb31992f8 Write specification for economy 2026-02-12 16:37:37 +09:00
fb0f24cb67 fix database verification for views 2026-02-05 02:50:11 +09:00
3d555ca9d1 menus/faq: fix indentation 2026-02-05 01:52:56 +09:00
af5710d663 verify database connection before starting 2026-02-05 01:39:40 +09:00
4d88409e97 helpers: fix search_user 2026-02-05 01:39:12 +09:00
72cd066414 example-config.toml: fix sqlite default path 2026-02-05 01:38:35 +09:00
b1bb1e556b Add --version flag to cli 2026-02-05 00:41:06 +09:00
70b04c0c45 Fix a bunch more lints 2026-02-04 22:59:18 +09:00
7bea5b0b96 Remove need for clear 2026-02-04 22:16:45 +09:00
3123b8b474 loop: disable autocommits, reset db session on looping 2026-02-04 00:38:40 +09:00
9091adedad stats: fix balance stat when missing database rows 2026-02-04 00:34:45 +09:00
94955cb706 treewide: fix a bunch more typing issues 2026-02-04 00:28:29 +09:00
3b6cd1d354 buymenu: fix warning message escapes 2026-02-04 00:24:36 +09:00
c2ee66c394 treewide: format 2026-02-04 00:01:38 +09:00
b5b2706085 helpermenus: add some more types 2026-02-04 00:01:17 +09:00
bf9cea7dfc loop: disable autoflushing, don't expire session on commit 2026-02-03 23:32:19 +09:00
cf945143ba treewide: fix a bunch of lints 2026-02-03 23:24:37 +09:00
e84b43e2a0 pyproject.toml: add ruff linting rules 2026-02-03 23:24:19 +09:00
17fc23ba97 menus/Menu: never unset sql_session 2026-02-03 23:02:11 +09:00
45179a9c43 loop: don't overload main name 2026-02-03 23:01:36 +09:00
dfaa818f46 treewide: rollback if commit was unsuccessful 2026-02-03 22:52:43 +09:00
ec43f67e58 flake.nix: fix nix run 2026-01-27 19:42:21 +09:00
1b09a904cb menus/mainmenu: register sql session in menu 2026-01-27 19:40:17 +09:00
8e84669d9b Temporarily disable brother-ql + friends, update to python 3.13 2026-01-26 13:02:34 +09:00
1d01e1b2cb package.nix: add clear to $PATH 2026-01-26 02:30:10 +09:00
019f419b12 models: a bit of back population 2026-01-25 22:54:01 +09:00
3bab62b3ac treewide: types, types and more types 2026-01-25 22:53:45 +09:00
e771fb0240 Propagate sql_session through constructors 2026-01-25 18:38:22 +09:00
2331e53795 config: structured database config 2026-01-25 18:08:50 +09:00
2ae651a1fa README: add link to wiki docs 2026-01-25 18:08:49 +09:00
76f07841be module.nix: fix lib.getExe warnings 2026-01-25 18:08:49 +09:00
ecaec99212 Replace configparser with tomllib 2026-01-25 18:08:49 +09:00
cb385097dc README: add note about vm 2026-01-11 22:36:51 +09:00
b86962ef0e flake.nix: system -> stdenv.hostPlatform.system 2026-01-09 06:14:35 +09:00
9c0bd54be6 parse config file argument as Path 2026-01-09 05:45:43 +09:00
919d7a5afe assert database_url is present 2026-01-09 05:45:42 +09:00
ddca959ad6 pyproject.toml: psycopg2 -> psycopg2-binary 2026-01-09 05:45:40 +09:00
1733843b77 pyproject.toml: set authors 2026-01-06 17:33:15 +09:00
4ed68ff05c nix: yeet skrott, massive module modifications tm, wrap package and more
Sorry for the kinda big commit that does everything at once

This change does the following:
- yeets skrott and skrot-specific settings from the NixOS module,
- adds a bunch more settings and generalizations to the NixOS module,
- adds two VM NixOS configurations for interactive testing
- wraps the nix package so that `less` is always present in `$PATH`
- yeah, that's about it

kthxbye
2026-01-06 17:01:21 +09:00
78161a96be Try to read config from /etc/dibbler/dibbler.conf 2026-01-06 16:01:34 +09:00
f4b5e1d6d4 pyproject.toml: set package version, fix nix package 2026-01-06 14:09:38 +09:00
109 changed files with 6241 additions and 2702 deletions

View File

@@ -0,0 +1,71 @@
name: Run benchmarks
on:
workflow_dispatch:
# TODO: make this only workflow_dispatch when merged into main
push:
jobs:
run-tests:
runs-on: debian-latest
steps:
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install dependencies
run: uv sync --locked --group test
- name: Run benchmarks
continue-on-error: true
run: |
set -euo pipefail
set -x
PYTEST_ARGS=(
-vv
--benchmark-only
-k test_benchmark
)
uv run -- pytest "${PYTEST_ARGS[@]}"
- name: Upload benchmark JSON report
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v2
with:
source: ./benchmark/*/*.json
quote-source: false
target: ${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/benchmark.json
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Upload histograms
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v2
with:
source: ./benchmark/*.svg
quote-source: false
target: ${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
# NOTE: $GITHUB_STEP_SUMMARY when...
- name: Run information
run: |
echo "Benchmark run ID: ${{ github.run_id }}"
echo "Benchmark JSON: https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/benchmark.json"
echo "Histograms: https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_owners.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_price.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_stock.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-transaction_log.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-user_balance.svg"
- name: Check failure
if: failure()
run: |
echo "Tests failed"
exit 1

View File

@@ -16,66 +16,57 @@ jobs:
run-tests:
runs-on: debian-latest
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install dependencies
run: uv sync --locked --group test
- name: Install dependencies
run: uv sync --locked --group test
- name: Run tests
continue-on-error: true
run: |
set -euo pipefail
set -x
- name: Run tests
continue-on-error: true
run: |
set -euo pipefail
set -x
PYTEST_ARGS=(
-vv
--cov=dibbler.lib
--cov=dibbler.models
--cov=dibbler.queries
--cov-report=html
--cov-branch
--self-contained-html
--html=./test-report/index.html
)
if [ "$DEBUG_SQL" == "true" ]; then
PYTEST_ARGS+=(
--debug-sql
PYTEST_ARGS=(
-vv
)
fi
uv run -- pytest "${PYTEST_ARGS[@]}"
if [ "$DEBUG_SQL" == "true" ]; then
PYTEST_ARGS+=(
--debug-sql
)
fi
- name: Generate badge
run: uv run -- coverage-badge -o htmlcov/badge.svg
uv run -- pytest "${PYTEST_ARGS[@]}"
- name: Upload test report
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
with:
source: ./test-report/
target: ${{ gitea.ref_name }}/test-report/
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Generate badge
run: uv run -- coverage-badge -o htmlcov/badge.svg
- name: Upload coverage report
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
with:
source: ./htmlcov/
target: ${{ gitea.ref_name }}/coverage/
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Upload test report
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
with:
source: ./test-report/
target: ${{ gitea.ref_name }}/test-report/
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Check failure
if: failure()
run: |
echo "Tests failed"
exit 1
- name: Upload coverage report
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
with:
source: ./htmlcov/
target: ${{ gitea.ref_name }}/coverage/
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Check failure
if: failure()
run: |
echo "Tests failed"
exit 1

6
.gitignore vendored
View File

@@ -8,6 +8,12 @@ test.db
.ruff_cache
*.qcow2
dibbler/_version.py
.coverage
.coverage.*
htmlcov
test-report
/benchmark

View File

@@ -23,8 +23,9 @@ Installer python, og lag og aktiver et venv. Installer så avhengighetene med `p
Deretter kan du kjøre programmet med
```console
python -m dibbler -c example-config.ini create-db
python -m dibbler -c example-config.ini loop
python -m dibbler -c example-config.toml create-db
python -m dibbler -c example-config.toml seed-data
python -m dibbler -c example-config.toml loop
```
## Prosjektstruktur
@@ -61,25 +62,30 @@ Her ligger enhetstester for prosjektet. Testene bruker `pytest` som testløper.
## Nix
### Bygge nytt image
> [!NOTE]
> Vi har skrevet nix-kode for å generere en QEMU-VM med tilnærmet produksjonsoppsett.
> Det kjører ikke nødvendigvis noen VM-er i produksjon, og ihvertfall ikke denne VM-en.
> Den er hovedsakelig laget for enkel interaktiv testing, og for å teste NixOS modulen.
For å bygge et image trenger du en builder som takler å bygge for arkitekturen du skal lage et image for.
Du kan enklest komme i gang med nix-utvikling ved å kjøre test VM-en:
(Eller be til gudene om at cross compile funker)
```console
nix run .#vm
Flaket exposer en modul som autologger inn med en bruker som automatisk kjører dibbler, og setter opp et minimalistisk miljø.
# Eller hvis du trenger tilgang til terminalen i VM-en også:
nix run .#vm-non-kiosk
```
Før du bygger imaget burde du kopiere og endre `example-config.ini` lokalt til å inneholde instillingene dine. **NB: Denne kommer til å ligge i nix storen, ikke si noe her som du ikke vil at moren din skal høre.**
Du kan også bygge pakken manuelt, eller kjøre den direkte:
Du kan også endre hvilken config-fil som blir brukt direkte i pakken eller i modulen.
```console
nix build .#dibbler
Se eksempelet for hvordan skrot er satt opp i `flake.nix` og `nix/skrott.nix`
nix run .# -- --config example-config.toml create-db
nix run .# -- --config example-config.toml seed-data
nix run .# -- --config example-config.toml loop
```
### Bygge image for skrot
## Produksjonssetting
Skrot har et image definert i flake.nix:
1. endre `example-config.ini`
2. `nix build .#images.skrot`
3. ???
4. non-profit
Se https://wiki.pvv.ntnu.no/wiki/Drift/Dibbler

View File

@@ -1,6 +1,56 @@
# This module is supposed to act as a singleton and be filled
# with config variables by cli.py
import os
import sys
import tomllib
from pathlib import Path
from typing import Any
import configparser
from dibbler.lib.helpers import file_is_submissive_and_readable
config = configparser.ConfigParser()
DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml")
config: dict[str, dict[str, Any]] = {}
def load_config(config_path: Path | None = None) -> None:
global config
if config_path is not None:
with Path(config_path).open("rb") as file:
config = tomllib.load(file)
elif file_is_submissive_and_readable(DEFAULT_CONFIG_PATH):
with DEFAULT_CONFIG_PATH.open("rb") as file:
config = tomllib.load(file)
else:
print(
"Could not read config file, it was neither provided nor readable in default location",
file=sys.stderr,
)
sys.exit(1)
def config_db_string() -> str:
db_type = config["database"]["type"]
if db_type == "sqlite":
path = Path(config["database"]["sqlite"]["path"])
return f"sqlite:///{path.absolute()}"
if db_type == "postgresql":
host = config["database"]["postgresql"]["host"]
port = config["database"]["postgresql"].get("port", 5432)
username = config["database"]["postgresql"].get("username", "dibbler")
dbname = config["database"]["postgresql"].get("dbname", "dibbler")
if "password_file" in config["database"]["postgresql"]:
with Path(config["database"]["postgresql"]["password_file"]).open("r") as f:
password = f.read().strip()
elif "password" in config["database"]["postgresql"]:
password = config["database"]["postgresql"]["password"]
else:
password = ""
if host.startswith("/"):
return f"postgresql+psycopg2://{username}:{password}@/{dbname}?host={host}"
return f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{dbname}"
print(f"Error: unknown database type '{db_type}'")
exit(1)

View File

@@ -1,19 +0,0 @@
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from dibbler.conf import config
database_url: str | None = None
if (url := config.get("database", "url")) is not None:
database_url = url
elif (url_file := config.get("database", "url_file")) is not None:
with Path(url_file).open() as file:
database_url = file.read().strip()
assert database_url is not None, "Database URL must be specified in config"
engine = create_engine(database_url)
Session = sessionmaker(bind=engine)

View File

@@ -1,71 +1,71 @@
import os
# import os
from PIL import ImageFont
from barcode.writer import ImageWriter, mm2px
from brother_ql.labels import ALL_LABELS
# from PIL import ImageFont
# from barcode.writer import ImageWriter, mm2px
# from brother_ql.labels import ALL_LABELS
def px2mm(px, dpi=300):
return (25.4 * px) / dpi
# def px2mm(px, dpi=300):
# return (25.4 * px) / dpi
class BrotherLabelWriter(ImageWriter):
def __init__(self, typ="62", max_height=350, rot=False, text=None):
super(BrotherLabelWriter, self).__init__()
label = next([l for l in ALL_LABELS if l.identifier == typ])
assert label is not None
self.rot = rot
if self.rot:
self._h, self._w = label.dots_printable
if self._w == 0 or self._w > max_height:
self._w = min(max_height, self._h / 2)
else:
self._w, self._h = label.dots_printable
if self._h == 0 or self._h > max_height:
self._h = min(max_height, self._w / 2)
self._xo = 0.0
self._yo = 0.0
self._title = text
# class BrotherLabelWriter(ImageWriter):
# def __init__(self, typ="62", max_height=350, rot=False, text=None):
# super(BrotherLabelWriter, self).__init__()
# label = next([l for l in ALL_LABELS if l.identifier == typ])
# assert label is not None
# self.rot = rot
# if self.rot:
# self._h, self._w = label.dots_printable
# if self._w == 0 or self._w > max_height:
# self._w = min(max_height, self._h / 2)
# else:
# self._w, self._h = label.dots_printable
# if self._h == 0 or self._h > max_height:
# self._h = min(max_height, self._w / 2)
# self._xo = 0.0
# self._yo = 0.0
# self._title = text
def _init(self, code):
self.text = None
super(BrotherLabelWriter, self)._init(code)
# def _init(self, code):
# self.text = None
# super(BrotherLabelWriter, self)._init(code)
def calculate_size(self, modules_per_line, number_of_lines, dpi=300):
x, y = super(BrotherLabelWriter, self).calculate_size(
modules_per_line, number_of_lines, dpi
)
# def calculate_size(self, modules_per_line, number_of_lines, dpi=300):
# x, y = super(BrotherLabelWriter, self).calculate_size(
# modules_per_line, number_of_lines, dpi
# )
self._xo = (px2mm(self._w) - px2mm(x)) / 2
self._yo = px2mm(self._h) - px2mm(y)
assert self._xo >= 0
assert self._yo >= 0
# self._xo = (px2mm(self._w) - px2mm(x)) / 2
# self._yo = px2mm(self._h) - px2mm(y)
# assert self._xo >= 0
# assert self._yo >= 0
return int(self._w), int(self._h)
# return int(self._w), int(self._h)
def _paint_module(self, xpos, ypos, width, color):
super(BrotherLabelWriter, self)._paint_module(
xpos + self._xo, ypos + self._yo, width, color
)
# def _paint_module(self, xpos, ypos, width, color):
# super(BrotherLabelWriter, self)._paint_module(
# xpos + self._xo, ypos + self._yo, width, color
# )
def _paint_text(self, xpos, ypos):
super(BrotherLabelWriter, self)._paint_text(xpos + self._xo, ypos + self._yo)
# def _paint_text(self, xpos, ypos):
# super(BrotherLabelWriter, self)._paint_text(xpos + self._xo, ypos + self._yo)
def _finish(self):
if self._title:
width = self._w + 1
height = 0
max_h = self._h - mm2px(self._yo, self.dpi)
fs = int(max_h / 1.2)
font_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"Stranger back in the Night.ttf",
)
font = ImageFont.truetype(font_path, 10)
while width > self._w or height > max_h:
font = ImageFont.truetype(font_path, fs)
width, height = font.getsize(self._title)
fs -= 1
pos = ((self._w - width) // 2, 0 - (height // 8))
self._draw.text(pos, self._title, font=font, fill=self.foreground)
return self._image
# def _finish(self):
# if self._title:
# width = self._w + 1
# height = 0
# max_h = self._h - mm2px(self._yo, self.dpi)
# fs = int(max_h / 1.2)
# font_path = os.path.join(
# os.path.dirname(os.path.realpath(__file__)),
# "Stranger back in the Night.ttf",
# )
# font = ImageFont.truetype(font_path, 10)
# while width > self._w or height > max_h:
# font = ImageFont.truetype(font_path, fs)
# width, height = font.getsize(self._title)
# fs -= 1
# pos = ((self._w - width) // 2, 0 - (height // 8))
# self._draw.text(pos, self._title, font=font, fill=self.foreground)
# return self._image

View File

@@ -0,0 +1,108 @@
import sys
from pathlib import Path
from sqlalchemy import Engine, create_engine, inspect, select
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.clsregistry import _ModuleMarker
from dibbler.lib.helpers import file_is_submissive_and_readable
from dibbler.models import Base
def check_db_health(engine: Engine, verify_table_existence: bool = False) -> None:
dialect_name = getattr(engine.dialect, "name", "").lower()
if "postgres" in dialect_name:
check_postgres_ping(engine)
elif dialect_name == "sqlite":
check_sqlite_file(engine)
if verify_table_existence:
verify_tables_and_columns(engine)
def check_postgres_ping(engine: Engine) -> None:
try:
with engine.connect() as conn:
result = conn.execute(select(1))
scalar = result.scalar()
if scalar != 1 and scalar is not None:
print(
"Unexpected response from Postgres when running 'SELECT 1'",
file=sys.stderr,
)
sys.exit(1)
except (OperationalError, DBAPIError) as exc:
print(f"Failed to connect to Postgres database: {exc}", file=sys.stderr)
sys.exit(1)
def check_sqlite_file(engine: Engine) -> None:
db_path = engine.url.database
# Don't verify in-memory databases or empty paths
if db_path in (None, "", ":memory:"):
return
db_path = db_path.removeprefix("file:").removeprefix("sqlite:")
# Strip query parameters
if "?" in db_path:
db_path = db_path.split("?", 1)[0]
path = Path(db_path)
if not path.exists():
print(f"SQLite database file does not exist: {path}", file=sys.stderr)
sys.exit(1)
if not path.is_file():
print(f"SQLite database path is not a file: {path}", file=sys.stderr)
sys.exit(1)
if not file_is_submissive_and_readable(path):
print(f"SQLite database file is not submissive and readable: {path}", file=sys.stderr)
sys.exit(1)
return
def verify_tables_and_columns(engine: Engine) -> None:
iengine = inspect(engine)
errors = False
tables = iengine.get_table_names()
views = iengine.get_view_names()
tables.extend(views)
for _name, klass in Base.registry._class_registry.items():
if isinstance(klass, _ModuleMarker):
continue
table = klass.__tablename__
if table in tables:
columns = [c["name"] for c in iengine.get_columns(table)]
mapper = inspect(klass)
for column_prop in mapper.attrs:
if isinstance(column_prop, RelationshipProperty):
pass
else:
for column in column_prop.columns:
if not column.key in columns:
print(
f"Model '{klass}' declares column '{column.key}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
else:
print(
f"Model '{klass}' declares table '{table}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
if errors:
print("Have you remembered to run `dibbler create-db?", file=sys.stderr)
sys.exit(1)

View File

@@ -2,9 +2,12 @@ import os
import pwd
import signal
import subprocess
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal
def system_user_exists(username):
def system_user_exists(username: str) -> bool:
try:
pwd.getpwnam(username)
except KeyError:
@@ -15,7 +18,7 @@ def system_user_exists(username):
return True
def guess_data_type(string):
def guess_data_type(string: str) -> Literal["card", "rfid", "bar_code", "username"] | None:
if string.startswith("ntnu") and string[4:].isdigit():
return "card"
if string.isdigit() and len(string) == 10:
@@ -29,7 +32,11 @@ def guess_data_type(string):
return None
def argmax(d, all=False, value=None):
def argmax(
d: dict[Any, Any],
all_: bool = False,
value: Callable[[Any], Any] | None = None,
) -> Any | list[Any] | None:
maxarg = None
if value is not None:
dd = d
@@ -39,12 +46,12 @@ def argmax(d, all=False, value=None):
for key in list(d.keys()):
if maxarg is None or d[key] > d[maxarg]:
maxarg = key
if all:
if all_:
return [k for k in list(d.keys()) if d[k] == d[maxarg]]
return maxarg
def less(string):
def less(string: str) -> None:
"""
Run less with string as input; wait until it finishes.
"""
@@ -56,3 +63,13 @@ def less(string):
proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE)
proc.communicate(string)
signal.signal(signal.SIGINT, int_handler)
def file_is_submissive_and_readable(file: Path) -> bool:
return file.is_file() and any(
[
file.stat().st_mode & 0o400 and file.stat().st_uid == os.getuid(),
file.stat().st_mode & 0o040 and file.stat().st_gid == os.getgid(),
file.stat().st_mode & 0o004,
],
)

View File

@@ -1,98 +1,95 @@
import os
import datetime
# import barcode
# from brother_ql.brother_ql_create import create_label
# from brother_ql.raster import BrotherQLRaster
# from brother_ql.backends import backend_factory
# from brother_ql.labels import ALL_LABELS
# from PIL import Image, ImageDraw, ImageFont
import barcode
from brother_ql.brother_ql_create import create_label
from brother_ql.raster import BrotherQLRaster
from brother_ql.backends import backend_factory
from brother_ql.labels import ALL_LABELS
from PIL import Image, ImageDraw, ImageFont
from .barcode_helpers import BrotherLabelWriter
# from .barcode_helpers import BrotherLabelWriter
def print_name_label(
text,
margin=10,
rotate=False,
label_type="62",
printer_type="QL-700",
):
label = next([l for l in ALL_LABELS if l.identifier == label_type])
if not rotate:
width, height = label.dots_printable
else:
height, width = label.dots_printable
# def print_name_label(
# text,
# margin=10,
# rotate=False,
# label_type="62",
# printer_type="QL-700",
# ):
# label = next([l for l in ALL_LABELS if l.identifier == label_type])
# if not rotate:
# width, height = label.dots_printable
# else:
# height, width = label.dots_printable
font_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ChopinScript.ttf")
fs = 2000
tw, th = width, height
if width == 0:
while th + 2 * margin > height:
font = ImageFont.truetype(font_path, fs)
tw, th = font.getsize(text)
fs -= 1
width = tw + 2 * margin
elif height == 0:
while tw + 2 * margin > width:
font = ImageFont.truetype(font_path, fs)
tw, th = font.getsize(text)
fs -= 1
height = th + 2 * margin
else:
while tw + 2 * margin > width or th + 2 * margin > height:
font = ImageFont.truetype(font_path, fs)
tw, th = font.getsize(text)
fs -= 1
# font_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ChopinScript.ttf")
# fs = 2000
# tw, th = width, height
# if width == 0:
# while th + 2 * margin > height:
# font = ImageFont.truetype(font_path, fs)
# tw, th = font.getsize(text)
# fs -= 1
# width = tw + 2 * margin
# elif height == 0:
# while tw + 2 * margin > width:
# font = ImageFont.truetype(font_path, fs)
# tw, th = font.getsize(text)
# fs -= 1
# height = th + 2 * margin
# else:
# while tw + 2 * margin > width or th + 2 * margin > height:
# font = ImageFont.truetype(font_path, fs)
# tw, th = font.getsize(text)
# fs -= 1
xp = (width // 2) - (tw // 2)
yp = (height // 2) - (th // 2)
# xp = (width // 2) - (tw // 2)
# yp = (height // 2) - (th // 2)
im = Image.new("RGB", (width, height), (255, 255, 255))
dr = ImageDraw.Draw(im)
# im = Image.new("RGB", (width, height), (255, 255, 255))
# dr = ImageDraw.Draw(im)
dr.text((xp, yp), text, fill=(0, 0, 0), font=font)
now = datetime.datetime.now()
date = now.strftime("%Y-%m-%d")
dr.text((0, 0), date, fill=(0, 0, 0))
# dr.text((xp, yp), text, fill=(0, 0, 0), font=font)
# now = datetime.datetime.now()
# date = now.strftime("%Y-%m-%d")
# dr.text((0, 0), date, fill=(0, 0, 0))
base_path = os.path.dirname(os.path.realpath(__file__))
fn = os.path.join(base_path, "bar_codes", text + ".png")
# base_path = os.path.dirname(os.path.realpath(__file__))
# fn = os.path.join(base_path, "bar_codes", text + ".png")
im.save(fn, "PNG")
print_image(fn, printer_type, label_type)
# im.save(fn, "PNG")
# print_image(fn, printer_type, label_type)
def print_bar_code(
barcode_value,
barcode_text,
barcode_type="ean13",
rotate=False,
printer_type="QL-700",
label_type="62",
):
bar_coder = barcode.get_barcode_class(barcode_type)
wr = BrotherLabelWriter(typ=label_type, rot=rotate, text=barcode_text, max_height=1000)
# def print_bar_code(
# barcode_value,
# barcode_text,
# barcode_type="ean13",
# rotate=False,
# printer_type="QL-700",
# label_type="62",
# ):
# bar_coder = barcode.get_barcode_class(barcode_type)
# wr = BrotherLabelWriter(typ=label_type, rot=rotate, text=barcode_text, max_height=1000)
test = bar_coder(barcode_value, writer=wr)
base_path = os.path.dirname(os.path.realpath(__file__))
fn = test.save(os.path.join(base_path, "bar_codes", barcode_value))
print_image(fn, printer_type, label_type)
# test = bar_coder(barcode_value, writer=wr)
# base_path = os.path.dirname(os.path.realpath(__file__))
# fn = test.save(os.path.join(base_path, "bar_codes", barcode_value))
# print_image(fn, printer_type, label_type)
def print_image(fn, printer_type="QL-700", label_type="62"):
qlr = BrotherQLRaster(printer_type)
qlr.exception_on_warning = True
create_label(qlr, fn, label_type, threshold=70, cut=True)
# def print_image(fn, printer_type="QL-700", label_type="62"):
# qlr = BrotherQLRaster(printer_type)
# qlr.exception_on_warning = True
# create_label(qlr, fn, label_type, threshold=70, cut=True)
be = backend_factory("pyusb")
list_available_devices = be["list_available_devices"]
BrotherQLBackend = be["backend_class"]
# be = backend_factory("pyusb")
# list_available_devices = be["list_available_devices"]
# BrotherQLBackend = be["backend_class"]
ad = list_available_devices()
assert ad
string_descr = ad[0]["string_descr"]
# ad = list_available_devices()
# assert ad
# string_descr = ad[0]["string_descr"]
printer = BrotherQLBackend(string_descr)
# printer = BrotherQLBackend(string_descr)
printer.write(qlr.data)
# printer.write(qlr.data)

View File

@@ -1,22 +0,0 @@
from typing import TypeVar
from sqlalchemy import BindParameter, literal
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
CONST_ONE: BindParameter[int] = const(1)
CONST_TRUE: BindParameter[bool] = const(True)
CONST_FALSE: BindParameter[bool] = const(False)
CONST_NONE: BindParameter[None] = const(None)

View File

@@ -1,3 +1,4 @@
from dibbler.lib.render_tree import render_tree
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import EXPECTED_FIELDS
@@ -10,23 +11,19 @@ def render_transaction_log(transaction_log: list[Transaction]) -> str:
aggregated_log = _aggregate_joint_transactions(transaction_log)
lines = []
for i, transaction in enumerate(aggregated_log):
for transaction in aggregated_log:
if isinstance(transaction, list):
inner_lines = []
is_last = i == len(aggregated_log) - 1
lines.append(_render_transaction(transaction[0], is_last))
for j, sub_transaction in enumerate(transaction[1:]):
is_last_inner = j == len(transaction) - 2
line = _render_transaction(sub_transaction, is_last_inner)
lines.append(_render_transaction(transaction[0]))
for sub_transaction in transaction[1:]:
line = _render_transaction(sub_transaction)
inner_lines.append(line)
indented_inner_lines = _indent_lines(inner_lines, is_last=is_last)
lines.extend(indented_inner_lines)
lines.append(inner_lines)
else:
is_last = i == len(aggregated_log) - 1
line = _render_transaction(transaction, is_last)
line = _render_transaction(transaction)
lines.append(line)
return "\n".join(lines)
return render_tree(lines)
def _aggregate_joint_transactions(
@@ -61,17 +58,7 @@ def _aggregate_joint_transactions(
return aggregated
def _indent_lines(lines: list[str], is_last: bool = False) -> list[str]:
indented_lines = []
for line in lines:
if is_last:
indented_lines.append(" " + line)
else:
indented_lines.append("│ " + line)
return indented_lines
def _render_transaction(transaction: Transaction, is_last: bool) -> str:
def _render_transaction(transaction: Transaction) -> str:
match transaction.type_:
case TransactionType.ADD_PRODUCT:
line = f"ADD_PRODUCT({transaction.id}, {transaction.user.name}"
@@ -125,5 +112,4 @@ def _render_transaction(transaction: Transaction, is_last: bool) -> str:
line = (
f"UNKNOWN[{transaction.type_}](id={transaction.id}, user_id={transaction.user_id})"
)
return "└─ " + line if is_last else "├─ " + line
return line

115
dibbler/lib/render_tree.py Normal file
View File

@@ -0,0 +1,115 @@
_TREE_CHARS = {
"normal": {
"vertical": "│ ",
"branch": "├─ ",
"last": "└─ ",
"empty": " ",
},
"ascii": {
"vertical": "| ",
"branch": "|-- ",
"last": "`-- ",
"empty": " ",
},
}
assert set(_TREE_CHARS["normal"].keys()) == set(_TREE_CHARS["ascii"].keys())
assert all(len(v) == 3 for v in _TREE_CHARS["normal"].values())
assert all(len(v) == 4 for v in _TREE_CHARS["ascii"].values())
def render_tree(
tree: list[str | list],
ascii_only: bool = False,
) -> str:
"""
Render a tree structure as a string.
Each item in the `tree` list can be either a string (a leaf node)
or another list (a subtree).
When `ascii_only` is `True`, only ASCII characters are used for drawing the tree.
Example:
```python
tree = [
"root",
[
"child1",
[
"grandchild1",
"grandchild2",
],
"child2",
],
"root2",
]
print(render_tree(tree, ascii_only=False))
```
Output:
```
├─ root
│ ├─ child1
│ │ ├─ grandchild1
│ │ └─ grandchild2
│ └─ child2
└─ root2
```
Example with ASCII only:
```python
print(render_tree(tree, ascii_only=True))
```
Output:
```
|-- root
| |-- child1
| | |-- grandchild1
| | `-- grandchild2
| `-- child2
`-- root2
```
"""
result: list[str] = []
for index, item in enumerate(tree):
is_last = index == len(tree) - 1
item_lines = _render_tree_line(item, is_last, ascii_only)
result.extend(item_lines)
return "\n".join(result)
def _render_tree_line(
item: str | list,
is_last: bool,
ascii_only: bool,
prefix: str = "",
) -> list[str]:
chars = _TREE_CHARS["ascii"] if ascii_only else _TREE_CHARS["normal"]
lines: list[str] = []
if isinstance(item, str):
line_prefix = chars["last"] if is_last else chars["branch"]
item_lines = item.splitlines()
for line_index, line in enumerate(item_lines):
if line_index == 0:
lines.append(f"{prefix}{line_prefix}{line}")
else:
lines.append(f"{prefix}{chars['vertical']}{line}")
elif isinstance(item, list):
new_prefix = prefix + (chars["empty"] if is_last else chars["vertical"])
for sub_index, sub_item in enumerate(item):
sub_is_last = sub_index == len(item) - 1
sub_lines = _render_tree_line(sub_item, sub_is_last, ascii_only, new_prefix)
lines.extend(sub_lines)
else:
raise ValueError("Item must be either a string or a list.")
return lines

View File

@@ -3,18 +3,20 @@
import datetime
from collections import defaultdict
from pathlib import Path
from sqlalchemy.orm import Session
from .helpers import *
from ..models import Transaction
from ..db import Session
from .helpers import *
def getUser():
def getUser(sql_session: Session) -> str:
assert sql_session is not None
while 1:
string = input("user? ")
session = Session()
user = search_user(string, session)
session.close()
user = search_user(string, sql_session)
sql_session.close()
if not isinstance(user, list):
return user.name
i = 0
@@ -37,12 +39,11 @@ def getUser():
return user[n].name
def getProduct():
def getProduct(sql_session: Session) -> str:
assert sql_session is not None
while 1:
string = input("product? ")
session = Session()
product = search_product(string, session)
session.close()
product = search_product(string, sql_session)
if not isinstance(product, list):
return product.name
i = 0
@@ -89,7 +90,7 @@ class Database:
class InputLine:
def __init__(self, u, p, t):
def __init__(self, u, p, t) -> None:
self.inputUser = u
self.inputProduct = p
self.inputType = t
@@ -122,17 +123,17 @@ def getInputType():
return int(inp)
def getProducts(products):
def getProducts(products: str) -> list[tuple[str]]:
product = []
products = products.partition("¤")
split_products = products.partition("¤")
product.append(products[0])
while products[1] == "¤":
products = products[2].partition("¤")
split_products = split_products[2].partition("¤")
product.append(products[0])
return product
def getDateFile(date, inp):
def getDateFile(date: str, inp: str) -> datetime.date:
try:
year = inp.partition("-")
month = year[2].partition("-")
@@ -176,7 +177,7 @@ def addLineToDatabase(database, inputLine):
if abs(inputLine.price) > 90000:
return database
# fyller inn for varer
if (not inputLine.product == "") and (
if (inputLine.product != "") and (
(inputLine.inputProduct == "") or (inputLine.inputProduct == inputLine.product)
):
database.varePersonAntall[inputLine.product][inputLine.user] = (
@@ -190,7 +191,7 @@ def addLineToDatabase(database, inputLine):
database.vareUkedagAntall[inputLine.product][inputLine.weekday] += 1
# fyller inn for personer
if (inputLine.inputUser == "") or (inputLine.inputUser == inputLine.user):
if not inputLine.product == "":
if inputLine.product != "":
database.personVareAntall[inputLine.user][inputLine.product] = (
database.personVareAntall[inputLine.user].setdefault(inputLine.product, 0) + 1
)
@@ -214,7 +215,7 @@ def addLineToDatabase(database, inputLine):
database.personNegTransactions[inputLine.user] = (
database.personNegTransactions.setdefault(inputLine.user, 0) + inputLine.price
)
elif not (inputLine.inputType == 1):
elif inputLine.inputType != 1:
database.globalVareAntall[inputLine.product] = (
database.globalVareAntall.setdefault(inputLine.product, 0) + 1
)
@@ -225,7 +226,7 @@ def addLineToDatabase(database, inputLine):
# fyller inn for global statistikk
if (inputLine.inputType == 3) or (inputLine.inputType == 4):
database.pengebeholdning[inputLine.dateNum] += inputLine.price
if not (inputLine.product == ""):
if inputLine.product != "":
database.globalPersonAntall[inputLine.user] = (
database.globalPersonAntall.setdefault(inputLine.user, 0) + 1
)
@@ -238,12 +239,12 @@ def addLineToDatabase(database, inputLine):
return database
def buildDatabaseFromDb(inputType, inputProduct, inputUser):
def buildDatabaseFromDb(inputType, inputProduct, inputUser, sql_session: Session):
assert sql_session is not None
sdate = input("enter start date (yyyy-mm-dd)? ")
edate = input("enter end date (yyyy-mm-dd)? ")
print("building database...")
session = Session()
transaction_list = session.query(Transaction).all()
transaction_list = sql_session.query(Transaction).all()
inputLine = InputLine(inputUser, inputProduct, inputType)
startDate = getDateDb(transaction_list[0].time, sdate)
endDate = getDateDb(transaction_list[-1].time, edate)
@@ -273,9 +274,9 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser):
inputLine.price = 0
print("saving as default.dibblerlog...", end=" ")
f = open("default.dibblerlog", "w")
f = Path.open("default.dibblerlog", "w")
line_format = "%s|%s|%s|%s|%s|%s\n"
transaction_list = session.query(Transaction).all()
transaction_list = sql_session.query(Transaction).all()
for transaction in transaction_list:
if transaction.purchase:
products = "¤".join([ent.product.name for ent in transaction.purchase.entries])
@@ -290,8 +291,7 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser):
transaction.description,
)
f.write(line.encode("utf8"))
session.close()
f.close
f.close()
# bygg database.pengebeholdning
if (inputType == 3) or (inputType == 4):
for i in range(inputLine.numberOfDays + 1):
@@ -311,7 +311,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
sdate = input("enter start date (yyyy-mm-dd)? ")
edate = input("enter end date (yyyy-mm-dd)? ")
f = open(inputFile)
f = Path.open(inputFile)
try:
fileLines = f.readlines()
finally:
@@ -329,7 +329,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
database.globalUkedagForbruk = [0] * 7
database.pengebeholdning = [0] * (inputLine.numberOfDays + 1)
for linje in fileLines:
if not (linje[0] == "#") and not (linje == "\n"):
if linje[0] != "#" and linje != "\n":
# henter dateNum, products, user, price
restDel = linje.partition("|")
restDel = restDel[2].partition(" ")
@@ -359,7 +359,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
return database, dateLine
def printTopDict(dictionary, n, k):
def printTopDict(dictionary: dict[str, Any], n: int, k: bool) -> None:
i = 0
for key in sorted(dictionary, key=dictionary.get, reverse=k):
print(key, ": ", dictionary[key])
@@ -369,7 +369,7 @@ def printTopDict(dictionary, n, k):
break
def printTopDict2(dictionary, dictionary2, n):
def printTopDict2(dictionary, dictionary2, n) -> None:
print("")
print("product : price[kr] ( number )")
i = 0
@@ -381,7 +381,7 @@ def printTopDict2(dictionary, dictionary2, n):
break
def printWeekdays(week, days):
def printWeekdays(week, days) -> None:
if week == [] or days == 0:
return
print(
@@ -404,10 +404,10 @@ def printWeekdays(week, days):
print("")
def printBalance(database, user):
def printBalance(database, user) -> None:
forbruk = 0
if user in database.personVareVerdi:
forbruk = sum([i for i in list(database.personVareVerdi[user].values())])
forbruk = sum(database.personVareVerdi[user].values())
print("totalt kjøpt for: ", forbruk, end=" ")
if user in database.personNegTransactions:
print("kr, totalt lagt til: ", -database.personNegTransactions[user], end=" ")
@@ -419,14 +419,14 @@ def printBalance(database, user):
print("")
def printUser(database, dateLine, user, n):
def printUser(database, dateLine, user, n) -> None:
printTopDict2(database.personVareVerdi[user], database.personVareAntall[user], n)
print("\nforbruk per ukedag [kr/dag],", end=" ")
printWeekdays(database.personUkedagVerdi[user], len(dateLine))
printBalance(database, user)
def printProduct(database, dateLine, product, n):
def printProduct(database, dateLine, product, n) -> None:
printTopDict(database.varePersonAntall[product], n, 1)
print("\nforbruk per ukedag [antall/dag],", end=" ")
printWeekdays(database.vareUkedagAntall[product], len(dateLine))
@@ -440,7 +440,7 @@ def printProduct(database, dateLine, product, n):
)
def printGlobal(database, dateLine, n):
def printGlobal(database, dateLine, n) -> None:
print("\nmest lagt til: ")
printTopDict(database.personNegTransactions, n, 0)
print("\nmest tatt fra:")
@@ -454,9 +454,9 @@ def printGlobal(database, dateLine, n):
"Det er solgt varer til en verdi av: ",
sum(database.globalDatoForbruk),
"kr, det er lagt til",
-sum([i for i in list(database.personNegTransactions.values())]),
-sum(database.personNegTransactions.values()),
"og tatt fra",
sum([i for i in list(database.personPosTransactions.values())]),
sum(database.personPosTransactions.values()),
end=" ",
)
print(
@@ -466,23 +466,24 @@ def printGlobal(database, dateLine, n):
)
def alt4menuTextOnly(database, dateLine):
def alt4menuTextOnly(database, dateLine, sql_session: Session) -> None:
assert sql_session is not None
n = 10
while 1:
print(
"\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit"
"\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit",
)
inp = input("")
if inp == "q":
break
elif inp == "1":
if inp == "1":
try:
printUser(database, dateLine, getUser(), n)
printUser(database, dateLine, getUser(sql_session), n)
except:
print("\n\nSomething is not right, (last date prior to first date?)")
elif inp == "2":
try:
printProduct(database, dateLine, getProduct(), n)
printProduct(database, dateLine, getProduct(sql_session), n)
except:
print("\n\nSomething is not right, (last date prior to first date?)")
elif inp == "3":
@@ -494,15 +495,16 @@ def alt4menuTextOnly(database, dateLine):
n = int(input("set number to show "))
def statisticsTextOnly():
def statisticsTextOnly(sql_session: Session) -> None:
assert sql_session is not None
inputType = 4
product = ""
user = ""
print("\n0: from file, 1: from database, q:quit")
inp = input("")
if inp == "1":
database, dateLine = buildDatabaseFromDb(inputType, product, user)
database, dateLine = buildDatabaseFromDb(inputType, product, user, sql_session)
elif inp == "0" or inp == "":
database, dateLine = buildDatabaseFromFile("default.dibblerlog", inputType, product, user)
if not inp == "q":
alt4menuTextOnly(database, dateLine)
if inp != "q":
alt4menuTextOnly(database, dateLine, sql_session)

View File

@@ -1,7 +1,12 @@
import argparse
import sys
from pathlib import Path
from dibbler.conf import config
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from dibbler.conf import config_db_string, load_config
from dibbler.lib.check_db_health import check_db_health
parser = argparse.ArgumentParser()
@@ -11,13 +16,20 @@ parser.add_argument(
help="Path to the config file",
type=Path,
metavar="FILE",
default="config.ini",
required=False,
)
parser.add_argument(
"-V",
"--version",
help="Show program version",
action="store_true",
default=False,
)
subparsers = parser.add_subparsers(
title="subcommands",
dest="subcommand",
required=True,
)
subparsers.add_parser("loop", help="Run the dibbler loop")
subparsers.add_parser("create-db", help="Create the database")
@@ -26,29 +38,55 @@ subparsers.add_parser("seed-data", help="Fill with mock data")
subparsers.add_parser("transaction-log", help="Print transaction log")
def main():
def main() -> None:
args = parser.parse_args()
config.read(args.config)
if args.version:
from ._version import commit_id, version
print(f"Dibbler version {version}, commit {commit_id if commit_id else '<unknown>'}")
return
if not args.subcommand:
parser.print_help()
sys.exit(1)
load_config(args.config)
engine = create_engine(config_db_string())
sql_session = Session(
engine,
expire_on_commit=False,
autocommit=False,
autoflush=False,
close_resets_only=True,
)
check_db_health(
engine,
verify_table_existence=args.subcommand != "create-db",
)
if args.subcommand == "loop":
import dibbler.subcommands.loop as loop
loop.main()
loop.main(sql_session)
elif args.subcommand == "create-db":
import dibbler.subcommands.makedb as makedb
makedb.main()
makedb.main(engine)
elif args.subcommand == "slabbedasker":
import dibbler.subcommands.slabbedasker as slabbedasker
slabbedasker.main()
slabbedasker.main(sql_session)
elif args.subcommand == "seed-data":
import dibbler.subcommands.seed_test_data as seed_test_data
seed_test_data.main()
seed_test_data.main(sql_session)
elif args.subcommand == "transaction-log":
import dibbler.subcommands.transaction_log as transaction_log

View File

@@ -26,28 +26,28 @@ __all__ = [
from .addstock import AddStockMenu
from .buymenu import BuyMenu
from .editing import (
AddUserMenu,
EditUserMenu,
AddProductMenu,
EditProductMenu,
AddUserMenu,
AdjustStockMenu,
CleanupStockMenu,
EditProductMenu,
EditUserMenu,
)
from .faq import FAQMenu
from .helpermenus import Menu
from .mainmenu import MainMenu
from .miscmenus import (
ProductSearchMenu,
TransferMenu,
AdjustCreditMenu,
UserListMenu,
ShowUserMenu,
ProductListMenu,
ProductSearchMenu,
ShowUserMenu,
TransferMenu,
UserListMenu,
)
from .printermenu import PrintLabelMenu
from .stats import (
ProductPopularityMenu,
ProductRevenueMenu,
BalanceMenu,
LoggedStatisticsMenu,
ProductPopularityMenu,
ProductRevenueMenu,
)

View File

@@ -1,6 +1,7 @@
from math import ceil
import sqlalchemy
from sqlalchemy.orm import Session
from dibbler.models import (
Product,
@@ -9,12 +10,13 @@ from dibbler.models import (
Transaction,
User,
)
from .helpermenus import Menu
class AddStockMenu(Menu):
def __init__(self):
Menu.__init__(self, "Add stock and adjust credit", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Add stock and adjust credit", sql_session)
self.help_text = """
Enter what you have bought for PVVVV here, along with your user name and how
much money you're due in credits for the purchase when prompted.\n"""
@@ -23,7 +25,7 @@ much money you're due in credits for the purchase when prompted.\n"""
self.products = {}
self.price = 0
def _execute(self):
def _execute(self, **_kwargs) -> bool | None:
questions = {
(
False,
@@ -86,10 +88,10 @@ much money you're due in credits for the purchase when prompted.\n"""
self.perform_transaction()
def complete_input(self):
return bool(self.users) and len(self.products) and self.price
def complete_input(self) -> bool:
return self.users is not None and len(self.products) > 0 and self.price > 0
def print_info(self):
def print_info(self) -> None:
width = 6 + Product.name_length
print()
print(width * "-")
@@ -109,7 +111,12 @@ much money you're due in credits for the purchase when prompted.\n"""
print(f"{self.products[product][0]}".rjust(width - len(product.name)))
print(width * "-")
def add_thing_to_pending(self, thing, amount, price):
def add_thing_to_pending(
self,
thing: User | Product,
amount: int,
price: int,
) -> None:
if isinstance(thing, User):
self.users.append(thing)
elif thing in list(self.products.keys()):
@@ -119,7 +126,7 @@ much money you're due in credits for the purchase when prompted.\n"""
else:
self.products[thing] = [amount, price]
def perform_transaction(self):
def perform_transaction(self) -> None:
print("Did you pay a different price?")
if self.confirm(">", default=False):
self.price = self.input_int("How much did you pay?", 0, self.price, default=self.price)
@@ -132,10 +139,11 @@ much money you're due in credits for the purchase when prompted.\n"""
old_price = product.price
old_hidden = product.hidden
product.price = int(
ceil(float(value) / (max(product.stock, 0) + self.products[product][0]))
ceil(float(value) / (max(product.stock, 0) + self.products[product][0])),
)
product.stock = max(
self.products[product][0], product.stock + self.products[product][0]
self.products[product][0],
product.stock + self.products[product][0],
)
product.hidden = False
print(
@@ -151,13 +159,14 @@ much money you're due in credits for the purchase when prompted.\n"""
PurchaseEntry(purchase, product, -self.products[product][0])
purchase.perform_soft_purchase(-self.price, round_up=False)
self.session.add(purchase)
self.sql_session.add(purchase)
try:
self.session.commit()
self.sql_session.commit()
print("Success! Transaction performed:")
# self.print_info()
for user in self.users:
print(f"User {user.name}'s credit is now {user.credit:d}")
except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not perform transaction: {e}")

View File

@@ -1,4 +1,8 @@
from typing import Any
import sqlalchemy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from dibbler.conf import config
from dibbler.models import (
@@ -13,10 +17,11 @@ from .helpermenus import Menu
class BuyMenu(Menu):
def __init__(self, session=None):
Menu.__init__(self, "Buy", uses_db=True)
if session:
self.session = session
superfast_mode: bool
purchase: Purchase
def __init__(self, sql_session: Session) -> None:
super().__init__("Buy", sql_session)
self.superfast_mode = False
self.help_text = """
Each purchase may contain one or more products and one or more buyers.
@@ -28,7 +33,7 @@ addition, and you can type 'what' at any time to redisplay it.
When finished, write an empty line to confirm the purchase.\n"""
@staticmethod
def credit_check(user):
def credit_check(user: User) -> bool:
"""
:param user:
@@ -37,28 +42,32 @@ When finished, write an empty line to confirm the purchase.\n"""
"""
assert isinstance(user, User)
return user.credit > config.getint("limits", "low_credit_warning_limit")
return user.credit > config["limits"]["low_credit_warning_limit"]
def low_credit_warning(self, user, timeout=False):
def low_credit_warning(
self,
user: User,
timeout: bool = False,
) -> bool:
assert isinstance(user, User)
print("***********************************************************************")
print("***********************************************************************")
print("")
print("$$\ $$\ $$$$$$\ $$$$$$$\ $$\ $$\ $$$$$$\ $$\ $$\ $$$$$$\\")
print("$$ | $\ $$ |$$ __$$\ $$ __$$\ $$$\ $$ |\_$$ _|$$$\ $$ |$$ __$$\\")
print("$$ |$$$\ $$ |$$ / $$ |$$ | $$ |$$$$\ $$ | $$ | $$$$\ $$ |$$ / \__|")
print("$$ $$ $$\$$ |$$$$$$$$ |$$$$$$$ |$$ $$\$$ | $$ | $$ $$\$$ |$$ |$$$$\\")
print("$$$$ _$$$$ |$$ __$$ |$$ __$$< $$ \$$$$ | $$ | $$ \$$$$ |$$ |\_$$ |")
print("$$$ / \$$$ |$$ | $$ |$$ | $$ |$$ |\$$$ | $$ | $$ |\$$$ |$$ | $$ |")
print("$$ / \$$ |$$ | $$ |$$ | $$ |$$ | \$$ |$$$$$$\ $$ | \$$ |\$$$$$$ |")
print("\__/ \__|\__| \__|\__| \__|\__| \__|\______|\__| \__| \______/")
print("")
print("***********************************************************************")
print("***********************************************************************")
print("")
print(r"***********************************************************************")
print(r"***********************************************************************")
print(r"")
print(r"$$\ $$\ $$$$$$\ $$$$$$$\ $$\ $$\ $$$$$$\ $$\ $$\ $$$$$$\\")
print(r"$$ | $\ $$ |$$ __$$\ $$ __$$\ $$$\ $$ |\_$$ _|$$$\ $$ |$$ __$$\\")
print(r"$$ |$$$\ $$ |$$ / $$ |$$ | $$ |$$$$\ $$ | $$ | $$$$\ $$ |$$ / \__|")
print(r"$$ $$ $$\$$ |$$$$$$$$ |$$$$$$$ |$$ $$\$$ | $$ | $$ $$\$$ |$$ |$$$$\\")
print(r"$$$$ _$$$$ |$$ __$$ |$$ __$$< $$ \$$$$ | $$ | $$ \$$$$ |$$ |\_$$ |")
print(r"$$$ / \$$$ |$$ | $$ |$$ | $$ |$$ |\$$$ | $$ | $$ |\$$$ |$$ | $$ |")
print(r"$$ / \$$ |$$ | $$ |$$ | $$ |$$ | \$$ |$$$$$$\ $$ | \$$ |\$$$$$$ |")
print(r"\__/ \__|\__| \__|\__| \__|\__| \__|\______|\__| \__| \______/")
print(r"")
print(r"***********************************************************************")
print(r"***********************************************************************")
print(r"")
print(
f"USER {user.name} HAS LOWER CREDIT THAN {config.getint('limits', 'low_credit_warning_limit'):d}."
f"USER {user.name} HAS LOWER CREDIT THAN {config['limits']['low_credit_warning_limit']:d}.",
)
print("THIS PURCHASE WILL CHARGE YOUR CREDIT TWICE AS MUCH.")
print("CONSIDER PUTTING MONEY IN THE BOX TO AVOID THIS.")
@@ -68,10 +77,13 @@ When finished, write an empty line to confirm the purchase.\n"""
if timeout:
print("THIS PURCHASE WILL AUTOMATICALLY BE PERFORMED IN 3 MINUTES!")
return self.confirm(prompt=">", default=True, timeout=180)
else:
return self.confirm(prompt=">", default=True)
return self.confirm(prompt=">", default=True)
def add_thing_to_purchase(self, thing, amount=1):
def add_thing_to_purchase(
self,
thing: User | Product,
amount: int = 1,
) -> bool:
if isinstance(thing, User):
if thing.is_anonymous():
print("---------------------------------------------")
@@ -80,7 +92,10 @@ When finished, write an empty line to confirm the purchase.\n"""
print("---------------------------------------------")
if not self.credit_check(thing):
if self.low_credit_warning(user=thing, timeout=self.superfast_mode):
if self.low_credit_warning(
user=thing,
timeout=self.superfast_mode,
):
Transaction(thing, purchase=self.purchase, penalty=2)
else:
return False
@@ -95,7 +110,11 @@ When finished, write an empty line to confirm the purchase.\n"""
PurchaseEntry(self.purchase, thing, amount)
return True
def _execute(self, initial_contents=None):
def _execute(
self,
initial_contents: list[tuple[User | Product, int]] | None = None,
**_kwargs,
) -> bool:
self.print_header()
self.purchase = Purchase()
self.exit_confirm_msg = None
@@ -107,7 +126,7 @@ When finished, write an empty line to confirm the purchase.\n"""
for thing, num in initial_contents:
self.add_thing_to_purchase(thing, num)
def is_product(candidate):
def is_product(candidate: Any) -> bool:
return isinstance(candidate[0], Product)
if len(initial_contents) > 0 and all(map(is_product, initial_contents)):
@@ -129,7 +148,7 @@ When finished, write an empty line to confirm the purchase.\n"""
True,
True,
): "Enter more products or users, or an empty line to confirm",
}[(len(self.purchase.transactions) > 0, len(self.purchase.entries) > 0)]
}[(len(self.purchase.transactions) > 0, len(self.purchase.entries) > 0)],
)
# Read in a 'thing' (product or user):
@@ -147,16 +166,16 @@ When finished, write an empty line to confirm the purchase.\n"""
if thing is None:
if not self.complete_input():
if self.confirm(
"Not enough information entered. Abort purchase?", default=True
"Not enough information entered. Abort purchase?",
default=True,
):
return False
continue
break
else:
# once we get something in the
# purchase, we want to protect the
# user from accidentally killing it
self.exit_confirm_msg = "Abort purchase?"
# once we get something in the
# purchase, we want to protect the
# user from accidentally killing it
self.exit_confirm_msg = "Abort purchase?"
# Add the thing to our purchase object:
if not self.add_thing_to_purchase(thing, amount=num):
@@ -167,10 +186,11 @@ When finished, write an empty line to confirm the purchase.\n"""
break
self.purchase.perform_purchase()
self.session.add(self.purchase)
self.sql_session.add(self.purchase)
try:
self.session.commit()
except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.commit()
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store purchase: {e}")
else:
print("Purchase stored.")
@@ -178,9 +198,9 @@ When finished, write an empty line to confirm the purchase.\n"""
for t in self.purchase.transactions:
if not t.user.is_anonymous():
print(f"User {t.user.name}'s credit is now {t.user.credit:d} kr")
if t.user.credit < config.getint("limits", "low_credit_warning_limit"):
if t.user.credit < config["limits"]["low_credit_warning_limit"]:
print(
f"USER {t.user.name} HAS LOWER CREDIT THAN {config.getint('limits', 'low_credit_warning_limit'):d},",
f"USER {t.user.name} HAS LOWER CREDIT THAN {config['limits']['low_credit_warning_limit']:d},",
"AND SHOULD CONSIDER PUTTING SOME MONEY IN THE BOX.",
)
@@ -189,10 +209,10 @@ When finished, write an empty line to confirm the purchase.\n"""
print("")
return True
def complete_input(self):
def complete_input(self) -> bool:
return self.purchase.is_complete()
def format_purchase(self):
def format_purchase(self) -> str | None:
self.purchase.set_price()
transactions = self.purchase.transactions
entries = self.purchase.entries
@@ -204,7 +224,10 @@ When finished, write an empty line to confirm the purchase.\n"""
string += "(empty)"
else:
string += ", ".join(
[t.user.name + ("*" if not self.credit_check(t.user) else "") for t in transactions]
[
t.user.name + ("*" if not self.credit_check(t.user) else "")
for t in transactions
],
)
string += "\n products: "
if len(entries) == 0:
@@ -212,7 +235,7 @@ When finished, write an empty line to confirm the purchase.\n"""
else:
string += "\n "
string += "\n ".join(
[f"{e.amount:d}x {e.product.name} ({e.product.price:d} kr)" for e in entries]
[f"{e.amount:d}x {e.product.name} ({e.product.price:d} kr)" for e in entries],
)
if len(transactions) > 1:
string += f"\n price per person: {self.purchase.price_per_transaction():d} kr"
@@ -228,7 +251,7 @@ When finished, write an empty line to confirm the purchase.\n"""
return string
def print_purchase(self):
def print_purchase(self) -> None:
info = self.format_purchase()
if info is not None:
self.set_context(info)

View File

@@ -1,6 +1,9 @@
import sqlalchemy
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session
from dibbler.models import Product, User
from dibbler.models import User, Product
from .helpermenus import Menu, Selector
__all__ = [
@@ -14,32 +17,48 @@ __all__ = [
class AddUserMenu(Menu):
def __init__(self):
Menu.__init__(self, "Add user", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Add user", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
username = self.input_str(
"Username (should be same as PVV username)",
regex=User.name_re,
length_range=(1, 10),
)
cardnum = self.input_str("Card number (optional)", regex=User.card_re, length_range=(0, 10))
cardnum = cardnum.lower()
rfid = self.input_str("RFID (optional)", regex=User.rfid_re, length_range=(0, 10))
assert username is not None
cardnum = self.input_str(
"Card number (optional)",
regex=User.card_re,
length_range=(0, 10),
empty_string_is_none=True,
)
if cardnum is not None:
cardnum = cardnum.lower()
rfid = self.input_str(
"RFID (optional)",
regex=User.rfid_re,
length_range=(0, 10),
empty_string_is_none=True,
)
user = User(username, cardnum, rfid)
self.session.add(user)
self.sql_session.add(user)
try:
self.session.commit()
self.sql_session.commit()
print(f"User {username} stored")
except sqlalchemy.exc.IntegrityError as e:
except IntegrityError as e:
self.sql_session.rollback()
print(f"Could not store user {username}: {e}")
self.pause()
class EditUserMenu(Menu):
def __init__(self):
Menu.__init__(self, "Edit user", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Edit user", sql_session)
self.help_text = """
The only editable part of a user is its card number and rfid.
@@ -47,7 +66,7 @@ First select an existing user, then enter a new card number for that
user, then rfid (write an empty line to remove the card number or rfid).
"""
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
user = self.input_user("User")
self.printc(f"Editing user {user.name}")
@@ -69,43 +88,50 @@ user, then rfid (write an empty line to remove the card number or rfid).
empty_string_is_none=True,
)
try:
self.session.commit()
self.sql_session.commit()
print(f"User {user.name} stored")
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store user {user.name}: {e}")
self.pause()
class AddProductMenu(Menu):
def __init__(self):
Menu.__init__(self, "Add product", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Add product", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
bar_code = self.input_str("Bar code", regex=Product.bar_code_re, length_range=(8, 13))
assert bar_code is not None
name = self.input_str("Name", regex=Product.name_re, length_range=(1, Product.name_length))
assert name is not None
price = self.input_int("Price", 1, 100000)
product = Product(bar_code, name, price)
self.session.add(product)
self.sql_session.add(product)
try:
self.session.commit()
self.sql_session.commit()
print(f"Product {name} stored")
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store product {name}: {e}")
self.pause()
class EditProductMenu(Menu):
def __init__(self):
Menu.__init__(self, "Edit product", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Edit product", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
product = self.input_product("Product")
self.printc(f"Editing product {product.name}")
while True:
selector = Selector(
f"Do what with {product.name}?",
sql_session=self.sql_session,
items=[
("name", "Edit name"),
("price", "Edit price"),
@@ -135,9 +161,10 @@ class EditProductMenu(Menu):
product.hidden = self.confirm(f"Hidden(currently {product.hidden})", default=False)
elif what == "store":
try:
self.session.commit()
self.sql_session.commit()
print(f"Product {product.name} stored")
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store product {product.name}: {e}")
self.pause()
return
@@ -149,10 +176,10 @@ class EditProductMenu(Menu):
class AdjustStockMenu(Menu):
def __init__(self):
Menu.__init__(self, "Adjust stock", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Adjust stock", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
product = self.input_product("Product")
@@ -168,10 +195,11 @@ class AdjustStockMenu(Menu):
product.stock += add_stock
try:
self.session.commit()
self.sql_session.commit()
print("Stock is now stored")
self.pause()
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store stock: {e}")
self.pause()
return
@@ -179,13 +207,13 @@ class AdjustStockMenu(Menu):
class CleanupStockMenu(Menu):
def __init__(self):
Menu.__init__(self, "Stock Cleanup", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Stock Cleanup", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
products = self.session.query(Product).filter(Product.stock != 0).all()
products = self.sql_session.query(Product).filter(Product.stock != 0).all()
print("Every product in stock will be printed.")
print("Entering no value will keep current stock or set it to 0 if it is negative.")
@@ -199,15 +227,16 @@ class CleanupStockMenu(Menu):
for product in products:
oldstock = product.stock
product.stock = self.input_int(product.name, 0, 10000, default=max(0, oldstock))
self.session.add(product)
self.sql_session.add(product)
if oldstock != product.stock:
changed_products.append((product, oldstock))
try:
self.session.commit()
self.sql_session.commit()
print("New stocks are now stored.")
self.pause()
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store stock: {e}")
self.pause()
return

View File

@@ -1,129 +1,146 @@
# -*- coding: utf-8 -*-
from textwrap import dedent
from .helpermenus import MessageMenu, Menu
from sqlalchemy.orm import Session
from .helpermenus import Menu, MessageMenu
class FAQMenu(Menu):
def __init__(self):
Menu.__init__(self, "Frequently Asked Questions")
def __init__(self, sql_session: Session) -> None:
super().__init__("Frequently Asked Questions", sql_session)
self.items = [
MessageMenu(
"What is the meaning with this program?",
"""
We want to avoid keeping lots of cash in PVVVV's money box and to
make it easy to pay for stuff without using money. (Without using
money each time, that is. You do of course have to pay for the things
you buy eventually).
dedent("""
We want to avoid keeping lots of cash in PVVVV's money box and to
make it easy to pay for stuff without using money. (Without using
money each time, that is. You do of course have to pay for the things
you buy eventually).
Dibbler stores a "credit" amount for each user. When you register a
purchase in Dibbler, this amount is decreased. To increase your
credit, purchase products for dibbler, and register them using "Add
stock and adjust credit".
Alternatively, add money to the money box and use "Adjust credit" to
tell Dibbler about it.
""",
Dibbler stores a "credit" amount for each user. When you register a
purchase in Dibbler, this amount is decreased. To increase your
credit, purchase products for dibbler, and register them using "Add
stock and adjust credit".
Alternatively, add money to the money box and use "Adjust credit" to
tell Dibbler about it.
"""),
sql_session,
),
MessageMenu(
"Can I still pay for stuff using cash?",
"""
Please put money in the money box and use "Adjust Credit" so that
dibbler can keep track of credit and purchases.""",
dedent("""
Please put money in the money box and use "Adjust Credit" so that
dibbler can keep track of credit and purchases.
"""),
sql_session,
),
MessageMenu(
"How do I exit from a submenu/dialog/thing?",
'Type "exit", "q", or ^d.',
sql_session,
),
MessageMenu("How do I exit from a submenu/dialog/thing?", 'Type "exit", "q", or ^d.'),
MessageMenu(
'What does "." mean?',
"""
The "." character, known as "full stop" or "period", is most often
used to indicate the end of a sentence.
dedent("""
The "." character, known as "full stop" or "period", is most often
used to indicate the end of a sentence.
It is also used by Dibbler to indicate that the program wants you to
read some text before continuing. Whenever some output ends with a
line containing only a period, you should read the lines above and
then press enter to continue.
""",
It is also used by Dibbler to indicate that the program wants you to
read some text before continuing. Whenever some output ends with a
line containing only a period, you should read the lines above and
then press enter to continue.
"""),
sql_session,
),
MessageMenu(
"Why is the user interface so terribly unintuitive?",
"""
Answer #1: It is not.
dedent("""
Answer #1: It is not.
Answer #2: We are trying to compete with PVV's microwave oven in
userfriendliness.
Answer #2: We are trying to compete with PVV's microwave oven in
userfriendliness.
Answer #3: YOU are unintuitive.
""",
Answer #3: YOU are unintuitive.
"""),
sql_session,
),
MessageMenu(
"Why is there no help command?",
'There is. Have you tried typing "help"?',
sql_session,
),
MessageMenu(
'Where are the easter eggs? I tried saying "moo", but nothing happened.',
'Don\'t say "moo".',
sql_session,
),
MessageMenu(
"Why does the program speak English when all the users are Norwegians?",
"Godt spørsmål. Det virket sikkert som en god idé der og da.",
sql_session,
),
MessageMenu(
"Why does the screen have strange colours?",
"""
Type "c" on the main menu to change the colours of the display, or
"cs" if you are a boring person.
""",
dedent("""
Type "c" on the main menu to change the colours of the display, or
"cs" if you are a boring person.
"""),
sql_session,
),
MessageMenu(
"I found a bug; is there a reward?",
"""
No.
dedent("""
No.
But if you are certain that it is a bug, not a feature, then you
should fix it (or better: force someone else to do it).
But if you are certain that it is a bug, not a feature, then you
should fix it (or better: force someone else to do it).
Follow this procedure:
Follow this procedure:
1. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
1. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
2. Fix the bug.
2. Fix the bug.
3. Check that the program still runs (and, preferably, that the bug is
in fact fixed).
3. Check that the program still runs (and, preferably, that the bug is
in fact fixed).
4. Commit.
4. Commit.
5. Update the running copy from svn:
5. Update the running copy from svn:
$ su -
# su -l -s /bin/bash pvvvv
$ cd dibbler
$ git pull
$ su -
# su -l -s /bin/bash pvvvv
$ cd dibbler
$ git pull
6. Type "restart" in Dibbler to replace the running process by a new
one using the updated files.
""",
6. Type "restart" in Dibbler to replace the running process by a new
one using the updated files.
"""),
sql_session,
),
MessageMenu(
"My question isn't listed here; what do I do?",
"""
DON'T PANIC.
dedent("""
DON'T PANIC.
Follow this procedure:
Follow this procedure:
1. Ask someone (or read the source code) and get an answer.
1. Ask someone (or read the source code) and get an answer.
2. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
2. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
3. Add your question (with answer) to the FAQ and commit.
3. Add your question (with answer) to the FAQ and commit.
4. Update the running copy from svn:
4. Update the running copy from svn:
$ su -
# su -l -s /bin/bash pvvvv
$ cd dibbler
$ git pull
$ su -
# su -l -s /bin/bash pvvvv
$ cd dibbler
$ git pull
5. Type "restart" in Dibbler to replace the running process by a new
one using the updated files.
""",
5. Type "restart" in Dibbler to replace the running process by a new
one using the updated files.
"""),
sql_session,
),
]

View File

@@ -1,44 +1,64 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import sys
from select import select
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar
from dibbler.db import Session
from dibbler.models import User
from dibbler.lib.helpers import (
search_user,
search_product,
guess_data_type,
argmax,
guess_data_type,
search_product,
search_user,
)
from dibbler.models import Product, User
exit_commands = ["exit", "abort", "quit", "bye", "eat flaming death", "q"]
help_commands = ["help", "?"]
context_commands = ["what", "??"]
local_help_commands = ["help!", "???"]
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from sqlalchemy.orm import Session
exit_commands: list[str] = ["exit", "abort", "quit", "bye", "eat flaming death", "q"]
help_commands: list[str] = ["help", "?"]
context_commands: list[str] = ["what", "??"]
local_help_commands: list[str] = ["help!", "???"]
class ExitMenu(Exception):
class ExitMenuException(Exception):
pass
class Menu(object):
MenuItemType = TypeVar("MenuItemType", bound="Menu")
class Menu:
name: str
sql_session: Session
items: list[Menu | tuple[MenuItemType, str] | str]
prompt: str | None
end_prompt: str | None
return_index: bool
exit_msg: str | None
exit_confirm_msg: str | None
exit_disallowed_msg: str | None
help_text: str | None
context: str | None
def __init__(
self,
name,
items=None,
prompt=None,
end_prompt="> ",
return_index=True,
exit_msg=None,
exit_confirm_msg=None,
exit_disallowed_msg=None,
help_text=None,
uses_db=False,
):
self.name = name
name: str,
sql_session: Session,
items: list[Self | tuple[MenuItemType, str] | str] | None = None,
prompt: str | None = None,
end_prompt: str | None = "> ",
return_index: bool = True,
exit_msg: str | None = None,
exit_confirm_msg: str | None = None,
exit_disallowed_msg: str | None = None,
help_text: str | None = None,
) -> None:
self.name: str = name
self.sql_session: Session = sql_session
self.items = items if items is not None else []
self.prompt = prompt
self.end_prompt = end_prompt
@@ -48,54 +68,61 @@ class Menu(object):
self.exit_disallowed_msg = exit_disallowed_msg
self.help_text = help_text
self.context = None
self.uses_db = uses_db
self.session = None
def exit_menu(self):
assert name is not None
assert self.sql_session is not None
def exit_menu(self) -> None:
if self.exit_disallowed_msg is not None:
print(self.exit_disallowed_msg)
return
if self.exit_confirm_msg is not None:
if not self.confirm(self.exit_confirm_msg, default=True):
return
raise ExitMenu()
raise ExitMenuException()
def at_exit(self):
def at_exit(self) -> None:
if self.exit_msg:
print(self.exit_msg)
def set_context(self, string, display=True):
def set_context(
self,
string: str | None,
display: bool = True,
) -> None:
self.context = string
if self.context is not None and display:
print(self.context)
def add_to_context(self, string):
self.context += string
def add_to_context(self, string: str) -> None:
if self.context is not None:
self.context += string
else:
self.context = string
def printc(self, string):
def printc(self, string: str) -> None:
print(string)
if self.context is None:
self.context = string
else:
self.context += "\n" + string
def show_context(self):
def show_context(self) -> None:
print(self.header())
if self.context is not None:
print(self.context)
def item_is_submenu(self, i):
def item_is_submenu(self, i: int) -> bool:
return isinstance(self.items[i], Menu)
def item_name(self, i):
def item_name(self, i: int) -> str:
if self.item_is_submenu(i):
return self.items[i].name
elif isinstance(self.items[i], tuple):
if isinstance(self.items[i], tuple):
return self.items[i][1]
else:
return self.items[i]
return self.items[i]
def item_value(self, i):
def item_value(self, i: int) -> MenuItemType | int:
if isinstance(self.items[i], tuple):
return self.items[i][0]
if self.return_index:
@@ -104,14 +131,14 @@ class Menu(object):
def input_str(
self,
prompt=None,
end_prompt=None,
regex=None,
length_range=(None, None),
empty_string_is_none=False,
timeout=None,
default=None,
):
prompt: str | None = None,
end_prompt: str | None = None,
regex: str | None = None,
length_range: tuple[int | None, int | None] = (None, None),
empty_string_is_none: bool = False,
timeout: int | None = None,
default: str | None = None,
) -> str | None:
if prompt is None:
prompt = self.prompt if self.prompt is not None else ""
if default is not None:
@@ -168,7 +195,7 @@ class Menu(object):
):
if length_range[0] and length_range[1]:
print(
f"Value must have length in range [{length_range[0]:d}, {length_range[1]:d}]"
f"Value must have length in range [{length_range[0]:d}, {length_range[1]:d}]",
)
elif length_range[0]:
print(f"Value must have length at least {length_range[0]:d}")
@@ -177,7 +204,7 @@ class Menu(object):
continue
return result
def special_input_options(self, result):
def special_input_options(self, result) -> bool:
"""
Handles special, magic input for input_str
@@ -187,7 +214,7 @@ class Menu(object):
"""
return False
def special_input_choice(self, in_str):
def special_input_choice(self, in_str: str) -> bool:
"""
Handle choices which are not simply menu items.
@@ -197,33 +224,39 @@ class Menu(object):
"""
return False
def input_choice(self, number_of_choices, prompt=None, end_prompt=None):
def input_choice(
self,
number_of_choices: int,
prompt: str | None = None,
end_prompt: str | None = None,
) -> int:
while True:
result = self.input_str(prompt, end_prompt)
assert result is not None
if result == "":
print("Please enter something")
else:
if result.isdigit():
choice = int(result)
if choice == 0 and 10 <= number_of_choices:
if choice == 0 and number_of_choices >= 10:
return 10
if 0 < choice <= number_of_choices:
return choice
if not self.special_input_choice(result):
self.invalid_menu_choice(result)
def invalid_menu_choice(self, in_str):
def invalid_menu_choice(self, in_str: str) -> None:
print("Please enter a valid choice.")
def input_int(
self,
prompt=None,
minimum=None,
maximum=None,
null_allowed=False,
zero_allowed=True,
default=None,
):
prompt: str,
minimum: int | None = None,
maximum: int | None = None,
null_allowed: bool = False,
zero_allowed: bool = True,
default: int | None = None,
) -> int | Literal[False]:
if minimum is not None and maximum is not None:
end_prompt = f"({minimum}-{maximum})>"
elif minimum is not None:
@@ -234,7 +267,11 @@ class Menu(object):
end_prompt = ""
while True:
result = self.input_str(prompt + end_prompt, default=default)
result = self.input_str(
prompt + end_prompt,
default=str(default) if default is not None else None,
)
assert result is not None
if result == "" and null_allowed:
return False
try:
@@ -252,93 +289,115 @@ class Menu(object):
except ValueError:
print("Please enter an integer")
def input_user(self, prompt=None, end_prompt=None):
def input_user(
self,
prompt: str | None = None,
end_prompt: str | None = None,
) -> User:
user = None
while user is None:
user = self.retrieve_user(self.input_str(prompt, end_prompt))
search_string = self.input_str(prompt, end_prompt)
assert search_string is not None
user = self.retrieve_user(search_string)
return user
def retrieve_user(self, search_str):
def retrieve_user(self, search_str: str) -> User | None:
return self.search_ui(search_user, search_str, "user")
def input_product(self, prompt=None, end_prompt=None):
def input_product(
self,
prompt: str | None = None,
end_prompt: str | None = None,
) -> Product:
product = None
while product is None:
product = self.retrieve_product(self.input_str(prompt, end_prompt))
search_string = self.input_str(prompt, end_prompt)
assert search_string is not None
product = self.retrieve_product(search_string)
return product
def retrieve_product(self, search_str):
def retrieve_product(self, search_str: str) -> Product | None:
return self.search_ui(search_product, search_str, "product")
def input_thing(
self,
prompt=None,
end_prompt=None,
permitted_things=("user", "product"),
add_nonexisting=(),
empty_input_permitted=False,
find_hidden_products=True,
):
prompt: str | None = None,
end_prompt: str | None = None,
permitted_things: Iterable[str] = ("user", "product"),
add_nonexisting: Iterable[str] = (),
empty_input_permitted: bool = False,
find_hidden_products: bool = True,
) -> User | Product | None:
result = None
while result is None:
search_str = self.input_str(prompt, end_prompt)
assert search_str is not None
if search_str == "" and empty_input_permitted:
return None
result = self.search_for_thing(
search_str, permitted_things, add_nonexisting, find_hidden_products
search_str,
permitted_things,
add_nonexisting,
find_hidden_products,
)
return result
def input_multiple(
self,
prompt=None,
end_prompt=None,
permitted_things=("user", "product"),
add_nonexisting=(),
empty_input_permitted=False,
find_hidden_products=True,
):
prompt: str | None = None,
end_prompt: str | None = None,
permitted_things: Iterable[str] = ("user", "product"),
add_nonexisting: Iterable[str] = (),
empty_input_permitted: bool = False,
find_hidden_products: bool = True,
) -> tuple[User | Product, int] | None:
result = None
num = 0
while result is None:
search_str = self.input_str(prompt, end_prompt)
assert search_str is not None
search_lst = search_str.split(" ")
if search_str == "" and empty_input_permitted:
return None
else:
result = self.search_for_thing(
search_str, permitted_things, add_nonexisting, find_hidden_products
)
num = 1
result = self.search_for_thing(
search_str,
permitted_things,
add_nonexisting,
find_hidden_products,
)
num = 1
if (result is None) and (len(search_lst) > 1):
print('Interpreting input as "<number> <product>"')
try:
num = int(search_lst[0])
result = self.search_for_thing(
" ".join(search_lst[1:]),
permitted_things,
add_nonexisting,
find_hidden_products,
)
# Her kan det legges inn en except ValueError,
# men da blir det fort mye plaging av brukeren
except Exception as e:
print(e)
if (result is None) and (len(search_lst) > 1):
print('Interpreting input as "<number> <product>"')
try:
num = int(search_lst[0])
result = self.search_for_thing(
" ".join(search_lst[1:]),
permitted_things,
add_nonexisting,
find_hidden_products,
)
# Her kan det legges inn en except ValueError,
# men da blir det fort mye plaging av brukeren
except Exception as e:
print(e)
return result, num
def search_for_thing(
self,
search_str,
permitted_things=("user", "product"),
add_non_existing=(),
find_hidden_products=True,
):
search_fun = {"user": search_user, "product": search_product}
search_str: str,
permitted_things: Iterable[str] = ("user", "product"),
add_non_existing: Iterable[str] = (),
find_hidden_products: bool = True,
) -> User | Product | None:
search_fun = {
"user": search_user,
"product": search_product,
}
results = {}
result_values = {}
for thing in permitted_things:
results[thing] = search_fun[thing](search_str, self.session, find_hidden_products)
results[thing] = search_fun[thing](search_str, self.sql_session, find_hidden_products)
result_values[thing] = self.search_result_value(results[thing])
selected_thing = argmax(result_values)
if not results[selected_thing]:
@@ -353,10 +412,14 @@ class Menu(object):
return self.search_add(search_str)
# print('No match found for "%s".' % search_str)
return None
return self.search_ui2(search_str, results[selected_thing], selected_thing)
return self.search_ui2(
search_str,
results[selected_thing],
selected_thing,
)
@staticmethod
def search_result_value(result):
def search_result_value(result) -> Literal[0, 1, 2, 3]:
if result is None:
return 0
if not isinstance(result, list):
@@ -367,18 +430,19 @@ class Menu(object):
return 2
return 1
def search_add(self, string):
def search_add(self, string: str) -> User | None:
type_guess = guess_data_type(string)
if type_guess == "username":
print(f'"{string}" looks like a username, but no such user exists.')
if self.confirm(f"Create user {string}?"):
user = User(string, None)
self.session.add(user)
self.sql_session.add(user)
return user
return None
if type_guess == "card":
selector = Selector(
f'"{string}" looks like a card number, but no user with that card number exists.',
self.sql_session,
[
("create", f"Create user with card number {string}"),
("set", f"Set card number of an existing user to {string}"),
@@ -387,12 +451,14 @@ class Menu(object):
selection = selector.execute()
if selection == "create":
username = self.input_str(
"Username for new user (should be same as PVV username)",
User.name_re,
(1, 10),
prompt="Username for new user (should be same as PVV username)",
end_prompt=None,
regex=User.name_re,
length_range=(1, 10),
)
assert username is not None
user = User(username, string)
self.session.add(user)
self.sql_session.add(user)
return user
if selection == "set":
user = self.input_user("User to set card number for")
@@ -405,11 +471,21 @@ class Menu(object):
print(f'"{string}" looks like the bar code for a product, but no such product exists.')
return None
def search_ui(self, search_fun, search_str, thing):
result = search_fun(search_str, self.session)
def search_ui(
self,
search_fun: Callable[[str, Session], list[Any] | Any],
search_str: str,
thing: str,
) -> Any:
result = search_fun(search_str, self.sql_session)
return self.search_ui2(search_str, result, thing)
def search_ui2(self, search_str, result, thing):
def search_ui2(
self,
search_str: str,
result: list[Any] | Any,
thing: str,
) -> Any:
if not isinstance(result, list):
return result
if len(result) == 0:
@@ -429,25 +505,41 @@ class Menu(object):
else:
select_header = f'{len(result):d} {thing}s matching "{search_str}"'
select_items = result
selector = Selector(select_header, items=select_items, return_index=False)
selector = Selector(
select_header,
self.sql_session,
items=select_items,
return_index=False,
)
return selector.execute()
@staticmethod
def confirm(prompt, end_prompt=None, default=None, timeout=None):
return ConfirmMenu(prompt, end_prompt=None, default=default, timeout=timeout).execute()
def confirm(
self,
prompt: str,
end_prompt: str | None = None,
default: bool | None = None,
timeout: int | None = None,
) -> bool:
return ConfirmMenu(
self.sql_session,
prompt,
end_prompt=None,
default=default,
timeout=timeout,
).execute()
def header(self):
def header(self) -> str:
return f"[{self.name}]"
def print_header(self):
def print_header(self) -> None:
print("")
print(self.header())
def pause(self):
def pause(self) -> None:
self.input_str(".", end_prompt="")
@staticmethod
def general_help():
def general_help() -> None:
print(
"""
DIBBLER HELP
@@ -470,10 +562,10 @@ class Menu(object):
of money PVVVV owes the user. This value decreases with the
appropriate amount when you register a purchase, and you may increase
it by putting money in the box and using the "Adjust credit" menu.
"""
""",
)
def local_help(self):
def local_help(self) -> None:
if self.help_text is None:
print("no help here")
else:
@@ -481,21 +573,15 @@ class Menu(object):
print(f"Help for {self.header()}:")
print(self.help_text)
def execute(self, **kwargs):
def execute(self, **_kwargs) -> MenuItemType | int | None:
self.set_context(None)
try:
if self.uses_db and not self.session:
self.session = Session()
return self._execute(**kwargs)
except ExitMenu:
return self._execute(**_kwargs)
except ExitMenuException:
self.at_exit()
return None
finally:
if self.session is not None:
self.session.close()
self.session = None
def _execute(self, **kwargs):
def _execute(self, **_kwargs) -> MenuItemType | int | None:
while True:
self.print_header()
self.set_context(None)
@@ -514,12 +600,21 @@ class Menu(object):
class MessageMenu(Menu):
def __init__(self, name, message, pause_after_message=True):
Menu.__init__(self, name)
message: str
pause_after_message: bool
def __init__(
self,
name: str,
message: str,
sql_session: Session,
pause_after_message: bool = True,
) -> None:
super().__init__(name, sql_session)
self.message = message.strip()
self.pause_after_message = pause_after_message
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
print("")
print(self.message)
@@ -528,10 +623,17 @@ class MessageMenu(Menu):
class ConfirmMenu(Menu):
def __init__(self, prompt="confirm? ", end_prompt=": ", default=None, timeout=0):
Menu.__init__(
self,
def __init__(
self,
sql_session: Session,
prompt: str = "confirm? ",
end_prompt: str | None = ": ",
default: bool | None = None,
timeout: int | None = 0,
) -> None:
super().__init__(
"question",
sql_session,
prompt=prompt,
end_prompt=end_prompt,
exit_disallowed_msg="Please answer yes or no",
@@ -539,45 +641,55 @@ class ConfirmMenu(Menu):
self.default = default
self.timeout = timeout
def _execute(self):
def _execute(self, **_kwargs) -> bool:
options = {True: "[y]/n", False: "y/[n]", None: "y/n"}[self.default]
while True:
result = self.input_str(
f"{self.prompt} ({options})", end_prompt=": ", timeout=self.timeout
f"{self.prompt} ({options})",
end_prompt=": ",
timeout=self.timeout,
)
result = result.lower().strip()
if result in ["y", "yes"]:
return True
elif result in ["n", "no"]:
if result in ["n", "no"]:
return False
elif self.default is not None and result == "":
if self.default is not None and result == "":
return self.default
else:
print("Please answer yes or no")
print("Please answer yes or no")
class Selector(Menu):
def __init__(
self,
name,
items=None,
prompt="select",
return_index=True,
exit_msg=None,
exit_confirm_msg=None,
help_text=None,
):
name: str,
sql_session: Session,
items: list[Self | tuple[MenuItemType, str] | str] | None = None,
prompt: str | None = "select",
return_index: bool = True,
exit_msg: str | None = None,
exit_confirm_msg: str | None = None,
help_text: str | None = None,
) -> None:
if items is None:
items = []
Menu.__init__(self, name, items, prompt, return_index=return_index, exit_msg=exit_msg)
super().__init__(
name,
sql_session,
items,
prompt,
return_index=return_index,
exit_msg=exit_msg,
help_text=help_text,
)
def header(self):
def header(self) -> str:
return self.name
def print_header(self):
def print_header(self) -> None:
print(self.header())
def local_help(self):
def local_help(self) -> None:
if self.help_text is None:
print("This is a selection menu. Enter one of the listed numbers, or")
print("'exit' to go out and do something else.")

View File

@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
import os
import random
import sys
from dibbler.db import Session
from sqlalchemy.orm import Session
from .buymenu import BuyMenu
from .faq import FAQMenu
@@ -13,14 +12,17 @@ faq_commands = ["faq"]
restart_commands = ["restart"]
def restart():
def restart() -> None:
# Does not work if the script is not executable, or if it was
# started by searching $PATH.
os.execv(sys.argv[0], sys.argv)
class MainMenu(Menu):
def special_input_choice(self, in_str):
def __init__(self, sql_session: Session, **_kwargs) -> None:
super().__init__("Dibbler main menu", sql_session, **_kwargs)
def special_input_choice(self, in_str: str) -> bool:
mv = in_str.split()
if len(mv) == 2 and mv[0].isdigit():
num = int(mv[0])
@@ -28,7 +30,7 @@ class MainMenu(Menu):
else:
num = 1
item_name = in_str
buy_menu = BuyMenu(Session())
buy_menu = BuyMenu(self.sql_session)
thing = buy_menu.search_for_thing(item_name, find_hidden_products=False)
if thing:
buy_menu.execute(initial_contents=[(thing, num)])
@@ -36,32 +38,26 @@ class MainMenu(Menu):
return True
return False
def special_input_options(self, result):
def special_input_options(self, result: str) -> bool:
if result in faq_commands:
FAQMenu().execute()
FAQMenu(self.sql_session).execute()
return True
if result in restart_commands:
if self.confirm("Restart Dibbler?"):
restart()
pass
return True
elif result == "c":
os.system(
'echo -e "\033['
+ str(random.randint(40, 49))
+ ";"
+ str(random.randint(30, 37))
+ ';5m"'
)
os.system("clear")
if result == "c":
print(f"\033[{random.randint(40, 49)};{random.randint(30, 37)};5m")
print("\033[2J")
self.show_context()
return True
elif result == "cs":
os.system('echo -e "\033[0m"')
os.system("clear")
if result == "cs":
print("\033[0m")
print("\033[2J")
self.show_context()
return True
return False
def invalid_menu_choice(self, in_str):
def invalid_menu_choice(self, in_str: str) -> None:
print(self.show_context())

View File

@@ -1,17 +1,19 @@
import sqlalchemy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from dibbler.conf import config
from dibbler.models import Transaction, Product, User
from dibbler.lib.helpers import less
from dibbler.models import Product, Transaction, User
from .helpermenus import Menu, Selector
class TransferMenu(Menu):
def __init__(self):
Menu.__init__(self, "Transfer credit between users", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Transfer credit between users", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
amount = self.input_int("Transfer amount", 1, 100000)
self.set_context(f"Transferring {amount:d} kr", display=False)
@@ -26,24 +28,25 @@ class TransferMenu(Menu):
t2 = Transaction(user2, -amount, f'transfer from {user1.name} "{comment}"')
t1.perform_transaction()
t2.perform_transaction()
self.session.add(t1)
self.session.add(t2)
self.sql_session.add(t1)
self.sql_session.add(t2)
try:
self.session.commit()
self.sql_session.commit()
print(f"Transferred {amount:d} kr from {user1} to {user2}")
print(f"User {user1}'s credit is now {user1.credit:d} kr")
print(f"User {user2}'s credit is now {user2.credit:d} kr")
print(f"Comment: {comment}")
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not perform transfer: {e}")
# self.pause()
class ShowUserMenu(Menu):
def __init__(self):
Menu.__init__(self, "Show user", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Show user", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
user = self.input_user("User name, card number or RFID")
print(f"User name: {user.name}")
@@ -52,11 +55,12 @@ class ShowUserMenu(Menu):
print(f"Credit: {user.credit} kr")
selector = Selector(
f"What do you want to know about {user.name}?",
self.sql_session,
items=[
(
"transactions",
"Recent transactions (List of last "
+ str(config.getint("limits", "user_recent_transaction_limit"))
+ str(config["limits"]["user_recent_transaction_limit"])
+ ")",
),
("products", f"Which products {user.name} has bought, and how many"),
@@ -65,7 +69,7 @@ class ShowUserMenu(Menu):
)
what = selector.execute()
if what == "transactions":
self.print_transactions(user, config.getint("limits", "user_recent_transaction_limit"))
self.print_transactions(user, config["limits"]["user_recent_transaction_limit"])
elif what == "products":
self.print_purchased_products(user)
elif what == "transactions-all":
@@ -74,7 +78,7 @@ class ShowUserMenu(Menu):
print("What what?")
@staticmethod
def print_transactions(user, limit=None):
def print_transactions(user: User, limit: int | None = None) -> None:
num_trans = len(user.transactions)
if limit is None:
limit = num_trans
@@ -87,10 +91,7 @@ class ShowUserMenu(Menu):
if t.purchase:
products = []
for entry in t.purchase.entries:
if abs(entry.amount) != 1:
amount = f"{abs(entry.amount)}x "
else:
amount = ""
amount = f"{abs(entry.amount)}x " if abs(entry.amount) != 1 else ""
product = f"{amount}{entry.product.name}"
products.append(product)
string += "purchase ("
@@ -98,13 +99,13 @@ class ShowUserMenu(Menu):
string += ")"
if t.penalty > 1:
string += f" * {t.penalty:d}x penalty applied"
else:
elif t.description is not None:
string += t.description
string += "\n"
less(string)
@staticmethod
def print_purchased_products(user):
def print_purchased_products(user: User) -> None:
products = []
for ref in user.products:
product = ref.product
@@ -123,13 +124,13 @@ class ShowUserMenu(Menu):
class UserListMenu(Menu):
def __init__(self):
Menu.__init__(self, "User list", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("User list", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
user_list = self.session.query(User).all()
total_credit = self.session.query(sqlalchemy.func.sum(User.credit)).first()[0]
user_list = self.sql_session.query(User).all()
total_credit = self.sql_session.query(sqlalchemy.func.sum(User.credit)).first()[0]
line_format = "%-12s | %6s\n"
hline = "---------------------\n"
@@ -144,10 +145,10 @@ class UserListMenu(Menu):
class AdjustCreditMenu(Menu):
def __init__(self):
Menu.__init__(self, "Adjust credit", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Adjust credit", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
user = self.input_user("User")
print(f"User {user.name}'s credit is {user.credit:d} kr")
@@ -164,24 +165,25 @@ class AdjustCreditMenu(Menu):
description = "manually adjusted credit"
transaction = Transaction(user, -amount, description)
transaction.perform_transaction()
self.session.add(transaction)
self.sql_session.add(transaction)
try:
self.session.commit()
self.sql_session.commit()
print(f"User {user.name}'s credit is now {user.credit:d} kr")
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store transaction: {e}")
# self.pause()
class ProductListMenu(Menu):
def __init__(self):
Menu.__init__(self, "Product list", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Product list", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
text = ""
product_list = (
self.session.query(Product)
self.sql_session.query(Product)
.filter(Product.hidden.is_(False))
.order_by(Product.stock.desc())
)
@@ -204,21 +206,22 @@ class ProductListMenu(Menu):
class ProductSearchMenu(Menu):
def __init__(self):
Menu.__init__(self, "Product search", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Product search", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
self.set_context("Enter (part of) product name or bar code")
product = self.input_product()
print(
"Result: %s, price: %d kr, bar code: %s, stock: %d, hidden: %s"
% (
product.name,
product.price,
product.bar_code,
product.stock,
("Y" if product.hidden else "N"),
)
", ".join(
[
f"Result: {product.name}",
f"price: {product.price} kr",
f"bar code: {product.bar_code}",
f"stock: {product.stock}",
f"hidden: {'Y' if product.hidden else 'N'}",
],
),
)
# self.pause()

View File

@@ -1,45 +1,46 @@
import re
from dibbler.conf import config
from dibbler.models import Product, User
from dibbler.lib.printer_helpers import print_bar_code, print_name_label
from sqlalchemy.orm import Session
# from dibbler.lib.printer_helpers import print_bar_code, print_name_label
from .helpermenus import Menu
class PrintLabelMenu(Menu):
def __init__(self):
Menu.__init__(self, "Print a label", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Print a label", sql_session)
self.help_text = """
Prints out a product bar code on the printer
Put it up somewhere in the vicinity.
"""
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
thing = self.input_thing("Product/User")
print("Printer menu is under renovation, please be patient")
if isinstance(thing, Product):
if re.match(r"^[0-9]{13}$", thing.bar_code):
bar_type = "ean13"
elif re.match(r"^[0-9]{8}$", thing.bar_code):
bar_type = "ean8"
else:
bar_type = "code39"
print_bar_code(
thing.bar_code,
thing.name,
barcode_type=bar_type,
rotate=config.getboolean("printer", "rotate"),
printer_type="QL-700",
label_type=config.get("printer", "label_type"),
)
elif isinstance(thing, User):
print_name_label(
text=thing.name,
label_type=config.get("printer", "label_type"),
rotate=config.getboolean("printer", "rotate"),
printer_type="QL-700",
)
return
# thing = self.input_thing("Product/User")
# if isinstance(thing, Product):
# if re.match(r"^[0-9]{13}$", thing.bar_code):
# bar_type = "ean13"
# elif re.match(r"^[0-9]{8}$", thing.bar_code):
# bar_type = "ean8"
# else:
# bar_type = "code39"
# print_bar_code(
# thing.bar_code,
# thing.name,
# barcode_type=bar_type,
# rotate=config["printer"]["rotate"],
# printer_type="QL-700",
# label_type=config.get("printer", "label_type"),
# )
# elif isinstance(thing, User):
# print_name_label(
# text=thing.name,
# label_type=config["printer"]["label_type"],
# rotate=config["printer"]["rotate"],
# printer_type="QL-700",
# )

View File

@@ -1,8 +1,9 @@
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from dibbler.lib.helpers import less
from dibbler.models import PurchaseEntry, Product, User
from dibbler.lib.statistikkHelpers import statisticsTextOnly
from dibbler.models import Product, PurchaseEntry, User
from .helpermenus import Menu
@@ -15,14 +16,14 @@ __all__ = [
class ProductPopularityMenu(Menu):
def __init__(self):
Menu.__init__(self, "Products by popularity", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Products by popularity", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
text = ""
sub = (
self.session.query(
self.sql_session.query(
PurchaseEntry.product_id,
func.sum(PurchaseEntry.amount).label("purchase_count"),
)
@@ -31,8 +32,8 @@ class ProductPopularityMenu(Menu):
.subquery()
)
product_list = (
self.session.query(Product, sub.c.purchase_count)
.outerjoin((sub, Product.product_id == sub.c.product_id))
self.sql_session.query(Product, sub.c.purchase_count)
.outerjoin(sub, Product.product_id == sub.c.product_id)
.order_by(desc(sub.c.purchase_count))
.filter(sub.c.purchase_count is not None)
.all()
@@ -48,14 +49,14 @@ class ProductPopularityMenu(Menu):
class ProductRevenueMenu(Menu):
def __init__(self):
Menu.__init__(self, "Products by revenue", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Products by revenue", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
text = ""
sub = (
self.session.query(
self.sql_session.query(
PurchaseEntry.product_id,
func.sum(PurchaseEntry.amount).label("purchase_count"),
)
@@ -64,8 +65,8 @@ class ProductRevenueMenu(Menu):
.subquery()
)
product_list = (
self.session.query(Product, sub.c.purchase_count)
.outerjoin((sub, Product.product_id == sub.c.product_id))
self.sql_session.query(Product, sub.c.purchase_count)
.outerjoin(sub, Product.product_id == sub.c.product_id)
.order_by(desc(sub.c.purchase_count * Product.price))
.filter(sub.c.purchase_count is not None)
.all()
@@ -86,22 +87,26 @@ class ProductRevenueMenu(Menu):
class BalanceMenu(Menu):
def __init__(self):
Menu.__init__(self, "Total balance of PVVVV", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Total balance of PVVVV", sql_session)
def _execute(self):
def _execute(self, **_kwargs) -> None:
self.print_header()
text = ""
total_value = 0
product_list = self.session.query(Product).filter(Product.stock > 0).all()
product_list = self.sql_session.query(Product).filter(Product.stock > 0).all()
for p in product_list:
total_value += p.stock * p.price
total_positive_credit = (
self.session.query(func.sum(User.credit)).filter(User.credit > 0).first()[0]
self.sql_session.query(func.coalesce(func.sum(User.credit), 0))
.filter(User.credit > 0)
.first()[0]
)
total_negative_credit = (
self.session.query(func.sum(User.credit)).filter(User.credit < 0).first()[0]
self.sql_session.query(func.coalesce(func.sum(User.credit), 0))
.filter(User.credit < 0)
.first()[0]
)
total_credit = total_positive_credit + total_negative_credit
@@ -119,8 +124,8 @@ class BalanceMenu(Menu):
class LoggedStatisticsMenu(Menu):
def __init__(self):
Menu.__init__(self, "Statistics from log", uses_db=True)
def __init__(self, sql_session: Session) -> None:
super().__init__("Statistics from log", sql_session)
def _execute(self):
statisticsTextOnly()
def _execute(self, **_kwargs) -> None:
statisticsTextOnly(self.sql_session)

View File

@@ -17,16 +17,19 @@ def _pascal_case_to_snake_case(name: str) -> str:
class Base(DeclarativeBase):
metadata = MetaData(
naming_convention={
"ix": "ix_%(table_name)s_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
"ix": "ix__%(table_name)s__%(column_0_label)s",
"uq": "uq__%(table_name)s__%(column_0_name)s",
"ck": "ck__%(table_name)s__%(constraint_name)s",
"fk": "fk__%(table_name)s__%(column_0_name)s_%(referred_table_name)s",
"pk": "pk__%(table_name)s",
},
)
@declared_attr.directive
def __tablename__(cls) -> str:
if hasattr(cls, "__table_name__"):
assert isinstance(cls.__table_name__, str)
return cls.__table_name__
return _pascal_case_to_snake_case(cls.__name__)
# NOTE: This is the default implementation of __repr__ for all tables,
@@ -46,7 +49,7 @@ class Base(DeclarativeBase):
isinstance(v, InstrumentedList),
isinstance(v, InstrumentedSet),
isinstance(v, InstrumentedDict),
]
],
)
)
return f"<{self.__class__.__name__}({columns})>"

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import Transaction
class LastCacheTransaction(Base):
"""Tracks the last transaction that affected various caches."""
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID"""
transaction_id: Mapped[int | None] = mapped_column(ForeignKey("trx.id"), index=True)
"""The ID of the last transaction that affected the cache(s)."""
transaction: Mapped[Transaction | None] = relationship(
lazy="joined",
foreign_keys=[transaction_id],
)
"""The last transaction that affected the cache(s)."""

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Self
from typing import TYPE_CHECKING, Self
from sqlalchemy import (
Boolean,
@@ -19,6 +19,7 @@ class Product(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID"""
# TODO: add more validation for barcode
bar_code: Mapped[str] = mapped_column(String(13), unique=True)
"""
The bar code of the product.

View File

@@ -1,16 +1,33 @@
from datetime import datetime
from __future__ import annotations
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import LastCacheTransaction, Product
class ProductCache(Base):
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID"""
product_id: Mapped[int] = mapped_column(ForeignKey("product.id"))
product: Mapped[Product] = relationship(
lazy="joined",
foreign_keys=[product_id],
)
price: Mapped[int] = mapped_column(Integer)
price_timestamp: Mapped[datetime] = mapped_column(DateTime)
stock: Mapped[int] = mapped_column(Integer)
stock_timestamp: Mapped[datetime] = mapped_column(DateTime)
last_cache_transaction_id: Mapped[int | None] = mapped_column(
ForeignKey("last_cache_transaction.id"), nullable=True,
)
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
lazy="joined",
foreign_keys=[last_cache_transaction_id],
)

View File

@@ -33,11 +33,10 @@ if TYPE_CHECKING:
from .Product import Product
from .User import User
# TODO: rename to *_PERCENT
# NOTE: these only matter when there are no adjustments made in the database.
DEFAULT_INTEREST_RATE_PERCENTAGE = 100
DEFAULT_INTEREST_RATE_PERCENT = 100
DEFAULT_PENALTY_THRESHOLD = -100
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200
DEFAULT_PENALTY_MULTIPLIER_PERCENT = 200
_DYNAMIC_FIELDS: set[str] = {
"amount",
@@ -88,6 +87,7 @@ def _transaction_type_field_constraints(
class Transaction(Base):
__tablename__ = "trx"
__table_args__ = (
*[
_transaction_type_field_constraints(transaction_type, expected_fields)
@@ -131,12 +131,12 @@ class Transaction(Base):
),
name="trx_joint_transaction_id_not_self",
),
# Speed up product count calculation
Index("product_user_time", "product_id", "user_id", "time"),
# Speed up product stock calculation
Index("ix__transaction__product_id_type_time", "product_id", "type", "time"),
# Speed up product owner calculation
Index("user_product_time", "user_id", "product_id", "time"),
Index("ix__transaction__user_id_product_time", "user_id", "product_id", "time"),
# Speed up user transaction list / credit calculation
Index("user_time", "user_id", "time"),
Index("ix__transaction__user_id_time", "user_id", "time"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
@@ -146,7 +146,7 @@ class Transaction(Base):
Not used for anything else than identifying the transaction in the database.
"""
time: Mapped[datetime] = mapped_column(DateTime)
time: Mapped[datetime] = mapped_column(DateTime, index=True)
"""
The time when the transaction took place.
@@ -162,7 +162,7 @@ class Transaction(Base):
This is not used for any calculations, but can be useful for debugging.
"""
type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type")
type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type", index=True)
"""
Which type of transaction this is.
@@ -189,7 +189,7 @@ class Transaction(Base):
that the user paid in the store would be stored in the `amount` field.
"""
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), index=True)
"""The user who performs the transaction. See `user` for more details."""
user: Mapped[User] = relationship(
lazy="joined",
@@ -207,7 +207,10 @@ class Transaction(Base):
In the case of `JOINT` transactions, this is the user who initiated the joint transaction.
"""
joint_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("transaction.id"))
joint_transaction_id: Mapped[int | None] = mapped_column(
ForeignKey("trx.id"),
index=True,
)
"""
An optional ID to group multiple transactions together as part of a joint transaction.
@@ -223,7 +226,7 @@ class Transaction(Base):
"""
# Receiving user when moving credit from one user to another
transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"), index=True)
"""The user who receives money in a `TRANSFER` transaction."""
transfer_user: Mapped[User | None] = relationship(
lazy="joined",
@@ -232,7 +235,7 @@ class Transaction(Base):
"""The user who receives money in a `TRANSFER` transaction."""
# The product that is either being added or bought
product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"))
product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"), index=True)
"""The product being added or bought."""
product: Mapped[Product | None] = relationship(lazy="joined")
"""The product being added or bought."""
@@ -330,7 +333,6 @@ class Transaction(Base):
Validates the transaction's fields based on its type.
Raises `ValueError` if the transaction is invalid.
"""
# TODO: do we allow free products?
if self.amount == 0:
raise ValueError("Amount must not be zero.")
@@ -352,7 +354,7 @@ class Transaction(Base):
and self.amount > self.per_product * self.product_count
):
raise ValueError(
"The real amount of the transaction must be less than the total value of the products."
"The real amount of the transaction must be less than the total value of the products.",
)
# TODO: improve printing further
@@ -383,7 +385,7 @@ class Transaction(Base):
isinstance(v, InstrumentedSet),
isinstance(v, InstrumentedDict),
*[k in (_DYNAMIC_FIELDS - EXPECTED_FIELDS[self.type_])],
]
],
)
)
return f"{self.type_.upper()}({columns})"
@@ -400,6 +402,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating an `ADJUST_BALANCE` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.ADJUST_BALANCE,
@@ -416,6 +423,14 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating an `ADJUST_INTEREST` transaction.
Note that the `interest_rate_percent` is absolute, not relative to the previous interest rate.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.ADJUST_INTEREST,
@@ -433,6 +448,14 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating an `ADJUST_PENALTY` transaction.
Note that both `penalty_multiplier_percent` and `penalty_threshold` are absolute,
not relative to the previous settings.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.ADJUST_PENALTY,
@@ -451,6 +474,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating an `ADJUST_STOCK` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.ADJUST_STOCK,
@@ -471,6 +499,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating an `ADD_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.ADD_PRODUCT,
@@ -491,6 +524,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating a `BUY_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.BUY_PRODUCT,
@@ -509,6 +547,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating a `JOINT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.JOINT,
@@ -526,6 +569,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating a `JOINT_BUY_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.JOINT_BUY_PRODUCT,
@@ -543,6 +591,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating a `TRANSFER` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.TRANSFER,
@@ -561,6 +614,11 @@ class Transaction(Base):
time: datetime | None = None,
message: str | None = None,
) -> Self:
"""
Convenience constructor for creating a `THROW_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls(
time=time,
type_=TransactionType.THROW_PRODUCT,

View File

@@ -1,14 +1,33 @@
from datetime import datetime
from __future__ import annotations
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import LastCacheTransaction, User
# More like user balance cash money flow, amirite?
class UserBalanceCache(Base):
user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
class UserCache(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""internal database id"""
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
user: Mapped[User] = relationship(
lazy="joined",
foreign_keys=[user_id],
)
balance: Mapped[int] = mapped_column(Integer)
timestamp: Mapped[datetime] = mapped_column(DateTime)
last_cache_transaction_id: Mapped[int | None] = mapped_column(
ForeignKey("last_cache_transaction.id"), nullable=True,
)
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
lazy="joined",
foreign_keys=[last_cache_transaction_id],
)

View File

@@ -1,13 +1,19 @@
__all__ = [
"Base",
"LastCacheTransaction",
"Product",
"ProductCache",
"Transaction",
"TransactionType",
"User",
"UserCache",
]
from .Base import Base
from .LastCacheTransaction import LastCacheTransaction
from .Product import Product
from .ProductCache import ProductCache
from .Transaction import Transaction
from .TransactionType import TransactionType
from .User import User
from .UserCache import UserCache

View File

@@ -1,8 +1,13 @@
__all__ = [
# "add_product",
# "add_user",
"add_product",
"adjust_balance",
"adjust_interest",
"adjust_penalty",
"adjust_stock",
"affected_products",
"affected_users",
"create_product",
"create_user",
"current_interest",
"current_penalty",
"joint_buy_product",
@@ -11,27 +16,37 @@ __all__ = [
"product_price",
"product_price_log",
"product_stock",
# "products_owned_by_user",
"search_product",
"search_user",
"throw_product",
"transaction_log",
"transfer",
"update_cache",
"user_balance",
"user_balance_log",
"user_products",
]
# from .add_product import add_product
# from .add_user import add_user
from .add_product import add_product
from .adjust_balance import adjust_balance
from .adjust_interest import adjust_interest
from .adjust_penalty import adjust_penalty
from .adjust_stock import adjust_stock
from .affected_products import affected_products
from .affected_users import affected_users
from .create_product import create_product
from .create_user import create_user
from .current_interest import current_interest
from .current_penalty import current_penalty
from .joint_buy_product import joint_buy_product
from .product_owners import product_owners, product_owners_log
from .product_price import product_price, product_price_log
from .product_stock import product_stock
# from .products_owned_by_user import products_owned_by_user
from .search_product import search_product
from .search_user import search_user
from .throw_product import throw_product
from .transaction_log import transaction_log
from .transfer import transfer
from .update_cache import update_cache
from .user_balance import user_balance, user_balance_log
from .user_products import user_products

View File

@@ -1 +1,51 @@
# TODO: implement me
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def add_product(
sql_session: Session,
user: User,
product: Product,
amount: int,
per_product: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
if per_product <= 0:
raise ValueError("Per product price must be positive.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
if per_product * product_count < amount:
raise ValueError("Total per product price must be at least equal to amount.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=amount,
per_product=per_product,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -1 +0,0 @@
# TODO: implement me

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def adjust_balance(
sql_session: Session,
user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if amount == 0:
raise ValueError("Amount must be non-zero.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_balance(
user_id=user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
@@ -10,19 +12,25 @@ def adjust_interest(
sql_session: Session,
user: User,
new_interest: int,
time: datetime | None = None,
message: str | None = None,
) -> None:
) -> Transaction:
if new_interest < 0:
raise ValueError("Interest rate cannot be negative")
if user.id is None:
raise ValueError("User must be persisted in the database.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_interest(
user_id=user.id,
interest_rate_percent=new_interest,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
@@ -12,8 +14,9 @@ def adjust_penalty(
user: User,
new_penalty: int | None = None,
new_penalty_multiplier: int | None = None,
time: datetime | None = None,
message: str | None = None,
) -> None:
) -> Transaction:
if new_penalty is None and new_penalty_multiplier is None:
raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided")
@@ -30,12 +33,17 @@ def adjust_penalty(
if new_penalty_multiplier is None:
new_penalty_multiplier = existing_penalty_multiplier
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_penalty(
user_id=user.id,
penalty_threshold=new_penalty,
penalty_multiplier_percent=new_penalty_multiplier,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def adjust_stock(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count == 0:
raise ValueError("Product count must be non-zero.")
# TODO: it should not be possible to reduce stock below zero.
#
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,88 @@
from datetime import datetime
from sqlalchemy import BindParameter, select
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, TransactionType
from dibbler.queries.query_helpers import after_filter, until_filter
def affected_products(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
after_time: BindParameter[datetime] | datetime | None = None,
after_transaction: Transaction | None = None,
after_inclusive: bool = True,
) -> set[Product]:
"""
Get all products where attributes were affected over a given interval.
"""
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
result = sql_session.scalars(
select(Product)
.distinct()
.join(Transaction, Product.id == Transaction.product_id)
.where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT.as_literal_column(),
],
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
after_filter(
after_time=after_time,
after_transaction_id=after_transaction_id,
after_inclusive=after_inclusive,
),
)
.order_by(Transaction.time.desc()),
).all()
return set(result)

View File

@@ -0,0 +1,90 @@
from datetime import datetime
from sqlalchemy import BindParameter, or_, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType, User
from dibbler.queries.query_helpers import after_filter, until_filter
def affected_users(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
after_time: BindParameter[datetime] | datetime | None = None,
after_transaction: Transaction | None = None,
after_inclusive: bool = True,
) -> set[User]:
"""
Get all users where attributes were affected over a given interval.
"""
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
result = sql_session.scalars(
select(User)
.distinct()
.join(
Transaction,
or_(User.id == Transaction.user_id, User.id == Transaction.transfer_user_id),
)
.where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
],
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
after_filter(
after_time=after_time,
after_transaction_id=after_transaction_id,
after_inclusive=after_inclusive,
),
)
.order_by(Transaction.time.desc()),
).all()
return set(result)

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def buy_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,25 @@
from sqlalchemy.orm import Session
from dibbler.models import Product
def create_product(
sql_session: Session,
name: str,
barcode: str,
) -> Product:
if not name:
raise ValueError("Name cannot be empty.")
if not barcode:
raise ValueError("Barcode cannot be empty.")
# TODO: check for duplicate names, barcodes
# TODO: add more validation for barcode
product = Product(barcode, name)
sql_session.add(product)
sql_session.commit()
return product

View File

@@ -0,0 +1,21 @@
from sqlalchemy.orm import Session
from dibbler.models import User
def create_user(
sql_session: Session,
name: str,
card: str | None,
rfid: str | None,
) -> User:
if not name:
raise ValueError("Name cannot be empty.")
# TODO: check for duplicate names, cards, rfids
user = User(name=name, card=card, rfid=rfid)
sql_session.add(user)
sql_session.commit()
return user

View File

@@ -1,24 +1,55 @@
from sqlalchemy import select
from datetime import datetime
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.queries.query_helpers import until_filter
# TODO: add until transaction parameter
# TODO: add until datetime parameter
def current_interest(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Get the current interest rate percentage as of a given time or transaction.
Returns the interest rate percentage as an integer.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
def current_interest(sql_session: Session) -> int:
result = sql_session.scalars(
select(Transaction)
.where(Transaction.type_ == TransactionType.ADJUST_INTEREST)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(1)
.limit(1),
).one_or_none()
if result is None:
return DEFAULT_INTEREST_RATE_PERCENTAGE
return DEFAULT_INTEREST_RATE_PERCENT
elif result.interest_rate_percent is None:
return DEFAULT_INTEREST_RATE_PERCENTAGE
return DEFAULT_INTEREST_RATE_PERCENT
else:
return result.interest_rate_percent

View File

@@ -1,26 +1,57 @@
from sqlalchemy import select
from datetime import datetime
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries.query_helpers import until_filter
# TODO: add until transaction parameter
# TODO: add until datetime parameter
def current_penalty(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> tuple[int, int]:
"""
Get the current penalty settings (threshold and multiplier percentage) as of a given time or transaction.
Returns a tuple of `(penalty_threshold, penalty_multiplier_percentage)`.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
def current_penalty(sql_session: Session) -> tuple[int, int]:
result = sql_session.scalars(
select(Transaction)
.where(Transaction.type_ == TransactionType.ADJUST_PENALTY)
.where(
Transaction.type_ == TransactionType.ADJUST_PENALTY,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(1)
.limit(1),
).one_or_none()
if result is None:
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENT
assert result.penalty_threshold is not None, "Penalty threshold must be set"
assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set"

View File

@@ -17,7 +17,7 @@ def joint_buy_product(
users: list[User],
time: datetime | None = None,
message: str | None = None,
) -> None:
) -> list[Transaction]:
"""
Create buy product transactions for multiple users at once.
"""
@@ -25,15 +25,23 @@ def joint_buy_product(
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if instigator not in users:
raise ValueError("Instigator must be in the list of users buying the product.")
if instigator.id is None:
raise ValueError("Instigator must be persisted in the database.")
if len(users) == 0:
raise ValueError("At least bying one user must be specified.")
if any(user.id is None for user in users):
raise ValueError("All users must be persisted in the database.")
if instigator not in users:
raise ValueError("Instigator must be in the list of users buying the product.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
joint_transaction = Transaction.joint(
user_id=instigator.id,
product_id=product.id,
@@ -44,6 +52,8 @@ def joint_buy_product(
sql_session.add(joint_transaction)
sql_session.flush() # Ensure joint_transaction gets an ID
transactions = [joint_transaction]
for user in users:
buy_transaction = Transaction.joint_buy_product(
user_id=user.id,
@@ -52,5 +62,7 @@ def joint_buy_product(
message=message,
)
sql_session.add(buy_transaction)
transactions.append(buy_transaction)
sql_session.commit()
return transactions

View File

@@ -8,11 +8,11 @@ from sqlalchemy import (
bindparam,
case,
func,
or_,
select,
)
from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO
from dibbler.models import (
Product,
Transaction,
@@ -20,13 +20,22 @@ from dibbler.models import (
User,
)
from dibbler.queries.product_stock import _product_stock_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
until_filter,
)
def _product_owners_query(
product_id: BindParameter[int] | int,
use_cache: bool = True,
until: BindParameter[datetime] | datetime | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE:
"""
The inner query for inferring the owners of a given product.
@@ -38,13 +47,25 @@ def _product_owners_query(
if isinstance(product_id, int):
product_id = bindparam("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
if until_time is not None and until_transaction is not None:
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = bindparam("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
product_stock = _product_stock_query(
product_id=product_id,
use_cache=use_cache,
until=until,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
# Subset of transactions that we'll want to iterate over.
@@ -57,22 +78,23 @@ def _product_owners_query(
Transaction.user_id,
Transaction.product_count,
)
# TODO: maybe add value constraint on ADJUST_STOCK?
.where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
# TransactionType.BUY_PRODUCT,
TransactionType.ADJUST_STOCK.as_literal_column(),
# TransactionType.JOINT,
# TransactionType.THROW_PRODUCT,
]
or_(
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
and_(
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
Transaction.product_count > CONST_ZERO,
),
),
Transaction.product_id == product_id,
CONST_TRUE if until is None else Transaction.time <= until,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.subquery()
.subquery(trx_subset_name)
)
initial_element = select(
@@ -117,35 +139,19 @@ def _product_owners_query(
).label("product_count"),
# How many products left to account for
case(
# Someone adds the product -> increase the number of products left to account for
# Someone adds the product -> known owner, decrease the number of products left to account for
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
# Someone buys/joins/throws the product -> decrease the number of products left to account for
# (
# trx_subset.c.type_.in_(
# [
# TransactionType.BUY_PRODUCT,
# TransactionType.JOINT,
# TransactionType.THROW_PRODUCT,
# ]
# ),
# recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
# ),
# Someone adjusts the stock ->
# If adjusted upwards -> products owned by nobody, decrease products left to account for
# If adjusted downwards -> products taken away from owners, decrease products left to account for
# Stock got adjusted upwards -> none owner, decrease the number of products left to account for
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
and (trx_subset.c.product_count > CONST_ZERO),
and_(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
trx_subset.c.product_count > CONST_ZERO,
),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
# (
# (trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
# and (trx_subset.c.product_count < 0),
# recursive_cte.c.products_left_to_account_for + trx_subset.c.product_count,
# ),
else_=recursive_cte.c.products_left_to_account_for,
).label("products_left_to_account_for"),
)
@@ -153,8 +159,9 @@ def _product_owners_query(
.where(
and_(
trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
# Base case: stop if we've accounted for all products
recursive_cte.c.products_left_to_account_for > CONST_ZERO,
)
),
)
)
@@ -167,13 +174,14 @@ class ProductOwnersLogEntry:
user: User | None
products_left_to_account_for: int
# TODO: add until datetime parameter
def product_owners_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: Transaction | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductOwnersLogEntry]:
"""
Returns a log of the product ownership calculation for the given product.
@@ -184,13 +192,12 @@ def product_owners_log(
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if until is not None and until.id is None:
raise ValueError("'until' transaction must be persisted in the database.")
recursive_cte = _product_owners_query(
product_id=product.id,
use_cache=use_cache,
until=until.time if until else None,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
@@ -209,13 +216,13 @@ def product_owners_log(
onclause=User.id == recursive_cte.c.user_id,
isouter=True,
)
.order_by(recursive_cte.c.time.desc())
.order_by(recursive_cte.c.time.desc()),
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})."
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id}).",
)
return [
@@ -228,13 +235,13 @@ def product_owners_log(
]
# TODO: add until transaction parameter
def product_owners(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: datetime | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[User | None]:
"""
Returns an ordered list of users owning the given product.
@@ -248,7 +255,9 @@ def product_owners(
recursive_cte = _product_owners_query(
product_id=product.id,
use_cache=use_cache,
until=until,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
db_result = sql_session.execute(
@@ -258,7 +267,7 @@ def product_owners(
User,
)
.join(User, User.id == recursive_cte.c.user_id, isouter=True)
.order_by(recursive_cte.c.time.desc())
.order_by(recursive_cte.c.time.desc()),
).all()
print(db_result)
@@ -292,7 +301,7 @@ def product_owners(
result.extend([None] * none_count)
# # NOTE: if the last line exeeds the product count, we need to truncate it
# # NOTE: if the last line exceeds the product count, we need to truncate it
# result.extend([user] * min(user_count, products_left_to_account_for))
# redistribute the user counts to a list of users

View File

@@ -6,7 +6,7 @@ from sqlalchemy import (
BindParameter,
ColumnElement,
Integer,
asc,
bindparam,
case,
cast,
func,
@@ -14,52 +14,111 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO
from dibbler.models import (
LastCacheTransaction,
Product,
ProductCache,
Transaction,
TransactionType,
)
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
after_filter,
until_filter,
)
def _product_price_query(
product_id: int | ColumnElement[int],
use_cache: bool = True,
until: BindParameter[datetime] | datetime | None = None,
until_including: BindParameter[bool] | bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
):
"""
The inner query for calculating the product price.
"""
if use_cache:
print("WARNING: Using cache for product price query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_including, bool):
until_including = BindParameter("until_including", value=until_including)
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if use_cache:
initial_element_fields = (
select(
Transaction.time.label("time"),
Transaction.id.label("transaction_id"),
ProductCache.price.label("price"),
ProductCache.stock.label("product_count"),
)
.select_from(ProductCache)
.join(
LastCacheTransaction,
ProductCache.last_cache_transaction_id == LastCacheTransaction.id,
)
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
.where(
ProductCache.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.union(
select(
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
),
)
.order_by(Transaction.time.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
.subquery()
.alias("initial_element_fields")
)
initial_element = select(
CONST_ZERO.label("i"),
initial_element_fields.c.time,
initial_element_fields.c.transaction_id,
initial_element_fields.c.price,
initial_element_fields.c.product_count,
).select_from(initial_element_fields)
else:
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.id,
Transaction.time,
Transaction.type_,
@@ -73,18 +132,22 @@ def _product_price_query(
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
]
],
),
Transaction.product_id == product_id,
case(
(until_including, Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else CONST_TRUE,
after_filter(
after_time=None,
after_transaction_id=recursive_cte.c.transaction_id,
after_inclusive=False,
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.asc())
.alias("trx_subset")
.subquery(trx_subset_name)
)
recursive_elements = (
@@ -115,7 +178,7 @@ def _product_price_query(
# and other disastrous phenomena.
func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.product_count
)
),
),
Integer,
),
@@ -170,13 +233,13 @@ class ProductPriceLogEntry:
product_count: int
# TODO: add until datetime parameter
def product_price_log(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: Transaction | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductPriceLogEntry]:
"""
Calculates the price of a product and returns a log of the price changes.
@@ -185,13 +248,12 @@ def product_price_log(
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if until is not None and until.id is None:
raise ValueError("'until' transaction must be persisted in the database.")
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until=until.time if until else None,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
@@ -205,13 +267,13 @@ def product_price_log(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
.order_by(recursive_cte.c.i.asc()),
).all()
if result is None:
# If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError(
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})."
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id}).",
)
return [
@@ -224,13 +286,13 @@ def product_price_log(
]
# TODO: add until datetime parameter
def product_price(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: Transaction | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
include_interest: bool = False,
) -> int:
"""
@@ -240,13 +302,22 @@ def product_price(
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if until is not None and until.id is None:
raise ValueError("'until' transaction must be persisted in the database.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
recursive_cte = _product_price_query(
product.id,
use_cache=use_cache,
until=until.time if until else None,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
# TODO: optionally verify subresults:
@@ -257,13 +328,13 @@ def product_price(
select(recursive_cte.c.price)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
.offset(CONST_ZERO),
).one_or_none()
if result is None:
# If there are no transactions for this product, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id}).",
)
if include_interest:
@@ -272,12 +343,16 @@ def product_price(
select(Transaction.interest_rate_percent)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
CONST_TRUE if until is None else Transaction.time <= until.time,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc())
.limit(CONST_ONE)
.limit(CONST_ONE),
)
or DEFAULT_INTEREST_RATE_PERCENTAGE
or DEFAULT_INTEREST_RATE_PERCENT
)
result = math.ceil(result * interest_rate / 100)

View File

@@ -1,27 +1,31 @@
from datetime import datetime
from typing import Tuple
from sqlalchemy import (
BindParameter,
Select,
bindparam,
case,
func,
select,
)
from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_TRUE
from dibbler.models import (
Product,
Transaction,
TransactionType,
)
from dibbler.queries.query_helpers import until_filter
def _product_stock_query(
product_id: BindParameter[int] | int,
use_cache: bool = True,
until: BindParameter[datetime] | datetime | None = None,
) -> Select:
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[tuple[int]]:
"""
The inner query for calculating the product stock.
"""
@@ -32,8 +36,18 @@ def _product_stock_query(
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select(
func.sum(
@@ -59,8 +73,8 @@ def _product_stock_query(
-Transaction.product_count,
),
else_=0,
)
)
),
).label("stock"),
).where(
Transaction.type_.in_(
[
@@ -69,22 +83,26 @@ def _product_stock_query(
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT.as_literal_column(),
]
],
),
Transaction.product_id == product_id,
Transaction.time <= until if until is not None else CONST_TRUE,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
return query
# TODO: add until transaction parameter
def product_stock(
sql_session: Session,
product: Product,
use_cache: bool = True,
until: datetime | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Returns the number of products in stock.
@@ -98,7 +116,9 @@ def product_stock(
query = _product_stock_query(
product_id=product.id,
use_cache=use_cache,
until=until,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.scalars(query).one_or_none()

View File

@@ -0,0 +1,119 @@
from datetime import datetime
from typing import TypeVar
from sqlalchemy import (
BindParameter,
ColumnExpressionArgument,
literal,
select,
)
from sqlalchemy.orm import QueryableAttribute
from dibbler.models import Transaction
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
"""A constant SQL expression `0`. This will render as a literal `0` in SQL queries."""
CONST_ONE: BindParameter[int] = const(1)
"""A constant SQL expression `1`. This will render as a literal `1` in SQL queries."""
CONST_TRUE: BindParameter[bool] = const(True)
"""A constant SQL expression `TRUE`. This will render as a literal `TRUE` in SQL queries."""
CONST_FALSE: BindParameter[bool] = const(False)
"""A constant SQL expression `FALSE`. This will render as a literal `FALSE` in SQL queries."""
CONST_NONE: BindParameter[None] = const(None)
"""A constant SQL expression `NULL`. This will render as a literal `NULL` in SQL queries."""
def until_filter(
until_time: BindParameter[datetime] | None = None,
until_transaction_id: BindParameter[int] | None = None,
until_inclusive: bool = True,
transaction_time: QueryableAttribute = Transaction.time,
) -> ColumnExpressionArgument[bool]:
"""
Create a filter condition for transactions up to a given time or transaction.
Only one of `until_time` or `until_transaction_id` may be specified.
"""
assert not (until_time is not None and until_transaction_id is not None), (
"Cannot filter by both until_time and until_transaction_id."
)
match (until_time, until_transaction_id, until_inclusive):
case (BindParameter(), None, True):
return transaction_time <= until_time
case (BindParameter(), None, False):
return transaction_time < until_time
case (None, BindParameter(), True):
return (
transaction_time
<= select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
case (None, BindParameter(), False):
return (
transaction_time
< select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
return CONST_TRUE
def after_filter(
after_time: BindParameter[datetime] | None = None,
after_transaction_id: BindParameter[int] | None = None,
after_inclusive: bool = True,
transaction_time: QueryableAttribute = Transaction.time,
) -> ColumnExpressionArgument[bool]:
"""
Create a filter condition for transactions after a given time or transaction.
Only one of `after_time` or `after_transaction_id` may be specified.
"""
assert not (after_time is not None and after_transaction_id is not None), (
"Cannot filter by both after_time and after_transaction_id."
)
match (after_time, after_transaction_id, after_inclusive):
case (BindParameter(), None, True):
return transaction_time >= after_time
case (BindParameter(), None, False):
return transaction_time > after_time
case (None, BindParameter(), True):
return (
transaction_time
>= select(Transaction.time)
.where(Transaction.id == after_transaction_id)
.scalar_subquery()
)
case (None, BindParameter(), False):
return (
transaction_time
> select(Transaction.time)
.where(Transaction.id == after_transaction_id)
.scalar_subquery()
)
return CONST_TRUE

View File

@@ -9,6 +9,9 @@ def search_product(
sql_session: Session,
find_hidden_products=False,
) -> Product | list[Product]:
if not string:
raise ValueError("Search string cannot be empty.")
exact_match = sql_session.scalars(
select(Product).where(
or_(
@@ -17,8 +20,8 @@ def search_product(
Product.name == string,
literal(True) if find_hidden_products else not_(Product.hidden),
),
)
)
),
),
).first()
if exact_match:
@@ -32,8 +35,8 @@ def search_product(
Product.name.ilike(f"%{string}%"),
literal(True) if find_hidden_products else not_(Product.hidden),
),
)
)
),
),
).all()
return list(product_list)

View File

@@ -8,6 +8,9 @@ def search_user(
string: str,
sql_session: Session,
) -> User | list[User]:
if not string:
raise ValueError("Search string cannot be empty.")
string = string.lower()
exact_match = sql_session.scalars(
@@ -16,8 +19,8 @@ def search_user(
User.name == string,
User.card == string,
User.rfid == string,
)
)
),
),
).first()
if exact_match:
@@ -29,8 +32,8 @@ def search_user(
User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"),
)
)
),
),
).all()
return list(user_list)

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def throw_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
raise NotImplementedError(
"Please don't use this function until relevant calculations have been added to user_balance.",
)
transaction = Transaction.throw_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -1,4 +1,6 @@
from sqlalchemy import select
from datetime import datetime
from sqlalchemy import BindParameter, select
from sqlalchemy.orm import Session
from dibbler.models import (
@@ -15,12 +17,12 @@ def transaction_log(
sql_session: Session,
user: User | None = None,
product: Product | None = None,
exclusive_after: bool = False,
after_time=None,
after_transaction_id: int | None = None,
exclusive_before: bool = False,
before_time=None,
before_transaction_id: int | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
after_time: BindParameter[datetime] | datetime | None = None,
after_transaction: Transaction | None = None,
after_inclusive: bool = True,
transaction_type: list[TransactionType] | None = None,
negate_transaction_type_filter: bool = False,
limit: int | None = None,
@@ -29,51 +31,101 @@ def transaction_log(
Retrieve the transaction log, optionally filtered.
Only one of `user` or `product` may be specified.
Only one of `until_time` or `until_transaction_id` may be specified.
Only one of `after_time` or `after_transaction_id` may be specified.
Only one of `before_time` or `before_transaction_id` may be specified.
The before and after filters are inclusive by default.
The after and after filters are inclusive by default.
"""
if not (user is None or product is None):
raise ValueError("Cannot filter by both user and product.")
if user is not None and user.id is None:
raise ValueError("User must be persisted in the database.")
if isinstance(user, User):
if user.id is None:
raise ValueError("User must be persisted in the database.")
user_id = BindParameter("user_id", value=user.id)
else:
user_id = None
if product is not None and product.id is None:
raise ValueError("Product must be persisted in the database.")
if isinstance(product, Product):
if product.id is None:
raise ValueError("Product must be persisted in the database.")
product_id = BindParameter("product_id", value=product.id)
else:
product_id = None
if not (after_time is None or after_transaction_id is None):
raise ValueError("Cannot filter by both from_time and from_transaction_id.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
if limit is not None and limit <= 0:
raise ValueError("Limit must be positive.")
query = select(Transaction)
if user is not None:
query = query.where(Transaction.user_id == user.id)
query = query.where(Transaction.user_id == user_id)
if product is not None:
query = query.where(Transaction.product_id == product.id)
query = query.where(Transaction.product_id == product_id)
if after_time is not None:
if exclusive_after:
query = query.where(Transaction.time > after_time)
else:
match (until_time, until_transaction_id, until_inclusive):
case (BindParameter(), None, True):
query = query.where(Transaction.time <= until_time)
case (BindParameter(), None, False):
query = query.where(Transaction.time < until_time)
case (None, BindParameter(), True):
query = query.where(Transaction.id <= until_transaction_id)
case (None, BindParameter(), False):
query = query.where(Transaction.id < until_transaction_id)
case _:
pass
match (after_time, after_transaction_id, after_inclusive):
case (BindParameter(), None, True):
query = query.where(Transaction.time >= after_time)
if after_transaction_id is not None:
if exclusive_after:
query = query.where(Transaction.id > after_transaction_id)
else:
case (BindParameter(), None, False):
query = query.where(Transaction.time > after_time)
case (None, BindParameter(), True):
query = query.where(Transaction.id >= after_transaction_id)
if before_time is not None:
if exclusive_before:
query = query.where(Transaction.time < before_time)
else:
query = query.where(Transaction.time <= before_time)
if before_transaction_id is not None:
if exclusive_before:
query = query.where(Transaction.id < before_transaction_id)
else:
query = query.where(Transaction.id <= before_transaction_id)
case (None, BindParameter(), False):
query = query.where(Transaction.id > after_transaction_id)
case _:
pass
if transaction_type is not None:
if negate_transaction_type_filter:

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def transfer(
sql_session: Session,
from_user: User,
to_user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if from_user.id is None:
raise ValueError("From user must be persisted in the database.")
if to_user.id is None:
raise ValueError("To user must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.transfer(
user_id=from_user.id,
transfer_user_id=to_user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction

View File

@@ -0,0 +1,118 @@
from sqlalchemy import insert, select
from sqlalchemy.orm import Session
from dibbler.models import LastCacheTransaction, ProductCache, Transaction, UserCache
from dibbler.queries.affected_products import affected_products
from dibbler.queries.affected_users import affected_users
from dibbler.queries.product_price import product_price
from dibbler.queries.product_stock import product_stock
from dibbler.queries.user_balance import user_balance
def update_cache(
sql_session: Session,
use_cache: bool = True,
) -> None:
"""
Update the cache used for searching products.
If `use_cache` is False, the cache will be rebuilt from scratch.
"""
last_transaction = sql_session.scalars(
select(Transaction).order_by(Transaction.time.desc()).limit(1),
).one_or_none()
print(last_transaction)
if last_transaction is None:
# No transactions exist, nothing to update
return
if use_cache:
last_cache_transaction = sql_session.scalars(
select(LastCacheTransaction)
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
.order_by(Transaction.time.desc())
.limit(1),
).one_or_none()
if last_cache_transaction is not None:
last_cache_transaction = last_cache_transaction.transaction
else:
last_cache_transaction = None
if last_cache_transaction is not None and last_cache_transaction.id == last_transaction.id:
# Cache is already up to date
return
users = affected_users(
sql_session,
after_transaction=last_cache_transaction,
after_inclusive=False,
until_transaction=last_transaction,
)
products = affected_products(
sql_session,
after_transaction=last_cache_transaction,
after_inclusive=False,
until_transaction=last_transaction,
)
user_balances = {}
for user in users:
x = user_balance(
sql_session,
user,
use_cache=use_cache,
until_transaction=last_transaction,
)
user_balances[user.id] = x
product_stocks = {}
product_prices = {}
for product in products:
product_stocks[product.id] = product_stock(
sql_session,
product,
use_cache=use_cache,
until_transaction=last_transaction,
)
product_prices[product.id] = product_price(
sql_session,
product,
use_cache=use_cache,
until_transaction=last_transaction,
)
next_cache_transaction = LastCacheTransaction(transaction_id=last_transaction.id)
sql_session.add(next_cache_transaction)
sql_session.flush()
if not len(users) == 0:
sql_session.execute(
insert(UserCache),
[
{
"user_id": user.id,
"balance": user_balances[user.id],
"last_cache_transaction_id": next_cache_transaction.id,
}
for user in users
],
)
if not len(products) == 0:
sql_session.execute(
insert(ProductCache),
[
{
"product_id": product.id,
"stock": product_stocks[product.id],
"price": product_prices[product.id],
"last_cache_transaction_id": next_cache_transaction.id,
}
for product in products
],
)
sql_session.commit()

View File

@@ -1,12 +1,15 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Tuple
from sqlalchemy import (
CTE,
BindParameter,
Float,
Integer,
Select,
and_,
bindparam,
case,
cast,
column,
@@ -14,28 +17,243 @@ from sqlalchemy import (
or_,
select,
)
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, aliased
from sqlalchemy.sql.elements import KeyedColumnElement
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import (
Transaction,
TransactionType,
User,
)
from dibbler.models.Transaction import (
DEFAULT_INTEREST_RATE_PERCENTAGE,
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_INTEREST_RATE_PERCENT,
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
DEFAULT_PENALTY_THRESHOLD,
)
from dibbler.queries.product_price import _product_price_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
const,
until_filter,
)
def _joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[tuple[int, int, int]]:
"""
The inner query for getting joint transactions relevant to a user.
This scans for JOINT_BUY_PRODUCT transactions made by the user,
then finds the corresponding JOINT transactions, and counts how many "shares"
of the joint transaction the user has, as well as the total number of shares.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
# First, select all joint buy product transactions for the given user
# sub_joint_transaction = aliased(Transaction, name="right_trx")
sub_joint_transaction = (
select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id"))
.where(
Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
Transaction.user_id == user_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
transaction_time=Transaction.time,
),
)
.subquery("sub_joint_transaction")
)
# Join those with their main joint transaction
# (just use Transaction)
# Then, count how many users are involved in each joint transaction
joint_transaction_count = aliased(Transaction, name="count_trx")
joint_transaction = (
select(
Transaction.id,
# Shares the user has in the transaction,
func.sum(
case(
(joint_transaction_count.user_id == user_id, CONST_ONE),
else_=CONST_ZERO,
),
).label("user_shares"),
# The total number of shares in the transaction,
func.count(joint_transaction_count.id).label("user_count"),
)
.select_from(sub_joint_transaction)
.join(
Transaction,
onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id,
)
.join(
joint_transaction_count,
onclause=joint_transaction_count.joint_transaction_id == Transaction.id,
)
.group_by(joint_transaction_count.joint_transaction_id)
)
return joint_transaction
def _non_joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[tuple[int, None, None]]:
"""
The inner query for getting non-joint transactions relevant to a user.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select(
Transaction.id,
CONST_NONE.label("user_shares"),
CONST_NONE.label("user_count"),
).where(
or_(
and_(
Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
],
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY.as_literal_column(),
],
),
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
return query
def _product_cost_expression(
product_count_column: KeyedColumnElement[int],
product_id_column: KeyedColumnElement[int],
interest_rate_percent_column: KeyedColumnElement[int],
user_balance_column: KeyedColumnElement[int],
penalty_threshold_column: KeyedColumnElement[int],
penalty_multiplier_percent_column: KeyedColumnElement[int],
joint_user_shares_column: KeyedColumnElement[int],
joint_user_count_column: KeyedColumnElement[int],
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "product_price_cte",
trx_subset_name: str = "product_price_trx_subset",
):
# TODO: This can get quite expensive real quick, so we should do some caching of the
# product prices somehow.
expression = (
select(
cast(
func.ceil(
# Base price
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
# Interest
+ (
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
* cast(interest_rate_percent_column - const(100), Float)
/ const(100.0)
)
# Penalty
+ (
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
* cast(penalty_multiplier_percent_column - const(100), Float)
/ const(100.0)
* cast(user_balance_column < penalty_threshold_column, Integer)
),
),
Integer,
),
)
.select_from(
_product_price_query(
product_id_column,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=cte_name,
trx_subset_name=trx_subset_name,
),
)
.order_by(column("i").desc())
.limit(CONST_ONE)
.scalar_subquery()
)
return expression
def _user_balance_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until: BindParameter[datetime] | BindParameter[None] | datetime | None = None,
until_including: BindParameter[bool] | bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE:
"""
The inner query for calculating the user's balance.
@@ -47,30 +265,44 @@ def _user_balance_query(
if isinstance(user_id, int):
user_id = BindParameter("user_id", value=user_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until, type_=datetime)
if isinstance(until_including, bool):
until_including = BindParameter("until_including", value=until_including, type_=bool)
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("balance"),
const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"),
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
trx_subset_subset = (
_non_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
.union_all(
_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
),
)
.subquery(f"{trx_subset_name}_subset")
)
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.amount,
Transaction.id,
Transaction.amount,
Transaction.interest_rate_percent,
Transaction.penalty_multiplier_percent,
Transaction.penalty_threshold,
@@ -79,44 +311,16 @@ def _user_balance_query(
Transaction.time,
Transaction.transfer_user_id,
Transaction.type_,
trx_subset_subset.c.user_shares,
trx_subset_subset.c.user_count,
)
.where(
or_(
and_(
Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
# TODO: join this with the JOINT transactions, and determine
# how much the current user paid for the product.
TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
]
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY.as_literal_column(),
]
),
),
case(
(until_including, Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else CONST_TRUE,
.select_from(trx_subset_subset)
.join(
Transaction,
onclause=Transaction.id == trx_subset_subset.c.id,
)
.order_by(Transaction.time.asc())
.alias("trx_subset")
.subquery(trx_subset_name)
)
recursive_elements = (
@@ -139,49 +343,43 @@ def _user_balance_query(
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.balance
- (
trx_subset.c.product_count
# Price of a single product, accounted for penalties and interest.
* cast(
func.ceil(
# TODO: This can get quite expensive real quick, so we should do some caching of the
# product prices somehow.
# Base price
(
# FIXME: this always returns 0 for some reason...
select(cast(column("price"), Float))
.select_from(
_product_price_query(
trx_subset.c.product_id,
use_cache=use_cache,
until=trx_subset.c.time,
until_including=False,
cte_name="product_price_cte",
)
)
.order_by(column("i").desc())
.limit(CONST_ONE)
).scalar_subquery()
# TODO: should interest be applied before or after the penalty multiplier?
# at the moment of writing, after sound right, but maybe ask someone?
# Interest
* (cast(recursive_cte.c.interest_rate_percent, Float) / const(100))
# TODO: these should be added together, not multiplied, see specification
# Penalty
* case(
(
recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
(
cast(recursive_cte.c.penalty_multiplier_percent, Float)
/ const(100)
),
),
else_=const(1.0),
)
),
Integer,
)
),
- _product_cost_expression(
product_count_column=trx_subset.c.product_count,
product_id_column=trx_subset.c.product_id,
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
user_balance_column=recursive_cte.c.balance,
penalty_threshold_column=recursive_cte.c.penalty_threshold,
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
joint_user_shares_column=CONST_ONE,
joint_user_count_column=CONST_ONE,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=f"{cte_name}_price",
trx_subset_name=f"{trx_subset_name}_price",
).label("product_cost"),
),
# Joint transaction -> balance decreases proportionally
(
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
recursive_cte.c.balance
- _product_cost_expression(
product_count_column=trx_subset.c.product_count,
product_id_column=trx_subset.c.product_id,
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
user_balance_column=recursive_cte.c.balance,
penalty_threshold_column=recursive_cte.c.penalty_threshold,
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
joint_user_shares_column=trx_subset.c.user_shares,
joint_user_count_column=trx_subset.c.user_count,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=f"{cte_name}_joint_price",
trx_subset_name=f"{trx_subset_name}_joint_price",
).label("joint_product_cost"),
),
# Transfers money to self -> balance increases
(
@@ -200,8 +398,7 @@ def _user_balance_query(
recursive_cte.c.balance - trx_subset.c.amount,
),
# Throws a product -> if the user is considered to have bought it, balance increases
# TODO:
# (
# TODO: # (
# trx_subset.c.type_ == TransactionType.THROW_PRODUCT,
# recursive_cte.c.balance + trx_subset.c.amount,
# ),
@@ -255,17 +452,16 @@ class UserBalanceLogEntry:
Returns whether this exact transaction is penalized.
"""
return False
raise NotImplementedError("is_penalized is not implemented yet.")
# return self.transaction.type_ == TransactionType.BUY_PRODUCT and prev?
# TODO: add until datetime parameter
def user_balance_log(
sql_session: Session,
user: User,
use_cache: bool = True,
until: Transaction | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[UserBalanceLogEntry]:
"""
Returns a log of the user's balance over time, including interest and penalty adjustments.
@@ -276,13 +472,18 @@ def user_balance_log(
if user.id is None:
raise ValueError("User must be persisted in the database.")
if until is not None and until.id is None:
raise ValueError("'until' transaction must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until=until.time if until else None,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.execute(
@@ -298,13 +499,13 @@ def user_balance_log(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
.order_by(recursive_cte.c.i.asc()),
).all()
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id}).",
)
return [
@@ -319,13 +520,13 @@ def user_balance_log(
]
# TODO: add until datetime parameter
def user_balance(
sql_session: Session,
user: User,
use_cache: bool = True,
until: Transaction | None = None,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Calculates the balance of a user.
@@ -336,23 +537,31 @@ def user_balance(
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query(
user.id,
use_cache=use_cache,
until=until.time if until else None,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
result = sql_session.scalar(
select(recursive_cte.c.balance)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
.offset(CONST_ZERO),
)
if result is None:
# If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id}).",
)
return result

View File

@@ -1,4 +1,11 @@
# This absoulutely needs a cache, else we can't stop recursing until we know all owners for all products...
from datetime import datetime
from sqlalchemy import BindParameter, bindparam
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
# NOTE: This absolutely needs a cache, else we can't stop recursing until we know all owners for all products...
#
# Since we know that the non-owned products will not get renowned by the user by other means,
# we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user.
@@ -8,3 +15,34 @@
# but we still need to check if the user passes out of ownership for the item, without needing to check past
# the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if
# a user has products multiple places in the queue, interleaved with other users?
def user_products(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[tuple[Product, int]]:
"""
Returns the list of products owned by the user, along with how many of each product they own.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
raise NotImplementedError("Not implemented yet, needs caching system first.")

View File

@@ -1,79 +1,111 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import random
import sys
import traceback
from signal import (
SIG_IGN,
SIGQUIT,
SIGTSTP,
)
from signal import (
signal as set_signal_handler,
)
from sqlalchemy.orm import Session
from ..conf import config
from ..lib.helpers import *
from ..menus import *
from ..menus import (
AddProductMenu,
AddStockMenu,
AddUserMenu,
AdjustCreditMenu,
AdjustStockMenu,
BalanceMenu,
BuyMenu,
CleanupStockMenu,
EditProductMenu,
EditUserMenu,
FAQMenu,
LoggedStatisticsMenu,