63 Commits

Author SHA1 Message Date
oysteikt 4c14a2cf65 fixup! WIP
Run tests / run-tests (push) Successful in 1m21s
2025-12-10 15:42:22 +09:00
oysteikt 90da53c26c fixup! .gitea/workflows: init test pipeline 2025-12-10 15:42:21 +09:00
oysteikt e9ce51b97b fixup! WIP
Run tests / run-tests (push) Successful in 1m23s
2025-12-10 15:38:32 +09:00
oysteikt a4a22e6565 fixup! WIP 2025-12-10 14:25:42 +09:00
oysteikt bb7d1a2743 fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 1m14s
2025-12-10 13:40:29 +09:00
oysteikt e68d7effcd fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Has been cancelled
2025-12-10 13:40:15 +09:00
oysteikt dc668ab113 fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 1m6s
2025-12-10 13:37:02 +09:00
oysteikt fa7ad3a258 fixup! WIP
Run tests / run-tests (push) Successful in 1m30s
2025-12-10 13:32:54 +09:00
oysteikt 7f4a980eef fixup! WIP
Run tests / run-tests (push) Successful in 1m25s
2025-12-10 11:39:30 +09:00
oysteikt 2207001136 fixup! WIP
Run tests / run-tests (push) Successful in 43s
2025-12-09 21:30:19 +09:00
oysteikt fead6257c7 fixup! WIP 2025-12-09 21:30:15 +09:00
oysteikt 2a9ace4263 fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 1m8s
2025-12-09 20:39:21 +09:00
oysteikt 60fa6529ee fixup! WIP 2025-12-09 20:39:01 +09:00
oysteikt 2e66a9a4b0 fixup! WIP
Run tests / run-tests (push) Successful in 41s
2025-12-09 18:53:14 +09:00
oysteikt a087d3bede fixup! WIP
Run tests / run-tests (push) Successful in 41s
2025-12-09 18:43:42 +09:00
oysteikt 45bb31aba0 fixup! WIP
Run tests / run-tests (push) Successful in 45s
2025-12-09 18:32:04 +09:00
oysteikt f15c748558 fixup! WIP
Run tests / run-tests (push) Has been cancelled
2025-12-09 18:31:10 +09:00
oysteikt d220342d56 fixup! WIP
Run tests / run-tests (push) Successful in 46s
2025-12-09 17:44:18 +09:00
oysteikt 108e17edb8 fixup! WIP: docs/economy 2025-12-09 17:40:41 +09:00
oysteikt 12028cee22 fixup! WIP: docs/economy
Run tests / run-tests (push) Successful in 46s
2025-12-09 17:03:47 +09:00
oysteikt 1515eb6dff fixup! WIP
Run tests / run-tests (push) Successful in 44s
2025-12-09 17:02:28 +09:00
oysteikt 7199cbf34a WIP: docs/economy 2025-12-09 17:02:22 +09:00
oysteikt 722f0cae93 fixup! WIP 2025-12-09 15:51:54 +09:00
oysteikt 16be0f0fbe fixup! WIP 2025-12-09 15:47:14 +09:00
oysteikt cec91d923c fixup! WIP 2025-12-09 15:30:16 +09:00
oysteikt 0504cc1a1e fixup! WIP
Run tests / run-tests (push) Successful in 46s
2025-12-09 15:17:52 +09:00
oysteikt e7453d0fdd fixup! WIP
Run tests / run-tests (push) Successful in 1m25s
2025-12-09 14:43:45 +09:00
oysteikt c6ecb6fae9 fixup! WIP 2025-12-09 13:25:34 +09:00
oysteikt aaa5a6c556 fixup! WIP
Run tests / run-tests (push) Successful in 46s
2025-12-09 13:00:16 +09:00
oysteikt 6a83a41f28 fixup! WIP
Run tests / run-tests (push) Successful in 46s
2025-12-09 12:55:24 +09:00
oysteikt aa4e8dbee5 fixup! WIP
Run tests / run-tests (push) Successful in 41s
2025-12-09 12:49:38 +09:00
oysteikt f39e649b3d fixup! WIP
Run tests / run-tests (push) Successful in 41s
2025-12-09 11:56:06 +09:00
oysteikt 0a2fc799dd fixup! WIP
Run tests / run-tests (push) Successful in 1m37s
2025-12-09 04:25:05 +09:00
oysteikt 7d498f9bf1 fixup! WIP
Run tests / run-tests (push) Successful in 40s
2025-12-08 21:55:25 +09:00
oysteikt f1b15357f9 fixup! WIP
Run tests / run-tests (push) Successful in 41s
2025-12-08 21:43:27 +09:00
oysteikt de896901bb models/Base: add comment about __repr__ impl 2025-12-08 21:43:26 +09:00
oysteikt 15d1763405 fixup! WIP 2025-12-08 21:43:23 +09:00
oysteikt 683981d9dc fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 40s
2025-12-08 21:30:07 +09:00
oysteikt 4289d63ac9 fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 45s
2025-12-08 21:28:01 +09:00
oysteikt ce3e65357b fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 54s
2025-12-08 21:16:04 +09:00
oysteikt 928ab2a98a fixup! WIP
Run tests / run-tests (push) Successful in 38s
2025-12-08 21:13:49 +09:00
oysteikt 0b59d469dd fixup! WIP
Run tests / run-tests (push) Successful in 56s
2025-12-08 21:04:49 +09:00
oysteikt 24c5a9af38 fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 53s
2025-12-08 20:33:03 +09:00
oysteikt 21ccf78401 fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 1m8s
2025-12-08 20:29:10 +09:00
oysteikt d5b481d97a fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Successful in 35s
2025-12-08 20:27:32 +09:00
oysteikt cac1b5be20 fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Failing after 24s
2025-12-08 20:24:47 +09:00
oysteikt ad1fcfe98d fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Failing after 22s
2025-12-08 20:20:35 +09:00
oysteikt cc7b40ab7e fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Failing after 21s
2025-12-08 20:19:09 +09:00
oysteikt d35ffd04cc fixup! .gitea/workflows: init test pipeline
Run tests / run-tests (push) Failing after 20s
2025-12-08 20:17:53 +09:00
oysteikt d39f1f8a92 pyproject.toml: psycopg2 -> psycopg2-binary
Run tests / run-tests (push) Failing after 24s
2025-12-08 20:07:37 +09:00
oysteikt 0e3bed9bf5 uv.lock: update 2025-12-08 20:07:37 +09:00
oysteikt 3a1fc58a68 .gitea/workflows: init test pipeline 2025-12-08 20:07:37 +09:00
oysteikt 1ec7c79378 fixup! WIP 2025-12-08 19:50:22 +09:00
oysteikt bc43d4948c README: add overview of project structure 2025-12-08 19:45:37 +09:00
oysteikt 7e5345c7fb fixup! WIP 2025-12-08 19:45:22 +09:00
oysteikt 50867db928 fixup! WIP 2025-12-08 18:26:49 +09:00
oysteikt 5252cb611f fixup! WIP 2025-12-08 18:26:49 +09:00
oysteikt 5f510ee5d8 fixup! WIP 2025-12-08 18:26:48 +09:00
oysteikt f8829a6c7b fixup! WIP 2025-12-08 18:26:48 +09:00
oysteikt 885e989659 fixup! WIP 2025-12-08 18:26:48 +09:00
oysteikt 5c0b2b5229 WIP 2025-12-08 18:26:48 +09:00
oysteikt 9f2d8229fd .gitignore: add pytest-cov data 2025-12-08 18:26:47 +09:00
oysteikt 8807d7278a {nix,pyproject.toml}: add pytest, pytest-cov 2025-12-08 18:26:46 +09:00
109 changed files with 2717 additions and 6256 deletions
-71
View File
@@ -1,71 +0,0 @@
name: Run benchmarks
on:
workflow_dispatch:
# TODO: make this only workflow_dispatch when merged into main
push:
jobs:
run-tests:
runs-on: debian-latest
steps:
- uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install dependencies
run: uv sync --locked --group test
- name: Run benchmarks
continue-on-error: true
run: |
set -euo pipefail
set -x
PYTEST_ARGS=(
-vv
--benchmark-only
-k test_benchmark
)
uv run -- pytest "${PYTEST_ARGS[@]}"
- name: Upload benchmark JSON report
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v2
with:
source: ./benchmark/*/*.json
quote-source: false
target: ${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/benchmark.json
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Upload histograms
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v2
with:
source: ./benchmark/*.svg
quote-source: false
target: ${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
# NOTE: $GITHUB_STEP_SUMMARY when...
- name: Run information
run: |
echo "Benchmark run ID: ${{ github.run_id }}"
echo "Benchmark JSON: https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/benchmark.json"
echo "Histograms: https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_owners.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_price.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_stock.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-transaction_log.svg"
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-user_balance.svg"
- name: Check failure
if: failure()
run: |
echo "Tests failed"
exit 1
+53 -44
View File
@@ -16,57 +16,66 @@ jobs:
run-tests: run-tests:
runs-on: debian-latest runs-on: debian-latest
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
- name: Install uv - name: Install uv
uses: astral-sh/setup-uv@v7 uses: astral-sh/setup-uv@v7
- name: Install dependencies - name: Install dependencies
run: uv sync --locked --group test run: uv sync --locked --group test
- name: Run tests - name: Run tests
continue-on-error: true continue-on-error: true
run: | run: |
set -euo pipefail set -euo pipefail
set -x set -x
PYTEST_ARGS=( PYTEST_ARGS=(
-vv -vv
--cov=dibbler.lib
--cov=dibbler.models
--cov=dibbler.queries
--cov-report=html
--cov-branch
--self-contained-html
--html=./test-report/index.html
)
if [ "$DEBUG_SQL" == "true" ]; then
PYTEST_ARGS+=(
--debug-sql
) )
fi
if [ "$DEBUG_SQL" == "true" ]; then uv run -- pytest "${PYTEST_ARGS[@]}"
PYTEST_ARGS+=(
--debug-sql
)
fi
uv run -- pytest "${PYTEST_ARGS[@]}" - name: Generate badge
run: uv run -- coverage-badge -o htmlcov/badge.svg
- name: Generate badge - name: Upload test report
run: uv run -- coverage-badge -o htmlcov/badge.svg uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
with:
source: ./test-report/
target: ${{ gitea.ref_name }}/test-report/
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Upload test report - name: Upload coverage report
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1 uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
with: with:
source: ./test-report/ source: ./htmlcov/
target: ${{ gitea.ref_name }}/test-report/ target: ${{ gitea.ref_name }}/coverage/
username: gitea-web username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }} ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg" known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Upload coverage report - name: Check failure
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1 if: failure()
with: run: |
source: ./htmlcov/ echo "Tests failed"
target: ${{ gitea.ref_name }}/coverage/ exit 1
username: gitea-web
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
host: pages.pvv.ntnu.no
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
- name: Check failure
if: failure()
run: |
echo "Tests failed"
exit 1
-6
View File
@@ -8,12 +8,6 @@ test.db
.ruff_cache .ruff_cache
*.qcow2
dibbler/_version.py
.coverage .coverage
.coverage.*
htmlcov htmlcov
test-report test-report
/benchmark
+16 -22
View File
@@ -23,9 +23,8 @@ Installer python, og lag og aktiver et venv. Installer så avhengighetene med `p
Deretter kan du kjøre programmet med Deretter kan du kjøre programmet med
```console ```console
python -m dibbler -c example-config.toml create-db python -m dibbler -c example-config.ini create-db
python -m dibbler -c example-config.toml seed-data python -m dibbler -c example-config.ini loop
python -m dibbler -c example-config.toml loop
``` ```
## Prosjektstruktur ## Prosjektstruktur
@@ -62,30 +61,25 @@ Her ligger enhetstester for prosjektet. Testene bruker `pytest` som testløper.
## Nix ## Nix
> [!NOTE] ### Bygge nytt image
> Vi har skrevet nix-kode for å generere en QEMU-VM med tilnærmet produksjonsoppsett.
> Det kjører ikke nødvendigvis noen VM-er i produksjon, og ihvertfall ikke denne VM-en.
> Den er hovedsakelig laget for enkel interaktiv testing, og for å teste NixOS modulen.
Du kan enklest komme i gang med nix-utvikling ved å kjøre test VM-en: For å bygge et image trenger du en builder som takler å bygge for arkitekturen du skal lage et image for.
```console (Eller be til gudene om at cross compile funker)
nix run .#vm
# Eller hvis du trenger tilgang til terminalen i VM-en også: Flaket exposer en modul som autologger inn med en bruker som automatisk kjører dibbler, og setter opp et minimalistisk miljø.
nix run .#vm-non-kiosk
```
Du kan også bygge pakken manuelt, eller kjøre den direkte: Før du bygger imaget burde du kopiere og endre `example-config.ini` lokalt til å inneholde instillingene dine. **NB: Denne kommer til å ligge i nix storen, ikke si noe her som du ikke vil at moren din skal høre.**
```console Du kan også endre hvilken config-fil som blir brukt direkte i pakken eller i modulen.
nix build .#dibbler
nix run .# -- --config example-config.toml create-db Se eksempelet for hvordan skrot er satt opp i `flake.nix` og `nix/skrott.nix`
nix run .# -- --config example-config.toml seed-data
nix run .# -- --config example-config.toml loop
```
## Produksjonssetting ### Bygge image for skrot
Se https://wiki.pvv.ntnu.no/wiki/Drift/Dibbler Skrot har et image definert i flake.nix:
1. endre `example-config.ini`
2. `nix build .#images.skrot`
3. ???
4. non-profit
+4 -54
View File
@@ -1,56 +1,6 @@
import os # This module is supposed to act as a singleton and be filled
import sys # with config variables by cli.py
import tomllib
from pathlib import Path
from typing import Any
from dibbler.lib.helpers import file_is_submissive_and_readable import configparser
DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml") config = configparser.ConfigParser()
config: dict[str, dict[str, Any]] = {}
def load_config(config_path: Path | None = None) -> None:
global config
if config_path is not None:
with Path(config_path).open("rb") as file:
config = tomllib.load(file)
elif file_is_submissive_and_readable(DEFAULT_CONFIG_PATH):
with DEFAULT_CONFIG_PATH.open("rb") as file:
config = tomllib.load(file)
else:
print(
"Could not read config file, it was neither provided nor readable in default location",
file=sys.stderr,
)
sys.exit(1)
def config_db_string() -> str:
db_type = config["database"]["type"]
if db_type == "sqlite":
path = Path(config["database"]["sqlite"]["path"])
return f"sqlite:///{path.absolute()}"
if db_type == "postgresql":
host = config["database"]["postgresql"]["host"]
port = config["database"]["postgresql"].get("port", 5432)
username = config["database"]["postgresql"].get("username", "dibbler")
dbname = config["database"]["postgresql"].get("dbname", "dibbler")
if "password_file" in config["database"]["postgresql"]:
with Path(config["database"]["postgresql"]["password_file"]).open("r") as f:
password = f.read().strip()
elif "password" in config["database"]["postgresql"]:
password = config["database"]["postgresql"]["password"]
else:
password = ""
if host.startswith("/"):
return f"postgresql+psycopg2://{username}:{password}@/{dbname}?host={host}"
return f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{dbname}"
print(f"Error: unknown database type '{db_type}'")
exit(1)
+19
View File
@@ -0,0 +1,19 @@
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from dibbler.conf import config
database_url: str | None = None
if (url := config.get("database", "url")) is not None:
database_url = url
elif (url_file := config.get("database", "url_file")) is not None:
with Path(url_file).open() as file:
database_url = file.read().strip()
assert database_url is not None, "Database URL must be specified in config"
engine = create_engine(database_url)
Session = sessionmaker(bind=engine)
+59 -59
View File
@@ -1,71 +1,71 @@
# import os import os
# from PIL import ImageFont from PIL import ImageFont
# from barcode.writer import ImageWriter, mm2px from barcode.writer import ImageWriter, mm2px
# from brother_ql.labels import ALL_LABELS from brother_ql.labels import ALL_LABELS
# def px2mm(px, dpi=300): def px2mm(px, dpi=300):
# return (25.4 * px) / dpi return (25.4 * px) / dpi
# class BrotherLabelWriter(ImageWriter): class BrotherLabelWriter(ImageWriter):
# def __init__(self, typ="62", max_height=350, rot=False, text=None): def __init__(self, typ="62", max_height=350, rot=False, text=None):
# super(BrotherLabelWriter, self).__init__() super(BrotherLabelWriter, self).__init__()
# label = next([l for l in ALL_LABELS if l.identifier == typ]) label = next([l for l in ALL_LABELS if l.identifier == typ])
# assert label is not None assert label is not None
# self.rot = rot self.rot = rot
# if self.rot: if self.rot:
# self._h, self._w = label.dots_printable self._h, self._w = label.dots_printable
# if self._w == 0 or self._w > max_height: if self._w == 0 or self._w > max_height:
# self._w = min(max_height, self._h / 2) self._w = min(max_height, self._h / 2)
# else: else:
# self._w, self._h = label.dots_printable self._w, self._h = label.dots_printable
# if self._h == 0 or self._h > max_height: if self._h == 0 or self._h > max_height:
# self._h = min(max_height, self._w / 2) self._h = min(max_height, self._w / 2)
# self._xo = 0.0 self._xo = 0.0
# self._yo = 0.0 self._yo = 0.0
# self._title = text self._title = text
# def _init(self, code): def _init(self, code):
# self.text = None self.text = None
# super(BrotherLabelWriter, self)._init(code) super(BrotherLabelWriter, self)._init(code)
# def calculate_size(self, modules_per_line, number_of_lines, dpi=300): def calculate_size(self, modules_per_line, number_of_lines, dpi=300):
# x, y = super(BrotherLabelWriter, self).calculate_size( x, y = super(BrotherLabelWriter, self).calculate_size(
# modules_per_line, number_of_lines, dpi modules_per_line, number_of_lines, dpi
# ) )
# self._xo = (px2mm(self._w) - px2mm(x)) / 2 self._xo = (px2mm(self._w) - px2mm(x)) / 2
# self._yo = px2mm(self._h) - px2mm(y) self._yo = px2mm(self._h) - px2mm(y)
# assert self._xo >= 0 assert self._xo >= 0
# assert self._yo >= 0 assert self._yo >= 0
# return int(self._w), int(self._h) return int(self._w), int(self._h)
# def _paint_module(self, xpos, ypos, width, color): def _paint_module(self, xpos, ypos, width, color):
# super(BrotherLabelWriter, self)._paint_module( super(BrotherLabelWriter, self)._paint_module(
# xpos + self._xo, ypos + self._yo, width, color xpos + self._xo, ypos + self._yo, width, color
# ) )
# def _paint_text(self, xpos, ypos): def _paint_text(self, xpos, ypos):
# super(BrotherLabelWriter, self)._paint_text(xpos + self._xo, ypos + self._yo) super(BrotherLabelWriter, self)._paint_text(xpos + self._xo, ypos + self._yo)
# def _finish(self): def _finish(self):
# if self._title: if self._title:
# width = self._w + 1 width = self._w + 1
# height = 0 height = 0
# max_h = self._h - mm2px(self._yo, self.dpi) max_h = self._h - mm2px(self._yo, self.dpi)
# fs = int(max_h / 1.2) fs = int(max_h / 1.2)
# font_path = os.path.join( font_path = os.path.join(
# os.path.dirname(os.path.realpath(__file__)), os.path.dirname(os.path.realpath(__file__)),
# "Stranger back in the Night.ttf", "Stranger back in the Night.ttf",
# ) )
# font = ImageFont.truetype(font_path, 10) font = ImageFont.truetype(font_path, 10)
# while width > self._w or height > max_h: while width > self._w or height > max_h:
# font = ImageFont.truetype(font_path, fs) font = ImageFont.truetype(font_path, fs)
# width, height = font.getsize(self._title) width, height = font.getsize(self._title)
# fs -= 1 fs -= 1
# pos = ((self._w - width) // 2, 0 - (height // 8)) pos = ((self._w - width) // 2, 0 - (height // 8))
# self._draw.text(pos, self._title, font=font, fill=self.foreground) self._draw.text(pos, self._title, font=font, fill=self.foreground)
# return self._image return self._image
-108
View File
@@ -1,108 +0,0 @@
import sys
from pathlib import Path
from sqlalchemy import Engine, create_engine, inspect, select
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.clsregistry import _ModuleMarker
from dibbler.lib.helpers import file_is_submissive_and_readable
from dibbler.models import Base
def check_db_health(engine: Engine, verify_table_existence: bool = False) -> None:
dialect_name = getattr(engine.dialect, "name", "").lower()
if "postgres" in dialect_name:
check_postgres_ping(engine)
elif dialect_name == "sqlite":
check_sqlite_file(engine)
if verify_table_existence:
verify_tables_and_columns(engine)
def check_postgres_ping(engine: Engine) -> None:
try:
with engine.connect() as conn:
result = conn.execute(select(1))
scalar = result.scalar()
if scalar != 1 and scalar is not None:
print(
"Unexpected response from Postgres when running 'SELECT 1'",
file=sys.stderr,
)
sys.exit(1)
except (OperationalError, DBAPIError) as exc:
print(f"Failed to connect to Postgres database: {exc}", file=sys.stderr)
sys.exit(1)
def check_sqlite_file(engine: Engine) -> None:
db_path = engine.url.database
# Don't verify in-memory databases or empty paths
if db_path in (None, "", ":memory:"):
return
db_path = db_path.removeprefix("file:").removeprefix("sqlite:")
# Strip query parameters
if "?" in db_path:
db_path = db_path.split("?", 1)[0]
path = Path(db_path)
if not path.exists():
print(f"SQLite database file does not exist: {path}", file=sys.stderr)
sys.exit(1)
if not path.is_file():
print(f"SQLite database path is not a file: {path}", file=sys.stderr)
sys.exit(1)
if not file_is_submissive_and_readable(path):
print(f"SQLite database file is not submissive and readable: {path}", file=sys.stderr)
sys.exit(1)
return
def verify_tables_and_columns(engine: Engine) -> None:
iengine = inspect(engine)
errors = False
tables = iengine.get_table_names()
views = iengine.get_view_names()
tables.extend(views)
for _name, klass in Base.registry._class_registry.items():
if isinstance(klass, _ModuleMarker):
continue
table = klass.__tablename__
if table in tables:
columns = [c["name"] for c in iengine.get_columns(table)]
mapper = inspect(klass)
for column_prop in mapper.attrs:
if isinstance(column_prop, RelationshipProperty):
pass
else:
for column in column_prop.columns:
if not column.key in columns:
print(
f"Model '{klass}' declares column '{column.key}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
else:
print(
f"Model '{klass}' declares table '{table}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
if errors:
print("Have you remembered to run `dibbler create-db?", file=sys.stderr)
sys.exit(1)
+5 -22
View File
@@ -2,12 +2,9 @@ import os
import pwd import pwd
import signal import signal
import subprocess import subprocess
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal
def system_user_exists(username: str) -> bool: def system_user_exists(username):
try: try:
pwd.getpwnam(username) pwd.getpwnam(username)
except KeyError: except KeyError:
@@ -18,7 +15,7 @@ def system_user_exists(username: str) -> bool:
return True return True
def guess_data_type(string: str) -> Literal["card", "rfid", "bar_code", "username"] | None: def guess_data_type(string):
if string.startswith("ntnu") and string[4:].isdigit(): if string.startswith("ntnu") and string[4:].isdigit():
return "card" return "card"
if string.isdigit() and len(string) == 10: if string.isdigit() and len(string) == 10:
@@ -32,11 +29,7 @@ def guess_data_type(string: str) -> Literal["card", "rfid", "bar_code", "usernam
return None return None
def argmax( def argmax(d, all=False, value=None):
d: dict[Any, Any],
all_: bool = False,
value: Callable[[Any], Any] | None = None,
) -> Any | list[Any] | None:
maxarg = None maxarg = None
if value is not None: if value is not None:
dd = d dd = d
@@ -46,12 +39,12 @@ def argmax(
for key in list(d.keys()): for key in list(d.keys()):
if maxarg is None or d[key] > d[maxarg]: if maxarg is None or d[key] > d[maxarg]:
maxarg = key maxarg = key
if all_: if all:
return [k for k in list(d.keys()) if d[k] == d[maxarg]] return [k for k in list(d.keys()) if d[k] == d[maxarg]]
return maxarg return maxarg
def less(string: str) -> None: def less(string):
""" """
Run less with string as input; wait until it finishes. Run less with string as input; wait until it finishes.
""" """
@@ -63,13 +56,3 @@ def less(string: str) -> None:
proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE) proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE)
proc.communicate(string) proc.communicate(string)
signal.signal(signal.SIGINT, int_handler) signal.signal(signal.SIGINT, int_handler)
def file_is_submissive_and_readable(file: Path) -> bool:
return file.is_file() and any(
[
file.stat().st_mode & 0o400 and file.stat().st_uid == os.getuid(),
file.stat().st_mode & 0o040 and file.stat().st_gid == os.getgid(),
file.stat().st_mode & 0o004,
],
)
+80 -77
View File
@@ -1,95 +1,98 @@
# import barcode import os
# from brother_ql.brother_ql_create import create_label import datetime
# from brother_ql.raster import BrotherQLRaster
# from brother_ql.backends import backend_factory
# from brother_ql.labels import ALL_LABELS
# from PIL import Image, ImageDraw, ImageFont
# from .barcode_helpers import BrotherLabelWriter import barcode
from brother_ql.brother_ql_create import create_label
from brother_ql.raster import BrotherQLRaster
from brother_ql.backends import backend_factory
from brother_ql.labels import ALL_LABELS
from PIL import Image, ImageDraw, ImageFont
from .barcode_helpers import BrotherLabelWriter
# def print_name_label( def print_name_label(
# text, text,
# margin=10, margin=10,
# rotate=False, rotate=False,
# label_type="62", label_type="62",
# printer_type="QL-700", printer_type="QL-700",
# ): ):
# label = next([l for l in ALL_LABELS if l.identifier == label_type]) label = next([l for l in ALL_LABELS if l.identifier == label_type])
# if not rotate: if not rotate:
# width, height = label.dots_printable width, height = label.dots_printable
# else: else:
# height, width = label.dots_printable height, width = label.dots_printable
# font_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ChopinScript.ttf") font_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ChopinScript.ttf")
# fs = 2000 fs = 2000
# tw, th = width, height tw, th = width, height
# if width == 0: if width == 0:
# while th + 2 * margin > height: while th + 2 * margin > height:
# font = ImageFont.truetype(font_path, fs) font = ImageFont.truetype(font_path, fs)
# tw, th = font.getsize(text) tw, th = font.getsize(text)
# fs -= 1 fs -= 1
# width = tw + 2 * margin width = tw + 2 * margin
# elif height == 0: elif height == 0:
# while tw + 2 * margin > width: while tw + 2 * margin > width:
# font = ImageFont.truetype(font_path, fs) font = ImageFont.truetype(font_path, fs)
# tw, th = font.getsize(text) tw, th = font.getsize(text)
# fs -= 1 fs -= 1
# height = th + 2 * margin height = th + 2 * margin
# else: else:
# while tw + 2 * margin > width or th + 2 * margin > height: while tw + 2 * margin > width or th + 2 * margin > height:
# font = ImageFont.truetype(font_path, fs) font = ImageFont.truetype(font_path, fs)
# tw, th = font.getsize(text) tw, th = font.getsize(text)
# fs -= 1 fs -= 1
# xp = (width // 2) - (tw // 2) xp = (width // 2) - (tw // 2)
# yp = (height // 2) - (th // 2) yp = (height // 2) - (th // 2)
# im = Image.new("RGB", (width, height), (255, 255, 255)) im = Image.new("RGB", (width, height), (255, 255, 255))
# dr = ImageDraw.Draw(im) dr = ImageDraw.Draw(im)
# dr.text((xp, yp), text, fill=(0, 0, 0), font=font) dr.text((xp, yp), text, fill=(0, 0, 0), font=font)
# now = datetime.datetime.now() now = datetime.datetime.now()
# date = now.strftime("%Y-%m-%d") date = now.strftime("%Y-%m-%d")
# dr.text((0, 0), date, fill=(0, 0, 0)) dr.text((0, 0), date, fill=(0, 0, 0))
# base_path = os.path.dirname(os.path.realpath(__file__)) base_path = os.path.dirname(os.path.realpath(__file__))
# fn = os.path.join(base_path, "bar_codes", text + ".png") fn = os.path.join(base_path, "bar_codes", text + ".png")
# im.save(fn, "PNG") im.save(fn, "PNG")
# print_image(fn, printer_type, label_type) print_image(fn, printer_type, label_type)
# def print_bar_code( def print_bar_code(
# barcode_value, barcode_value,
# barcode_text, barcode_text,
# barcode_type="ean13", barcode_type="ean13",
# rotate=False, rotate=False,
# printer_type="QL-700", printer_type="QL-700",
# label_type="62", label_type="62",
# ): ):
# bar_coder = barcode.get_barcode_class(barcode_type) bar_coder = barcode.get_barcode_class(barcode_type)
# wr = BrotherLabelWriter(typ=label_type, rot=rotate, text=barcode_text, max_height=1000) wr = BrotherLabelWriter(typ=label_type, rot=rotate, text=barcode_text, max_height=1000)
# test = bar_coder(barcode_value, writer=wr) test = bar_coder(barcode_value, writer=wr)
# base_path = os.path.dirname(os.path.realpath(__file__)) base_path = os.path.dirname(os.path.realpath(__file__))
# fn = test.save(os.path.join(base_path, "bar_codes", barcode_value)) fn = test.save(os.path.join(base_path, "bar_codes", barcode_value))
# print_image(fn, printer_type, label_type) print_image(fn, printer_type, label_type)
# def print_image(fn, printer_type="QL-700", label_type="62"): def print_image(fn, printer_type="QL-700", label_type="62"):
# qlr = BrotherQLRaster(printer_type) qlr = BrotherQLRaster(printer_type)
# qlr.exception_on_warning = True qlr.exception_on_warning = True
# create_label(qlr, fn, label_type, threshold=70, cut=True) create_label(qlr, fn, label_type, threshold=70, cut=True)
# be = backend_factory("pyusb") be = backend_factory("pyusb")
# list_available_devices = be["list_available_devices"] list_available_devices = be["list_available_devices"]
# BrotherQLBackend = be["backend_class"] BrotherQLBackend = be["backend_class"]
# ad = list_available_devices() ad = list_available_devices()
# assert ad assert ad
# string_descr = ad[0]["string_descr"] string_descr = ad[0]["string_descr"]
# printer = BrotherQLBackend(string_descr) printer = BrotherQLBackend(string_descr)
# printer.write(qlr.data) printer.write(qlr.data)
+22
View File
@@ -0,0 +1,22 @@
from typing import TypeVar
from sqlalchemy import BindParameter, literal
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
CONST_ONE: BindParameter[int] = const(1)
CONST_TRUE: BindParameter[bool] = const(True)
CONST_FALSE: BindParameter[bool] = const(False)
CONST_NONE: BindParameter[None] = const(None)
+25 -11
View File
@@ -1,4 +1,3 @@
from dibbler.lib.render_tree import render_tree
from dibbler.models import Transaction, TransactionType from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import EXPECTED_FIELDS from dibbler.models.Transaction import EXPECTED_FIELDS
@@ -11,19 +10,23 @@ def render_transaction_log(transaction_log: list[Transaction]) -> str:
aggregated_log = _aggregate_joint_transactions(transaction_log) aggregated_log = _aggregate_joint_transactions(transaction_log)
lines = [] lines = []
for transaction in aggregated_log:
for i, transaction in enumerate(aggregated_log):
if isinstance(transaction, list): if isinstance(transaction, list):
inner_lines = [] inner_lines = []
lines.append(_render_transaction(transaction[0])) is_last = i == len(aggregated_log) - 1
for sub_transaction in transaction[1:]: lines.append(_render_transaction(transaction[0], is_last))
line = _render_transaction(sub_transaction) for j, sub_transaction in enumerate(transaction[1:]):
is_last_inner = j == len(transaction) - 2
line = _render_transaction(sub_transaction, is_last_inner)
inner_lines.append(line) inner_lines.append(line)
lines.append(inner_lines) indented_inner_lines = _indent_lines(inner_lines, is_last=is_last)
lines.extend(indented_inner_lines)
else: else:
line = _render_transaction(transaction) is_last = i == len(aggregated_log) - 1
line = _render_transaction(transaction, is_last)
lines.append(line) lines.append(line)
return "\n".join(lines)
return render_tree(lines)
def _aggregate_joint_transactions( def _aggregate_joint_transactions(
@@ -58,7 +61,17 @@ def _aggregate_joint_transactions(
return aggregated return aggregated
def _render_transaction(transaction: Transaction) -> str: def _indent_lines(lines: list[str], is_last: bool = False) -> list[str]:
indented_lines = []
for line in lines:
if is_last:
indented_lines.append(" " + line)
else:
indented_lines.append("" + line)
return indented_lines
def _render_transaction(transaction: Transaction, is_last: bool) -> str:
match transaction.type_: match transaction.type_:
case TransactionType.ADD_PRODUCT: case TransactionType.ADD_PRODUCT:
line = f"ADD_PRODUCT({transaction.id}, {transaction.user.name}" line = f"ADD_PRODUCT({transaction.id}, {transaction.user.name}"
@@ -112,4 +125,5 @@ def _render_transaction(transaction: Transaction) -> str:
line = ( line = (
f"UNKNOWN[{transaction.type_}](id={transaction.id}, user_id={transaction.user_id})" f"UNKNOWN[{transaction.type_}](id={transaction.id}, user_id={transaction.user_id})"
) )
return line
return "└─ " + line if is_last else "├─ " + line
-115
View File
@@ -1,115 +0,0 @@
_TREE_CHARS = {
"normal": {
"vertical": "",
"branch": "├─ ",
"last": "└─ ",
"empty": " ",
},
"ascii": {
"vertical": "| ",
"branch": "|-- ",
"last": "`-- ",
"empty": " ",
},
}
assert set(_TREE_CHARS["normal"].keys()) == set(_TREE_CHARS["ascii"].keys())
assert all(len(v) == 3 for v in _TREE_CHARS["normal"].values())
assert all(len(v) == 4 for v in _TREE_CHARS["ascii"].values())
def render_tree(
tree: list[str | list],
ascii_only: bool = False,
) -> str:
"""
Render a tree structure as a string.
Each item in the `tree` list can be either a string (a leaf node)
or another list (a subtree).
When `ascii_only` is `True`, only ASCII characters are used for drawing the tree.
Example:
```python
tree = [
"root",
[
"child1",
[
"grandchild1",
"grandchild2",
],
"child2",
],
"root2",
]
print(render_tree(tree, ascii_only=False))
```
Output:
```
├─ root
│ ├─ child1
│ │ ├─ grandchild1
│ │ └─ grandchild2
│ └─ child2
└─ root2
```
Example with ASCII only:
```python
print(render_tree(tree, ascii_only=True))
```
Output:
```
|-- root
| |-- child1
| | |-- grandchild1
| | `-- grandchild2
| `-- child2
`-- root2
```
"""
result: list[str] = []
for index, item in enumerate(tree):
is_last = index == len(tree) - 1
item_lines = _render_tree_line(item, is_last, ascii_only)
result.extend(item_lines)
return "\n".join(result)
def _render_tree_line(
item: str | list,
is_last: bool,
ascii_only: bool,
prefix: str = "",
) -> list[str]:
chars = _TREE_CHARS["ascii"] if ascii_only else _TREE_CHARS["normal"]
lines: list[str] = []
if isinstance(item, str):
line_prefix = chars["last"] if is_last else chars["branch"]
item_lines = item.splitlines()
for line_index, line in enumerate(item_lines):
if line_index == 0:
lines.append(f"{prefix}{line_prefix}{line}")
else:
lines.append(f"{prefix}{chars['vertical']}{line}")
elif isinstance(item, list):
new_prefix = prefix + (chars["empty"] if is_last else chars["vertical"])
for sub_index, sub_item in enumerate(item):
sub_is_last = sub_index == len(item) - 1
sub_lines = _render_tree_line(sub_item, sub_is_last, ascii_only, new_prefix)
lines.extend(sub_lines)
else:
raise ValueError("Item must be either a string or a list.")
return lines
+47 -49
View File
@@ -3,20 +3,18 @@
import datetime import datetime
from collections import defaultdict from collections import defaultdict
from pathlib import Path
from sqlalchemy.orm import Session
from ..models import Transaction
from .helpers import * from .helpers import *
from ..models import Transaction
from ..db import Session
def getUser(sql_session: Session) -> str: def getUser():
assert sql_session is not None
while 1: while 1:
string = input("user? ") string = input("user? ")
user = search_user(string, sql_session) session = Session()
sql_session.close() user = search_user(string, session)
session.close()
if not isinstance(user, list): if not isinstance(user, list):
return user.name return user.name
i = 0 i = 0
@@ -39,11 +37,12 @@ def getUser(sql_session: Session) -> str:
return user[n].name return user[n].name
def getProduct(sql_session: Session) -> str: def getProduct():
assert sql_session is not None
while 1: while 1:
string = input("product? ") string = input("product? ")
product = search_product(string, sql_session) session = Session()
product = search_product(string, session)
session.close()
if not isinstance(product, list): if not isinstance(product, list):
return product.name return product.name
i = 0 i = 0
@@ -90,7 +89,7 @@ class Database:
class InputLine: class InputLine:
def __init__(self, u, p, t) -> None: def __init__(self, u, p, t):
self.inputUser = u self.inputUser = u
self.inputProduct = p self.inputProduct = p
self.inputType = t self.inputType = t
@@ -123,17 +122,17 @@ def getInputType():
return int(inp) return int(inp)
def getProducts(products: str) -> list[tuple[str]]: def getProducts(products):
product = [] product = []
split_products = products.partition("¤") products = products.partition("¤")
product.append(products[0]) product.append(products[0])
while products[1] == "¤": while products[1] == "¤":
split_products = split_products[2].partition("¤") products = products[2].partition("¤")
product.append(products[0]) product.append(products[0])
return product return product
def getDateFile(date: str, inp: str) -> datetime.date: def getDateFile(date, inp):
try: try:
year = inp.partition("-") year = inp.partition("-")
month = year[2].partition("-") month = year[2].partition("-")
@@ -177,7 +176,7 @@ def addLineToDatabase(database, inputLine):
if abs(inputLine.price) > 90000: if abs(inputLine.price) > 90000:
return database return database
# fyller inn for varer # fyller inn for varer
if (inputLine.product != "") and ( if (not inputLine.product == "") and (
(inputLine.inputProduct == "") or (inputLine.inputProduct == inputLine.product) (inputLine.inputProduct == "") or (inputLine.inputProduct == inputLine.product)
): ):
database.varePersonAntall[inputLine.product][inputLine.user] = ( database.varePersonAntall[inputLine.product][inputLine.user] = (
@@ -191,7 +190,7 @@ def addLineToDatabase(database, inputLine):
database.vareUkedagAntall[inputLine.product][inputLine.weekday] += 1 database.vareUkedagAntall[inputLine.product][inputLine.weekday] += 1
# fyller inn for personer # fyller inn for personer
if (inputLine.inputUser == "") or (inputLine.inputUser == inputLine.user): if (inputLine.inputUser == "") or (inputLine.inputUser == inputLine.user):
if inputLine.product != "": if not inputLine.product == "":
database.personVareAntall[inputLine.user][inputLine.product] = ( database.personVareAntall[inputLine.user][inputLine.product] = (
database.personVareAntall[inputLine.user].setdefault(inputLine.product, 0) + 1 database.personVareAntall[inputLine.user].setdefault(inputLine.product, 0) + 1
) )
@@ -215,7 +214,7 @@ def addLineToDatabase(database, inputLine):
database.personNegTransactions[inputLine.user] = ( database.personNegTransactions[inputLine.user] = (
database.personNegTransactions.setdefault(inputLine.user, 0) + inputLine.price database.personNegTransactions.setdefault(inputLine.user, 0) + inputLine.price
) )
elif inputLine.inputType != 1: elif not (inputLine.inputType == 1):
database.globalVareAntall[inputLine.product] = ( database.globalVareAntall[inputLine.product] = (
database.globalVareAntall.setdefault(inputLine.product, 0) + 1 database.globalVareAntall.setdefault(inputLine.product, 0) + 1
) )
@@ -226,7 +225,7 @@ def addLineToDatabase(database, inputLine):
# fyller inn for global statistikk # fyller inn for global statistikk
if (inputLine.inputType == 3) or (inputLine.inputType == 4): if (inputLine.inputType == 3) or (inputLine.inputType == 4):
database.pengebeholdning[inputLine.dateNum] += inputLine.price database.pengebeholdning[inputLine.dateNum] += inputLine.price
if inputLine.product != "": if not (inputLine.product == ""):
database.globalPersonAntall[inputLine.user] = ( database.globalPersonAntall[inputLine.user] = (
database.globalPersonAntall.setdefault(inputLine.user, 0) + 1 database.globalPersonAntall.setdefault(inputLine.user, 0) + 1
) )
@@ -239,12 +238,12 @@ def addLineToDatabase(database, inputLine):
return database return database
def buildDatabaseFromDb(inputType, inputProduct, inputUser, sql_session: Session): def buildDatabaseFromDb(inputType, inputProduct, inputUser):
assert sql_session is not None
sdate = input("enter start date (yyyy-mm-dd)? ") sdate = input("enter start date (yyyy-mm-dd)? ")
edate = input("enter end date (yyyy-mm-dd)? ") edate = input("enter end date (yyyy-mm-dd)? ")
print("building database...") print("building database...")
transaction_list = sql_session.query(Transaction).all() session = Session()
transaction_list = session.query(Transaction).all()
inputLine = InputLine(inputUser, inputProduct, inputType) inputLine = InputLine(inputUser, inputProduct, inputType)
startDate = getDateDb(transaction_list[0].time, sdate) startDate = getDateDb(transaction_list[0].time, sdate)
endDate = getDateDb(transaction_list[-1].time, edate) endDate = getDateDb(transaction_list[-1].time, edate)
@@ -274,9 +273,9 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser, sql_session: Session
inputLine.price = 0 inputLine.price = 0
print("saving as default.dibblerlog...", end=" ") print("saving as default.dibblerlog...", end=" ")
f = Path.open("default.dibblerlog", "w") f = open("default.dibblerlog", "w")
line_format = "%s|%s|%s|%s|%s|%s\n" line_format = "%s|%s|%s|%s|%s|%s\n"
transaction_list = sql_session.query(Transaction).all() transaction_list = session.query(Transaction).all()
for transaction in transaction_list: for transaction in transaction_list:
if transaction.purchase: if transaction.purchase:
products = "¤".join([ent.product.name for ent in transaction.purchase.entries]) products = "¤".join([ent.product.name for ent in transaction.purchase.entries])
@@ -291,7 +290,8 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser, sql_session: Session
transaction.description, transaction.description,
) )
f.write(line.encode("utf8")) f.write(line.encode("utf8"))
f.close() session.close()
f.close
# bygg database.pengebeholdning # bygg database.pengebeholdning
if (inputType == 3) or (inputType == 4): if (inputType == 3) or (inputType == 4):
for i in range(inputLine.numberOfDays + 1): for i in range(inputLine.numberOfDays + 1):
@@ -311,7 +311,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
sdate = input("enter start date (yyyy-mm-dd)? ") sdate = input("enter start date (yyyy-mm-dd)? ")
edate = input("enter end date (yyyy-mm-dd)? ") edate = input("enter end date (yyyy-mm-dd)? ")
f = Path.open(inputFile) f = open(inputFile)
try: try:
fileLines = f.readlines() fileLines = f.readlines()
finally: finally:
@@ -329,7 +329,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
database.globalUkedagForbruk = [0] * 7 database.globalUkedagForbruk = [0] * 7
database.pengebeholdning = [0] * (inputLine.numberOfDays + 1) database.pengebeholdning = [0] * (inputLine.numberOfDays + 1)
for linje in fileLines: for linje in fileLines:
if linje[0] != "#" and linje != "\n": if not (linje[0] == "#") and not (linje == "\n"):
# henter dateNum, products, user, price # henter dateNum, products, user, price
restDel = linje.partition("|") restDel = linje.partition("|")
restDel = restDel[2].partition(" ") restDel = restDel[2].partition(" ")
@@ -359,7 +359,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
return database, dateLine return database, dateLine
def printTopDict(dictionary: dict[str, Any], n: int, k: bool) -> None: def printTopDict(dictionary, n, k):
i = 0 i = 0
for key in sorted(dictionary, key=dictionary.get, reverse=k): for key in sorted(dictionary, key=dictionary.get, reverse=k):
print(key, ": ", dictionary[key]) print(key, ": ", dictionary[key])
@@ -369,7 +369,7 @@ def printTopDict(dictionary: dict[str, Any], n: int, k: bool) -> None:
break break
def printTopDict2(dictionary, dictionary2, n) -> None: def printTopDict2(dictionary, dictionary2, n):
print("") print("")
print("product : price[kr] ( number )") print("product : price[kr] ( number )")
i = 0 i = 0
@@ -381,7 +381,7 @@ def printTopDict2(dictionary, dictionary2, n) -> None:
break break
def printWeekdays(week, days) -> None: def printWeekdays(week, days):
if week == [] or days == 0: if week == [] or days == 0:
return return
print( print(
@@ -404,10 +404,10 @@ def printWeekdays(week, days) -> None:
print("") print("")
def printBalance(database, user) -> None: def printBalance(database, user):
forbruk = 0 forbruk = 0
if user in database.personVareVerdi: if user in database.personVareVerdi:
forbruk = sum(database.personVareVerdi[user].values()) forbruk = sum([i for i in list(database.personVareVerdi[user].values())])
print("totalt kjøpt for: ", forbruk, end=" ") print("totalt kjøpt for: ", forbruk, end=" ")
if user in database.personNegTransactions: if user in database.personNegTransactions:
print("kr, totalt lagt til: ", -database.personNegTransactions[user], end=" ") print("kr, totalt lagt til: ", -database.personNegTransactions[user], end=" ")
@@ -419,14 +419,14 @@ def printBalance(database, user) -> None:
print("") print("")
def printUser(database, dateLine, user, n) -> None: def printUser(database, dateLine, user, n):
printTopDict2(database.personVareVerdi[user], database.personVareAntall[user], n) printTopDict2(database.personVareVerdi[user], database.personVareAntall[user], n)
print("\nforbruk per ukedag [kr/dag],", end=" ") print("\nforbruk per ukedag [kr/dag],", end=" ")
printWeekdays(database.personUkedagVerdi[user], len(dateLine)) printWeekdays(database.personUkedagVerdi[user], len(dateLine))
printBalance(database, user) printBalance(database, user)
def printProduct(database, dateLine, product, n) -> None: def printProduct(database, dateLine, product, n):
printTopDict(database.varePersonAntall[product], n, 1) printTopDict(database.varePersonAntall[product], n, 1)
print("\nforbruk per ukedag [antall/dag],", end=" ") print("\nforbruk per ukedag [antall/dag],", end=" ")
printWeekdays(database.vareUkedagAntall[product], len(dateLine)) printWeekdays(database.vareUkedagAntall[product], len(dateLine))
@@ -440,7 +440,7 @@ def printProduct(database, dateLine, product, n) -> None:
) )
def printGlobal(database, dateLine, n) -> None: def printGlobal(database, dateLine, n):
print("\nmest lagt til: ") print("\nmest lagt til: ")
printTopDict(database.personNegTransactions, n, 0) printTopDict(database.personNegTransactions, n, 0)
print("\nmest tatt fra:") print("\nmest tatt fra:")
@@ -454,9 +454,9 @@ def printGlobal(database, dateLine, n) -> None:
"Det er solgt varer til en verdi av: ", "Det er solgt varer til en verdi av: ",
sum(database.globalDatoForbruk), sum(database.globalDatoForbruk),
"kr, det er lagt til", "kr, det er lagt til",
-sum(database.personNegTransactions.values()), -sum([i for i in list(database.personNegTransactions.values())]),
"og tatt fra", "og tatt fra",
sum(database.personPosTransactions.values()), sum([i for i in list(database.personPosTransactions.values())]),
end=" ", end=" ",
) )
print( print(
@@ -466,24 +466,23 @@ def printGlobal(database, dateLine, n) -> None:
) )
def alt4menuTextOnly(database, dateLine, sql_session: Session) -> None: def alt4menuTextOnly(database, dateLine):
assert sql_session is not None
n = 10 n = 10
while 1: while 1:
print( print(
"\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit", "\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit"
) )
inp = input("") inp = input("")
if inp == "q": if inp == "q":
break break
if inp == "1": elif inp == "1":
try: try:
printUser(database, dateLine, getUser(sql_session), n) printUser(database, dateLine, getUser(), n)
except: except:
print("\n\nSomething is not right, (last date prior to first date?)") print("\n\nSomething is not right, (last date prior to first date?)")
elif inp == "2": elif inp == "2":
try: try:
printProduct(database, dateLine, getProduct(sql_session), n) printProduct(database, dateLine, getProduct(), n)
except: except:
print("\n\nSomething is not right, (last date prior to first date?)") print("\n\nSomething is not right, (last date prior to first date?)")
elif inp == "3": elif inp == "3":
@@ -495,16 +494,15 @@ def alt4menuTextOnly(database, dateLine, sql_session: Session) -> None:
n = int(input("set number to show ")) n = int(input("set number to show "))
def statisticsTextOnly(sql_session: Session) -> None: def statisticsTextOnly():
assert sql_session is not None
inputType = 4 inputType = 4
product = "" product = ""
user = "" user = ""
print("\n0: from file, 1: from database, q:quit") print("\n0: from file, 1: from database, q:quit")
inp = input("") inp = input("")
if inp == "1": if inp == "1":
database, dateLine = buildDatabaseFromDb(inputType, product, user, sql_session) database, dateLine = buildDatabaseFromDb(inputType, product, user)
elif inp == "0" or inp == "": elif inp == "0" or inp == "":
database, dateLine = buildDatabaseFromFile("default.dibblerlog", inputType, product, user) database, dateLine = buildDatabaseFromFile("default.dibblerlog", inputType, product, user)
if inp != "q": if not inp == "q":
alt4menuTextOnly(database, dateLine, sql_session) alt4menuTextOnly(database, dateLine)
+9 -47
View File
@@ -1,12 +1,7 @@
import argparse import argparse
import sys
from pathlib import Path from pathlib import Path
from sqlalchemy import create_engine from dibbler.conf import config
from sqlalchemy.orm import Session
from dibbler.conf import config_db_string, load_config
from dibbler.lib.check_db_health import check_db_health
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -16,20 +11,13 @@ parser.add_argument(
help="Path to the config file", help="Path to the config file",
type=Path, type=Path,
metavar="FILE", metavar="FILE",
required=False, default="config.ini",
)
parser.add_argument(
"-V",
"--version",
help="Show program version",
action="store_true",
default=False,
) )
subparsers = parser.add_subparsers( subparsers = parser.add_subparsers(
title="subcommands", title="subcommands",
dest="subcommand", dest="subcommand",
required=True,
) )
subparsers.add_parser("loop", help="Run the dibbler loop") subparsers.add_parser("loop", help="Run the dibbler loop")
subparsers.add_parser("create-db", help="Create the database") subparsers.add_parser("create-db", help="Create the database")
@@ -38,55 +26,29 @@ subparsers.add_parser("seed-data", help="Fill with mock data")
subparsers.add_parser("transaction-log", help="Print transaction log") subparsers.add_parser("transaction-log", help="Print transaction log")
def main() -> None: def main():
args = parser.parse_args() args = parser.parse_args()
config.read(args.config)
if args.version:
from ._version import commit_id, version
print(f"Dibbler version {version}, commit {commit_id if commit_id else '<unknown>'}")
return
if not args.subcommand:
parser.print_help()
sys.exit(1)
load_config(args.config)
engine = create_engine(config_db_string())
sql_session = Session(
engine,
expire_on_commit=False,
autocommit=False,
autoflush=False,
close_resets_only=True,
)
check_db_health(
engine,
verify_table_existence=args.subcommand != "create-db",
)
if args.subcommand == "loop": if args.subcommand == "loop":
import dibbler.subcommands.loop as loop import dibbler.subcommands.loop as loop
loop.main(sql_session) loop.main()
elif args.subcommand == "create-db": elif args.subcommand == "create-db":
import dibbler.subcommands.makedb as makedb import dibbler.subcommands.makedb as makedb
makedb.main(engine) makedb.main()
elif args.subcommand == "slabbedasker": elif args.subcommand == "slabbedasker":
import dibbler.subcommands.slabbedasker as slabbedasker import dibbler.subcommands.slabbedasker as slabbedasker
slabbedasker.main(sql_session) slabbedasker.main()
elif args.subcommand == "seed-data": elif args.subcommand == "seed-data":
import dibbler.subcommands.seed_test_data as seed_test_data import dibbler.subcommands.seed_test_data as seed_test_data
seed_test_data.main(sql_session) seed_test_data.main()
elif args.subcommand == "transaction-log": elif args.subcommand == "transaction-log":
import dibbler.subcommands.transaction_log as transaction_log import dibbler.subcommands.transaction_log as transaction_log
+8 -8
View File
@@ -26,28 +26,28 @@ __all__ = [
from .addstock import AddStockMenu from .addstock import AddStockMenu
from .buymenu import BuyMenu from .buymenu import BuyMenu
from .editing import ( from .editing import (
AddProductMenu,
AddUserMenu, AddUserMenu,
EditUserMenu,
AddProductMenu,
EditProductMenu,
AdjustStockMenu, AdjustStockMenu,
CleanupStockMenu, CleanupStockMenu,
EditProductMenu,
EditUserMenu,
) )
from .faq import FAQMenu from .faq import FAQMenu
from .helpermenus import Menu from .helpermenus import Menu
from .mainmenu import MainMenu from .mainmenu import MainMenu
from .miscmenus import ( from .miscmenus import (
AdjustCreditMenu,
ProductListMenu,
ProductSearchMenu, ProductSearchMenu,
ShowUserMenu,
TransferMenu, TransferMenu,
AdjustCreditMenu,
UserListMenu, UserListMenu,
ShowUserMenu,
ProductListMenu,
) )
from .printermenu import PrintLabelMenu from .printermenu import PrintLabelMenu
from .stats import ( from .stats import (
BalanceMenu,
LoggedStatisticsMenu,
ProductPopularityMenu, ProductPopularityMenu,
ProductRevenueMenu, ProductRevenueMenu,
BalanceMenu,
LoggedStatisticsMenu,
) )
+12 -21
View File
@@ -1,7 +1,6 @@
from math import ceil from math import ceil
import sqlalchemy import sqlalchemy
from sqlalchemy.orm import Session
from dibbler.models import ( from dibbler.models import (
Product, Product,
@@ -10,13 +9,12 @@ from dibbler.models import (
Transaction, Transaction,
User, User,
) )
from .helpermenus import Menu from .helpermenus import Menu
class AddStockMenu(Menu): class AddStockMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Add stock and adjust credit", sql_session) Menu.__init__(self, "Add stock and adjust credit", uses_db=True)
self.help_text = """ self.help_text = """
Enter what you have bought for PVVVV here, along with your user name and how Enter what you have bought for PVVVV here, along with your user name and how
much money you're due in credits for the purchase when prompted.\n""" much money you're due in credits for the purchase when prompted.\n"""
@@ -25,7 +23,7 @@ much money you're due in credits for the purchase when prompted.\n"""
self.products = {} self.products = {}
self.price = 0 self.price = 0
def _execute(self, **_kwargs) -> bool | None: def _execute(self):
questions = { questions = {
( (
False, False,
@@ -88,10 +86,10 @@ much money you're due in credits for the purchase when prompted.\n"""
self.perform_transaction() self.perform_transaction()
def complete_input(self) -> bool: def complete_input(self):
return self.users is not None and len(self.products) > 0 and self.price > 0 return bool(self.users) and len(self.products) and self.price
def print_info(self) -> None: def print_info(self):
width = 6 + Product.name_length width = 6 + Product.name_length
print() print()
print(width * "-") print(width * "-")
@@ -111,12 +109,7 @@ much money you're due in credits for the purchase when prompted.\n"""
print(f"{self.products[product][0]}".rjust(width - len(product.name))) print(f"{self.products[product][0]}".rjust(width - len(product.name)))
print(width * "-") print(width * "-")
def add_thing_to_pending( def add_thing_to_pending(self, thing, amount, price):
self,
thing: User | Product,
amount: int,
price: int,
) -> None:
if isinstance(thing, User): if isinstance(thing, User):
self.users.append(thing) self.users.append(thing)
elif thing in list(self.products.keys()): elif thing in list(self.products.keys()):
@@ -126,7 +119,7 @@ much money you're due in credits for the purchase when prompted.\n"""
else: else:
self.products[thing] = [amount, price] self.products[thing] = [amount, price]
def perform_transaction(self) -> None: def perform_transaction(self):
print("Did you pay a different price?") print("Did you pay a different price?")
if self.confirm(">", default=False): if self.confirm(">", default=False):
self.price = self.input_int("How much did you pay?", 0, self.price, default=self.price) self.price = self.input_int("How much did you pay?", 0, self.price, default=self.price)
@@ -139,11 +132,10 @@ much money you're due in credits for the purchase when prompted.\n"""
old_price = product.price old_price = product.price
old_hidden = product.hidden old_hidden = product.hidden
product.price = int( product.price = int(
ceil(float(value) / (max(product.stock, 0) + self.products[product][0])), ceil(float(value) / (max(product.stock, 0) + self.products[product][0]))
) )
product.stock = max( product.stock = max(
self.products[product][0], self.products[product][0], product.stock + self.products[product][0]
product.stock + self.products[product][0],
) )
product.hidden = False product.hidden = False
print( print(
@@ -159,14 +151,13 @@ much money you're due in credits for the purchase when prompted.\n"""
PurchaseEntry(purchase, product, -self.products[product][0]) PurchaseEntry(purchase, product, -self.products[product][0])
purchase.perform_soft_purchase(-self.price, round_up=False) purchase.perform_soft_purchase(-self.price, round_up=False)
self.sql_session.add(purchase) self.session.add(purchase)
try: try:
self.sql_session.commit() self.session.commit()
print("Success! Transaction performed:") print("Success! Transaction performed:")
# self.print_info() # self.print_info()
for user in self.users: for user in self.users:
print(f"User {user.name}'s credit is now {user.credit:d}") print(f"User {user.name}'s credit is now {user.credit:d}")
except sqlalchemy.exc.SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not perform transaction: {e}") print(f"Could not perform transaction: {e}")
+46 -69
View File
@@ -1,8 +1,4 @@
from typing import Any
import sqlalchemy import sqlalchemy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from dibbler.conf import config from dibbler.conf import config
from dibbler.models import ( from dibbler.models import (
@@ -17,11 +13,10 @@ from .helpermenus import Menu
class BuyMenu(Menu): class BuyMenu(Menu):
superfast_mode: bool def __init__(self, session=None):
purchase: Purchase Menu.__init__(self, "Buy", uses_db=True)
if session:
def __init__(self, sql_session: Session) -> None: self.session = session
super().__init__("Buy", sql_session)
self.superfast_mode = False self.superfast_mode = False
self.help_text = """ self.help_text = """
Each purchase may contain one or more products and one or more buyers. Each purchase may contain one or more products and one or more buyers.
@@ -33,7 +28,7 @@ addition, and you can type 'what' at any time to redisplay it.
When finished, write an empty line to confirm the purchase.\n""" When finished, write an empty line to confirm the purchase.\n"""
@staticmethod @staticmethod
def credit_check(user: User) -> bool: def credit_check(user):
""" """
:param user: :param user:
@@ -42,32 +37,28 @@ When finished, write an empty line to confirm the purchase.\n"""
""" """
assert isinstance(user, User) assert isinstance(user, User)
return user.credit > config["limits"]["low_credit_warning_limit"] return user.credit > config.getint("limits", "low_credit_warning_limit")
def low_credit_warning( def low_credit_warning(self, user, timeout=False):
self,
user: User,
timeout: bool = False,
) -> bool:
assert isinstance(user, User) assert isinstance(user, User)
print(r"***********************************************************************") print("***********************************************************************")
print(r"***********************************************************************") print("***********************************************************************")
print(r"") print("")
print(r"$$\ $$\ $$$$$$\ $$$$$$$\ $$\ $$\ $$$$$$\ $$\ $$\ $$$$$$\\") print("$$\ $$\ $$$$$$\ $$$$$$$\ $$\ $$\ $$$$$$\ $$\ $$\ $$$$$$\\")
print(r"$$ | $\ $$ |$$ __$$\ $$ __$$\ $$$\ $$ |\_$$ _|$$$\ $$ |$$ __$$\\") print("$$ | $\ $$ |$$ __$$\ $$ __$$\ $$$\ $$ |\_$$ _|$$$\ $$ |$$ __$$\\")
print(r"$$ |$$$\ $$ |$$ / $$ |$$ | $$ |$$$$\ $$ | $$ | $$$$\ $$ |$$ / \__|") print("$$ |$$$\ $$ |$$ / $$ |$$ | $$ |$$$$\ $$ | $$ | $$$$\ $$ |$$ / \__|")
print(r"$$ $$ $$\$$ |$$$$$$$$ |$$$$$$$ |$$ $$\$$ | $$ | $$ $$\$$ |$$ |$$$$\\") print("$$ $$ $$\$$ |$$$$$$$$ |$$$$$$$ |$$ $$\$$ | $$ | $$ $$\$$ |$$ |$$$$\\")
print(r"$$$$ _$$$$ |$$ __$$ |$$ __$$< $$ \$$$$ | $$ | $$ \$$$$ |$$ |\_$$ |") print("$$$$ _$$$$ |$$ __$$ |$$ __$$< $$ \$$$$ | $$ | $$ \$$$$ |$$ |\_$$ |")
print(r"$$$ / \$$$ |$$ | $$ |$$ | $$ |$$ |\$$$ | $$ | $$ |\$$$ |$$ | $$ |") print("$$$ / \$$$ |$$ | $$ |$$ | $$ |$$ |\$$$ | $$ | $$ |\$$$ |$$ | $$ |")
print(r"$$ / \$$ |$$ | $$ |$$ | $$ |$$ | \$$ |$$$$$$\ $$ | \$$ |\$$$$$$ |") print("$$ / \$$ |$$ | $$ |$$ | $$ |$$ | \$$ |$$$$$$\ $$ | \$$ |\$$$$$$ |")
print(r"\__/ \__|\__| \__|\__| \__|\__| \__|\______|\__| \__| \______/") print("\__/ \__|\__| \__|\__| \__|\__| \__|\______|\__| \__| \______/")
print(r"") print("")
print(r"***********************************************************************") print("***********************************************************************")
print(r"***********************************************************************") print("***********************************************************************")
print(r"") print("")
print( print(
f"USER {user.name} HAS LOWER CREDIT THAN {config['limits']['low_credit_warning_limit']:d}.", f"USER {user.name} HAS LOWER CREDIT THAN {config.getint('limits', 'low_credit_warning_limit'):d}."
) )
print("THIS PURCHASE WILL CHARGE YOUR CREDIT TWICE AS MUCH.") print("THIS PURCHASE WILL CHARGE YOUR CREDIT TWICE AS MUCH.")
print("CONSIDER PUTTING MONEY IN THE BOX TO AVOID THIS.") print("CONSIDER PUTTING MONEY IN THE BOX TO AVOID THIS.")
@@ -77,13 +68,10 @@ When finished, write an empty line to confirm the purchase.\n"""
if timeout: if timeout:
print("THIS PURCHASE WILL AUTOMATICALLY BE PERFORMED IN 3 MINUTES!") print("THIS PURCHASE WILL AUTOMATICALLY BE PERFORMED IN 3 MINUTES!")
return self.confirm(prompt=">", default=True, timeout=180) return self.confirm(prompt=">", default=True, timeout=180)
return self.confirm(prompt=">", default=True) else:
return self.confirm(prompt=">", default=True)
def add_thing_to_purchase( def add_thing_to_purchase(self, thing, amount=1):
self,
thing: User | Product,
amount: int = 1,
) -> bool:
if isinstance(thing, User): if isinstance(thing, User):
if thing.is_anonymous(): if thing.is_anonymous():
print("---------------------------------------------") print("---------------------------------------------")
@@ -92,10 +80,7 @@ When finished, write an empty line to confirm the purchase.\n"""
print("---------------------------------------------") print("---------------------------------------------")
if not self.credit_check(thing): if not self.credit_check(thing):
if self.low_credit_warning( if self.low_credit_warning(user=thing, timeout=self.superfast_mode):
user=thing,
timeout=self.superfast_mode,
):
Transaction(thing, purchase=self.purchase, penalty=2) Transaction(thing, purchase=self.purchase, penalty=2)
else: else:
return False return False
@@ -110,11 +95,7 @@ When finished, write an empty line to confirm the purchase.\n"""
PurchaseEntry(self.purchase, thing, amount) PurchaseEntry(self.purchase, thing, amount)
return True return True
def _execute( def _execute(self, initial_contents=None):
self,
initial_contents: list[tuple[User | Product, int]] | None = None,
**_kwargs,
) -> bool:
self.print_header() self.print_header()
self.purchase = Purchase() self.purchase = Purchase()
self.exit_confirm_msg = None self.exit_confirm_msg = None
@@ -126,7 +107,7 @@ When finished, write an empty line to confirm the purchase.\n"""
for thing, num in initial_contents: for thing, num in initial_contents:
self.add_thing_to_purchase(thing, num) self.add_thing_to_purchase(thing, num)
def is_product(candidate: Any) -> bool: def is_product(candidate):
return isinstance(candidate[0], Product) return isinstance(candidate[0], Product)
if len(initial_contents) > 0 and all(map(is_product, initial_contents)): if len(initial_contents) > 0 and all(map(is_product, initial_contents)):
@@ -148,7 +129,7 @@ When finished, write an empty line to confirm the purchase.\n"""
True, True,
True, True,
): "Enter more products or users, or an empty line to confirm", ): "Enter more products or users, or an empty line to confirm",
}[(len(self.purchase.transactions) > 0, len(self.purchase.entries) > 0)], }[(len(self.purchase.transactions) > 0, len(self.purchase.entries) > 0)]
) )
# Read in a 'thing' (product or user): # Read in a 'thing' (product or user):
@@ -166,16 +147,16 @@ When finished, write an empty line to confirm the purchase.\n"""
if thing is None: if thing is None:
if not self.complete_input(): if not self.complete_input():
if self.confirm( if self.confirm(
"Not enough information entered. Abort purchase?", "Not enough information entered. Abort purchase?", default=True
default=True,
): ):
return False return False
continue continue
break break
# once we get something in the else:
# purchase, we want to protect the # once we get something in the
# user from accidentally killing it # purchase, we want to protect the
self.exit_confirm_msg = "Abort purchase?" # user from accidentally killing it
self.exit_confirm_msg = "Abort purchase?"
# Add the thing to our purchase object: # Add the thing to our purchase object:
if not self.add_thing_to_purchase(thing, amount=num): if not self.add_thing_to_purchase(thing, amount=num):
@@ -186,11 +167,10 @@ When finished, write an empty line to confirm the purchase.\n"""
break break
self.purchase.perform_purchase() self.purchase.perform_purchase()
self.sql_session.add(self.purchase) self.session.add(self.purchase)
try: try:
self.sql_session.commit() self.session.commit()
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store purchase: {e}") print(f"Could not store purchase: {e}")
else: else:
print("Purchase stored.") print("Purchase stored.")
@@ -198,9 +178,9 @@ When finished, write an empty line to confirm the purchase.\n"""
for t in self.purchase.transactions: for t in self.purchase.transactions:
if not t.user.is_anonymous(): if not t.user.is_anonymous():
print(f"User {t.user.name}'s credit is now {t.user.credit:d} kr") print(f"User {t.user.name}'s credit is now {t.user.credit:d} kr")
if t.user.credit < config["limits"]["low_credit_warning_limit"]: if t.user.credit < config.getint("limits", "low_credit_warning_limit"):
print( print(
f"USER {t.user.name} HAS LOWER CREDIT THAN {config['limits']['low_credit_warning_limit']:d},", f"USER {t.user.name} HAS LOWER CREDIT THAN {config.getint('limits', 'low_credit_warning_limit'):d},",
"AND SHOULD CONSIDER PUTTING SOME MONEY IN THE BOX.", "AND SHOULD CONSIDER PUTTING SOME MONEY IN THE BOX.",
) )
@@ -209,10 +189,10 @@ When finished, write an empty line to confirm the purchase.\n"""
print("") print("")
return True return True
def complete_input(self) -> bool: def complete_input(self):
return self.purchase.is_complete() return self.purchase.is_complete()
def format_purchase(self) -> str | None: def format_purchase(self):
self.purchase.set_price() self.purchase.set_price()
transactions = self.purchase.transactions transactions = self.purchase.transactions
entries = self.purchase.entries entries = self.purchase.entries
@@ -224,10 +204,7 @@ When finished, write an empty line to confirm the purchase.\n"""
string += "(empty)" string += "(empty)"
else: else:
string += ", ".join( string += ", ".join(
[ [t.user.name + ("*" if not self.credit_check(t.user) else "") for t in transactions]
t.user.name + ("*" if not self.credit_check(t.user) else "")
for t in transactions
],
) )
string += "\n products: " string += "\n products: "
if len(entries) == 0: if len(entries) == 0:
@@ -235,7 +212,7 @@ When finished, write an empty line to confirm the purchase.\n"""
else: else:
string += "\n " string += "\n "
string += "\n ".join( string += "\n ".join(
[f"{e.amount:d}x {e.product.name} ({e.product.price:d} kr)" for e in entries], [f"{e.amount:d}x {e.product.name} ({e.product.price:d} kr)" for e in entries]
) )
if len(transactions) > 1: if len(transactions) > 1:
string += f"\n price per person: {self.purchase.price_per_transaction():d} kr" string += f"\n price per person: {self.purchase.price_per_transaction():d} kr"
@@ -251,7 +228,7 @@ When finished, write an empty line to confirm the purchase.\n"""
return string return string
def print_purchase(self) -> None: def print_purchase(self):
info = self.format_purchase() info = self.format_purchase()
if info is not None: if info is not None:
self.set_context(info) self.set_context(info)
+38 -67
View File
@@ -1,9 +1,6 @@
import sqlalchemy import sqlalchemy
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session
from dibbler.models import Product, User
from dibbler.models import User, Product
from .helpermenus import Menu, Selector from .helpermenus import Menu, Selector
__all__ = [ __all__ = [
@@ -17,48 +14,32 @@ __all__ = [
class AddUserMenu(Menu): class AddUserMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Add user", sql_session) Menu.__init__(self, "Add user", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
username = self.input_str( username = self.input_str(
"Username (should be same as PVV username)", "Username (should be same as PVV username)",
regex=User.name_re, regex=User.name_re,
length_range=(1, 10), length_range=(1, 10),
) )
assert username is not None cardnum = self.input_str("Card number (optional)", regex=User.card_re, length_range=(0, 10))
cardnum = cardnum.lower()
cardnum = self.input_str( rfid = self.input_str("RFID (optional)", regex=User.rfid_re, length_range=(0, 10))
"Card number (optional)",
regex=User.card_re,
length_range=(0, 10),
empty_string_is_none=True,
)
if cardnum is not None:
cardnum = cardnum.lower()
rfid = self.input_str(
"RFID (optional)",
regex=User.rfid_re,
length_range=(0, 10),
empty_string_is_none=True,
)
user = User(username, cardnum, rfid) user = User(username, cardnum, rfid)
self.sql_session.add(user) self.session.add(user)
try: try:
self.sql_session.commit() self.session.commit()
print(f"User {username} stored") print(f"User {username} stored")
except IntegrityError as e: except sqlalchemy.exc.IntegrityError as e:
self.sql_session.rollback()
print(f"Could not store user {username}: {e}") print(f"Could not store user {username}: {e}")
self.pause() self.pause()
class EditUserMenu(Menu): class EditUserMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Edit user", sql_session) Menu.__init__(self, "Edit user", uses_db=True)
self.help_text = """ self.help_text = """
The only editable part of a user is its card number and rfid. The only editable part of a user is its card number and rfid.
@@ -66,7 +47,7 @@ First select an existing user, then enter a new card number for that
user, then rfid (write an empty line to remove the card number or rfid). user, then rfid (write an empty line to remove the card number or rfid).
""" """
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
user = self.input_user("User") user = self.input_user("User")
self.printc(f"Editing user {user.name}") self.printc(f"Editing user {user.name}")
@@ -88,50 +69,43 @@ user, then rfid (write an empty line to remove the card number or rfid).
empty_string_is_none=True, empty_string_is_none=True,
) )
try: try:
self.sql_session.commit() self.session.commit()
print(f"User {user.name} stored") print(f"User {user.name} stored")
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store user {user.name}: {e}") print(f"Could not store user {user.name}: {e}")
self.pause() self.pause()
class AddProductMenu(Menu): class AddProductMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Add product", sql_session) Menu.__init__(self, "Add product", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
bar_code = self.input_str("Bar code", regex=Product.bar_code_re, length_range=(8, 13)) bar_code = self.input_str("Bar code", regex=Product.bar_code_re, length_range=(8, 13))
assert bar_code is not None
name = self.input_str("Name", regex=Product.name_re, length_range=(1, Product.name_length)) name = self.input_str("Name", regex=Product.name_re, length_range=(1, Product.name_length))
assert name is not None
price = self.input_int("Price", 1, 100000) price = self.input_int("Price", 1, 100000)
product = Product(bar_code, name, price) product = Product(bar_code, name, price)
self.sql_session.add(product) self.session.add(product)
try: try:
self.sql_session.commit() self.session.commit()
print(f"Product {name} stored") print(f"Product {name} stored")
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store product {name}: {e}") print(f"Could not store product {name}: {e}")
self.pause() self.pause()
class EditProductMenu(Menu): class EditProductMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Edit product", sql_session) Menu.__init__(self, "Edit product", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
product = self.input_product("Product") product = self.input_product("Product")
self.printc(f"Editing product {product.name}") self.printc(f"Editing product {product.name}")
while True: while True:
selector = Selector( selector = Selector(
f"Do what with {product.name}?", f"Do what with {product.name}?",
sql_session=self.sql_session,
items=[ items=[
("name", "Edit name"), ("name", "Edit name"),
("price", "Edit price"), ("price", "Edit price"),
@@ -161,10 +135,9 @@ class EditProductMenu(Menu):
product.hidden = self.confirm(f"Hidden(currently {product.hidden})", default=False) product.hidden = self.confirm(f"Hidden(currently {product.hidden})", default=False)
elif what == "store": elif what == "store":
try: try:
self.sql_session.commit() self.session.commit()
print(f"Product {product.name} stored") print(f"Product {product.name} stored")
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store product {product.name}: {e}") print(f"Could not store product {product.name}: {e}")
self.pause() self.pause()
return return
@@ -176,10 +149,10 @@ class EditProductMenu(Menu):
class AdjustStockMenu(Menu): class AdjustStockMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Adjust stock", sql_session) Menu.__init__(self, "Adjust stock", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
product = self.input_product("Product") product = self.input_product("Product")
@@ -195,11 +168,10 @@ class AdjustStockMenu(Menu):
product.stock += add_stock product.stock += add_stock
try: try:
self.sql_session.commit() self.session.commit()
print("Stock is now stored") print("Stock is now stored")
self.pause() self.pause()
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store stock: {e}") print(f"Could not store stock: {e}")
self.pause() self.pause()
return return
@@ -207,13 +179,13 @@ class AdjustStockMenu(Menu):
class CleanupStockMenu(Menu): class CleanupStockMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Stock Cleanup", sql_session) Menu.__init__(self, "Stock Cleanup", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
products = self.sql_session.query(Product).filter(Product.stock != 0).all() products = self.session.query(Product).filter(Product.stock != 0).all()
print("Every product in stock will be printed.") print("Every product in stock will be printed.")
print("Entering no value will keep current stock or set it to 0 if it is negative.") print("Entering no value will keep current stock or set it to 0 if it is negative.")
@@ -227,16 +199,15 @@ class CleanupStockMenu(Menu):
for product in products: for product in products:
oldstock = product.stock oldstock = product.stock
product.stock = self.input_int(product.name, 0, 10000, default=max(0, oldstock)) product.stock = self.input_int(product.name, 0, 10000, default=max(0, oldstock))
self.sql_session.add(product) self.session.add(product)
if oldstock != product.stock: if oldstock != product.stock:
changed_products.append((product, oldstock)) changed_products.append((product, oldstock))
try: try:
self.sql_session.commit() self.session.commit()
print("New stocks are now stored.") print("New stocks are now stored.")
self.pause() self.pause()
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store stock: {e}") print(f"Could not store stock: {e}")
self.pause() self.pause()
return return
+70 -87
View File
@@ -1,146 +1,129 @@
from textwrap import dedent # -*- coding: utf-8 -*-
from sqlalchemy.orm import Session from .helpermenus import MessageMenu, Menu
from .helpermenus import Menu, MessageMenu
class FAQMenu(Menu): class FAQMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Frequently Asked Questions", sql_session) Menu.__init__(self, "Frequently Asked Questions")
self.items = [ self.items = [
MessageMenu( MessageMenu(
"What is the meaning with this program?", "What is the meaning with this program?",
dedent(""" """
We want to avoid keeping lots of cash in PVVVV's money box and to We want to avoid keeping lots of cash in PVVVV's money box and to
make it easy to pay for stuff without using money. (Without using make it easy to pay for stuff without using money. (Without using
money each time, that is. You do of course have to pay for the things money each time, that is. You do of course have to pay for the things
you buy eventually). you buy eventually).
Dibbler stores a "credit" amount for each user. When you register a Dibbler stores a "credit" amount for each user. When you register a
purchase in Dibbler, this amount is decreased. To increase your purchase in Dibbler, this amount is decreased. To increase your
credit, purchase products for dibbler, and register them using "Add credit, purchase products for dibbler, and register them using "Add
stock and adjust credit". stock and adjust credit".
Alternatively, add money to the money box and use "Adjust credit" to Alternatively, add money to the money box and use "Adjust credit" to
tell Dibbler about it. tell Dibbler about it.
"""), """,
sql_session,
), ),
MessageMenu( MessageMenu(
"Can I still pay for stuff using cash?", "Can I still pay for stuff using cash?",
dedent(""" """
Please put money in the money box and use "Adjust Credit" so that Please put money in the money box and use "Adjust Credit" so that
dibbler can keep track of credit and purchases. dibbler can keep track of credit and purchases.""",
"""),
sql_session,
),
MessageMenu(
"How do I exit from a submenu/dialog/thing?",
'Type "exit", "q", or ^d.',
sql_session,
), ),
MessageMenu("How do I exit from a submenu/dialog/thing?", 'Type "exit", "q", or ^d.'),
MessageMenu( MessageMenu(
'What does "." mean?', 'What does "." mean?',
dedent(""" """
The "." character, known as "full stop" or "period", is most often The "." character, known as "full stop" or "period", is most often
used to indicate the end of a sentence. used to indicate the end of a sentence.
It is also used by Dibbler to indicate that the program wants you to It is also used by Dibbler to indicate that the program wants you to
read some text before continuing. Whenever some output ends with a read some text before continuing. Whenever some output ends with a
line containing only a period, you should read the lines above and line containing only a period, you should read the lines above and
then press enter to continue. then press enter to continue.
"""), """,
sql_session,
), ),
MessageMenu( MessageMenu(
"Why is the user interface so terribly unintuitive?", "Why is the user interface so terribly unintuitive?",
dedent(""" """
Answer #1: It is not. Answer #1: It is not.
Answer #2: We are trying to compete with PVV's microwave oven in Answer #2: We are trying to compete with PVV's microwave oven in
userfriendliness. userfriendliness.
Answer #3: YOU are unintuitive. Answer #3: YOU are unintuitive.
"""), """,
sql_session,
), ),
MessageMenu( MessageMenu(
"Why is there no help command?", "Why is there no help command?",
'There is. Have you tried typing "help"?', 'There is. Have you tried typing "help"?',
sql_session,
), ),
MessageMenu( MessageMenu(
'Where are the easter eggs? I tried saying "moo", but nothing happened.', 'Where are the easter eggs? I tried saying "moo", but nothing happened.',
'Don\'t say "moo".', 'Don\'t say "moo".',
sql_session,
), ),
MessageMenu( MessageMenu(
"Why does the program speak English when all the users are Norwegians?", "Why does the program speak English when all the users are Norwegians?",
"Godt spørsmål. Det virket sikkert som en god idé der og da.", "Godt spørsmål. Det virket sikkert som en god idé der og da.",
sql_session,
), ),
MessageMenu( MessageMenu(
"Why does the screen have strange colours?", "Why does the screen have strange colours?",
dedent(""" """
Type "c" on the main menu to change the colours of the display, or Type "c" on the main menu to change the colours of the display, or
"cs" if you are a boring person. "cs" if you are a boring person.
"""), """,
sql_session,
), ),
MessageMenu( MessageMenu(
"I found a bug; is there a reward?", "I found a bug; is there a reward?",
dedent(""" """
No. No.
But if you are certain that it is a bug, not a feature, then you But if you are certain that it is a bug, not a feature, then you
should fix it (or better: force someone else to do it). should fix it (or better: force someone else to do it).
Follow this procedure: Follow this procedure:
1. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler 1. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
2. Fix the bug. 2. Fix the bug.
3. Check that the program still runs (and, preferably, that the bug is 3. Check that the program still runs (and, preferably, that the bug is
in fact fixed). in fact fixed).
4. Commit. 4. Commit.
5. Update the running copy from svn: 5. Update the running copy from svn:
$ su - $ su -
# su -l -s /bin/bash pvvvv # su -l -s /bin/bash pvvvv
$ cd dibbler $ cd dibbler
$ git pull $ git pull
6. Type "restart" in Dibbler to replace the running process by a new 6. Type "restart" in Dibbler to replace the running process by a new
one using the updated files. one using the updated files.
"""), """,
sql_session,
), ),
MessageMenu( MessageMenu(
"My question isn't listed here; what do I do?", "My question isn't listed here; what do I do?",
dedent(""" """
DON'T PANIC. DON'T PANIC.
Follow this procedure: Follow this procedure:
1. Ask someone (or read the source code) and get an answer. 1. Ask someone (or read the source code) and get an answer.
2. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler 2. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
3. Add your question (with answer) to the FAQ and commit. 3. Add your question (with answer) to the FAQ and commit.
4. Update the running copy from svn: 4. Update the running copy from svn:
$ su - $ su -
# su -l -s /bin/bash pvvvv # su -l -s /bin/bash pvvvv
$ cd dibbler $ cd dibbler
$ git pull $ git pull
5. Type "restart" in Dibbler to replace the running process by a new 5. Type "restart" in Dibbler to replace the running process by a new
one using the updated files. one using the updated files.
"""), """,
sql_session,
), ),
] ]
+165 -277
View File
@@ -1,64 +1,44 @@
from __future__ import annotations # -*- coding: utf-8 -*-
import re import re
import sys import sys
from select import select from select import select
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar
from dibbler.db import Session
from dibbler.models import User
from dibbler.lib.helpers import ( from dibbler.lib.helpers import (
argmax,
guess_data_type,
search_product,
search_user, search_user,
search_product,
guess_data_type,
argmax,
) )
from dibbler.models import Product, User
if TYPE_CHECKING: exit_commands = ["exit", "abort", "quit", "bye", "eat flaming death", "q"]
from collections.abc import Callable, Iterable help_commands = ["help", "?"]
context_commands = ["what", "??"]
from sqlalchemy.orm import Session local_help_commands = ["help!", "???"]
exit_commands: list[str] = ["exit", "abort", "quit", "bye", "eat flaming death", "q"]
help_commands: list[str] = ["help", "?"]
context_commands: list[str] = ["what", "??"]
local_help_commands: list[str] = ["help!", "???"]
class ExitMenuException(Exception): class ExitMenu(Exception):
pass pass
MenuItemType = TypeVar("MenuItemType", bound="Menu") class Menu(object):
class Menu:
name: str
sql_session: Session
items: list[Menu | tuple[MenuItemType, str] | str]
prompt: str | None
end_prompt: str | None
return_index: bool
exit_msg: str | None
exit_confirm_msg: str | None
exit_disallowed_msg: str | None
help_text: str | None
context: str | None
def __init__( def __init__(
self, self,
name: str, name,
sql_session: Session, items=None,
items: list[Self | tuple[MenuItemType, str] | str] | None = None, prompt=None,
prompt: str | None = None, end_prompt="> ",
end_prompt: str | None = "> ", return_index=True,
return_index: bool = True, exit_msg=None,
exit_msg: str | None = None, exit_confirm_msg=None,
exit_confirm_msg: str | None = None, exit_disallowed_msg=None,
exit_disallowed_msg: str | None = None, help_text=None,
help_text: str | None = None, uses_db=False,
) -> None: ):
self.name: str = name self.name = name
self.sql_session: Session = sql_session
self.items = items if items is not None else [] self.items = items if items is not None else []
self.prompt = prompt self.prompt = prompt
self.end_prompt = end_prompt self.end_prompt = end_prompt
@@ -68,61 +48,54 @@ class Menu:
self.exit_disallowed_msg = exit_disallowed_msg self.exit_disallowed_msg = exit_disallowed_msg
self.help_text = help_text self.help_text = help_text
self.context = None self.context = None
self.uses_db = uses_db
self.session = None
assert name is not None def exit_menu(self):
assert self.sql_session is not None
def exit_menu(self) -> None:
if self.exit_disallowed_msg is not None: if self.exit_disallowed_msg is not None:
print(self.exit_disallowed_msg) print(self.exit_disallowed_msg)
return return
if self.exit_confirm_msg is not None: if self.exit_confirm_msg is not None:
if not self.confirm(self.exit_confirm_msg, default=True): if not self.confirm(self.exit_confirm_msg, default=True):
return return
raise ExitMenuException() raise ExitMenu()
def at_exit(self) -> None: def at_exit(self):
if self.exit_msg: if self.exit_msg:
print(self.exit_msg) print(self.exit_msg)
def set_context( def set_context(self, string, display=True):
self,
string: str | None,
display: bool = True,
) -> None:
self.context = string self.context = string
if self.context is not None and display: if self.context is not None and display:
print(self.context) print(self.context)
def add_to_context(self, string: str) -> None: def add_to_context(self, string):
if self.context is not None: self.context += string
self.context += string
else:
self.context = string
def printc(self, string: str) -> None: def printc(self, string):
print(string) print(string)
if self.context is None: if self.context is None:
self.context = string self.context = string
else: else:
self.context += "\n" + string self.context += "\n" + string
def show_context(self) -> None: def show_context(self):
print(self.header()) print(self.header())
if self.context is not None: if self.context is not None:
print(self.context) print(self.context)
def item_is_submenu(self, i: int) -> bool: def item_is_submenu(self, i):
return isinstance(self.items[i], Menu) return isinstance(self.items[i], Menu)
def item_name(self, i: int) -> str: def item_name(self, i):
if self.item_is_submenu(i): if self.item_is_submenu(i):
return self.items[i].name return self.items[i].name
if isinstance(self.items[i], tuple): elif isinstance(self.items[i], tuple):
return self.items[i][1] return self.items[i][1]
return self.items[i] else:
return self.items[i]
def item_value(self, i: int) -> MenuItemType | int: def item_value(self, i):
if isinstance(self.items[i], tuple): if isinstance(self.items[i], tuple):
return self.items[i][0] return self.items[i][0]
if self.return_index: if self.return_index:
@@ -131,14 +104,14 @@ class Menu:
def input_str( def input_str(
self, self,
prompt: str | None = None, prompt=None,
end_prompt: str | None = None, end_prompt=None,
regex: str | None = None, regex=None,
length_range: tuple[int | None, int | None] = (None, None), length_range=(None, None),
empty_string_is_none: bool = False, empty_string_is_none=False,
timeout: int | None = None, timeout=None,
default: str | None = None, default=None,
) -> str | None: ):
if prompt is None: if prompt is None:
prompt = self.prompt if self.prompt is not None else "" prompt = self.prompt if self.prompt is not None else ""
if default is not None: if default is not None:
@@ -195,7 +168,7 @@ class Menu:
): ):
if length_range[0] and length_range[1]: if length_range[0] and length_range[1]:
print( print(
f"Value must have length in range [{length_range[0]:d}, {length_range[1]:d}]", f"Value must have length in range [{length_range[0]:d}, {length_range[1]:d}]"
) )
elif length_range[0]: elif length_range[0]:
print(f"Value must have length at least {length_range[0]:d}") print(f"Value must have length at least {length_range[0]:d}")
@@ -204,7 +177,7 @@ class Menu:
continue continue
return result return result
def special_input_options(self, result) -> bool: def special_input_options(self, result):
""" """
Handles special, magic input for input_str Handles special, magic input for input_str
@@ -214,7 +187,7 @@ class Menu:
""" """
return False return False
def special_input_choice(self, in_str: str) -> bool: def special_input_choice(self, in_str):
""" """
Handle choices which are not simply menu items. Handle choices which are not simply menu items.
@@ -224,39 +197,33 @@ class Menu:
""" """
return False return False
def input_choice( def input_choice(self, number_of_choices, prompt=None, end_prompt=None):
self,
number_of_choices: int,
prompt: str | None = None,
end_prompt: str | None = None,
) -> int:
while True: while True:
result = self.input_str(prompt, end_prompt) result = self.input_str(prompt, end_prompt)
assert result is not None
if result == "": if result == "":
print("Please enter something") print("Please enter something")
else: else:
if result.isdigit(): if result.isdigit():
choice = int(result) choice = int(result)
if choice == 0 and number_of_choices >= 10: if choice == 0 and 10 <= number_of_choices:
return 10 return 10
if 0 < choice <= number_of_choices: if 0 < choice <= number_of_choices:
return choice return choice
if not self.special_input_choice(result): if not self.special_input_choice(result):
self.invalid_menu_choice(result) self.invalid_menu_choice(result)
def invalid_menu_choice(self, in_str: str) -> None: def invalid_menu_choice(self, in_str):
print("Please enter a valid choice.") print("Please enter a valid choice.")
def input_int( def input_int(
self, self,
prompt: str, prompt=None,
minimum: int | None = None, minimum=None,
maximum: int | None = None, maximum=None,
null_allowed: bool = False, null_allowed=False,
zero_allowed: bool = True, zero_allowed=True,
default: int | None = None, default=None,
) -> int | Literal[False]: ):
if minimum is not None and maximum is not None: if minimum is not None and maximum is not None:
end_prompt = f"({minimum}-{maximum})>" end_prompt = f"({minimum}-{maximum})>"
elif minimum is not None: elif minimum is not None:
@@ -267,11 +234,7 @@ class Menu:
end_prompt = "" end_prompt = ""
while True: while True:
result = self.input_str( result = self.input_str(prompt + end_prompt, default=default)
prompt + end_prompt,
default=str(default) if default is not None else None,
)
assert result is not None
if result == "" and null_allowed: if result == "" and null_allowed:
return False return False
try: try:
@@ -289,115 +252,93 @@ class Menu:
except ValueError: except ValueError:
print("Please enter an integer") print("Please enter an integer")
def input_user( def input_user(self, prompt=None, end_prompt=None):
self,
prompt: str | None = None,
end_prompt: str | None = None,
) -> User:
user = None user = None
while user is None: while user is None:
search_string = self.input_str(prompt, end_prompt) user = self.retrieve_user(self.input_str(prompt, end_prompt))
assert search_string is not None
user = self.retrieve_user(search_string)
return user return user
def retrieve_user(self, search_str: str) -> User | None: def retrieve_user(self, search_str):
return self.search_ui(search_user, search_str, "user") return self.search_ui(search_user, search_str, "user")
def input_product( def input_product(self, prompt=None, end_prompt=None):
self,
prompt: str | None = None,
end_prompt: str | None = None,
) -> Product:
product = None product = None
while product is None: while product is None:
search_string = self.input_str(prompt, end_prompt) product = self.retrieve_product(self.input_str(prompt, end_prompt))
assert search_string is not None
product = self.retrieve_product(search_string)
return product return product
def retrieve_product(self, search_str: str) -> Product | None: def retrieve_product(self, search_str):
return self.search_ui(search_product, search_str, "product") return self.search_ui(search_product, search_str, "product")
def input_thing( def input_thing(
self, self,
prompt: str | None = None, prompt=None,
end_prompt: str | None = None, end_prompt=None,
permitted_things: Iterable[str] = ("user", "product"), permitted_things=("user", "product"),
add_nonexisting: Iterable[str] = (), add_nonexisting=(),
empty_input_permitted: bool = False, empty_input_permitted=False,
find_hidden_products: bool = True, find_hidden_products=True,
) -> User | Product | None: ):
result = None result = None
while result is None: while result is None:
search_str = self.input_str(prompt, end_prompt) search_str = self.input_str(prompt, end_prompt)
assert search_str is not None
if search_str == "" and empty_input_permitted: if search_str == "" and empty_input_permitted:
return None return None
result = self.search_for_thing( result = self.search_for_thing(
search_str, search_str, permitted_things, add_nonexisting, find_hidden_products
permitted_things,
add_nonexisting,
find_hidden_products,
) )
return result return result
def input_multiple( def input_multiple(
self, self,
prompt: str | None = None, prompt=None,
end_prompt: str | None = None, end_prompt=None,
permitted_things: Iterable[str] = ("user", "product"), permitted_things=("user", "product"),
add_nonexisting: Iterable[str] = (), add_nonexisting=(),
empty_input_permitted: bool = False, empty_input_permitted=False,
find_hidden_products: bool = True, find_hidden_products=True,
) -> tuple[User | Product, int] | None: ):
result = None result = None
num = 0 num = 0
while result is None: while result is None:
search_str = self.input_str(prompt, end_prompt) search_str = self.input_str(prompt, end_prompt)
assert search_str is not None
search_lst = search_str.split(" ") search_lst = search_str.split(" ")
if search_str == "" and empty_input_permitted: if search_str == "" and empty_input_permitted:
return None return None
result = self.search_for_thing( else:
search_str, result = self.search_for_thing(
permitted_things, search_str, permitted_things, add_nonexisting, find_hidden_products
add_nonexisting, )
find_hidden_products, num = 1
)
num = 1
if (result is None) and (len(search_lst) > 1): if (result is None) and (len(search_lst) > 1):
print('Interpreting input as "<number> <product>"') print('Interpreting input as "<number> <product>"')
try: try:
num = int(search_lst[0]) num = int(search_lst[0])
result = self.search_for_thing( result = self.search_for_thing(
" ".join(search_lst[1:]), " ".join(search_lst[1:]),
permitted_things, permitted_things,
add_nonexisting, add_nonexisting,
find_hidden_products, find_hidden_products,
) )
# Her kan det legges inn en except ValueError, # Her kan det legges inn en except ValueError,
# men da blir det fort mye plaging av brukeren # men da blir det fort mye plaging av brukeren
except Exception as e: except Exception as e:
print(e) print(e)
return result, num return result, num
def search_for_thing( def search_for_thing(
self, self,
search_str: str, search_str,
permitted_things: Iterable[str] = ("user", "product"), permitted_things=("user", "product"),
add_non_existing: Iterable[str] = (), add_non_existing=(),
find_hidden_products: bool = True, find_hidden_products=True,
) -> User | Product | None: ):
search_fun = { search_fun = {"user": search_user, "product": search_product}
"user": search_user,
"product": search_product,
}
results = {} results = {}
result_values = {} result_values = {}
for thing in permitted_things: for thing in permitted_things:
results[thing] = search_fun[thing](search_str, self.sql_session, find_hidden_products) results[thing] = search_fun[thing](search_str, self.session, find_hidden_products)
result_values[thing] = self.search_result_value(results[thing]) result_values[thing] = self.search_result_value(results[thing])
selected_thing = argmax(result_values) selected_thing = argmax(result_values)
if not results[selected_thing]: if not results[selected_thing]:
@@ -412,14 +353,10 @@ class Menu:
return self.search_add(search_str) return self.search_add(search_str)
# print('No match found for "%s".' % search_str) # print('No match found for "%s".' % search_str)
return None return None
return self.search_ui2( return self.search_ui2(search_str, results[selected_thing], selected_thing)
search_str,
results[selected_thing],
selected_thing,
)
@staticmethod @staticmethod
def search_result_value(result) -> Literal[0, 1, 2, 3]: def search_result_value(result):
if result is None: if result is None:
return 0 return 0
if not isinstance(result, list): if not isinstance(result, list):
@@ -430,19 +367,18 @@ class Menu:
return 2 return 2
return 1 return 1
def search_add(self, string: str) -> User | None: def search_add(self, string):
type_guess = guess_data_type(string) type_guess = guess_data_type(string)
if type_guess == "username": if type_guess == "username":
print(f'"{string}" looks like a username, but no such user exists.') print(f'"{string}" looks like a username, but no such user exists.')
if self.confirm(f"Create user {string}?"): if self.confirm(f"Create user {string}?"):
user = User(string, None) user = User(string, None)
self.sql_session.add(user) self.session.add(user)
return user return user
return None return None
if type_guess == "card": if type_guess == "card":
selector = Selector( selector = Selector(
f'"{string}" looks like a card number, but no user with that card number exists.', f'"{string}" looks like a card number, but no user with that card number exists.',
self.sql_session,
[ [
("create", f"Create user with card number {string}"), ("create", f"Create user with card number {string}"),
("set", f"Set card number of an existing user to {string}"), ("set", f"Set card number of an existing user to {string}"),
@@ -451,14 +387,12 @@ class Menu:
selection = selector.execute() selection = selector.execute()
if selection == "create": if selection == "create":
username = self.input_str( username = self.input_str(
prompt="Username for new user (should be same as PVV username)", "Username for new user (should be same as PVV username)",
end_prompt=None, User.name_re,
regex=User.name_re, (1, 10),
length_range=(1, 10),
) )
assert username is not None
user = User(username, string) user = User(username, string)
self.sql_session.add(user) self.session.add(user)
return user return user
if selection == "set": if selection == "set":
user = self.input_user("User to set card number for") user = self.input_user("User to set card number for")
@@ -471,21 +405,11 @@ class Menu:
print(f'"{string}" looks like the bar code for a product, but no such product exists.') print(f'"{string}" looks like the bar code for a product, but no such product exists.')
return None return None
def search_ui( def search_ui(self, search_fun, search_str, thing):
self, result = search_fun(search_str, self.session)
search_fun: Callable[[str, Session], list[Any] | Any],
search_str: str,
thing: str,
) -> Any:
result = search_fun(search_str, self.sql_session)
return self.search_ui2(search_str, result, thing) return self.search_ui2(search_str, result, thing)
def search_ui2( def search_ui2(self, search_str, result, thing):
self,
search_str: str,
result: list[Any] | Any,
thing: str,
) -> Any:
if not isinstance(result, list): if not isinstance(result, list):
return result return result
if len(result) == 0: if len(result) == 0:
@@ -505,41 +429,25 @@ class Menu:
else: else:
select_header = f'{len(result):d} {thing}s matching "{search_str}"' select_header = f'{len(result):d} {thing}s matching "{search_str}"'
select_items = result select_items = result
selector = Selector( selector = Selector(select_header, items=select_items, return_index=False)
select_header,
self.sql_session,
items=select_items,
return_index=False,
)
return selector.execute() return selector.execute()
def confirm( @staticmethod
self, def confirm(prompt, end_prompt=None, default=None, timeout=None):
prompt: str, return ConfirmMenu(prompt, end_prompt=None, default=default, timeout=timeout).execute()
end_prompt: str | None = None,
default: bool | None = None,
timeout: int | None = None,
) -> bool:
return ConfirmMenu(
self.sql_session,
prompt,
end_prompt=None,
default=default,
timeout=timeout,
).execute()
def header(self) -> str: def header(self):
return f"[{self.name}]" return f"[{self.name}]"
def print_header(self) -> None: def print_header(self):
print("") print("")
print(self.header()) print(self.header())
def pause(self) -> None: def pause(self):
self.input_str(".", end_prompt="") self.input_str(".", end_prompt="")
@staticmethod @staticmethod
def general_help() -> None: def general_help():
print( print(
""" """
DIBBLER HELP DIBBLER HELP
@@ -562,10 +470,10 @@ class Menu:
of money PVVVV owes the user. This value decreases with the of money PVVVV owes the user. This value decreases with the
appropriate amount when you register a purchase, and you may increase appropriate amount when you register a purchase, and you may increase
it by putting money in the box and using the "Adjust credit" menu. it by putting money in the box and using the "Adjust credit" menu.
""", """
) )
def local_help(self) -> None: def local_help(self):
if self.help_text is None: if self.help_text is None:
print("no help here") print("no help here")
else: else:
@@ -573,15 +481,21 @@ class Menu:
print(f"Help for {self.header()}:") print(f"Help for {self.header()}:")
print(self.help_text) print(self.help_text)
def execute(self, **_kwargs) -> MenuItemType | int | None: def execute(self, **kwargs):
self.set_context(None) self.set_context(None)
try: try:
return self._execute(**_kwargs) if self.uses_db and not self.session:
except ExitMenuException: self.session = Session()
return self._execute(**kwargs)
except ExitMenu:
self.at_exit() self.at_exit()
return None return None
finally:
if self.session is not None:
self.session.close()
self.session = None
def _execute(self, **_kwargs) -> MenuItemType | int | None: def _execute(self, **kwargs):
while True: while True:
self.print_header() self.print_header()
self.set_context(None) self.set_context(None)
@@ -600,21 +514,12 @@ class Menu:
class MessageMenu(Menu): class MessageMenu(Menu):
message: str def __init__(self, name, message, pause_after_message=True):
pause_after_message: bool Menu.__init__(self, name)
def __init__(
self,
name: str,
message: str,
sql_session: Session,
pause_after_message: bool = True,
) -> None:
super().__init__(name, sql_session)
self.message = message.strip() self.message = message.strip()
self.pause_after_message = pause_after_message self.pause_after_message = pause_after_message
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
print("") print("")
print(self.message) print(self.message)
@@ -623,17 +528,10 @@ class MessageMenu(Menu):
class ConfirmMenu(Menu): class ConfirmMenu(Menu):
def __init__( def __init__(self, prompt="confirm? ", end_prompt=": ", default=None, timeout=0):
self, Menu.__init__(
sql_session: Session, self,
prompt: str = "confirm? ",
end_prompt: str | None = ": ",
default: bool | None = None,
timeout: int | None = 0,
) -> None:
super().__init__(
"question", "question",
sql_session,
prompt=prompt, prompt=prompt,
end_prompt=end_prompt, end_prompt=end_prompt,
exit_disallowed_msg="Please answer yes or no", exit_disallowed_msg="Please answer yes or no",
@@ -641,55 +539,45 @@ class ConfirmMenu(Menu):
self.default = default self.default = default
self.timeout = timeout self.timeout = timeout
def _execute(self, **_kwargs) -> bool: def _execute(self):
options = {True: "[y]/n", False: "y/[n]", None: "y/n"}[self.default] options = {True: "[y]/n", False: "y/[n]", None: "y/n"}[self.default]
while True: while True:
result = self.input_str( result = self.input_str(
f"{self.prompt} ({options})", f"{self.prompt} ({options})", end_prompt=": ", timeout=self.timeout
end_prompt=": ",
timeout=self.timeout,
) )
result = result.lower().strip() result = result.lower().strip()
if result in ["y", "yes"]: if result in ["y", "yes"]:
return True return True
if result in ["n", "no"]: elif result in ["n", "no"]:
return False return False
if self.default is not None and result == "": elif self.default is not None and result == "":
return self.default return self.default
print("Please answer yes or no") else:
print("Please answer yes or no")
class Selector(Menu): class Selector(Menu):
def __init__( def __init__(
self, self,
name: str, name,
sql_session: Session, items=None,
items: list[Self | tuple[MenuItemType, str] | str] | None = None, prompt="select",
prompt: str | None = "select", return_index=True,
return_index: bool = True, exit_msg=None,
exit_msg: str | None = None, exit_confirm_msg=None,
exit_confirm_msg: str | None = None, help_text=None,
help_text: str | None = None, ):
) -> None:
if items is None: if items is None:
items = [] items = []
super().__init__( Menu.__init__(self, name, items, prompt, return_index=return_index, exit_msg=exit_msg)
name,
sql_session,
items,
prompt,
return_index=return_index,
exit_msg=exit_msg,
help_text=help_text,
)
def header(self) -> str: def header(self):
return self.name return self.name
def print_header(self) -> None: def print_header(self):
print(self.header()) print(self.header())
def local_help(self) -> None: def local_help(self):
if self.help_text is None: if self.help_text is None:
print("This is a selection menu. Enter one of the listed numbers, or") print("This is a selection menu. Enter one of the listed numbers, or")
print("'exit' to go out and do something else.") print("'exit' to go out and do something else.")
+20 -16
View File
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import os import os
import random import random
import sys import sys
from sqlalchemy.orm import Session from dibbler.db import Session
from .buymenu import BuyMenu from .buymenu import BuyMenu
from .faq import FAQMenu from .faq import FAQMenu
@@ -12,17 +13,14 @@ faq_commands = ["faq"]
restart_commands = ["restart"] restart_commands = ["restart"]
def restart() -> None: def restart():
# Does not work if the script is not executable, or if it was # Does not work if the script is not executable, or if it was
# started by searching $PATH. # started by searching $PATH.
os.execv(sys.argv[0], sys.argv) os.execv(sys.argv[0], sys.argv)
class MainMenu(Menu): class MainMenu(Menu):
def __init__(self, sql_session: Session, **_kwargs) -> None: def special_input_choice(self, in_str):
super().__init__("Dibbler main menu", sql_session, **_kwargs)
def special_input_choice(self, in_str: str) -> bool:
mv = in_str.split() mv = in_str.split()
if len(mv) == 2 and mv[0].isdigit(): if len(mv) == 2 and mv[0].isdigit():
num = int(mv[0]) num = int(mv[0])
@@ -30,7 +28,7 @@ class MainMenu(Menu):
else: else:
num = 1 num = 1
item_name = in_str item_name = in_str
buy_menu = BuyMenu(self.sql_session) buy_menu = BuyMenu(Session())
thing = buy_menu.search_for_thing(item_name, find_hidden_products=False) thing = buy_menu.search_for_thing(item_name, find_hidden_products=False)
if thing: if thing:
buy_menu.execute(initial_contents=[(thing, num)]) buy_menu.execute(initial_contents=[(thing, num)])
@@ -38,26 +36,32 @@ class MainMenu(Menu):
return True return True
return False return False
def special_input_options(self, result: str) -> bool: def special_input_options(self, result):
if result in faq_commands: if result in faq_commands:
FAQMenu(self.sql_session).execute() FAQMenu().execute()
return True return True
if result in restart_commands: if result in restart_commands:
if self.confirm("Restart Dibbler?"): if self.confirm("Restart Dibbler?"):
restart() restart()
pass pass
return True return True
if result == "c": elif result == "c":
print(f"\033[{random.randint(40, 49)};{random.randint(30, 37)};5m") os.system(
print("\033[2J") 'echo -e "\033['
+ str(random.randint(40, 49))
+ ";"
+ str(random.randint(30, 37))
+ ';5m"'
)
os.system("clear")
self.show_context() self.show_context()
return True return True
if result == "cs": elif result == "cs":
print("\033[0m") os.system('echo -e "\033[0m"')
print("\033[2J") os.system("clear")
self.show_context() self.show_context()
return True return True
return False return False
def invalid_menu_choice(self, in_str: str) -> None: def invalid_menu_choice(self, in_str):
print(self.show_context()) print(self.show_context())
+46 -49
View File
@@ -1,19 +1,17 @@
import sqlalchemy import sqlalchemy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from dibbler.conf import config from dibbler.conf import config
from dibbler.models import Transaction, Product, User
from dibbler.lib.helpers import less from dibbler.lib.helpers import less
from dibbler.models import Product, Transaction, User
from .helpermenus import Menu, Selector from .helpermenus import Menu, Selector
class TransferMenu(Menu): class TransferMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Transfer credit between users", sql_session) Menu.__init__(self, "Transfer credit between users", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
amount = self.input_int("Transfer amount", 1, 100000) amount = self.input_int("Transfer amount", 1, 100000)
self.set_context(f"Transferring {amount:d} kr", display=False) self.set_context(f"Transferring {amount:d} kr", display=False)
@@ -28,25 +26,24 @@ class TransferMenu(Menu):
t2 = Transaction(user2, -amount, f'transfer from {user1.name} "{comment}"') t2 = Transaction(user2, -amount, f'transfer from {user1.name} "{comment}"')
t1.perform_transaction() t1.perform_transaction()
t2.perform_transaction() t2.perform_transaction()
self.sql_session.add(t1) self.session.add(t1)
self.sql_session.add(t2) self.session.add(t2)
try: try:
self.sql_session.commit() self.session.commit()
print(f"Transferred {amount:d} kr from {user1} to {user2}") print(f"Transferred {amount:d} kr from {user1} to {user2}")
print(f"User {user1}'s credit is now {user1.credit:d} kr") print(f"User {user1}'s credit is now {user1.credit:d} kr")
print(f"User {user2}'s credit is now {user2.credit:d} kr") print(f"User {user2}'s credit is now {user2.credit:d} kr")
print(f"Comment: {comment}") print(f"Comment: {comment}")
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not perform transfer: {e}") print(f"Could not perform transfer: {e}")
# self.pause() # self.pause()
class ShowUserMenu(Menu): class ShowUserMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Show user", sql_session) Menu.__init__(self, "Show user", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
user = self.input_user("User name, card number or RFID") user = self.input_user("User name, card number or RFID")
print(f"User name: {user.name}") print(f"User name: {user.name}")
@@ -55,12 +52,11 @@ class ShowUserMenu(Menu):
print(f"Credit: {user.credit} kr") print(f"Credit: {user.credit} kr")
selector = Selector( selector = Selector(
f"What do you want to know about {user.name}?", f"What do you want to know about {user.name}?",
self.sql_session,
items=[ items=[
( (
"transactions", "transactions",
"Recent transactions (List of last " "Recent transactions (List of last "
+ str(config["limits"]["user_recent_transaction_limit"]) + str(config.getint("limits", "user_recent_transaction_limit"))
+ ")", + ")",
), ),
("products", f"Which products {user.name} has bought, and how many"), ("products", f"Which products {user.name} has bought, and how many"),
@@ -69,7 +65,7 @@ class ShowUserMenu(Menu):
) )
what = selector.execute() what = selector.execute()
if what == "transactions": if what == "transactions":
self.print_transactions(user, config["limits"]["user_recent_transaction_limit"]) self.print_transactions(user, config.getint("limits", "user_recent_transaction_limit"))
elif what == "products": elif what == "products":
self.print_purchased_products(user) self.print_purchased_products(user)
elif what == "transactions-all": elif what == "transactions-all":
@@ -78,7 +74,7 @@ class ShowUserMenu(Menu):
print("What what?") print("What what?")
@staticmethod @staticmethod
def print_transactions(user: User, limit: int | None = None) -> None: def print_transactions(user, limit=None):
num_trans = len(user.transactions) num_trans = len(user.transactions)
if limit is None: if limit is None:
limit = num_trans limit = num_trans
@@ -91,7 +87,10 @@ class ShowUserMenu(Menu):
if t.purchase: if t.purchase:
products = [] products = []
for entry in t.purchase.entries: for entry in t.purchase.entries:
amount = f"{abs(entry.amount)}x " if abs(entry.amount) != 1 else "" if abs(entry.amount) != 1:
amount = f"{abs(entry.amount)}x "
else:
amount = ""
product = f"{amount}{entry.product.name}" product = f"{amount}{entry.product.name}"
products.append(product) products.append(product)
string += "purchase (" string += "purchase ("
@@ -99,13 +98,13 @@ class ShowUserMenu(Menu):
string += ")" string += ")"
if t.penalty > 1: if t.penalty > 1:
string += f" * {t.penalty:d}x penalty applied" string += f" * {t.penalty:d}x penalty applied"
elif t.description is not None: else:
string += t.description string += t.description
string += "\n" string += "\n"
less(string) less(string)
@staticmethod @staticmethod
def print_purchased_products(user: User) -> None: def print_purchased_products(user):
products = [] products = []
for ref in user.products: for ref in user.products:
product = ref.product product = ref.product
@@ -124,13 +123,13 @@ class ShowUserMenu(Menu):
class UserListMenu(Menu): class UserListMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("User list", sql_session) Menu.__init__(self, "User list", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
user_list = self.sql_session.query(User).all() user_list = self.session.query(User).all()
total_credit = self.sql_session.query(sqlalchemy.func.sum(User.credit)).first()[0] total_credit = self.session.query(sqlalchemy.func.sum(User.credit)).first()[0]
line_format = "%-12s | %6s\n" line_format = "%-12s | %6s\n"
hline = "---------------------\n" hline = "---------------------\n"
@@ -145,10 +144,10 @@ class UserListMenu(Menu):
class AdjustCreditMenu(Menu): class AdjustCreditMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Adjust credit", sql_session) Menu.__init__(self, "Adjust credit", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
user = self.input_user("User") user = self.input_user("User")
print(f"User {user.name}'s credit is {user.credit:d} kr") print(f"User {user.name}'s credit is {user.credit:d} kr")
@@ -165,25 +164,24 @@ class AdjustCreditMenu(Menu):
description = "manually adjusted credit" description = "manually adjusted credit"
transaction = Transaction(user, -amount, description) transaction = Transaction(user, -amount, description)
transaction.perform_transaction() transaction.perform_transaction()
self.sql_session.add(transaction) self.session.add(transaction)
try: try:
self.sql_session.commit() self.session.commit()
print(f"User {user.name}'s credit is now {user.credit:d} kr") print(f"User {user.name}'s credit is now {user.credit:d} kr")
except SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as e:
self.sql_session.rollback()
print(f"Could not store transaction: {e}") print(f"Could not store transaction: {e}")
# self.pause() # self.pause()
class ProductListMenu(Menu): class ProductListMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Product list", sql_session) Menu.__init__(self, "Product list", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
text = "" text = ""
product_list = ( product_list = (
self.sql_session.query(Product) self.session.query(Product)
.filter(Product.hidden.is_(False)) .filter(Product.hidden.is_(False))
.order_by(Product.stock.desc()) .order_by(Product.stock.desc())
) )
@@ -206,22 +204,21 @@ class ProductListMenu(Menu):
class ProductSearchMenu(Menu): class ProductSearchMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Product search", sql_session) Menu.__init__(self, "Product search", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
self.set_context("Enter (part of) product name or bar code") self.set_context("Enter (part of) product name or bar code")
product = self.input_product() product = self.input_product()
print( print(
", ".join( "Result: %s, price: %d kr, bar code: %s, stock: %d, hidden: %s"
[ % (
f"Result: {product.name}", product.name,
f"price: {product.price} kr", product.price,
f"bar code: {product.bar_code}", product.bar_code,
f"stock: {product.stock}", product.stock,
f"hidden: {'Y' if product.hidden else 'N'}", ("Y" if product.hidden else "N"),
], )
),
) )
# self.pause() # self.pause()
+31 -32
View File
@@ -1,46 +1,45 @@
from sqlalchemy.orm import Session import re
from dibbler.conf import config
from dibbler.models import Product, User
from dibbler.lib.printer_helpers import print_bar_code, print_name_label
# from dibbler.lib.printer_helpers import print_bar_code, print_name_label
from .helpermenus import Menu from .helpermenus import Menu
class PrintLabelMenu(Menu): class PrintLabelMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Print a label", sql_session) Menu.__init__(self, "Print a label", uses_db=True)
self.help_text = """ self.help_text = """
Prints out a product bar code on the printer Prints out a product bar code on the printer
Put it up somewhere in the vicinity. Put it up somewhere in the vicinity.
""" """
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
print("Printer menu is under renovation, please be patient") thing = self.input_thing("Product/User")
return if isinstance(thing, Product):
if re.match(r"^[0-9]{13}$", thing.bar_code):
# thing = self.input_thing("Product/User") bar_type = "ean13"
elif re.match(r"^[0-9]{8}$", thing.bar_code):
# if isinstance(thing, Product): bar_type = "ean8"
# if re.match(r"^[0-9]{13}$", thing.bar_code): else:
# bar_type = "ean13" bar_type = "code39"
# elif re.match(r"^[0-9]{8}$", thing.bar_code): print_bar_code(
# bar_type = "ean8" thing.bar_code,
# else: thing.name,
# bar_type = "code39" barcode_type=bar_type,
# print_bar_code( rotate=config.getboolean("printer", "rotate"),
# thing.bar_code, printer_type="QL-700",
# thing.name, label_type=config.get("printer", "label_type"),
# barcode_type=bar_type, )
# rotate=config["printer"]["rotate"], elif isinstance(thing, User):
# printer_type="QL-700", print_name_label(
# label_type=config.get("printer", "label_type"), text=thing.name,
# ) label_type=config.get("printer", "label_type"),
# elif isinstance(thing, User): rotate=config.getboolean("printer", "rotate"),
# print_name_label( printer_type="QL-700",
# text=thing.name, )
# label_type=config["printer"]["label_type"],
# rotate=config["printer"]["rotate"],
# printer_type="QL-700",
# )
+23 -28
View File
@@ -1,9 +1,8 @@
from sqlalchemy import desc, func from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from dibbler.lib.helpers import less from dibbler.lib.helpers import less
from dibbler.models import PurchaseEntry, Product, User
from dibbler.lib.statistikkHelpers import statisticsTextOnly from dibbler.lib.statistikkHelpers import statisticsTextOnly
from dibbler.models import Product, PurchaseEntry, User
from .helpermenus import Menu from .helpermenus import Menu
@@ -16,14 +15,14 @@ __all__ = [
class ProductPopularityMenu(Menu): class ProductPopularityMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Products by popularity", sql_session) Menu.__init__(self, "Products by popularity", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
text = "" text = ""
sub = ( sub = (
self.sql_session.query( self.session.query(
PurchaseEntry.product_id, PurchaseEntry.product_id,
func.sum(PurchaseEntry.amount).label("purchase_count"), func.sum(PurchaseEntry.amount).label("purchase_count"),
) )
@@ -32,8 +31,8 @@ class ProductPopularityMenu(Menu):
.subquery() .subquery()
) )
product_list = ( product_list = (
self.sql_session.query(Product, sub.c.purchase_count) self.session.query(Product, sub.c.purchase_count)
.outerjoin(sub, Product.product_id == sub.c.product_id) .outerjoin((sub, Product.product_id == sub.c.product_id))
.order_by(desc(sub.c.purchase_count)) .order_by(desc(sub.c.purchase_count))
.filter(sub.c.purchase_count is not None) .filter(sub.c.purchase_count is not None)
.all() .all()
@@ -49,14 +48,14 @@ class ProductPopularityMenu(Menu):
class ProductRevenueMenu(Menu): class ProductRevenueMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Products by revenue", sql_session) Menu.__init__(self, "Products by revenue", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
text = "" text = ""
sub = ( sub = (
self.sql_session.query( self.session.query(
PurchaseEntry.product_id, PurchaseEntry.product_id,
func.sum(PurchaseEntry.amount).label("purchase_count"), func.sum(PurchaseEntry.amount).label("purchase_count"),
) )
@@ -65,8 +64,8 @@ class ProductRevenueMenu(Menu):
.subquery() .subquery()
) )
product_list = ( product_list = (
self.sql_session.query(Product, sub.c.purchase_count) self.session.query(Product, sub.c.purchase_count)
.outerjoin(sub, Product.product_id == sub.c.product_id) .outerjoin((sub, Product.product_id == sub.c.product_id))
.order_by(desc(sub.c.purchase_count * Product.price)) .order_by(desc(sub.c.purchase_count * Product.price))
.filter(sub.c.purchase_count is not None) .filter(sub.c.purchase_count is not None)
.all() .all()
@@ -87,26 +86,22 @@ class ProductRevenueMenu(Menu):
class BalanceMenu(Menu): class BalanceMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Total balance of PVVVV", sql_session) Menu.__init__(self, "Total balance of PVVVV", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
self.print_header() self.print_header()
text = "" text = ""
total_value = 0 total_value = 0
product_list = self.sql_session.query(Product).filter(Product.stock > 0).all() product_list = self.session.query(Product).filter(Product.stock > 0).all()
for p in product_list: for p in product_list:
total_value += p.stock * p.price total_value += p.stock * p.price
total_positive_credit = ( total_positive_credit = (
self.sql_session.query(func.coalesce(func.sum(User.credit), 0)) self.session.query(func.sum(User.credit)).filter(User.credit > 0).first()[0]
.filter(User.credit > 0)
.first()[0]
) )
total_negative_credit = ( total_negative_credit = (
self.sql_session.query(func.coalesce(func.sum(User.credit), 0)) self.session.query(func.sum(User.credit)).filter(User.credit < 0).first()[0]
.filter(User.credit < 0)
.first()[0]
) )
total_credit = total_positive_credit + total_negative_credit total_credit = total_positive_credit + total_negative_credit
@@ -124,8 +119,8 @@ class BalanceMenu(Menu):
class LoggedStatisticsMenu(Menu): class LoggedStatisticsMenu(Menu):
def __init__(self, sql_session: Session) -> None: def __init__(self):
super().__init__("Statistics from log", sql_session) Menu.__init__(self, "Statistics from log", uses_db=True)
def _execute(self, **_kwargs) -> None: def _execute(self):
statisticsTextOnly(self.sql_session) statisticsTextOnly()
+7 -10
View File
@@ -17,19 +17,16 @@ def _pascal_case_to_snake_case(name: str) -> str:
class Base(DeclarativeBase): class Base(DeclarativeBase):
metadata = MetaData( metadata = MetaData(
naming_convention={ naming_convention={
"ix": "ix__%(table_name)s__%(column_0_label)s", "ix": "ix_%(table_name)s_%(column_0_label)s",
"uq": "uq__%(table_name)s__%(column_0_name)s", "uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck__%(table_name)s__%(constraint_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk__%(table_name)s__%(column_0_name)s_%(referred_table_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk__%(table_name)s", "pk": "pk_%(table_name)s",
}, }
) )
@declared_attr.directive @declared_attr.directive
def __tablename__(cls) -> str: def __tablename__(cls) -> str:
if hasattr(cls, "__table_name__"):
assert isinstance(cls.__table_name__, str)
return cls.__table_name__
return _pascal_case_to_snake_case(cls.__name__) return _pascal_case_to_snake_case(cls.__name__)
# NOTE: This is the default implementation of __repr__ for all tables, # NOTE: This is the default implementation of __repr__ for all tables,
@@ -49,7 +46,7 @@ class Base(DeclarativeBase):
isinstance(v, InstrumentedList), isinstance(v, InstrumentedList),
isinstance(v, InstrumentedSet), isinstance(v, InstrumentedSet),
isinstance(v, InstrumentedDict), isinstance(v, InstrumentedDict),
], ]
) )
) )
return f"<{self.__class__.__name__}({columns})>" return f"<{self.__class__.__name__}({columns})>"
-27
View File
@@ -1,27 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import Transaction
class LastCacheTransaction(Base):
"""Tracks the last transaction that affected various caches."""
id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID"""
transaction_id: Mapped[int | None] = mapped_column(ForeignKey("trx.id"), index=True)
"""The ID of the last transaction that affected the cache(s)."""
transaction: Mapped[Transaction | None] = relationship(
lazy="joined",
foreign_keys=[transaction_id],
)
"""The last transaction that affected the cache(s)."""
+1 -2
View File
@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Self from typing import Self
from sqlalchemy import ( from sqlalchemy import (
Boolean, Boolean,
@@ -19,7 +19,6 @@ class Product(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True) id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID""" """Internal database ID"""
# TODO: add more validation for barcode
bar_code: Mapped[str] = mapped_column(String(13), unique=True) bar_code: Mapped[str] = mapped_column(String(13), unique=True)
""" """
The bar code of the product. The bar code of the product.
+7 -24
View File
@@ -1,33 +1,16 @@
from __future__ import annotations from datetime import datetime
from typing import TYPE_CHECKING from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import LastCacheTransaction, Product
class ProductCache(Base): class ProductCache(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True) product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""Internal database ID"""
product_id: Mapped[int] = mapped_column(ForeignKey("product.id"))
product: Mapped[Product] = relationship(
lazy="joined",
foreign_keys=[product_id],
)
price: Mapped[int] = mapped_column(Integer) price: Mapped[int] = mapped_column(Integer)
stock: Mapped[int] = mapped_column(Integer) price_timestamp: Mapped[datetime] = mapped_column(DateTime)
last_cache_transaction_id: Mapped[int | None] = mapped_column( stock: Mapped[int] = mapped_column(Integer)
ForeignKey("last_cache_transaction.id"), nullable=True, stock_timestamp: Mapped[datetime] = mapped_column(DateTime)
)
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
lazy="joined",
foreign_keys=[last_cache_transaction_id],
)
+16 -74
View File
@@ -33,10 +33,11 @@ if TYPE_CHECKING:
from .Product import Product from .Product import Product
from .User import User from .User import User
# TODO: rename to *_PERCENT
# NOTE: these only matter when there are no adjustments made in the database. # NOTE: these only matter when there are no adjustments made in the database.
DEFAULT_INTEREST_RATE_PERCENT = 100 DEFAULT_INTEREST_RATE_PERCENTAGE = 100
DEFAULT_PENALTY_THRESHOLD = -100 DEFAULT_PENALTY_THRESHOLD = -100
DEFAULT_PENALTY_MULTIPLIER_PERCENT = 200 DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200
_DYNAMIC_FIELDS: set[str] = { _DYNAMIC_FIELDS: set[str] = {
"amount", "amount",
@@ -87,7 +88,6 @@ def _transaction_type_field_constraints(
class Transaction(Base): class Transaction(Base):
__tablename__ = "trx"
__table_args__ = ( __table_args__ = (
*[ *[
_transaction_type_field_constraints(transaction_type, expected_fields) _transaction_type_field_constraints(transaction_type, expected_fields)
@@ -131,12 +131,12 @@ class Transaction(Base):
), ),
name="trx_joint_transaction_id_not_self", name="trx_joint_transaction_id_not_self",
), ),
# Speed up product stock calculation # Speed up product count calculation
Index("ix__transaction__product_id_type_time", "product_id", "type", "time"), Index("product_user_time", "product_id", "user_id", "time"),
# Speed up product owner calculation # Speed up product owner calculation
Index("ix__transaction__user_id_product_time", "user_id", "product_id", "time"), Index("user_product_time", "user_id", "product_id", "time"),
# Speed up user transaction list / credit calculation # Speed up user transaction list / credit calculation
Index("ix__transaction__user_id_time", "user_id", "time"), Index("user_time", "user_id", "time"),
) )
id: Mapped[int] = mapped_column(Integer, primary_key=True) id: Mapped[int] = mapped_column(Integer, primary_key=True)
@@ -146,7 +146,7 @@ class Transaction(Base):
Not used for anything else than identifying the transaction in the database. Not used for anything else than identifying the transaction in the database.
""" """
time: Mapped[datetime] = mapped_column(DateTime, index=True) time: Mapped[datetime] = mapped_column(DateTime)
""" """
The time when the transaction took place. The time when the transaction took place.
@@ -162,7 +162,7 @@ class Transaction(Base):
This is not used for any calculations, but can be useful for debugging. This is not used for any calculations, but can be useful for debugging.
""" """
type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type", index=True) type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type")
""" """
Which type of transaction this is. Which type of transaction this is.
@@ -189,7 +189,7 @@ class Transaction(Base):
that the user paid in the store would be stored in the `amount` field. that the user paid in the store would be stored in the `amount` field.
""" """
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), index=True) user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
"""The user who performs the transaction. See `user` for more details.""" """The user who performs the transaction. See `user` for more details."""
user: Mapped[User] = relationship( user: Mapped[User] = relationship(
lazy="joined", lazy="joined",
@@ -207,10 +207,7 @@ class Transaction(Base):
In the case of `JOINT` transactions, this is the user who initiated the joint transaction. In the case of `JOINT` transactions, this is the user who initiated the joint transaction.
""" """
joint_transaction_id: Mapped[int | None] = mapped_column( joint_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("transaction.id"))
ForeignKey("trx.id"),
index=True,
)
""" """
An optional ID to group multiple transactions together as part of a joint transaction. An optional ID to group multiple transactions together as part of a joint transaction.
@@ -226,7 +223,7 @@ class Transaction(Base):
""" """
# Receiving user when moving credit from one user to another # Receiving user when moving credit from one user to another
transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"), index=True) transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
"""The user who receives money in a `TRANSFER` transaction.""" """The user who receives money in a `TRANSFER` transaction."""
transfer_user: Mapped[User | None] = relationship( transfer_user: Mapped[User | None] = relationship(
lazy="joined", lazy="joined",
@@ -235,7 +232,7 @@ class Transaction(Base):
"""The user who receives money in a `TRANSFER` transaction.""" """The user who receives money in a `TRANSFER` transaction."""
# The product that is either being added or bought # The product that is either being added or bought
product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"), index=True) product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"))
"""The product being added or bought.""" """The product being added or bought."""
product: Mapped[Product | None] = relationship(lazy="joined") product: Mapped[Product | None] = relationship(lazy="joined")
"""The product being added or bought.""" """The product being added or bought."""
@@ -333,6 +330,7 @@ class Transaction(Base):
Validates the transaction's fields based on its type. Validates the transaction's fields based on its type.
Raises `ValueError` if the transaction is invalid. Raises `ValueError` if the transaction is invalid.
""" """
# TODO: do we allow free products?
if self.amount == 0: if self.amount == 0:
raise ValueError("Amount must not be zero.") raise ValueError("Amount must not be zero.")
@@ -354,7 +352,7 @@ class Transaction(Base):
and self.amount > self.per_product * self.product_count and self.amount > self.per_product * self.product_count
): ):
raise ValueError( raise ValueError(
"The real amount of the transaction must be less than the total value of the products.", "The real amount of the transaction must be less than the total value of the products."
) )
# TODO: improve printing further # TODO: improve printing further
@@ -385,7 +383,7 @@ class Transaction(Base):
isinstance(v, InstrumentedSet), isinstance(v, InstrumentedSet),
isinstance(v, InstrumentedDict), isinstance(v, InstrumentedDict),
*[k in (_DYNAMIC_FIELDS - EXPECTED_FIELDS[self.type_])], *[k in (_DYNAMIC_FIELDS - EXPECTED_FIELDS[self.type_])],
], ]
) )
) )
return f"{self.type_.upper()}({columns})" return f"{self.type_.upper()}({columns})"
@@ -402,11 +400,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating an `ADJUST_BALANCE` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.ADJUST_BALANCE, type_=TransactionType.ADJUST_BALANCE,
@@ -423,14 +416,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating an `ADJUST_INTEREST` transaction.
Note that the `interest_rate_percent` is absolute, not relative to the previous interest rate.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.ADJUST_INTEREST, type_=TransactionType.ADJUST_INTEREST,
@@ -448,14 +433,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating an `ADJUST_PENALTY` transaction.
Note that both `penalty_multiplier_percent` and `penalty_threshold` are absolute,
not relative to the previous settings.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.ADJUST_PENALTY, type_=TransactionType.ADJUST_PENALTY,
@@ -474,11 +451,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating an `ADJUST_STOCK` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.ADJUST_STOCK, type_=TransactionType.ADJUST_STOCK,
@@ -499,11 +471,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating an `ADD_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.ADD_PRODUCT, type_=TransactionType.ADD_PRODUCT,
@@ -524,11 +491,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating a `BUY_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.BUY_PRODUCT, type_=TransactionType.BUY_PRODUCT,
@@ -547,11 +509,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating a `JOINT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.JOINT, type_=TransactionType.JOINT,
@@ -569,11 +526,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating a `JOINT_BUY_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.JOINT_BUY_PRODUCT, type_=TransactionType.JOINT_BUY_PRODUCT,
@@ -591,11 +543,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating a `TRANSFER` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.TRANSFER, type_=TransactionType.TRANSFER,
@@ -614,11 +561,6 @@ class Transaction(Base):
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Self: ) -> Self:
"""
Convenience constructor for creating a `THROW_PRODUCT` transaction.
Should NOT be used directly in the application code; use the various queries instead.
"""
return cls( return cls(
time=time, time=time,
type_=TransactionType.THROW_PRODUCT, type_=TransactionType.THROW_PRODUCT,
+6 -25
View File
@@ -1,33 +1,14 @@
from __future__ import annotations from datetime import datetime
from typing import TYPE_CHECKING from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import LastCacheTransaction, User
# More like user balance cash money flow, amirite? # More like user balance cash money flow, amirite?
class UserCache(Base): class UserBalanceCache(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True) user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""internal database id"""
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
user: Mapped[User] = relationship(
lazy="joined",
foreign_keys=[user_id],
)
balance: Mapped[int] = mapped_column(Integer) balance: Mapped[int] = mapped_column(Integer)
timestamp: Mapped[datetime] = mapped_column(DateTime)
last_cache_transaction_id: Mapped[int | None] = mapped_column(
ForeignKey("last_cache_transaction.id"), nullable=True,
)
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
lazy="joined",
foreign_keys=[last_cache_transaction_id],
)
-6
View File
@@ -1,19 +1,13 @@
__all__ = [ __all__ = [
"Base", "Base",
"LastCacheTransaction",
"Product", "Product",
"ProductCache",
"Transaction", "Transaction",
"TransactionType", "TransactionType",
"User", "User",
"UserCache",
] ]
from .Base import Base from .Base import Base
from .LastCacheTransaction import LastCacheTransaction
from .Product import Product from .Product import Product
from .ProductCache import ProductCache
from .Transaction import Transaction from .Transaction import Transaction
from .TransactionType import TransactionType from .TransactionType import TransactionType
from .User import User from .User import User
from .UserCache import UserCache
+7 -22
View File
@@ -1,13 +1,8 @@
__all__ = [ __all__ = [
"add_product", # "add_product",
"adjust_balance", # "add_user",
"adjust_interest", "adjust_interest",
"adjust_penalty", "adjust_penalty",
"adjust_stock",
"affected_products",
"affected_users",
"create_product",
"create_user",
"current_interest", "current_interest",
"current_penalty", "current_penalty",
"joint_buy_product", "joint_buy_product",
@@ -16,37 +11,27 @@ __all__ = [
"product_price", "product_price",
"product_price_log", "product_price_log",
"product_stock", "product_stock",
# "products_owned_by_user",
"search_product", "search_product",
"search_user", "search_user",
"throw_product",
"transaction_log", "transaction_log",
"transfer",
"update_cache",
"user_balance", "user_balance",
"user_balance_log", "user_balance_log",
"user_products",
] ]
from .add_product import add_product # from .add_product import add_product
from .adjust_balance import adjust_balance # from .add_user import add_user
from .adjust_interest import adjust_interest from .adjust_interest import adjust_interest
from .adjust_penalty import adjust_penalty from .adjust_penalty import adjust_penalty
from .adjust_stock import adjust_stock
from .affected_products import affected_products
from .affected_users import affected_users
from .create_product import create_product
from .create_user import create_user
from .current_interest import current_interest from .current_interest import current_interest
from .current_penalty import current_penalty from .current_penalty import current_penalty
from .joint_buy_product import joint_buy_product from .joint_buy_product import joint_buy_product
from .product_owners import product_owners, product_owners_log from .product_owners import product_owners, product_owners_log
from .product_price import product_price, product_price_log from .product_price import product_price, product_price_log
from .product_stock import product_stock from .product_stock import product_stock
# from .products_owned_by_user import products_owned_by_user
from .search_product import search_product from .search_product import search_product
from .search_user import search_user from .search_user import search_user
from .throw_product import throw_product
from .transaction_log import transaction_log from .transaction_log import transaction_log
from .transfer import transfer
from .update_cache import update_cache
from .user_balance import user_balance, user_balance_log from .user_balance import user_balance, user_balance_log
from .user_products import user_products
+1 -51
View File
@@ -1,51 +1 @@
from datetime import datetime # TODO: implement me
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def add_product(
sql_session: Session,
user: User,
product: Product,
amount: int,
per_product: int,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
if per_product <= 0:
raise ValueError("Per product price must be positive.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
if per_product * product_count < amount:
raise ValueError("Total per product price must be at least equal to amount.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.add_product(
user_id=user.id,
product_id=product.id,
amount=amount,
per_product=per_product,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction
+1
View File
@@ -0,0 +1 @@
# TODO: implement me
-33
View File
@@ -1,33 +0,0 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def adjust_balance(
sql_session: Session,
user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if amount == 0:
raise ValueError("Amount must be non-zero.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_balance(
user_id=user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction
+1 -9
View File
@@ -1,5 +1,3 @@
from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Transaction, User from dibbler.models import Transaction, User
@@ -12,25 +10,19 @@ def adjust_interest(
sql_session: Session, sql_session: Session,
user: User, user: User,
new_interest: int, new_interest: int,
time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Transaction: ) -> None:
if new_interest < 0: if new_interest < 0:
raise ValueError("Interest rate cannot be negative") raise ValueError("Interest rate cannot be negative")
if user.id is None: if user.id is None:
raise ValueError("User must be persisted in the database.") raise ValueError("User must be persisted in the database.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_interest( transaction = Transaction.adjust_interest(
user_id=user.id, user_id=user.id,
interest_rate_percent=new_interest, interest_rate_percent=new_interest,
time=time,
message=message, message=message,
) )
sql_session.add(transaction) sql_session.add(transaction)
sql_session.commit() sql_session.commit()
return transaction
+1 -9
View File
@@ -1,5 +1,3 @@
from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Transaction, User from dibbler.models import Transaction, User
@@ -14,9 +12,8 @@ def adjust_penalty(
user: User, user: User,
new_penalty: int | None = None, new_penalty: int | None = None,
new_penalty_multiplier: int | None = None, new_penalty_multiplier: int | None = None,
time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> Transaction: ) -> None:
if new_penalty is None and new_penalty_multiplier is None: if new_penalty is None and new_penalty_multiplier is None:
raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided") raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided")
@@ -33,17 +30,12 @@ def adjust_penalty(
if new_penalty_multiplier is None: if new_penalty_multiplier is None:
new_penalty_multiplier = existing_penalty_multiplier new_penalty_multiplier = existing_penalty_multiplier
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_penalty( transaction = Transaction.adjust_penalty(
user_id=user.id, user_id=user.id,
penalty_threshold=new_penalty, penalty_threshold=new_penalty,
penalty_multiplier_percent=new_penalty_multiplier, penalty_multiplier_percent=new_penalty_multiplier,
time=time,
message=message, message=message,
) )
sql_session.add(transaction) sql_session.add(transaction)
sql_session.commit() sql_session.commit()
return transaction
-40
View File
@@ -1,40 +0,0 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def adjust_stock(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count == 0:
raise ValueError("Product count must be non-zero.")
# TODO: it should not be possible to reduce stock below zero.
#
# TODO: verify time is not behind last transaction's time
transaction = Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction
-88
View File
@@ -1,88 +0,0 @@
from datetime import datetime
from sqlalchemy import BindParameter, select
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, TransactionType
from dibbler.queries.query_helpers import after_filter, until_filter
def affected_products(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
after_time: BindParameter[datetime] | datetime | None = None,
after_transaction: Transaction | None = None,
after_inclusive: bool = True,
) -> set[Product]:
"""
Get all products where attributes were affected over a given interval.
"""
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
result = sql_session.scalars(
select(Product)
.distinct()
.join(Transaction, Product.id == Transaction.product_id)
.where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT.as_literal_column(),
],
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
after_filter(
after_time=after_time,
after_transaction_id=after_transaction_id,
after_inclusive=after_inclusive,
),
)
.order_by(Transaction.time.desc()),
).all()
return set(result)
-90
View File
@@ -1,90 +0,0 @@
from datetime import datetime
from sqlalchemy import BindParameter, or_, select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType, User
from dibbler.queries.query_helpers import after_filter, until_filter
def affected_users(
sql_session: Session,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
after_time: BindParameter[datetime] | datetime | None = None,
after_transaction: Transaction | None = None,
after_inclusive: bool = True,
) -> set[User]:
"""
Get all users where attributes were affected over a given interval.
"""
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
result = sql_session.scalars(
select(User)
.distinct()
.join(
Transaction,
or_(User.id == Transaction.user_id, User.id == Transaction.transfer_user_id),
)
.where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
],
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
after_filter(
after_time=after_time,
after_transaction_id=after_transaction_id,
after_inclusive=after_inclusive,
),
)
.order_by(Transaction.time.desc()),
).all()
return set(result)
-38
View File
@@ -1,38 +0,0 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def buy_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction
-25
View File
@@ -1,25 +0,0 @@
from sqlalchemy.orm import Session
from dibbler.models import Product
def create_product(
sql_session: Session,
name: str,
barcode: str,
) -> Product:
if not name:
raise ValueError("Name cannot be empty.")
if not barcode:
raise ValueError("Barcode cannot be empty.")
# TODO: check for duplicate names, barcodes
# TODO: add more validation for barcode
product = Product(barcode, name)
sql_session.add(product)
sql_session.commit()
return product
-21
View File
@@ -1,21 +0,0 @@
from sqlalchemy.orm import Session
from dibbler.models import User
def create_user(
sql_session: Session,
name: str,
card: str | None,
rfid: str | None,
) -> User:
if not name:
raise ValueError("Name cannot be empty.")
# TODO: check for duplicate names, cards, rfids
user = User(name=name, card=card, rfid=rfid)
sql_session.add(user)
sql_session.commit()
return user
+9 -40
View File
@@ -1,55 +1,24 @@
from datetime import datetime from sqlalchemy import select
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
from dibbler.queries.query_helpers import until_filter
def current_interest( # TODO: add until transaction parameter
sql_session: Session, # TODO: add until datetime parameter
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> int:
"""
Get the current interest rate percentage as of a given time or transaction.
Returns the interest rate percentage as an integer.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
def current_interest(sql_session: Session) -> int:
result = sql_session.scalars( result = sql_session.scalars(
select(Transaction) select(Transaction)
.where( .where(Transaction.type_ == TransactionType.ADJUST_INTEREST)
Transaction.type_ == TransactionType.ADJUST_INTEREST,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc()) .order_by(Transaction.time.desc())
.limit(1), .limit(1)
).one_or_none() ).one_or_none()
if result is None: if result is None:
return DEFAULT_INTEREST_RATE_PERCENT return DEFAULT_INTEREST_RATE_PERCENTAGE
elif result.interest_rate_percent is None: elif result.interest_rate_percent is None:
return DEFAULT_INTEREST_RATE_PERCENT return DEFAULT_INTEREST_RATE_PERCENTAGE
else: else:
return result.interest_rate_percent return result.interest_rate_percent
+8 -39
View File
@@ -1,57 +1,26 @@
from datetime import datetime from sqlalchemy import select
from sqlalchemy import BindParameter, bindparam, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType from dibbler.models import Transaction, TransactionType
from dibbler.models.Transaction import ( from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENT, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_THRESHOLD,
) )
from dibbler.queries.query_helpers import until_filter
def current_penalty( # TODO: add until transaction parameter
sql_session: Session, # TODO: add until datetime parameter
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: BindParameter[Transaction] | Transaction | None = None,
until_inclusive: bool = True,
) -> tuple[int, int]:
"""
Get the current penalty settings (threshold and multiplier percentage) as of a given time or transaction.
Returns a tuple of `(penalty_threshold, penalty_multiplier_percentage)`.
"""
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
def current_penalty(sql_session: Session) -> tuple[int, int]:
result = sql_session.scalars( result = sql_session.scalars(
select(Transaction) select(Transaction)
.where( .where(Transaction.type_ == TransactionType.ADJUST_PENALTY)
Transaction.type_ == TransactionType.ADJUST_PENALTY,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.order_by(Transaction.time.desc()) .order_by(Transaction.time.desc())
.limit(1), .limit(1)
).one_or_none() ).one_or_none()
if result is None: if result is None:
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENT return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE
assert result.penalty_threshold is not None, "Penalty threshold must be set" assert result.penalty_threshold is not None, "Penalty threshold must be set"
assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set" assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set"
+3 -15
View File
@@ -17,7 +17,7 @@ def joint_buy_product(
users: list[User], users: list[User],
time: datetime | None = None, time: datetime | None = None,
message: str | None = None, message: str | None = None,
) -> list[Transaction]: ) -> None:
""" """
Create buy product transactions for multiple users at once. Create buy product transactions for multiple users at once.
""" """
@@ -25,23 +25,15 @@ def joint_buy_product(
if product.id is None: if product.id is None:
raise ValueError("Product must be persisted in the database.") raise ValueError("Product must be persisted in the database.")
if instigator.id is None: if instigator not in users:
raise ValueError("Instigator must be persisted in the database.") raise ValueError("Instigator must be in the list of users buying the product.")
if len(users) == 0:
raise ValueError("At least bying one user must be specified.")
if any(user.id is None for user in users): if any(user.id is None for user in users):
raise ValueError("All users must be persisted in the database.") raise ValueError("All users must be persisted in the database.")
if instigator not in users:
raise ValueError("Instigator must be in the list of users buying the product.")
if product_count <= 0: if product_count <= 0:
raise ValueError("Product count must be positive.") raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
joint_transaction = Transaction.joint( joint_transaction = Transaction.joint(
user_id=instigator.id, user_id=instigator.id,
product_id=product.id, product_id=product.id,
@@ -52,8 +44,6 @@ def joint_buy_product(
sql_session.add(joint_transaction) sql_session.add(joint_transaction)
sql_session.flush() # Ensure joint_transaction gets an ID sql_session.flush() # Ensure joint_transaction gets an ID
transactions = [joint_transaction]
for user in users: for user in users:
buy_transaction = Transaction.joint_buy_product( buy_transaction = Transaction.joint_buy_product(
user_id=user.id, user_id=user.id,
@@ -62,7 +52,5 @@ def joint_buy_product(
message=message, message=message,
) )
sql_session.add(buy_transaction) sql_session.add(buy_transaction)
transactions.append(buy_transaction)
sql_session.commit() sql_session.commit()
return transactions
+53 -62
View File
@@ -8,11 +8,11 @@ from sqlalchemy import (
bindparam, bindparam,
case, case,
func, func,
or_,
select, select,
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO
from dibbler.models import ( from dibbler.models import (
Product, Product,
Transaction, Transaction,
@@ -20,22 +20,13 @@ from dibbler.models import (
User, User,
) )
from dibbler.queries.product_stock import _product_stock_query from dibbler.queries.product_stock import _product_stock_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
until_filter,
)
def _product_owners_query( def _product_owners_query(
product_id: BindParameter[int] | int, product_id: BindParameter[int] | int,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE: ) -> CTE:
""" """
The inner query for inferring the owners of a given product. The inner query for inferring the owners of a given product.
@@ -47,25 +38,13 @@ def _product_owners_query(
if isinstance(product_id, int): if isinstance(product_id, int):
product_id = bindparam("product_id", value=product_id) product_id = bindparam("product_id", value=product_id)
if until_time is not None and until_transaction is not None: if isinstance(until, datetime):
raise ValueError("Cannot filter by both until_time and until_transaction.") until = BindParameter("until", value=until)
if isinstance(until_time, datetime):
until_time = bindparam("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
product_stock = _product_stock_query( product_stock = _product_stock_query(
product_id=product_id, product_id=product_id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
# Subset of transactions that we'll want to iterate over. # Subset of transactions that we'll want to iterate over.
@@ -78,23 +57,22 @@ def _product_owners_query(
Transaction.user_id, Transaction.user_id,
Transaction.product_count, Transaction.product_count,
) )
# TODO: maybe add value constraint on ADJUST_STOCK?
.where( .where(
or_( Transaction.type_.in_(
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), [
and_( TransactionType.ADD_PRODUCT.as_literal_column(),
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), # TransactionType.BUY_PRODUCT,
Transaction.product_count > CONST_ZERO, TransactionType.ADJUST_STOCK.as_literal_column(),
), # TransactionType.JOINT,
# TransactionType.THROW_PRODUCT,
]
), ),
Transaction.product_id == product_id, Transaction.product_id == product_id,
until_filter( CONST_TRUE if until is None else Transaction.time <= until,
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
) )
.order_by(Transaction.time.desc()) .order_by(Transaction.time.desc())
.subquery(trx_subset_name) .subquery()
) )
initial_element = select( initial_element = select(
@@ -139,19 +117,35 @@ def _product_owners_query(
).label("product_count"), ).label("product_count"),
# How many products left to account for # How many products left to account for
case( case(
# Someone adds the product -> known owner, decrease the number of products left to account for # Someone adds the product -> increase the number of products left to account for
( (
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
), ),
# Stock got adjusted upwards -> none owner, decrease the number of products left to account for # Someone buys/joins/throws the product -> decrease the number of products left to account for
# (
# trx_subset.c.type_.in_(
# [
# TransactionType.BUY_PRODUCT,
# TransactionType.JOINT,
# TransactionType.THROW_PRODUCT,
# ]
# ),
# recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
# ),
# Someone adjusts the stock ->
# If adjusted upwards -> products owned by nobody, decrease products left to account for
# If adjusted downwards -> products taken away from owners, decrease products left to account for
( (
and_( (trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), and (trx_subset.c.product_count > CONST_ZERO),
trx_subset.c.product_count > CONST_ZERO,
),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
), ),
# (
# (trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
# and (trx_subset.c.product_count < 0),
# recursive_cte.c.products_left_to_account_for + trx_subset.c.product_count,
# ),
else_=recursive_cte.c.products_left_to_account_for, else_=recursive_cte.c.products_left_to_account_for,
).label("products_left_to_account_for"), ).label("products_left_to_account_for"),
) )
@@ -159,9 +153,8 @@ def _product_owners_query(
.where( .where(
and_( and_(
trx_subset.c.i == recursive_cte.c.i + CONST_ONE, trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
# Base case: stop if we've accounted for all products
recursive_cte.c.products_left_to_account_for > CONST_ZERO, recursive_cte.c.products_left_to_account_for > CONST_ZERO,
), )
) )
) )
@@ -174,14 +167,13 @@ class ProductOwnersLogEntry:
user: User | None user: User | None
products_left_to_account_for: int products_left_to_account_for: int
# TODO: add until datetime parameter
def product_owners_log( def product_owners_log(
sql_session: Session, sql_session: Session,
product: Product, product: Product,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: Transaction | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductOwnersLogEntry]: ) -> list[ProductOwnersLogEntry]:
""" """
Returns a log of the product ownership calculation for the given product. Returns a log of the product ownership calculation for the given product.
@@ -192,12 +184,13 @@ def product_owners_log(
if product.id is None: if product.id is None:
raise ValueError("Product must be persisted in the database.") raise ValueError("Product must be persisted in the database.")
if until is not None and until.id is None:
raise ValueError("'until' transaction must be persisted in the database.")
recursive_cte = _product_owners_query( recursive_cte = _product_owners_query(
product_id=product.id, product_id=product.id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until.time if until else None,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
result = sql_session.execute( result = sql_session.execute(
@@ -216,13 +209,13 @@ def product_owners_log(
onclause=User.id == recursive_cte.c.user_id, onclause=User.id == recursive_cte.c.user_id,
isouter=True, isouter=True,
) )
.order_by(recursive_cte.c.time.desc()), .order_by(recursive_cte.c.time.desc())
).all() ).all()
if result is None: if result is None:
# If there are no transactions for this product, the query should return an empty list, not None. # If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError( raise RuntimeError(
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id}).", f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})."
) )
return [ return [
@@ -235,13 +228,13 @@ def product_owners_log(
] ]
# TODO: add until transaction parameter
def product_owners( def product_owners(
sql_session: Session, sql_session: Session,
product: Product, product: Product,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[User | None]: ) -> list[User | None]:
""" """
Returns an ordered list of users owning the given product. Returns an ordered list of users owning the given product.
@@ -255,9 +248,7 @@ def product_owners(
recursive_cte = _product_owners_query( recursive_cte = _product_owners_query(
product_id=product.id, product_id=product.id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
db_result = sql_session.execute( db_result = sql_session.execute(
@@ -267,7 +258,7 @@ def product_owners(
User, User,
) )
.join(User, User.id == recursive_cte.c.user_id, isouter=True) .join(User, User.id == recursive_cte.c.user_id, isouter=True)
.order_by(recursive_cte.c.time.desc()), .order_by(recursive_cte.c.time.desc())
).all() ).all()
print(db_result) print(db_result)
@@ -301,7 +292,7 @@ def product_owners(
result.extend([None] * none_count) result.extend([None] * none_count)
# # NOTE: if the last line exceeds the product count, we need to truncate it # # NOTE: if the last line exeeds the product count, we need to truncate it
# result.extend([user] * min(user_count, products_left_to_account_for)) # result.extend([user] * min(user_count, products_left_to_account_for))
# redistribute the user counts to a list of users # redistribute the user counts to a list of users
+49 -124
View File
@@ -6,7 +6,7 @@ from sqlalchemy import (
BindParameter, BindParameter,
ColumnElement, ColumnElement,
Integer, Integer,
bindparam, asc,
case, case,
cast, cast,
func, func,
@@ -14,111 +14,52 @@ from sqlalchemy import (
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO
from dibbler.models import ( from dibbler.models import (
LastCacheTransaction,
Product, Product,
ProductCache,
Transaction, Transaction,
TransactionType, TransactionType,
) )
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
after_filter,
until_filter,
)
def _product_price_query( def _product_price_query(
product_id: int | ColumnElement[int], product_id: int | ColumnElement[int],
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None, until_including: BindParameter[bool] | bool = True,
until_inclusive: bool = True,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
): ):
""" """
The inner query for calculating the product price. The inner query for calculating the product price.
""" """
if use_cache:
print("WARNING: Using cache for product price query is not implemented yet.")
if isinstance(product_id, int): if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id) product_id = BindParameter("product_id", value=product_id)
if not (until_time is None or until_transaction is None): if isinstance(until, datetime):
raise ValueError("Cannot filter by both until_time and until_transaction.") until = BindParameter("until", value=until)
if isinstance(until_time, datetime): if isinstance(until_including, bool):
until_time = BindParameter("until_time", value=until_time) until_including = BindParameter("until_including", value=until_including)
if isinstance(until_transaction, Transaction): initial_element = select(
if until_transaction.id is None: CONST_ZERO.label("i"),
raise ValueError("until_transaction must be persisted in the database.") CONST_ZERO.label("time"),
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) CONST_NONE.label("transaction_id"),
else: CONST_ZERO.label("price"),
until_transaction_id = None CONST_ZERO.label("product_count"),
)
if use_cache:
initial_element_fields = (
select(
Transaction.time.label("time"),
Transaction.id.label("transaction_id"),
ProductCache.price.label("price"),
ProductCache.stock.label("product_count"),
)
.select_from(ProductCache)
.join(
LastCacheTransaction,
ProductCache.last_cache_transaction_id == LastCacheTransaction.id,
)
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
.where(
ProductCache.product_id == product_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
.union(
select(
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
),
)
.order_by(Transaction.time.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
.subquery()
.alias("initial_element_fields")
)
initial_element = select(
CONST_ZERO.label("i"),
initial_element_fields.c.time,
initial_element_fields.c.transaction_id,
initial_element_fields.c.price,
initial_element_fields.c.product_count,
).select_from(initial_element_fields)
else:
initial_element = select(
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True) recursive_cte = initial_element.cte(name=cte_name, recursive=True)
# Subset of transactions that we'll want to iterate over. # Subset of transactions that we'll want to iterate over.
trx_subset = ( trx_subset = (
select( select(
func.row_number().over(order_by=Transaction.time.asc()).label("i"), func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.id, Transaction.id,
Transaction.time, Transaction.time,
Transaction.type_, Transaction.type_,
@@ -132,22 +73,18 @@ def _product_price_query(
TransactionType.ADD_PRODUCT.as_literal_column(), TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(), TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.JOINT.as_literal_column(), TransactionType.JOINT.as_literal_column(),
], ]
), ),
Transaction.product_id == product_id, Transaction.product_id == product_id,
after_filter( case(
after_time=None, (until_including, Transaction.time <= until),
after_transaction_id=recursive_cte.c.transaction_id, else_=Transaction.time < until,
after_inclusive=False, )
), if until is not None
until_filter( else CONST_TRUE,
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
) )
.order_by(Transaction.time.asc()) .order_by(Transaction.time.asc())
.subquery(trx_subset_name) .alias("trx_subset")
) )
recursive_elements = ( recursive_elements = (
@@ -178,7 +115,7 @@ def _product_price_query(
# and other disastrous phenomena. # and other disastrous phenomena.
func.max(recursive_cte.c.product_count, CONST_ZERO) func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.product_count + trx_subset.c.product_count
), )
), ),
Integer, Integer,
), ),
@@ -233,13 +170,13 @@ class ProductPriceLogEntry:
product_count: int product_count: int
# TODO: add until datetime parameter
def product_price_log( def product_price_log(
sql_session: Session, sql_session: Session,
product: Product, product: Product,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: Transaction | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[ProductPriceLogEntry]: ) -> list[ProductPriceLogEntry]:
""" """
Calculates the price of a product and returns a log of the price changes. Calculates the price of a product and returns a log of the price changes.
@@ -248,12 +185,13 @@ def product_price_log(
if product.id is None: if product.id is None:
raise ValueError("Product must be persisted in the database.") raise ValueError("Product must be persisted in the database.")
if until is not None and until.id is None:
raise ValueError("'until' transaction must be persisted in the database.")
recursive_cte = _product_price_query( recursive_cte = _product_price_query(
product.id, product.id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until.time if until else None,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
result = sql_session.execute( result = sql_session.execute(
@@ -267,13 +205,13 @@ def product_price_log(
Transaction, Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id, onclause=Transaction.id == recursive_cte.c.transaction_id,
) )
.order_by(recursive_cte.c.i.asc()), .order_by(recursive_cte.c.i.asc())
).all() ).all()
if result is None: if result is None:
# If there are no transactions for this product, the query should return an empty list, not None. # If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError( raise RuntimeError(
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id}).", f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})."
) )
return [ return [
@@ -286,13 +224,13 @@ def product_price_log(
] ]
# TODO: add until datetime parameter
def product_price( def product_price(
sql_session: Session, sql_session: Session,
product: Product, product: Product,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: Transaction | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
include_interest: bool = False, include_interest: bool = False,
) -> int: ) -> int:
""" """
@@ -302,22 +240,13 @@ def product_price(
if product.id is None: if product.id is None:
raise ValueError("Product must be persisted in the database.") raise ValueError("Product must be persisted in the database.")
if isinstance(until_time, datetime): if until is not None and until.id is None:
until_time = BindParameter("until_time", value=until_time) raise ValueError("'until' transaction must be persisted in the database.")
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
recursive_cte = _product_price_query( recursive_cte = _product_price_query(
product.id, product.id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until.time if until else None,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
# TODO: optionally verify subresults: # TODO: optionally verify subresults:
@@ -328,13 +257,13 @@ def product_price(
select(recursive_cte.c.price) select(recursive_cte.c.price)
.order_by(recursive_cte.c.i.desc()) .order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE) .limit(CONST_ONE)
.offset(CONST_ZERO), .offset(CONST_ZERO)
).one_or_none() ).one_or_none()
if result is None: if result is None:
# If there are no transactions for this product, the query should return 0, not None. # If there are no transactions for this product, the query should return 0, not None.
raise RuntimeError( raise RuntimeError(
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id}).", f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
) )
if include_interest: if include_interest:
@@ -343,16 +272,12 @@ def product_price(
select(Transaction.interest_rate_percent) select(Transaction.interest_rate_percent)
.where( .where(
Transaction.type_ == TransactionType.ADJUST_INTEREST, Transaction.type_ == TransactionType.ADJUST_INTEREST,
until_filter( CONST_TRUE if until is None else Transaction.time <= until.time,
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
) )
.order_by(Transaction.time.desc()) .order_by(Transaction.time.desc())
.limit(CONST_ONE), .limit(CONST_ONE)
) )
or DEFAULT_INTEREST_RATE_PERCENT or DEFAULT_INTEREST_RATE_PERCENTAGE
) )
result = math.ceil(result * interest_rate / 100) result = math.ceil(result * interest_rate / 100)
+13 -33
View File
@@ -1,31 +1,27 @@
from datetime import datetime from datetime import datetime
from typing import Tuple
from sqlalchemy import ( from sqlalchemy import (
BindParameter, BindParameter,
Select, Select,
bindparam,
case, case,
func, func,
select, select,
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_TRUE
from dibbler.models import ( from dibbler.models import (
Product, Product,
Transaction, Transaction,
TransactionType, TransactionType,
) )
from dibbler.queries.query_helpers import until_filter
def _product_stock_query( def _product_stock_query(
product_id: BindParameter[int] | int, product_id: BindParameter[int] | int,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None, ) -> Select:
until_inclusive: bool = True,
) -> Select[tuple[int]]:
""" """
The inner query for calculating the product stock. The inner query for calculating the product stock.
""" """
@@ -36,18 +32,8 @@ def _product_stock_query(
if isinstance(product_id, int): if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id) product_id = BindParameter("product_id", value=product_id)
if not (until_time is None or until_transaction is None): if isinstance(until, datetime):
raise ValueError("Cannot filter by both until_time and until_transaction.") until = BindParameter("until", value=until)
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select( query = select(
func.sum( func.sum(
@@ -73,8 +59,8 @@ def _product_stock_query(
-Transaction.product_count, -Transaction.product_count,
), ),
else_=0, else_=0,
), )
).label("stock"), )
).where( ).where(
Transaction.type_.in_( Transaction.type_.in_(
[ [
@@ -83,26 +69,22 @@ def _product_stock_query(
TransactionType.BUY_PRODUCT.as_literal_column(), TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT.as_literal_column(), TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT.as_literal_column(), TransactionType.THROW_PRODUCT.as_literal_column(),
], ]
), ),
Transaction.product_id == product_id, Transaction.product_id == product_id,
until_filter( Transaction.time <= until if until is not None else CONST_TRUE,
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
) )
return query return query
# TODO: add until transaction parameter
def product_stock( def product_stock(
sql_session: Session, sql_session: Session,
product: Product, product: Product,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int: ) -> int:
""" """
Returns the number of products in stock. Returns the number of products in stock.
@@ -116,9 +98,7 @@ def product_stock(
query = _product_stock_query( query = _product_stock_query(
product_id=product.id, product_id=product.id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
result = sql_session.scalars(query).one_or_none() result = sql_session.scalars(query).one_or_none()
-119
View File
@@ -1,119 +0,0 @@
from datetime import datetime
from typing import TypeVar
from sqlalchemy import (
BindParameter,
ColumnExpressionArgument,
literal,
select,
)
from sqlalchemy.orm import QueryableAttribute
from dibbler.models import Transaction
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
"""A constant SQL expression `0`. This will render as a literal `0` in SQL queries."""
CONST_ONE: BindParameter[int] = const(1)
"""A constant SQL expression `1`. This will render as a literal `1` in SQL queries."""
CONST_TRUE: BindParameter[bool] = const(True)
"""A constant SQL expression `TRUE`. This will render as a literal `TRUE` in SQL queries."""
CONST_FALSE: BindParameter[bool] = const(False)
"""A constant SQL expression `FALSE`. This will render as a literal `FALSE` in SQL queries."""
CONST_NONE: BindParameter[None] = const(None)
"""A constant SQL expression `NULL`. This will render as a literal `NULL` in SQL queries."""
def until_filter(
until_time: BindParameter[datetime] | None = None,
until_transaction_id: BindParameter[int] | None = None,
until_inclusive: bool = True,
transaction_time: QueryableAttribute = Transaction.time,
) -> ColumnExpressionArgument[bool]:
"""
Create a filter condition for transactions up to a given time or transaction.
Only one of `until_time` or `until_transaction_id` may be specified.
"""
assert not (until_time is not None and until_transaction_id is not None), (
"Cannot filter by both until_time and until_transaction_id."
)
match (until_time, until_transaction_id, until_inclusive):
case (BindParameter(), None, True):
return transaction_time <= until_time
case (BindParameter(), None, False):
return transaction_time < until_time
case (None, BindParameter(), True):
return (
transaction_time
<= select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
case (None, BindParameter(), False):
return (
transaction_time
< select(Transaction.time)
.where(Transaction.id == until_transaction_id)
.scalar_subquery()
)
return CONST_TRUE
def after_filter(
after_time: BindParameter[datetime] | None = None,
after_transaction_id: BindParameter[int] | None = None,
after_inclusive: bool = True,
transaction_time: QueryableAttribute = Transaction.time,
) -> ColumnExpressionArgument[bool]:
"""
Create a filter condition for transactions after a given time or transaction.
Only one of `after_time` or `after_transaction_id` may be specified.
"""
assert not (after_time is not None and after_transaction_id is not None), (
"Cannot filter by both after_time and after_transaction_id."
)
match (after_time, after_transaction_id, after_inclusive):
case (BindParameter(), None, True):
return transaction_time >= after_time
case (BindParameter(), None, False):
return transaction_time > after_time
case (None, BindParameter(), True):
return (
transaction_time
>= select(Transaction.time)
.where(Transaction.id == after_transaction_id)
.scalar_subquery()
)
case (None, BindParameter(), False):
return (
transaction_time
> select(Transaction.time)
.where(Transaction.id == after_transaction_id)
.scalar_subquery()
)
return CONST_TRUE
+4 -7
View File
@@ -9,9 +9,6 @@ def search_product(
sql_session: Session, sql_session: Session,
find_hidden_products=False, find_hidden_products=False,
) -> Product | list[Product]: ) -> Product | list[Product]:
if not string:
raise ValueError("Search string cannot be empty.")
exact_match = sql_session.scalars( exact_match = sql_session.scalars(
select(Product).where( select(Product).where(
or_( or_(
@@ -20,8 +17,8 @@ def search_product(
Product.name == string, Product.name == string,
literal(True) if find_hidden_products else not_(Product.hidden), literal(True) if find_hidden_products else not_(Product.hidden),
), ),
), )
), )
).first() ).first()
if exact_match: if exact_match:
@@ -35,8 +32,8 @@ def search_product(
Product.name.ilike(f"%{string}%"), Product.name.ilike(f"%{string}%"),
literal(True) if find_hidden_products else not_(Product.hidden), literal(True) if find_hidden_products else not_(Product.hidden),
), ),
), )
), )
).all() ).all()
return list(product_list) return list(product_list)
+4 -7
View File
@@ -8,9 +8,6 @@ def search_user(
string: str, string: str,
sql_session: Session, sql_session: Session,
) -> User | list[User]: ) -> User | list[User]:
if not string:
raise ValueError("Search string cannot be empty.")
string = string.lower() string = string.lower()
exact_match = sql_session.scalars( exact_match = sql_session.scalars(
@@ -19,8 +16,8 @@ def search_user(
User.name == string, User.name == string,
User.card == string, User.card == string,
User.rfid == string, User.rfid == string,
), )
), )
).first() ).first()
if exact_match: if exact_match:
@@ -32,8 +29,8 @@ def search_user(
User.name.ilike(f"%{string}%"), User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"), User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"), User.rfid.ilike(f"%{string}%"),
), )
), )
).all() ).all()
return list(user_list) return list(user_list)
-42
View File
@@ -1,42 +0,0 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
def throw_product(
sql_session: Session,
user: User,
product: Product,
product_count: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if user.id is None:
raise ValueError("User must be persisted in the database.")
if product.id is None:
raise ValueError("Product must be persisted in the database.")
if product_count <= 0:
raise ValueError("Product count must be positive.")
# TODO: verify time is not behind last transaction's time
raise NotImplementedError(
"Please don't use this function until relevant calculations have been added to user_balance.",
)
transaction = Transaction.throw_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction
+36 -88
View File
@@ -1,6 +1,4 @@
from datetime import datetime from sqlalchemy import select
from sqlalchemy import BindParameter, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import ( from dibbler.models import (
@@ -17,12 +15,12 @@ def transaction_log(
sql_session: Session, sql_session: Session,
user: User | None = None, user: User | None = None,
product: Product | None = None, product: Product | None = None,
until_time: BindParameter[datetime] | datetime | None = None, exclusive_after: bool = False,
until_transaction: Transaction | None = None, after_time=None,
until_inclusive: bool = True, after_transaction_id: int | None = None,
after_time: BindParameter[datetime] | datetime | None = None, exclusive_before: bool = False,
after_transaction: Transaction | None = None, before_time=None,
after_inclusive: bool = True, before_transaction_id: int | None = None,
transaction_type: list[TransactionType] | None = None, transaction_type: list[TransactionType] | None = None,
negate_transaction_type_filter: bool = False, negate_transaction_type_filter: bool = False,
limit: int | None = None, limit: int | None = None,
@@ -31,101 +29,51 @@ def transaction_log(
Retrieve the transaction log, optionally filtered. Retrieve the transaction log, optionally filtered.
Only one of `user` or `product` may be specified. Only one of `user` or `product` may be specified.
Only one of `until_time` or `until_transaction_id` may be specified.
Only one of `after_time` or `after_transaction_id` may be specified. Only one of `after_time` or `after_transaction_id` may be specified.
Only one of `before_time` or `before_transaction_id` may be specified.
The after and after filters are inclusive by default. The before and after filters are inclusive by default.
""" """
if not (user is None or product is None): if not (user is None or product is None):
raise ValueError("Cannot filter by both user and product.") raise ValueError("Cannot filter by both user and product.")
if isinstance(user, User): if user is not None and user.id is None:
if user.id is None: raise ValueError("User must be persisted in the database.")
raise ValueError("User must be persisted in the database.")
user_id = BindParameter("user_id", value=user.id)
else:
user_id = None
if isinstance(product, Product): if product is not None and product.id is None:
if product.id is None: raise ValueError("Product must be persisted in the database.")
raise ValueError("Product must be persisted in the database.")
product_id = BindParameter("product_id", value=product.id)
else:
product_id = None
if not (until_time is None or until_transaction is None): if not (after_time is None or after_transaction_id is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.") raise ValueError("Cannot filter by both from_time and from_transaction_id.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
if not (after_time is None or after_transaction is None):
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
if isinstance(after_time, datetime):
after_time = BindParameter("after_time", value=after_time)
if isinstance(after_transaction, Transaction):
if after_transaction.id is None:
raise ValueError("after_transaction must be persisted in the database.")
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
else:
after_transaction_id = None
if after_time is not None and until_time is not None:
assert isinstance(after_time.value, datetime)
assert isinstance(until_time.value, datetime)
if after_time.value > until_time.value:
raise ValueError("after_time cannot be after until_time.")
if after_transaction is not None and until_transaction is not None:
assert after_transaction.time is not None
assert until_transaction.time is not None
if after_transaction.time > until_transaction.time:
raise ValueError("after_transaction cannot be after until_transaction.")
if limit is not None and limit <= 0:
raise ValueError("Limit must be positive.")
query = select(Transaction) query = select(Transaction)
if user is not None: if user is not None:
query = query.where(Transaction.user_id == user_id) query = query.where(Transaction.user_id == user.id)
if product is not None: if product is not None:
query = query.where(Transaction.product_id == product_id) query = query.where(Transaction.product_id == product.id)
match (until_time, until_transaction_id, until_inclusive): if after_time is not None:
case (BindParameter(), None, True): if exclusive_after:
query = query.where(Transaction.time <= until_time)
case (BindParameter(), None, False):
query = query.where(Transaction.time < until_time)
case (None, BindParameter(), True):
query = query.where(Transaction.id <= until_transaction_id)
case (None, BindParameter(), False):
query = query.where(Transaction.id < until_transaction_id)
case _:
pass
match (after_time, after_transaction_id, after_inclusive):
case (BindParameter(), None, True):
query = query.where(Transaction.time >= after_time)
case (BindParameter(), None, False):
query = query.where(Transaction.time > after_time) query = query.where(Transaction.time > after_time)
case (None, BindParameter(), True): else:
query = query.where(Transaction.id >= after_transaction_id) query = query.where(Transaction.time >= after_time)
case (None, BindParameter(), False): if after_transaction_id is not None:
if exclusive_after:
query = query.where(Transaction.id > after_transaction_id) query = query.where(Transaction.id > after_transaction_id)
case _: else:
pass query = query.where(Transaction.id >= after_transaction_id)
if before_time is not None:
if exclusive_before:
query = query.where(Transaction.time < before_time)
else:
query = query.where(Transaction.time <= before_time)
if before_transaction_id is not None:
if exclusive_before:
query = query.where(Transaction.id < before_transaction_id)
else:
query = query.where(Transaction.id <= before_transaction_id)
if transaction_type is not None: if transaction_type is not None:
if negate_transaction_type_filter: if negate_transaction_type_filter:
-38
View File
@@ -1,38 +0,0 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
def transfer(
sql_session: Session,
from_user: User,
to_user: User,
amount: int,
time: datetime | None = None,
message: str | None = None,
) -> Transaction:
if from_user.id is None:
raise ValueError("From user must be persisted in the database.")
if to_user.id is None:
raise ValueError("To user must be persisted in the database.")
if amount <= 0:
raise ValueError("Amount must be positive.")
# TODO: verify time is not behind last transaction's time
transaction = Transaction.transfer(
user_id=from_user.id,
transfer_user_id=to_user.id,
amount=amount,
time=time,
message=message,
)
sql_session.add(transaction)
sql_session.commit()
return transaction
-118
View File
@@ -1,118 +0,0 @@
from sqlalchemy import insert, select
from sqlalchemy.orm import Session
from dibbler.models import LastCacheTransaction, ProductCache, Transaction, UserCache
from dibbler.queries.affected_products import affected_products
from dibbler.queries.affected_users import affected_users
from dibbler.queries.product_price import product_price
from dibbler.queries.product_stock import product_stock
from dibbler.queries.user_balance import user_balance
def update_cache(
sql_session: Session,
use_cache: bool = True,
) -> None:
"""
Update the cache used for searching products.
If `use_cache` is False, the cache will be rebuilt from scratch.
"""
last_transaction = sql_session.scalars(
select(Transaction).order_by(Transaction.time.desc()).limit(1),
).one_or_none()
print(last_transaction)
if last_transaction is None:
# No transactions exist, nothing to update
return
if use_cache:
last_cache_transaction = sql_session.scalars(
select(LastCacheTransaction)
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
.order_by(Transaction.time.desc())
.limit(1),
).one_or_none()
if last_cache_transaction is not None:
last_cache_transaction = last_cache_transaction.transaction
else:
last_cache_transaction = None
if last_cache_transaction is not None and last_cache_transaction.id == last_transaction.id:
# Cache is already up to date
return
users = affected_users(
sql_session,
after_transaction=last_cache_transaction,
after_inclusive=False,
until_transaction=last_transaction,
)
products = affected_products(
sql_session,
after_transaction=last_cache_transaction,
after_inclusive=False,
until_transaction=last_transaction,
)
user_balances = {}
for user in users:
x = user_balance(
sql_session,
user,
use_cache=use_cache,
until_transaction=last_transaction,
)
user_balances[user.id] = x
product_stocks = {}
product_prices = {}
for product in products:
product_stocks[product.id] = product_stock(
sql_session,
product,
use_cache=use_cache,
until_transaction=last_transaction,
)
product_prices[product.id] = product_price(
sql_session,
product,
use_cache=use_cache,
until_transaction=last_transaction,
)
next_cache_transaction = LastCacheTransaction(transaction_id=last_transaction.id)
sql_session.add(next_cache_transaction)
sql_session.flush()
if not len(users) == 0:
sql_session.execute(
insert(UserCache),
[
{
"user_id": user.id,
"balance": user_balances[user.id],
"last_cache_transaction_id": next_cache_transaction.id,
}
for user in users
],
)
if not len(products) == 0:
sql_session.execute(
insert(ProductCache),
[
{
"product_id": product.id,
"stock": product_stocks[product.id],
"price": product_prices[product.id],
"last_cache_transaction_id": next_cache_transaction.id,
}
for product in products
],
)
sql_session.commit()
+111 -320
View File
@@ -1,15 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Tuple
from sqlalchemy import ( from sqlalchemy import (
CTE, CTE,
BindParameter, BindParameter,
Float, Float,
Integer, Integer,
Select,
and_, and_,
bindparam,
case, case,
cast, cast,
column, column,
@@ -17,243 +14,28 @@ from sqlalchemy import (
or_, or_,
select, select,
) )
from sqlalchemy.orm import Session, aliased from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import KeyedColumnElement
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import ( from dibbler.models import (
Transaction, Transaction,
TransactionType, TransactionType,
User, User,
) )
from dibbler.models.Transaction import ( from dibbler.models.Transaction import (
DEFAULT_INTEREST_RATE_PERCENT, DEFAULT_INTEREST_RATE_PERCENTAGE,
DEFAULT_PENALTY_MULTIPLIER_PERCENT, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_THRESHOLD,
) )
from dibbler.queries.product_price import _product_price_query from dibbler.queries.product_price import _product_price_query
from dibbler.queries.query_helpers import (
CONST_NONE,
CONST_ONE,
CONST_ZERO,
const,
until_filter,
)
def _joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[tuple[int, int, int]]:
"""
The inner query for getting joint transactions relevant to a user.
This scans for JOINT_BUY_PRODUCT transactions made by the user,
then finds the corresponding JOINT transactions, and counts how many "shares"
of the joint transaction the user has, as well as the total number of shares.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
# First, select all joint buy product transactions for the given user
# sub_joint_transaction = aliased(Transaction, name="right_trx")
sub_joint_transaction = (
select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id"))
.where(
Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
Transaction.user_id == user_id,
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
transaction_time=Transaction.time,
),
)
.subquery("sub_joint_transaction")
)
# Join those with their main joint transaction
# (just use Transaction)
# Then, count how many users are involved in each joint transaction
joint_transaction_count = aliased(Transaction, name="count_trx")
joint_transaction = (
select(
Transaction.id,
# Shares the user has in the transaction,
func.sum(
case(
(joint_transaction_count.user_id == user_id, CONST_ONE),
else_=CONST_ZERO,
),
).label("user_shares"),
# The total number of shares in the transaction,
func.count(joint_transaction_count.id).label("user_count"),
)
.select_from(sub_joint_transaction)
.join(
Transaction,
onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id,
)
.join(
joint_transaction_count,
onclause=joint_transaction_count.joint_transaction_id == Transaction.id,
)
.group_by(joint_transaction_count.joint_transaction_id)
)
return joint_transaction
def _non_joint_transaction_query(
user_id: BindParameter[int] | int,
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> Select[tuple[int, None, None]]:
"""
The inner query for getting non-joint transactions relevant to a user.
"""
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
query = select(
Transaction.id,
CONST_NONE.label("user_shares"),
CONST_NONE.label("user_count"),
).where(
or_(
and_(
Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
],
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY.as_literal_column(),
],
),
),
until_filter(
until_time=until_time,
until_transaction_id=until_transaction_id,
until_inclusive=until_inclusive,
),
)
return query
def _product_cost_expression(
product_count_column: KeyedColumnElement[int],
product_id_column: KeyedColumnElement[int],
interest_rate_percent_column: KeyedColumnElement[int],
user_balance_column: KeyedColumnElement[int],
penalty_threshold_column: KeyedColumnElement[int],
penalty_multiplier_percent_column: KeyedColumnElement[int],
joint_user_shares_column: KeyedColumnElement[int],
joint_user_count_column: KeyedColumnElement[int],
use_cache: bool = True,
until_time: BindParameter[datetime] | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
cte_name: str = "product_price_cte",
trx_subset_name: str = "product_price_trx_subset",
):
# TODO: This can get quite expensive real quick, so we should do some caching of the
# product prices somehow.
expression = (
select(
cast(
func.ceil(
# Base price
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
# Interest
+ (
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
* cast(interest_rate_percent_column - const(100), Float)
/ const(100.0)
)
# Penalty
+ (
(
cast(
column("price") * product_count_column * joint_user_shares_column,
Float,
)
/ joint_user_count_column
)
* cast(penalty_multiplier_percent_column - const(100), Float)
/ const(100.0)
* cast(user_balance_column < penalty_threshold_column, Integer)
),
),
Integer,
),
)
.select_from(
_product_price_query(
product_id_column,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
cte_name=cte_name,
trx_subset_name=trx_subset_name,
),
)
.order_by(column("i").desc())
.limit(CONST_ONE)
.scalar_subquery()
)
return expression
def _user_balance_query( def _user_balance_query(
user_id: BindParameter[int] | int, user_id: BindParameter[int] | int,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | None = None, until: BindParameter[datetime] | BindParameter[None] | datetime | None = None,
until_transaction: Transaction | None = None, until_including: BindParameter[bool] | bool = True,
until_inclusive: bool = True,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
trx_subset_name: str = "trx_subset",
) -> CTE: ) -> CTE:
""" """
The inner query for calculating the user's balance. The inner query for calculating the user's balance.
@@ -265,44 +47,30 @@ def _user_balance_query(
if isinstance(user_id, int): if isinstance(user_id, int):
user_id = BindParameter("user_id", value=user_id) user_id = BindParameter("user_id", value=user_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until, type_=datetime)
if isinstance(until_including, bool):
until_including = BindParameter("until_including", value=until_including, type_=bool)
initial_element = select( initial_element = select(
CONST_ZERO.label("i"), CONST_ZERO.label("i"),
CONST_ZERO.label("time"), CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"), CONST_NONE.label("transaction_id"),
CONST_ZERO.label("balance"), CONST_ZERO.label("balance"),
const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"), const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"), const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
) )
recursive_cte = initial_element.cte(name=cte_name, recursive=True) recursive_cte = initial_element.cte(name=cte_name, recursive=True)
trx_subset_subset = (
_non_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
)
.union_all(
_joint_transaction_query(
user_id=user_id,
use_cache=use_cache,
until_time=until_time,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
),
)
.subquery(f"{trx_subset_name}_subset")
)
# Subset of transactions that we'll want to iterate over. # Subset of transactions that we'll want to iterate over.
trx_subset = ( trx_subset = (
select( select(
func.row_number().over(order_by=Transaction.time.asc()).label("i"), func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.id,
Transaction.amount, Transaction.amount,
Transaction.id,
Transaction.interest_rate_percent, Transaction.interest_rate_percent,
Transaction.penalty_multiplier_percent, Transaction.penalty_multiplier_percent,
Transaction.penalty_threshold, Transaction.penalty_threshold,
@@ -311,16 +79,44 @@ def _user_balance_query(
Transaction.time, Transaction.time,
Transaction.transfer_user_id, Transaction.transfer_user_id,
Transaction.type_, Transaction.type_,
trx_subset_subset.c.user_shares,
trx_subset_subset.c.user_count,
) )
.select_from(trx_subset_subset) .where(
.join( or_(
Transaction, and_(
onclause=Transaction.id == trx_subset_subset.c.id, Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
# TODO: join this with the JOINT transactions, and determine
# how much the current user paid for the product.
TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
]
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY.as_literal_column(),
]
),
),
case(
(until_including, Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else CONST_TRUE,
) )
.order_by(Transaction.time.asc()) .order_by(Transaction.time.asc())
.subquery(trx_subset_name) .alias("trx_subset")
) )
recursive_elements = ( recursive_elements = (
@@ -343,43 +139,49 @@ def _user_balance_query(
( (
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.balance recursive_cte.c.balance
- _product_cost_expression( - (
product_count_column=trx_subset.c.product_count, trx_subset.c.product_count
product_id_column=trx_subset.c.product_id, # Price of a single product, accounted for penalties and interest.
interest_rate_percent_column=recursive_cte.c.interest_rate_percent, * cast(
user_balance_column=recursive_cte.c.balance, func.ceil(
penalty_threshold_column=recursive_cte.c.penalty_threshold, # TODO: This can get quite expensive real quick, so we should do some caching of the
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, # product prices somehow.
joint_user_shares_column=CONST_ONE, # Base price
joint_user_count_column=CONST_ONE, (
use_cache=use_cache, # FIXME: this always returns 0 for some reason...
until_time=until_time, select(cast(column("price"), Float))
until_transaction=until_transaction, .select_from(
until_inclusive=until_inclusive, _product_price_query(
cte_name=f"{cte_name}_price", trx_subset.c.product_id,
trx_subset_name=f"{trx_subset_name}_price", use_cache=use_cache,
).label("product_cost"), until=trx_subset.c.time,
), until_including=False,
# Joint transaction -> balance decreases proportionally cte_name="product_price_cte",
( )
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(), )
recursive_cte.c.balance .order_by(column("i").desc())
- _product_cost_expression( .limit(CONST_ONE)
product_count_column=trx_subset.c.product_count, ).scalar_subquery()
product_id_column=trx_subset.c.product_id, # TODO: should interest be applied before or after the penalty multiplier?
interest_rate_percent_column=recursive_cte.c.interest_rate_percent, # at the moment of writing, after sound right, but maybe ask someone?
user_balance_column=recursive_cte.c.balance, # Interest
penalty_threshold_column=recursive_cte.c.penalty_threshold, * (cast(recursive_cte.c.interest_rate_percent, Float) / const(100))
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, # TODO: these should be added together, not multiplied, see specification
joint_user_shares_column=trx_subset.c.user_shares, # Penalty
joint_user_count_column=trx_subset.c.user_count, * case(
use_cache=use_cache, (
until_time=until_time, recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
until_transaction=until_transaction, (
until_inclusive=until_inclusive, cast(recursive_cte.c.penalty_multiplier_percent, Float)
cte_name=f"{cte_name}_joint_price", / const(100)
trx_subset_name=f"{trx_subset_name}_joint_price", ),
).label("joint_product_cost"), ),
else_=const(1.0),
)
),
Integer,
)
),
), ),
# Transfers money to self -> balance increases # Transfers money to self -> balance increases
( (
@@ -398,7 +200,8 @@ def _user_balance_query(
recursive_cte.c.balance - trx_subset.c.amount, recursive_cte.c.balance - trx_subset.c.amount,
), ),
# Throws a product -> if the user is considered to have bought it, balance increases # Throws a product -> if the user is considered to have bought it, balance increases
# TODO: # ( # TODO:
# (
# trx_subset.c.type_ == TransactionType.THROW_PRODUCT, # trx_subset.c.type_ == TransactionType.THROW_PRODUCT,
# recursive_cte.c.balance + trx_subset.c.amount, # recursive_cte.c.balance + trx_subset.c.amount,
# ), # ),
@@ -452,16 +255,17 @@ class UserBalanceLogEntry:
Returns whether this exact transaction is penalized. Returns whether this exact transaction is penalized.
""" """
raise NotImplementedError("is_penalized is not implemented yet.") return False
# return self.transaction.type_ == TransactionType.BUY_PRODUCT and prev?
# TODO: add until datetime parameter
def user_balance_log( def user_balance_log(
sql_session: Session, sql_session: Session,
user: User, user: User,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: Transaction | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[UserBalanceLogEntry]: ) -> list[UserBalanceLogEntry]:
""" """
Returns a log of the user's balance over time, including interest and penalty adjustments. Returns a log of the user's balance over time, including interest and penalty adjustments.
@@ -472,18 +276,13 @@ def user_balance_log(
if user.id is None: if user.id is None:
raise ValueError("User must be persisted in the database.") raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None): if until is not None and until.id is None:
raise ValueError("Cannot filter by both until_time and until_transaction.") raise ValueError("'until' transaction must be persisted in the database.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query( recursive_cte = _user_balance_query(
user.id, user.id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until.time if until else None,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
result = sql_session.execute( result = sql_session.execute(
@@ -499,13 +298,13 @@ def user_balance_log(
Transaction, Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id, onclause=Transaction.id == recursive_cte.c.transaction_id,
) )
.order_by(recursive_cte.c.i.asc()), .order_by(recursive_cte.c.i.asc())
).all() ).all()
if result is None: if result is None:
# If there are no transactions for this user, the query should return 0, not None. # If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError( raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id}).", f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
) )
return [ return [
@@ -520,13 +319,13 @@ def user_balance_log(
] ]
# TODO: add until datetime parameter
def user_balance( def user_balance(
sql_session: Session, sql_session: Session,
user: User, user: User,
use_cache: bool = True, use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None, until: Transaction | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> int: ) -> int:
""" """
Calculates the balance of a user. Calculates the balance of a user.
@@ -537,31 +336,23 @@ def user_balance(
if user.id is None: if user.id is None:
raise ValueError("User must be persisted in the database.") raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
recursive_cte = _user_balance_query( recursive_cte = _user_balance_query(
user.id, user.id,
use_cache=use_cache, use_cache=use_cache,
until_time=until_time, until=until.time if until else None,
until_transaction=until_transaction,
until_inclusive=until_inclusive,
) )
result = sql_session.scalar( result = sql_session.scalar(
select(recursive_cte.c.balance) select(recursive_cte.c.balance)
.order_by(recursive_cte.c.i.desc()) .order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE) .limit(CONST_ONE)
.offset(CONST_ZERO), .offset(CONST_ZERO)
) )
if result is None: if result is None:
# If there are no transactions for this user, the query should return 0, not None. # If there are no transactions for this user, the query should return 0, not None.
raise RuntimeError( raise RuntimeError(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id}).", f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
) )
return result return result
+1 -39
View File
@@ -1,11 +1,4 @@
from datetime import datetime # This absoulutely needs a cache, else we can't stop recursing until we know all owners for all products...
from sqlalchemy import BindParameter, bindparam
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
# NOTE: This absolutely needs a cache, else we can't stop recursing until we know all owners for all products...
# #
# Since we know that the non-owned products will not get renowned by the user by other means, # Since we know that the non-owned products will not get renowned by the user by other means,
# we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user. # we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user.
@@ -15,34 +8,3 @@ from dibbler.models import Product, Transaction, User
# but we still need to check if the user passes out of ownership for the item, without needing to check past # but we still need to check if the user passes out of ownership for the item, without needing to check past
# the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if # the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if
# a user has products multiple places in the queue, interleaved with other users? # a user has products multiple places in the queue, interleaved with other users?
def user_products(
sql_session: Session,
user: User,
use_cache: bool = True,
until_time: BindParameter[datetime] | datetime | None = None,
until_transaction: Transaction | None = None,
until_inclusive: bool = True,
) -> list[tuple[Product, int]]:
"""
Returns the list of products owned by the user, along with how many of each product they own.
"""
if user.id is None:
raise ValueError("User must be persisted in the database.")
if not (until_time is None or until_transaction is None):
raise ValueError("Cannot filter by both until_time and until_transaction.")
if isinstance(until_time, datetime):
until_time = BindParameter("until_time", value=until_time)
if isinstance(until_transaction, Transaction):
if until_transaction.id is None:
raise ValueError("until_transaction must be persisted in the database.")
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
else:
until_transaction_id = None
raise NotImplementedError("Not implemented yet, needs caching system first.")
+38 -70
View File
@@ -1,111 +1,79 @@
#!/usr/bin/python #!/usr/bin/python
# -*- coding: utf-8 -*-
import random import random
import sys import sys
import traceback import traceback
from signal import (
SIG_IGN,
SIGQUIT,
SIGTSTP,
)
from signal import (
signal as set_signal_handler,
)
from sqlalchemy.orm import Session
from ..conf import config from ..conf import config
from ..menus import ( from ..lib.helpers import *
AddProductMenu, from ..menus import *
AddStockMenu,
AddUserMenu,
AdjustCreditMenu,
AdjustStockMenu,
BalanceMenu,
BuyMenu,
CleanupStockMenu,
EditProductMenu,
EditUserMenu,
FAQMenu,
LoggedStatisticsMenu,
MainMenu,
Menu,
PrintLabelMenu,
ProductListMenu,
ProductPopularityMenu,
ProductRevenueMenu,
ProductSearchMenu,
ShowUserMenu,
TransferMenu,
UserListMenu,
)
random.seed() random.seed()
def main(sql_session: Session) -> None: def main():
if not config["general"]["stop_allowed"]: if not config.getboolean("general", "stop_allowed"):
set_signal_handler(SIGQUIT, SIG_IGN) signal.signal(signal.SIGQUIT, signal.SIG_IGN)
if not config["general"]["stop_allowed"]: if not config.getboolean("general", "stop_allowed"):
set_signal_handler(SIGTSTP, SIG_IGN) signal.signal(signal.SIGTSTP, signal.SIG_IGN)
main_menu = MainMenu( main = MainMenu(
sql_session, "Dibbler main menu",
items=[ items=[
BuyMenu(sql_session), BuyMenu(),
ProductListMenu(sql_session), ProductListMenu(),
ShowUserMenu(sql_session), ShowUserMenu(),
UserListMenu(sql_session), UserListMenu(),
AdjustCreditMenu(sql_session), AdjustCreditMenu(),
TransferMenu(sql_session), TransferMenu(),
AddStockMenu(sql_session), AddStockMenu(),
Menu( Menu(
"Add/edit", "Add/edit",
sql_session,
items=[ items=[
AddUserMenu(sql_session), AddUserMenu(),
EditUserMenu(sql_session), EditUserMenu(),
AddProductMenu(sql_session), AddProductMenu(),
EditProductMenu(sql_session), EditProductMenu(),
AdjustStockMenu(sql_session), AdjustStockMenu(),
CleanupStockMenu(sql_session), CleanupStockMenu(),
], ],
), ),
ProductSearchMenu(sql_session), ProductSearchMenu(),
Menu( Menu(
"Statistics", "Statistics",
sql_session,
items=[ items=[
ProductPopularityMenu(sql_session), ProductPopularityMenu(),
ProductRevenueMenu(sql_session), ProductRevenueMenu(),
BalanceMenu(sql_session), BalanceMenu(),
LoggedStatisticsMenu(sql_session), LoggedStatisticsMenu(),
], ],
), ),
FAQMenu(sql_session), FAQMenu(),
PrintLabelMenu(sql_session), PrintLabelMenu(),
], ],
exit_msg="happy happy joy joy", exit_msg="happy happy joy joy",
exit_confirm_msg="Really quit Dibbler?", exit_confirm_msg="Really quit Dibbler?",
) )
if not config["general"]["quit_allowed"]: if not config.getboolean("general", "quit_allowed"):
main_menu.exit_disallowed_msg = ( main.exit_disallowed_msg = "You can check out any time you like, but you can never leave."
"You can check out any time you like, but you can never leave."
)
while True: while True:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
main_menu.execute() main.execute()
except KeyboardInterrupt: except KeyboardInterrupt:
print("") print("")
print("Interrupted.") print("Interrupted.")
except: except:
print("Something went wrong.") print("Something went wrong.")
print(f"{sys.exc_info()[0]}: {sys.exc_info()[1]}") print(f"{sys.exc_info()[0]}: {sys.exc_info()[1]}")
if config["general"]["show_tracebacks"]: if config.getboolean("general", "show_tracebacks"):
traceback.print_tb(sys.exc_info()[2]) traceback.print_tb(sys.exc_info()[2])
else: else:
break break
print("Restarting main menu.") print("Restarting main menu.")
main_menu.sql_session.reset()
if __name__ == "__main__":
main()
+6 -4
View File
@@ -1,9 +1,11 @@
#!/usr/bin/python #!/usr/bin/python
from sqlalchemy.engine import Engine
from dibbler.models import Base from dibbler.models import Base
from dibbler.db import engine
def main(engine: Engine) -> None: def main():
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
if __name__ == "__main__":
main()
+1
View File
@@ -0,0 +1 @@
# TODO: implement me
+9 -7
View File
@@ -1,27 +1,29 @@
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from sqlalchemy.orm import Session from dibbler.db import Session
from dibbler.models import Product, Transaction, User from dibbler.models import Product, Transaction, User
from dibbler.queries import joint_buy_product from dibbler.queries import joint_buy_product
JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json" JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json"
def clear_db(sql_session: Session) -> None: # TODO: integrate this as a part of create-db, either asking interactively
# TODO: integrate this as a part of create-db, either asking interactively # whether to seed test data, or by using command line arguments for
# whether to seed test data, or by using command line arguments for # automatating the answer.
# automatating the answer.
def clear_db(sql_session):
sql_session.query(Product).delete() sql_session.query(Product).delete()
sql_session.query(User).delete() sql_session.query(User).delete()
sql_session.commit() sql_session.commit()
def main(sql_session: Session) -> None: def main():
# TODO: There is some leftover json data in the mock_data.json file. # TODO: There is some leftover json data in the mock_data.json file.
# It should be dealt with before merging this PR, either by removing # It should be dealt with before merging this PR, either by removing
# it or using it here. # it or using it here.
sql_session = Session()
clear_db(sql_session) clear_db(sql_session)
# Add users # Add users
+9 -4
View File
@@ -1,13 +1,18 @@
#!/usr/bin/python #!/usr/bin/python
from sqlalchemy.orm import Session from dibbler.db import Session
from dibbler.models import User from dibbler.models import User
def main(sql_session: Session) -> None: def main():
# Start an SQL session
session = Session()
# Let's find all users with a negative credit # Let's find all users with a negative credit
slabbedasker = sql_session.query(User).filter(User.credit < 0).all() slabbedasker = session.query(User).filter(User.credit < 0).all()
for slubbert in slabbedasker: for slubbert in slabbedasker:
print(f"{slubbert.name}, {slubbert.credit}") print(f"{slubbert.name}, {slubbert.credit}")
if __name__ == "__main__":
main()
+199 -199
View File
@@ -1,231 +1,231 @@
# #! /usr/bin/env python #! /usr/bin/env python
# # TODO: fixme # TODO: fixme
# # -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# import matplotlib.dates as mdates import matplotlib.dates as mdates
# from dibbler.lib.statistikkHelpers import * from dibbler.lib.statistikkHelpers import *
# def getInputType(): def getInputType():
# inp = 0 inp = 0
# while not (inp == "1" or inp == "2" or inp == "3" or inp == "4"): while not (inp == "1" or inp == "2" or inp == "3" or inp == "4"):
# print("type 1 for user-statistics") print("type 1 for user-statistics")
# print("type 2 for product-statistics") print("type 2 for product-statistics")
# print("type 3 for global-statistics") print("type 3 for global-statistics")
# print("type 4 to enter loop-mode") print("type 4 to enter loop-mode")
# inp = input("") inp = input("")
# return int(inp) return int(inp)
# def getDateFile(date, n): def getDateFile(date, n):
# try: try:
# if n == 0: if n == 0:
# inp = input("start date? (yyyy-mm-dd) ") inp = input("start date? (yyyy-mm-dd) ")
# elif n == -1: elif n == -1:
# inp = input("end date? (yyyy-mm-dd) ") inp = input("end date? (yyyy-mm-dd) ")
# year = inp.partition("-") year = inp.partition("-")
# month = year[2].partition("-") month = year[2].partition("-")
# return datetime.date(int(year[0]), int(month[0]), int(month[2])) return datetime.date(int(year[0]), int(month[0]), int(month[2]))
# except: except:
# print("invalid date, setting start start date") print("invalid date, setting start start date")
# if n == 0: if n == 0:
# print("to date found on first line") print("to date found on first line")
# elif n == -1: elif n == -1:
# print("to date found on last line") print("to date found on last line")
# print(date) print(date)
# return datetime.date( return datetime.date(
# int(date.partition("-")[0]), int(date.partition("-")[0]),
# int(date.partition("-")[2].partition("-")[0]), int(date.partition("-")[2].partition("-")[0]),
# int(date.partition("-")[2].partition("-")[2]), int(date.partition("-")[2].partition("-")[2]),
# ) )
# def dateToDateNumFile(date, startDate): def dateToDateNumFile(date, startDate):
# year = date.partition("-") year = date.partition("-")
# month = year[2].partition("-") month = year[2].partition("-")
# day = datetime.date(int(year[0]), int(month[0]), int(month[2])) day = datetime.date(int(year[0]), int(month[0]), int(month[2]))
# deltaDays = day - startDate deltaDays = day - startDate
# return int(deltaDays.days), day.weekday() return int(deltaDays.days), day.weekday()
# def getProducts(products): def getProducts(products):
# product = [] product = []
# products = products.partition("¤") products = products.partition("¤")
# product.append(products[0]) product.append(products[0])
# while products[1] == "¤": while products[1] == "¤":
# products = products[2].partition("¤") products = products[2].partition("¤")
# product.append(products[0]) product.append(products[0])
# return product return product
# def piePlot(dictionary, n): def piePlot(dictionary, n):
# keys = [] keys = []
# values = [] values = []
# i = 0 i = 0
# for key in sorted(dictionary, key=dictionary.get, reverse=True): for key in sorted(dictionary, key=dictionary.get, reverse=True):
# values.append(dictionary[key]) values.append(dictionary[key])
# if i < n: if i < n:
# keys.append(key) keys.append(key)
# i += 1 i += 1
# else: else:
# keys.append("") keys.append("")
# plt.pie(values, labels=keys) plt.pie(values, labels=keys)
# def datePlot(array, dateLine): def datePlot(array, dateLine):
# if not array == []: if not array == []:
# plt.bar(dateLine, array) plt.bar(dateLine, array)
# plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%b")) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%b"))
# def dayPlot(array, days): def dayPlot(array, days):
# if not array == []: if not array == []:
# for i in range(7): for i in range(7):
# array[i] = array[i] * 7.0 / days array[i] = array[i] * 7.0 / days
# plt.bar(list(range(7)), array) plt.bar(list(range(7)), array)
# plt.xticks( plt.xticks(
# list(range(7)), list(range(7)),
# [ [
# " mon", " mon",
# " tue", " tue",
# " wed", " wed",
# " thu", " thu",
# " fri", " fri",
# " sat", " sat",
# " sun", " sun",
# ], ],
# ) )
# def graphPlot(array, dateLine): def graphPlot(array, dateLine):
# if not array == []: if not array == []:
# plt.plot(dateLine, array) plt.plot(dateLine, array)
# plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%b")) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%b"))
# def plotUser(database, dateLine, user, n): def plotUser(database, dateLine, user, n):
# printUser(database, dateLine, user, n) printUser(database, dateLine, user, n)
# plt.subplot(221) plt.subplot(221)
# piePlot(database.personVareAntall[user], n) piePlot(database.personVareAntall[user], n)
# plt.xlabel("antall varer kjøpt gjengitt i antall") plt.xlabel("antall varer kjøpt gjengitt i antall")
# plt.subplot(222) plt.subplot(222)
# datePlot(database.personDatoVerdi[user], dateLine) datePlot(database.personDatoVerdi[user], dateLine)
# plt.xlabel("penger brukt over dato") plt.xlabel("penger brukt over dato")
# plt.subplot(223) plt.subplot(223)
# piePlot(database.personVareVerdi[user], n) piePlot(database.personVareVerdi[user], n)
# plt.xlabel("antall varer kjøpt gjengitt i verdi") plt.xlabel("antall varer kjøpt gjengitt i verdi")
# plt.subplot(224) plt.subplot(224)
# dayPlot(database.personUkedagVerdi[user], len(dateLine)) dayPlot(database.personUkedagVerdi[user], len(dateLine))
# plt.xlabel("forbruk over ukedager") plt.xlabel("forbruk over ukedager")
# plt.show() plt.show()
# def plotProduct(database, dateLine, product, n): def plotProduct(database, dateLine, product, n):
# printProduct(database, dateLine, product, n) printProduct(database, dateLine, product, n)
# plt.subplot(221) plt.subplot(221)
# piePlot(database.varePersonAntall[product], n) piePlot(database.varePersonAntall[product], n)
# plt.xlabel("personer som har handler produktet") plt.xlabel("personer som har handler produktet")
# plt.subplot(222) plt.subplot(222)
# datePlot(database.vareDatoAntall[product], dateLine) datePlot(database.vareDatoAntall[product], dateLine)
# plt.xlabel("antall produkter handlet per dag") plt.xlabel("antall produkter handlet per dag")
# # plt.subplot(223) # plt.subplot(223)
# plt.subplot(224) plt.subplot(224)
# dayPlot(database.vareUkedagAntall[product], len(dateLine)) dayPlot(database.vareUkedagAntall[product], len(dateLine))
# plt.xlabel("antall over ukedager") plt.xlabel("antall over ukedager")
# plt.show() plt.show()
# def plotGlobal(database, dateLine, n): def plotGlobal(database, dateLine, n):
# printGlobal(database, dateLine, n) printGlobal(database, dateLine, n)
# plt.subplot(231) plt.subplot(231)
# piePlot(database.globalVareVerdi, n) piePlot(database.globalVareVerdi, n)
# plt.xlabel("varer kjøpt gjengitt som verdi") plt.xlabel("varer kjøpt gjengitt som verdi")
# plt.subplot(232) plt.subplot(232)
# datePlot(database.globalDatoForbruk, dateLine) datePlot(database.globalDatoForbruk, dateLine)
# plt.xlabel("forbruk over dato") plt.xlabel("forbruk over dato")
# plt.subplot(233) plt.subplot(233)
# graphPlot(database.pengebeholdning, dateLine) graphPlot(database.pengebeholdning, dateLine)
# plt.xlabel("pengebeholdning over tid (negativ verdi utgjør samlet kreditt)") plt.xlabel("pengebeholdning over tid (negativ verdi utgjør samlet kreditt)")
# plt.subplot(234) plt.subplot(234)
# piePlot(database.globalPersonForbruk, n) piePlot(database.globalPersonForbruk, n)
# plt.xlabel("penger brukt av personer") plt.xlabel("penger brukt av personer")
# plt.subplot(235) plt.subplot(235)
# dayPlot(database.globalUkedagForbruk, len(dateLine)) dayPlot(database.globalUkedagForbruk, len(dateLine))
# plt.xlabel("forbruk over ukedager") plt.xlabel("forbruk over ukedager")
# plt.show() plt.show()
# def alt4menu(database, dateLine, useDatabase): def alt4menu(database, dateLine, useDatabase):
# n = 10 n = 10
# while 1: while 1:
# print( print(
# "\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit" "\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit"
# ) )
# try: try:
# inp = input("") inp = input("")
# except: except:
# continue continue
# if inp == "q": if inp == "q":
# break break
# elif inp == "1": elif inp == "1":
# if i == "0": if i == "0":
# user = input("input full username: ") user = input("input full username: ")
# else: else:
# user = getUser() user = getUser()
# plotUser(database, dateLine, user, n) plotUser(database, dateLine, user, n)
# elif inp == "2": elif inp == "2":
# if i == "0": if i == "0":
# product = input("input full product name: ") product = input("input full product name: ")
# else: else:
# product = getProduct() product = getProduct()
# plotProduct(database, dateLine, product, n) plotProduct(database, dateLine, product, n)
# elif inp == "3": elif inp == "3":
# plotGlobal(database, dateLine, n) plotGlobal(database, dateLine, n)
# elif inp == "n": elif inp == "n":
# try: try:
# n = int(input("set number to show ")) n = int(input("set number to show "))
# except: except:
# pass pass
# def main(): def main():
# inputType = getInputType() inputType = getInputType()
# i = input("0:fil, 1:database \n? ") i = input("0:fil, 1:database \n? ")
# if inputType == 1: if inputType == 1:
# if i == "0": if i == "0":
# user = input("input full username: ") user = input("input full username: ")
# else: else:
# user = getUser() user = getUser()
# product = "" product = ""
# elif inputType == 2: elif inputType == 2:
# if i == "0": if i == "0":
# product = input("input full product name: ") product = input("input full product name: ")
# else: else:
# product = getProduct() product = getProduct()
# user = "" user = ""
# else: else:
# product = "" product = ""
# user = "" user = ""
# if i == "0": if i == "0":
# inputFile = input("logfil? ") inputFile = input("logfil? ")
# if inputFile == "": if inputFile == "":
# inputFile = "default.dibblerlog" inputFile = "default.dibblerlog"
# database, dateLine = buildDatabaseFromFile(inputFile, inputType, product, user) database, dateLine = buildDatabaseFromFile(inputFile, inputType, product, user)
# else: else:
# database, dateLine = buildDatabaseFromDb(inputType, product, user) database, dateLine = buildDatabaseFromDb(inputType, product, user)
# if inputType == 1: if inputType == 1:
# plotUser(database, dateLine, user, 10) plotUser(database, dateLine, user, 10)
# if inputType == 2: if inputType == 2:
# plotProduct(database, dateLine, product, 10) plotProduct(database, dateLine, product, 10)
# if inputType == 3: if inputType == 3:
# plotGlobal(database, dateLine, 10) plotGlobal(database, dateLine, 10)
# if inputType == 4: if inputType == 4:
# alt4menu(database, dateLine, i) alt4menu(database, dateLine, i)
# if __name__ == "__main__": if __name__ == "__main__":
# main() main()
+1 -1
View File
@@ -1,6 +1,6 @@
from dibbler.db import Session from dibbler.db import Session
from dibbler.lib.render_transaction_log import render_transaction_log
from dibbler.queries import transaction_log from dibbler.queries import transaction_log
from dibbler.lib.render_transaction_log import render_transaction_log
def main() -> None: def main() -> None:
+28 -83
View File
@@ -1,4 +1,4 @@
# Dibbler economy spec v1 # Economics
This document provides an overview of how dibbler counts and calculates its running event log. This document provides an overview of how dibbler counts and calculates its running event log.
@@ -8,26 +8,20 @@ It is a sort of semi-formal specification for how dibbler's economy is intended
- All calculations involving money are done in whole numbers (integers). There are no fractional krs. - All calculations involving money are done in whole numbers (integers). There are no fractional krs.
- All rounding is done by rounding up to the nearest integer, in favor of the system economy - not the users. - All rounding is done by rounding up to the nearest integer, in favor of the system economy - not the users.
- All rounding is done as late as possible in calculations, to avoid rounding errors accumulating.
- The system allows negative stock counts, but acts a bit weirdly and potentially unfairly when that happens. - The system allows negative stock counts, but acts a bit weirdly and potentially unfairly when that happens.
The system should generally warn you about this, and recommend recounting the stock whenever it happens. The system should generally warn you about this, and recommend recounting the stock whenever it happens.
- Throughout the document, the penalty multiplier and interest rate are expressed as percentages in int (e.g. `penalty_multiplier = 150` means the prices should be multiplied by `1.5`, and `interest_rate = 120` means the prices should be multiplied by `1.2`).
## Adding products - product stock and product price ## Adding products - product stock and product price
This section covers what happens to the stock count and price of a product when a user adds more of that product to the system. This section covers what happens to the stock count and price of a product when a user adds more of that product to the system.
### Calculating the total value of products added
When a user adds a product, the resulting product price is averaged over the new products and the existing products. However, the new product price will become an integer. To avoid the economy going downwards, we round up the price after doing the averaging - i.e. in favor of the system, not the users.
### When the product count is `0` before adding. ### When the product count is `0` before adding.
When the product count is `0`, adding more of that product sets the product count to the amount added, and the product price will be set to the price of all products added divided by the number of products added, rounded up to the nearest integer. When the product count is `0`, adding more of that product sets the product count to the amount added, and the product price will be set to the price of all products added divided by the number of products added, rounded up to the nearest integer.
```python ```python
new_product_count: int = products_added new_product_count = products_added
new_product_price: int = math.ceil(total_value_of_products_added / products_added) new_product_price = math.ceil(total_value_of_products_added / products_added)
``` ```
### When the product count is greater than `0` before adding. ### When the product count is greater than `0` before adding.
@@ -35,8 +29,8 @@ new_product_price: int = math.ceil(total_value_of_products_added / products_adde
When the product count is greater than `0`, adding more of that product increases the product count by the amount added, and the product price will be recalculated as the total value of all existing products plus the total value of all newly added products, divided by the new total product count, rounded up to the nearest integer. When the product count is greater than `0`, adding more of that product increases the product count by the amount added, and the product price will be recalculated as the total value of all existing products plus the total value of all newly added products, divided by the new total product count, rounded up to the nearest integer.
```python ```python
new_product_count: int = product_count + products_added new_product_count = product_count + products_added
new_product_price: int = math.ceil((product_price * product_count + total_value_of_new_products_added) / new_product_count) new_product_price = math.ceil((total_value_of_existing_products + total_value_of_products_added) / new_product_count)
``` ```
### When the product count is less than `0` before adding. ### When the product count is less than `0` before adding.
@@ -50,11 +44,11 @@ When the product count is less than `0`, adding more of that product increases t
> [!WARN] > [!WARN]
> Note that this means that if you add products to a negative stock and the stock is still negative, > Note that this means that if you add products to a negative stock and the stock is still negative,
> the product price will be completely recalculated the next time someone adds the same product. > the product price will be completely recalculated the next time someone adds the same product.
> There will also be a noticeable effect if the stock goes from negative to positive. > There will also be a noticable effect if the stock goes from negative to positive.
```python ```python
new_product_count: int = product_count + products_added new_product_count = product_count + products_added
new_product_price: int = math.ceil(((product_price * max(product_count, 0)) + (total_value_of_new_products_added)) / new_product_count) new_product_price = math.ceil(((product_price * math.max(product_count, 0)) + (total_value_of_products_added)) / new_product_count)
``` ```
### A note about adding `0` items ### A note about adding `0` items
@@ -69,7 +63,7 @@ If a user attempts to add `0` items of a product, the system will not change the
When the product count is positive and a user buys an amount less than or equal to the current stock count, the product stock count will be decreased by the amount bought. When the product count is positive and a user buys an amount less than or equal to the current stock count, the product stock count will be decreased by the amount bought.
```python ```python
new_product_count: int = product_count - products_bought new_product_count = product_count - products_bought
``` ```
### When the product count is positive or `0` and you buy more than there are in stock ### When the product count is positive or `0` and you buy more than there are in stock
@@ -77,7 +71,7 @@ new_product_count: int = product_count - products_bought
When the product count is positive and a user buys an amount greater than the current stock count, the product stock count will be decreased by the amount bought, resulting in a negative stock count. When the product count is positive and a user buys an amount greater than the current stock count, the product stock count will be decreased by the amount bought, resulting in a negative stock count.
```python ```python
new_product_count: int = product_count - products_bought new_product_count = product_count - products_bought
``` ```
> [!NOTE] > [!NOTE]
@@ -88,7 +82,7 @@ new_product_count: int = product_count - products_bought
When the product count is negative, buying more of that product will further decrease the product stock count. When the product count is negative, buying more of that product will further decrease the product stock count.
```python ```python
new_product_count: int = product_count - products_bought new_product_count = product_count - products_bought
``` ```
> [!NOTE] > [!NOTE]
@@ -109,10 +103,10 @@ If a user attempts to buy `0` items of a product, the system will not change the
We have had some issues with the economy going in the negative, most likely due to users throwing away products gone bad. When the economy goes negative, we end up in a situation where users have money but there aren't really any products to buy, because the users don't have the incentive to add products back into the system to gain more balance. We have had some issues with the economy going in the negative, most likely due to users throwing away products gone bad. When the economy goes negative, we end up in a situation where users have money but there aren't really any products to buy, because the users don't have the incentive to add products back into the system to gain more balance.
To readjust the economy over time, there is an interest rate that will increase the amount you pay for each product by a certain percentage (the interest rate). This percentage can be adjusted by administrators when they see that the economy needs fixing. By default, the interest rate is set to `100%` (i.e., you don't pay anything extra). To readjust the economy over time, there is an interest rate that will increase the amount you pay for each product by a certain percentage (the interest rate). This percentage can be adjusted by administrators when they see that the economy needs fixing. By default, the interest rate is set to `0%`.
> [!NOTE] > [!NOTE]
> You can not go below `100%` interest rate. > You can not go below `0%` interest rate.
### What is penalty, and why do we need it ### What is penalty, and why do we need it
@@ -133,9 +127,7 @@ You gain balance equal to the total value of the products you add.
Note that this might be separate from the per-product cost of the products after you add them, due to rounding and price recalculation. Note that this might be separate from the per-product cost of the products after you add them, due to rounding and price recalculation.
```python ```python
new_user_balance: int = user_balance + total_value_of_products_added new_user_balance = user_balance + total_value_of_products_added
assert total_value_of_new_products_added >= product_price * products_added
``` ```
### When your existing balance is below the penalty threshold ### When your existing balance is below the penalty threshold
@@ -150,7 +142,7 @@ This case is the same as above.
You pay the normal product price for the products you buy, plus any interest. You pay the normal product price for the products you buy, plus any interest.
```python ```python
new_user_balance: int = user_balance - math.ceil(products_bought * product_price * (interest_rate / 100)) new_user_balance = user_balance - (products_bought * product_price * (1 + interest_rate))
``` ```
Note that the system performs a transaction for every product kind, so if you buy multiple different products in one go, the rounding is done per product kind. Note that the system performs a transaction for every product kind, so if you buy multiple different products in one go, the rounding is done per product kind.
@@ -162,67 +154,34 @@ You pay the penalized product price for the products you buy, plus any interest.
The interest and penalty are calculated separately before they are added together, *not* multiplied together. The interest and penalty are calculated separately before they are added together, *not* multiplied together.
```python ```python
base_cost: float = product_price * products_bought penalty = ((product_price * penalty_multiplier) - product_price)
penalty: float = (base_cost * (penalty_multiplier / 100)) - base_cost interest = (product_price * interest_rate)
interest: float = (base_cost * (interest_rate / 100)) - base_cost new_user_balance = user_balance - (products_bought * (product_price + penalty + interest))
new_user_balance: int = user_balance - math.ceil(base_cost + penalty + interest)
``` ```
### When your balance is above the penalty threshold before buying, but the purchase pushes you below the threshold ### When your balance is above the penalty threshold before buying, but the purchase pushes you below the threshold
When your balance is above the penalty threshold before buying, but the purchase pushes you below the threshold, the system not apply any penalty for the purchase. The entire purchase is done at the normal product price plus any interest. TODO:
```python ```python
new_user_balance: int = user_balance - math.ceil(products_bought * product_price * (interest_rate / 100))
``` ```
> [!NOTE]
> In the case where you are performing multiple transactions at once, the system should try its best to order the purchases in a way that minimizes the amount of penalties you need to pay.
### Joint purchases, when all users are above the penalty threshold and stays above the threshold ### Joint purchases, when all users are above the penalty threshold and stays above the threshold
When making joint purchases (multiple users buying products together), and all users are above the penalty threshold before and after the purchase, the total cost (including interest) will be split equally between all users. The price will be rounded up for each user after splitting the bill. TODO: how does rounding work here, does one user pay more than the other?
```python TODO: ordering the purchases in favor of the user.
total_cost: float = product_price * products_bought * (interest_rate / 100)
cost_per_user: float = total_cost / number_of_users
new_user_balance = user_balance - math.ceil(cost_per_user)
```
### Joint purchases where a user appears more than one time When performing joint purchases (multiple users
When a user appears more than once in a joint purchase (e.g. two people buying together, but one of them is buying twice as much as the other), the system will the amount of times a user appears in the purchase as a multiplier for the base price. You can think of it as if the user is having shares in the joint purchase.
```python
base_cost_for_user: float = product_price * products_bought * user_shares / total_user_shares
added_interest: float = base_cost_for_user * ((interest_rate - 100) / 100)
new_user_balance: int = user_balance - math.ceil(base_cost_for_user + added_interest)
```
### Joint purchases when one or more users are below the penalty threshold ### Joint purchases when one or more users are below the penalty threshold
The cost for each user will be calculated as usual, but for the users who are below the penalty threshold, the penalty will also be calculated and added to this user's cost. The penalty is calculated based on the share of the total purchase that this user is responsible for. TODO
```python
base_cost_for_user: float = product_price * products_bought * user_shares / total_user_shares
added_interest: float = base_cost_for_user * ((interest_rate - 100) / 100)
penalty: float = base_cost_for_user * ((penalty_multiplier - 100) / 100)
new_user_balance: int = user_balance - math.ceil(base_cost_for_user + added_interest + penalty)
```
### Joint purchases when one or more users will end up below the penalty threshold after the purchase ### Joint purchases when one or more users will end up below the penalty threshold after the purchase
Just as the single-user case, if a user who is part of a joint purchase is above the penalty threshold before the purchase, but will end up below the threshold after the purchase, no penalty will be applied to that user for this purchase. The entire cost (including interest) will be split equally between all users. TODO
```python
base_cost_for_user: float = product_price * products_bought * user_shares / total_user_shares
added_interest: float = base_cost_for_user * ((interest_rate - 100) / 100)
new_user_balance: int = user_balance - math.ceil(base_cost_for_user + added_interest)
```
> [!NOTE]
> In the case where you (and others) are performing multiple transactions at once, the system should try its best to order the purchases in a way that minimizes the amount of penalties you need to pay.
## Who owns a product ## Who owns a product
@@ -238,26 +197,12 @@ Upon throwing away products (not manual adjustment), the system will pull money
## Other actions ## Other actions
### Transfers Transfers
You can transfer money from one user to another. The amount transferred will be deducted from the sender's balance and added to the receiver's balance without any interest or penalty applied. Note about self-transfers
```python Balance adjustments
new_sender_balance: int = sender_balance - amount_transferred
new_receiver_balance: int = receiver_balance + amount_transferred
```
> [!NOTE]
> Transfers from one user to itself are not allowed.
### Balance adjustments
You can manually adjust a user's balance. This action will not have any multipliers of any kind applied, and will simply add or subtract the specified amount from the user's balance.
```python
new_user_balance: int = user_balance + adjustment_amount
```
## Updating the economy specification ## Updating the economy specification
All transactions in the database are tagged with the economy specification version they were created under. If you are to update this document with changes to how the economy works, and change the software accordingly, you will want to keep the old logic around and bump the version number. This way, the old event log is still valid, and will be aggregated using the old logic, while new transactions will user the logic applicable to the version they were created under. Keep old logic, database rows tagged with spec version.
+19
View File
@@ -0,0 +1,19 @@
[general]
quit_allowed = true
stop_allowed = false
show_tracebacks = true
input_encoding = 'utf8'
[database]
# url = "postgresql://robertem@127.0.0.1/pvvvv"
url = sqlite:///test.db
[limits]
low_credit_warning_limit = -100
user_recent_transaction_limit = 100
# See https://pypi.org/project/brother_ql/ for label types
# Set rotate to False for endless labels
[printer]
label_type = "62"
label_rotate = false
-35
View File
@@ -1,35 +0,0 @@
[general]
quit_allowed = true
stop_allowed = false
show_tracebacks = true
input_encoding = 'utf8'
[database]
type = 'sqlite'
[database.sqlite]
path = 'test.db'
[database.postgresql]
host = 'localhost'
# host = '/run/postgresql'
port = 5432
username = 'dibbler'
dbname = 'dibbler'
# You can either specify a path to a file containing the password,
# or just specify the password directly
# password = 'superhemlig'
# password_file = '/var/lib/dibbler/db-password'
[limits]
low_credit_warning_limit = -100
user_recent_transaction_limit = 100
# See https://pypi.org/project/brother_ql/ for label types
# Set rotate to False for endless labels
[printer]
label_type = '62'
label_rotate = false
Generated
+33
View File
@@ -1,5 +1,22 @@
{ {
"nodes": { "nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"id": "flake-utils",
"type": "indirect"
}
},
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1764950072, "lastModified": 1764950072,
@@ -18,8 +35,24 @@
}, },
"root": { "root": {
"inputs": { "inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs" "nixpkgs": "nixpkgs"
} }
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
} }
}, },
"root": "root", "root": "root",
+27 -35
View File
@@ -3,7 +3,7 @@
inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
outputs = { self, nixpkgs }: let outputs = { self, nixpkgs, flake-utils }: let
inherit (nixpkgs) lib; inherit (nixpkgs) lib;
systems = [ systems = [
@@ -17,57 +17,49 @@
pkgs = nixpkgs.legacyPackages.${system}; pkgs = nixpkgs.legacyPackages.${system};
in f system pkgs); in f system pkgs);
in { in {
apps = let packages = forAllSystems (system: pkgs: {
mkApp = program: description: { default = self.packages.${system}.dibbler;
type = "app"; dibbler = pkgs.callPackage ./nix/dibbler.nix {
program = toString program; python3Packages = pkgs.python312Packages;
meta = {
inherit description;
};
}; };
mkVm = name: mkApp "${self.nixosConfigurations.${name}.config.system.build.vm}/bin/run-nixos-vm"; skrot = self.nixosConfigurations.skrot.config.system.build.sdImage;
in forAllSystems (system: pkgs: {
default = self.apps.${system}.dibbler;
dibbler = let
app = pkgs.writeShellApplication {
name = "dibbler-with-default-config";
runtimeInputs = [ self.packages.${system}.dibbler ];
text = ''
dibbler -c ${./example-config.toml} "$@"
'';
};
in mkApp (lib.getExe app) "Run the dibbler cli with its default config against an SQLite database";
vm = mkVm "vm" "Start a NixOS VM with dibbler installed in kiosk-mode";
vm-non-kiosk = mkVm "vm-non-kiosk" "Start a NixOS VM with dibbler installed in nonkiosk-mode";
}); });
nixosModules.default = import ./nix/module.nix; apps = forAllSystems (system: pkgs: {
default = self.apps.${system}.dibbler;
nixosConfigurations = { dibbler = flake-utils.lib.mkApp {
vm = import ./nix/nixos-configurations/vm.nix { inherit self nixpkgs; }; drv = self.packages.${system}.dibbler;
vm-non-kiosk = import ./nix/nixos-configurations/vm-non-kiosk.nix { inherit self nixpkgs; }; };
}; });
overlays = { overlays = {
default = self.overlays.dibbler; default = self.overlays.dibbler;
dibbler = final: prev: { dibbler = final: prev: {
inherit (self.packages.${prev.stdenv.hostPlatform.system}) dibbler; inherit (self.packages.${prev.system}) dibbler;
}; };
}; };
devShells = forAllSystems (system: pkgs: { devShells = forAllSystems (system: pkgs: {
default = self.devShells.${system}.dibbler; default = self.devShells.${system}.dibbler;
dibbler = pkgs.callPackage ./nix/shell.nix { dibbler = pkgs.callPackage ./nix/shell.nix {
python3 = pkgs.python313; python3 = pkgs.python312;
}; };
}); });
packages = forAllSystems (system: pkgs: { # Note: using the module requires that you have applied the overlay first
default = self.packages.${system}.dibbler; nixosModules.default = import ./nix/module.nix;
dibbler = pkgs.callPackage ./nix/package.nix {
python3Packages = pkgs.python313Packages; nixosConfigurations.skrot = nixpkgs.lib.nixosSystem (rec {
inherit (self) sourceInfo; system = "aarch64-linux";
pkgs = import nixpkgs {
inherit system;
overlays = [ self.overlays.dibbler ];
}; };
modules = [
(nixpkgs + "/nixos/modules/installer/sd-card/sd-image-aarch64.nix")
self.nixosModules.default
./nix/skrott.nix
];
}); });
}; };
} }
+32
View File
@@ -0,0 +1,32 @@
{ lib
, python3Packages
, fetchFromGitHub
}:
python3Packages.buildPythonApplication {
pname = "dibbler";
version = "unstable";
src = lib.cleanSource ../.;
format = "pyproject";
# brother-ql is breaky breaky
# https://github.com/NixOS/nixpkgs/issues/285234
dontCheckRuntimeDeps = true;
pythonImportsCheck = [];
doCheck = true;
nativeCheckInputs = with python3Packages; [
pytest
pytestCheckHook
];
nativeBuildInputs = with python3Packages; [ setuptools ];
propagatedBuildInputs = with python3Packages; [
brother-ql
matplotlib
psycopg2-binary
python-barcode
sqlalchemy
];
}
+50 -159
View File
@@ -1,52 +1,13 @@
{ config, pkgs, lib, ... }: let { config, pkgs, lib, ... }: let
cfg = config.services.dibbler; cfg = config.services.dibbler;
format = pkgs.formats.toml { }; format = pkgs.formats.ini { };
in { in {
options.services.dibbler = { options.services.dibbler = {
enable = lib.mkEnableOption "dibbler, the little kiosk computer"; enable = lib.mkEnableOption "dibbler, the little kiosk computer";
package = lib.mkPackageOption pkgs "dibbler" { }; package = lib.mkPackageOption pkgs "dibbler" { };
screenPackage = lib.mkPackageOption pkgs "screen" { };
createLocalDatabase = lib.mkEnableOption "" // {
description = ''
Whether to set up a local postgres database automatically.
::: {.note}
You must set up postgres manually before enabling this option.
:::
'';
};
kioskMode = lib.mkEnableOption "" // {
description = ''
Whether to let dibbler take over the entire machine.
This will restrict the machine to a single TTY and make the program unquittable.
You can still get access to PTYs via SSH and similar, if enabled.
'';
};
limitScreenHeight = lib.mkOption {
type = with lib.types; nullOr ints.unsigned;
default = null;
example = 42;
description = ''
If set, limits the height of the screen dibbler uses to the given number of lines.
'';
};
limitScreenWidth = lib.mkOption {
type = with lib.types; nullOr ints.unsigned;
default = null;
example = 80;
description = ''
If set, limits the width of the screen dibbler uses to the given number of columns.
'';
};
settings = lib.mkOption { settings = lib.mkOption {
description = "Configuration for dibbler"; description = "Configuration for dibbler";
default = { }; default = { };
@@ -56,131 +17,61 @@ in {
}; };
}; };
config = lib.mkIf cfg.enable (lib.mkMerge [ config = let
{ screen = "${pkgs.screen}/bin/screen";
services.dibbler.settings = lib.pipe ../example-config.toml [ in lib.mkIf cfg.enable {
builtins.readFile services.dibbler.settings = lib.pipe ../example-config.ini [
builtins.fromTOML builtins.readFile
(lib.mapAttrsRecursive (_: lib.mkDefault)) builtins.fromTOML
]; (lib.mapAttrsRecursive (_: lib.mkDefault))
} ];
{
environment.systemPackages = [ cfg.package ];
environment.etc."dibbler/dibbler.toml".source = format.generate "dibbler.toml" cfg.settings; boot = {
consoleLogLevel = 0;
enableContainers = false;
loader.grub.enable = false;
};
users = { users = {
users.dibbler = { groups.dibbler = { };
group = "dibbler"; users.dibbler = {
isNormalUser = true; group = "dibbler";
};
groups.dibbler = { };
};
services.dibbler.settings.database = lib.mkIf cfg.createLocalDatabase {
type = "postgresql";
postgresql.host = "/run/postgresql";
};
services.postgresql = lib.mkIf cfg.createLocalDatabase {
ensureDatabases = [ "dibbler" ];
ensureUsers = [{
name = "dibbler";
ensureDBOwnership = true;
ensureClauses.login = true;
}];
};
systemd.services.dibbler-setup-database = lib.mkIf cfg.createLocalDatabase {
description = "Dibbler database setup";
wantedBy = [ "default.target" ];
after = [ "postgresql.service" ];
unitConfig = {
ConditionPathExists = "!/var/lib/dibbler/.db-setup-done";
};
serviceConfig = {
Type = "oneshot";
ExecStart = "${lib.getExe cfg.package} --config /etc/dibbler/dibbler.toml create-db";
ExecStartPost = "${lib.getExe' pkgs.coreutils "touch"} /var/lib/dibbler/.db-setup-done";
StateDirectory = "dibbler";
User = "dibbler";
Group = "dibbler";
};
};
}
(lib.mkIf cfg.kioskMode {
boot.kernelParams = [
"console=tty1"
];
users.users.dibbler = {
extraGroups = [ "lp" ]; extraGroups = [ "lp" ];
shell = (pkgs.writeShellScriptBin "login-shell" "${lib.getExe' cfg.screenPackage "screen"} -x dibbler") // { isNormalUser = true;
shellPath = "/bin/login-shell"; shell = (pkgs.writeShellScriptBin "login-shell" "${screen} -x dibbler") // {shellPath = "/bin/login-shell";};
};
}; };
};
services.dibbler.settings.general = { systemd.services.screen-daemon = {
quit_allowed = false; description = "Dibbler service screen";
stop_allowed = false; wantedBy = [ "default.target" ];
serviceConfig = {
ExecStartPre = "-${screen} -X -S dibbler kill";
ExecStart = let
config = format.generate "dibbler-config.ini" cfg.settings;
in "${screen} -dmS dibbler -O -l ${cfg.package}/bin/dibbler --config ${config} loop";
ExecStartPost = "${screen} -X -S dibbler width 42 80";
User = "dibbler";
Group = "dibbler";
Type = "forking";
RemainAfterExit = false;
Restart = "always";
RestartSec = "5s";
SuccessExitStatus = 1;
}; };
};
systemd.services.dibbler-screen-session = { # https://github.com/NixOS/nixpkgs/issues/84105
description = "Dibbler Screen Session"; boot.kernelParams = [
wantedBy = [ "console=ttyUSB0,9600"
"default.target" "console=tty1"
]; ];
after = if cfg.createLocalDatabase then [ systemd.services."serial-getty@ttyUSB0" = {
"postgresql.service" enable = true;
"dibbler-setup-database.service" wantedBy = [ "getty.target" ]; # to start at boot
] else [ serviceConfig.Restart = "always"; # restart when session is closed
"network.target" };
];
serviceConfig = {
Type = "forking";
RemainAfterExit = false;
Restart = "always";
RestartSec = "5s";
SuccessExitStatus = 1;
User = "dibbler"; services.getty.autologinUser = lib.mkForce "dibbler";
Group = "dibbler"; };
ExecStartPre = "-${lib.getExe' cfg.screenPackage "screen"} -X -S dibbler kill";
ExecStart = let
screenArgs = lib.escapeShellArgs [
# -dm creates the screen in detached mode without accessing it
"-dm"
# Session name
"-S"
"dibbler"
# Set optimal output mode instead of VT100 emulation
"-O"
# Enable login mode, updates utmp entries
"-l"
];
dibblerArgs = lib.cli.toCommandLineShellGNU { } {
config = "/etc/dibbler/dibbler.toml";
};
in "${lib.getExe' cfg.screenPackage "screen"} ${screenArgs} ${lib.getExe cfg.package} ${dibblerArgs} loop";
ExecStartPost =
lib.optionals (cfg.limitScreenWidth != null) [
"${lib.getExe' cfg.screenPackage "screen"} -X -S dibbler width ${toString cfg.limitScreenWidth}"
]
++ lib.optionals (cfg.limitScreenHeight != null) [
"${lib.getExe' cfg.screenPackage "screen"} -X -S dibbler height ${toString cfg.limitScreenHeight}"
];
};
};
services.getty.autologinUser = "dibbler";
})
]);
} }
-54
View File
@@ -1,54 +0,0 @@
{ self, nixpkgs, ... }:
nixpkgs.lib.nixosSystem {
system = "x86_64-linux";
pkgs = import nixpkgs {
system = "x86_64-linux";
overlays = [
self.overlays.dibbler
];
};
modules = [
"${nixpkgs}/nixos/modules/virtualisation/qemu-vm.nix"
"${nixpkgs}/nixos/tests/common/user-account.nix"
self.nixosModules.default
({ config, ... }: {
system.stateVersion = config.system.nixos.release;
virtualisation.graphics = false;
users.motd = ''
=================================
Welcome to the dibbler non-kiosk vm!
Try running:
${config.services.dibbler.package.meta.mainProgram} loop
Password for dibbler is 'dibbler'
To exit, press Ctrl+A, then X
=================================
'';
users.users.dibbler = {
isNormalUser = true;
password = "dibbler";
extraGroups = [ "wheel" ];
};
services.getty.autologinUser = "dibbler";
programs.vim = {
enable = true;
defaultEditor = true;
};
services.postgresql.enable = true;
services.dibbler = {
enable = true;
createLocalDatabase = true;
};
})
];
}
-29
View File
@@ -1,29 +0,0 @@
{ self, nixpkgs, ... }:
nixpkgs.lib.nixosSystem {
system = "x86_64-linux";
pkgs = import nixpkgs {
system = "x86_64-linux";
overlays = [
self.overlays.default
];
};
modules = [
"${nixpkgs}/nixos/modules/virtualisation/qemu-vm.nix"
"${nixpkgs}/nixos/tests/common/user-account.nix"
self.nixosModules.default
({ config, ... }: {
system.stateVersion = config.system.nixos.release;
virtualisation.graphics = false;
services.postgresql.enable = true;
services.dibbler = {
enable = true;
createLocalDatabase = true;
kioskMode = true;
};
})
];
}
-61
View File
@@ -1,61 +0,0 @@
{ lib
, sourceInfo
, python3Packages
, makeWrapper
, less
}:
let
pyproject = builtins.fromTOML (builtins.readFile ../pyproject.toml);
in
python3Packages.buildPythonApplication {
pname = pyproject.project.name;
version = "0.1";
src = lib.cleanSource ../.;
format = "pyproject";
# brother-ql is breaky breaky
# https://github.com/NixOS/nixpkgs/issues/285234
# dontCheckRuntimeDeps = true;
env.SETUPTOOLS_SCM_PRETEND_METADATA = (x: "{${x}}") (lib.concatStringsSep ", " [
"node=\"${sourceInfo.rev or (lib.substring 0 64 sourceInfo.dirtyRev)}\""
"node_date=${lib.substring 0 4 sourceInfo.lastModifiedDate}-${lib.substring 4 2 sourceInfo.lastModifiedDate}-${lib.substring 6 2 sourceInfo.lastModifiedDate}"
"dirty=${if sourceInfo ? dirtyRev then "true" else "false"}"
]);
nativeBuildInputs = with python3Packages; [
makeWrapper
setuptools
setuptools-scm
];
propagatedBuildInputs = with python3Packages; [
# brother-ql
# matplotlib
psycopg2-binary
# python-barcode
sqlalchemy
];
pythonImportsCheck = [];
doCheck = true;
nativeCheckInputs = with python3Packages; [
pytest
pytestCheckHook
sqlparse
pytest-html
pytest-cov
pytest-benchmark
];
postInstall = ''
wrapProgram $out/bin/dibbler \
--prefix PATH : "${lib.makeBinPath [ less ]}"
'';
meta = {
description = "The little kiosk that could";
mainProgram = "dibbler";
};
}
+3 -5
View File
@@ -10,18 +10,16 @@ mkShell {
ruff ruff
uv uv
(python3.withPackages (ps: with ps; [ (python3.withPackages (ps: with ps; [
# brother-ql brother-ql
# matplotlib matplotlib
psycopg2 psycopg2
# python-barcode python-barcode
sqlalchemy sqlalchemy
sqlparse sqlparse
pytest pytest
pytest-cov pytest-cov
pytest-html pytest-html
pytest-benchmark
pygal
])) ]))
]; ];
} }
+27
View File
@@ -0,0 +1,27 @@
{...}: {
system.stateVersion = "25.05";
services.dibbler.enable = true;
networking = {
hostName = "skrot";
domain = "pvv.ntnu.no";
nameservers = [ "129.241.0.200" "129.241.0.201" ];
defaultGateway = "129.241.210.129";
interfaces.eth0 = {
useDHCP = false;
ipv4.addresses = [{
address = "129.241.210.235";
prefixLength = 25;
}];
};
};
# services.resolved.enable = true;
# systemd.network.enable = true;
# systemd.network.networks."30-network" = {
# matchConfig.Name = "*";
# DHCP = "no";
# address = [ "129.241.210.235/25" ];
# gateway = [ "129.241.210.129" ];
# };
}
+8 -66
View File
@@ -1,16 +1,10 @@
[build-system] [build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
requires = [
"setuptools",
"setuptools-scm",
]
[project] [project]
name = "dibbler" name = "dibbler"
dynamic = ["version"] authors = []
authors = [
{ name = "Programvareverkstedet", email = "projects@pvv.ntnu.no" }
]
description = "EDB-system for PVV" description = "EDB-system for PVV"
readme = "README.md" readme = "README.md"
requires-python = ">=3.11" requires-python = ">=3.11"
@@ -19,12 +13,12 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"SQLAlchemy >= 2.0, <2.1", "SQLAlchemy >= 2.0, <2.1",
# "brother-ql", "brother-ql",
# "matplotlib", "matplotlib",
"psycopg2-binary >= 2.8, <2.10", "psycopg2-binary >= 2.8, <2.10",
# "python-barcode", "python-barcode",
] ]
scripts.dibbler = "dibbler.main:main" dynamic = ["version"]
[dependency-groups] [dependency-groups]
test = [ test = [
@@ -33,68 +27,16 @@ test = [
"coverage-badge>=1.1.2", "coverage-badge>=1.1.2",
"pytest-html>=4.1.1", "pytest-html>=4.1.1",
"sqlparse>=0.5.4", "sqlparse>=0.5.4",
"pytest-benchmark[histogram]>=5.2.3",
] ]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
include = ["dibbler*"] include = ["dibbler*"]
[tool.setuptools_scm] [project.scripts]
version_file = "dibbler/_version.py" dibbler = "dibbler.main:main"
[tool.black] [tool.black]
line-length = 100 line-length = 100
[tool.ruff] [tool.ruff]
line-length = 100 line-length = 100
[tool.ruff.lint]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"COM", # flake8-commas
"ANN",
# "E", # pycodestyle
# "F", # Pyflakes
"FA", # flake8-future-annotations
"I", # isort
"S", # flake8-bandit
"ICN", # flake8-import-conventions
"ISC", # flake8-implicit-str-concat
# "N", # pep8-naming
"PTH", # flake8-use-pathlib
# "RET", # flake8-return
# "SIM", # flake8-simplify
"TC", # flake8-type-checking
"UP", # pyupgrade
"YTT", # flake8-2020
]
ignore = [
"E501", # line too long
"S101", # assert detected
"S311", # non-cryptographic random generator
]
[tool.ruff.lint.flake8-annotations]
suppress-dummy-args = true
ignore-fully-untyped = true
[tool.pytest.ini_options]
addopts = [
"--cov=dibbler.lib",
"--cov=dibbler.models",
"--cov=dibbler.queries",
"--cov-report=html",
"--cov-branch",
"--self-contained-html",
"--html=./test-report/index.html",
"--benchmark-skip",
"--benchmark-autosave",
"--benchmark-save=default",
"--benchmark-verbose",
"--benchmark-storage=benchmark",
"--benchmark-histogram=benchmark/histogram",
]
-11
View File
@@ -1,11 +0,0 @@
TRANSACTION_GENERATOR_EXCEPTION_LIMIT = 15
"""
The random transaction generator uses a set seed to generate transactions.
However, not all transactions are valid in all contexts. We catch illegal
generated transactions with a try/except, and retry until we generate a valid
one. However, if we exceed this limit, something is likely wrong with the generator
instead, due to the unlikely high number of exceptions.
"""
BENCHMARK_ITERATIONS = 5
BENCHMARK_ROUNDS = 3
-306
View File
@@ -1,306 +0,0 @@
import random
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, TransactionType, User
from dibbler.queries import joint_buy_product
from tests.benchmark.benchmark_settings import TRANSACTION_GENERATOR_EXCEPTION_LIMIT
def insert_users_and_products(
sql_session: Session,
user_count: int = 10,
product_count: int = 10,
) -> tuple[list[User], list[Product]]:
users = []
for i in range(user_count):
user = User(f"User{i + 1}")
sql_session.add(user)
users.append(user)
sql_session.commit()
products = []
for i in range(product_count):
barcode = str(1000000000000 + i)
product = Product(barcode, f"Product{i + 1}")
sql_session.add(product)
products.append(product)
sql_session.commit()
return users, products
def generate_random_transactions(
sql_session: Session,
n: int,
seed: int = 42,
transaction_type_filter: list[TransactionType] | None = None,
distribution: dict[TransactionType, float] | None = None,
cache_every_n: int | None = None,
) -> list[Transaction]:
random.seed(seed)
if transaction_type_filter is None:
transaction_type_filter = list(TransactionType)
if TransactionType.JOINT_BUY_PRODUCT in transaction_type_filter:
transaction_type_filter.remove(TransactionType.JOINT_BUY_PRODUCT)
# TODO: implement me
if TransactionType.THROW_PRODUCT in transaction_type_filter:
transaction_type_filter.remove(TransactionType.THROW_PRODUCT)
if distribution is None:
distribution = {t: 1 / len(transaction_type_filter) for t in transaction_type_filter}
transaction_types = list(distribution.keys())
weights = list(distribution.values())
transactions: list[Transaction] = []
last_time = datetime(2023, 1, 1, 0, 0, 0)
for _ in range(n):
transaction_type = random.choices(transaction_types, weights=weights, k=1)[0]
generator = RANDOM_GENERATORS[transaction_type]
transaction_or_transactions = generator(sql_session, last_time)
if isinstance(transaction_or_transactions, list):
transactions.extend(transaction_or_transactions)
last_time = max(t.time for t in transaction_or_transactions)
else:
transactions.append(transaction_or_transactions)
last_time = transaction_or_transactions.time
return transactions
def random_add_product_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
user = random.choice(sql_session.query(User).all())
product = random.choice(sql_session.query(Product).all())
product_count = random.randint(1, 10)
product_price = random.randint(15, 45)
amount = product_count * product_price + random.randint(-7, 0)
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.add_product(
amount,
user.id,
product.id,
product_price,
product_count,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
def random_adjust_balance_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
user = random.choice(sql_session.query(User).all())
amount = random.randint(-50, 100)
if amount == 0:
amount = 1
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.adjust_balance(
amount,
user.id,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
def random_adjust_interest_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
user = random.choice(sql_session.query(User).all())
amount = random.randint(100, 105)
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.adjust_interest(
amount,
user.id,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
def random_adjust_penalty_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
user = random.choice(sql_session.query(User).all())
penalty_multiplier_percent = random.randint(100, 200)
penalty_threshold = random.randint(-150, -50)
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.adjust_penalty(
penalty_multiplier_percent,
penalty_threshold,
user.id,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
def random_adjust_stock_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
user = random.choice(sql_session.query(User).all())
product = random.choice(sql_session.query(Product).all())
stock_change = random.randint(-5, 6)
if stock_change == 0:
stock_change = 1
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.adjust_stock(
user_id=user.id,
product_id=product.id,
product_count=stock_change,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
def random_buy_product_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
user = random.choice(sql_session.query(User).all())
product = random.choice(sql_session.query(Product).all())
product_count = random.randint(1, 5)
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.buy_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
def random_joint_transaction(sql_session: Session, last_time: datetime) -> list[Transaction]:
i = 0
while True:
i += 1
user_count = random.randint(2, 4)
users = random.sample(sql_session.query(User).all(), k=user_count)
product = random.choice(sql_session.query(Product).all())
product_count = random.randint(1, 5)
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transactions = joint_buy_product(
sql_session,
product=product,
product_count=product_count,
instigator=users[0],
users=users,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transactions
def random_transfer_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
sender, receiver = random.sample(sql_session.query(User).all(), k=2)
amount = random.randint(1, 50)
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.transfer(
amount,
sender.id,
receiver.id,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
def random_throw_product_transaction(sql_session: Session, last_time: datetime) -> Transaction:
i = 0
while True:
i += 1
user = random.choice(sql_session.query(User).all())
product = random.choice(sql_session.query(Product).all())
product_count = random.randint(1, 5)
new_datetime = last_time + timedelta(minutes=random.randint(1, 60))
try:
transaction = Transaction.throw_product(
user_id=user.id,
product_id=product.id,
product_count=product_count,
time=new_datetime,
)
except Exception:
if i > TRANSACTION_GENERATOR_EXCEPTION_LIMIT:
raise RuntimeError(
"Too many failed attempts to create a valid transaction, consider changing the seed",
)
continue
return transaction
RANDOM_GENERATORS = {
TransactionType.ADD_PRODUCT: random_add_product_transaction,
TransactionType.ADJUST_BALANCE: random_adjust_balance_transaction,
TransactionType.ADJUST_INTEREST: random_adjust_interest_transaction,
TransactionType.ADJUST_PENALTY: random_adjust_penalty_transaction,
TransactionType.ADJUST_STOCK: random_adjust_stock_transaction,
TransactionType.BUY_PRODUCT: random_buy_product_transaction,
TransactionType.JOINT: random_joint_transaction,
TransactionType.TRANSFER: random_transfer_transaction,
TransactionType.THROW_PRODUCT: random_throw_product_transaction,
}
@@ -1,54 +0,0 @@
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, TransactionType
from dibbler.queries import product_owners
from tests.benchmark.benchmark_settings import BENCHMARK_ITERATIONS, BENCHMARK_ROUNDS
from tests.benchmark.helpers import generate_random_transactions, insert_users_and_products
@pytest.mark.benchmark(group="product_owners")
@pytest.mark.parametrize(
"transaction_count",
[
100,
500,
1000,
2000,
5000,
10000,
],
)
def test_benchmark_product_owners(
benchmark,
sql_session: Session,
transaction_count: int,
) -> None:
_users, products = insert_users_and_products(sql_session)
transactions = generate_random_transactions(
sql_session,
transaction_count,
transaction_type_filter=[
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
TransactionType.BUY_PRODUCT,
TransactionType.JOINT,
TransactionType.THROW_PRODUCT,
],
)
sql_session.add_all(transactions)
sql_session.commit()
benchmark.pedantic(
query_all_product_owners,
args=(sql_session, products),
iterations=BENCHMARK_ITERATIONS,
rounds=BENCHMARK_ROUNDS,
)
def query_all_product_owners(sql_session: Session, products: list[Product]) -> None:
for product in products:
product_owners(sql_session, product, use_cache=False)
@@ -1,92 +0,0 @@
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, TransactionType
from dibbler.queries import product_price, update_cache
from tests.benchmark.benchmark_settings import BENCHMARK_ITERATIONS, BENCHMARK_ROUNDS
from tests.benchmark.helpers import generate_random_transactions, insert_users_and_products
# @pytest.mark.benchmark(group="product_price")
# @pytest.mark.parametrize(
# "transaction_count",
# [
# 100,
# 500,
# 1000,
# 2000,
# 5000,
# 10000,
# ],
# )
# def test_benchmark_product_price(benchmark, sql_session: Session, transaction_count):
# _users, products = insert_users_and_products(sql_session)
# transactions = generate_random_transactions(
# sql_session,
# transaction_count,
# transaction_type_filter=[
# TransactionType.ADD_PRODUCT,
# TransactionType.ADJUST_STOCK,
# TransactionType.BUY_PRODUCT,
# TransactionType.JOINT,
# TransactionType.THROW_PRODUCT,
# ],
# )
# sql_session.add_all(transactions)
# sql_session.commit()
# benchmark.pedantic(
# query_all_products_price,
# args=(sql_session, products),
# iterations=BENCHMARK_ITERATIONS,
# rounds=BENCHMARK_ROUNDS,
# )
@pytest.mark.benchmark(group="product_price")
@pytest.mark.parametrize(
"transaction_count",
[
1000,
2000,
5000,
10000,
],
)
def test_benchmark_product_price_cache_every_500(
benchmark,
sql_session: Session,
transaction_count: int,
) -> None:
users, _products = insert_users_and_products(sql_session)
transactions = generate_random_transactions(
sql_session,
transaction_count,
transaction_type_filter=[
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
TransactionType.BUY_PRODUCT,
TransactionType.JOINT,
TransactionType.THROW_PRODUCT,
],
)
for i in range(0, len(transactions), 500):
update_cache(sql_session)
sql_session.add_all(transactions[i : i + 500])
sql_session.commit()
benchmark.pedantic(
query_all_products_price,
args=(sql_session, users),
iterations=BENCHMARK_ITERATIONS,
rounds=BENCHMARK_ROUNDS,
)
def query_all_products_price(sql_session: Session, products: list[Product]) -> None:
for product in products:
product_price(sql_session, product, use_cache=False)
@@ -1,54 +0,0 @@
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, TransactionType
from dibbler.queries import product_stock
from tests.benchmark.benchmark_settings import BENCHMARK_ITERATIONS, BENCHMARK_ROUNDS
from tests.benchmark.helpers import generate_random_transactions, insert_users_and_products
@pytest.mark.benchmark(group="product_stock")
@pytest.mark.parametrize(
"transaction_count",
[
100,
500,
1000,
2000,
5000,
10000,
],
)
def test_benchmark_product_stock(
benchmark,
sql_session: Session,
transaction_count: int,
) -> None:
_users, products = insert_users_and_products(sql_session)
transactions = generate_random_transactions(
sql_session,
transaction_count,
transaction_type_filter=[
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
TransactionType.BUY_PRODUCT,
TransactionType.JOINT,
TransactionType.THROW_PRODUCT,
],
)
sql_session.add_all(transactions)
sql_session.commit()
benchmark.pedantic(
query_all_products_stock,
args=(sql_session, products),
iterations=BENCHMARK_ITERATIONS,
rounds=BENCHMARK_ROUNDS,
)
def query_all_products_stock(sql_session: Session, products: list[Product]) -> None:
for product in products:
product_stock(sql_session, product, use_cache=False)
@@ -1,54 +0,0 @@
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, User
from dibbler.queries import transaction_log
from tests.benchmark.benchmark_settings import BENCHMARK_ITERATIONS, BENCHMARK_ROUNDS
from tests.benchmark.helpers import generate_random_transactions, insert_users_and_products
@pytest.mark.benchmark(group="transaction_log")
@pytest.mark.parametrize(
"transaction_count",
[
100,
500,
1000,
2000,
5000,
10000,
],
)
def test_benchmark_transaction_log(
benchmark,
sql_session: Session,
transaction_count: int,
) -> None:
users, products = insert_users_and_products(sql_session)
transactions = generate_random_transactions(
sql_session,
transaction_count,
)
sql_session.add_all(transactions)
sql_session.commit()
benchmark.pedantic(
query_transaction_log,
args=(
sql_session,
products,
users,
),
iterations=BENCHMARK_ITERATIONS,
rounds=BENCHMARK_ROUNDS,
)
def query_transaction_log(sql_session: Session, products: list[Product], users: list[User]) -> None:
for user in users:
transaction_log(sql_session, user=user)
for product in products:
transaction_log(sql_session, product=product)
@@ -1,81 +0,0 @@
import pytest
from sqlalchemy.orm import Session
from dibbler.models import User
from dibbler.queries import update_cache, user_balance
from tests.benchmark.benchmark_settings import BENCHMARK_ITERATIONS, BENCHMARK_ROUNDS
from tests.benchmark.helpers import generate_random_transactions, insert_users_and_products
@pytest.mark.benchmark(group="user_balance")
@pytest.mark.parametrize(
"transaction_count",
[
100,
500,
1000,
1500,
2000,
],
)
def test_benchmark_user_balance(
benchmark,
sql_session: Session,
transaction_count: int,
) -> None:
users, _products = insert_users_and_products(sql_session)
transactions = generate_random_transactions(
sql_session,
transaction_count,
)
sql_session.add_all(transactions)
sql_session.commit()
benchmark.pedantic(
query_all_users_balance,
args=(sql_session, users),
iterations=BENCHMARK_ITERATIONS,
rounds=BENCHMARK_ROUNDS,
)
@pytest.mark.benchmark(group="user_balance")
@pytest.mark.parametrize(
"transaction_count",
[
1000,
1500,
2000,
],
)
def test_benchmark_user_balance_cache_every_500(
benchmark,
sql_session: Session,
transaction_count: int,
) -> None:
users, _products = insert_users_and_products(sql_session)
transactions = generate_random_transactions(
sql_session,
transaction_count,
)
for i in range(0, len(transactions), 500):
update_cache(sql_session)
sql_session.add_all(transactions[i : i + 500])
sql_session.commit()
benchmark.pedantic(
query_all_users_balance,
args=(sql_session, users),
iterations=BENCHMARK_ITERATIONS,
rounds=BENCHMARK_ROUNDS,
)
def query_all_users_balance(sql_session: Session, users: list[User]) -> None:
for user in users:
user_balance(sql_session, user, use_cache=False)
+1 -3
View File
@@ -3,8 +3,7 @@ import logging
import pytest import pytest
import sqlparse import sqlparse
from sqlalchemy import create_engine, event from sqlalchemy import create_engine, event
from sqlalchemy.exc import OperationalError
# from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Base from dibbler.models import Base
@@ -67,7 +66,6 @@ def sql_session(request):
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
with Session(engine) as sql_session: with Session(engine) as sql_session:
yield sql_session yield sql_session
sql_session.close()
# FIXME: Declaring this hook seems to have a side effect where the database does not # FIXME: Declaring this hook seems to have a side effect where the database does not
-28
View File
@@ -1,28 +0,0 @@
from datetime import datetime, timedelta
from dibbler.models import Transaction
def assign_times(
transactions: list[Transaction],
start_time: datetime = datetime(2024, 1, 1, 0, 0, 0),
delta: timedelta = timedelta(minutes=1),
) -> None:
"""Assigns datetimes to a list of transactions starting from start_time and incrementing by delta."""
current_time = start_time
for transaction in transactions:
transaction.time = current_time
current_time += delta
def assert_id_order_similar_to_time_order(transactions: list[Transaction]) -> None:
"""Asserts that the order of transaction IDs is similar to the order of their timestamps."""
sorted_by_time = sorted(transactions, key=lambda t: t.time)
sorted_by_id = sorted(transactions, key=lambda t: t.id)
for t1, t2 in zip(sorted_by_time, sorted_by_id, strict=False):
assert t1.id == t2.id or t1.time == t2.time, (
f"Transaction ID order does not match time order:\n"
f"ID {t1.id} at time {t1.time}\n"
f"ID {t2.id} at time {t2.time}"
)
+2 -2
View File
@@ -12,7 +12,7 @@ def insert_test_data(sql_session: Session) -> Product:
return product return product
def test_product_no_duplicate_barcodes(sql_session: Session) -> None: def test_product_no_duplicate_barcodes(sql_session: Session):
product = insert_test_data(sql_session) product = insert_test_data(sql_session)
duplicate_product = Product(product.bar_code, "Hehe >:)") duplicate_product = Product(product.bar_code, "Hehe >:)")
@@ -22,7 +22,7 @@ def test_product_no_duplicate_barcodes(sql_session: Session) -> None:
sql_session.commit() sql_session.commit()
def test_product_no_duplicate_names(sql_session: Session) -> None: def test_product_no_duplicate_names(sql_session: Session):
product = insert_test_data(sql_session) product = insert_test_data(sql_session)
duplicate_product = Product("1918238911928", product.name) duplicate_product = Product("1918238911928", product.name)
+18
View File
@@ -123,6 +123,24 @@ def test_transaction_buy_product_more_than_stock(sql_session: Session) -> None:
assert product_stock(sql_session, product) == 1 - 10 assert product_stock(sql_session, product) == 1 - 10
def test_transaction_buy_product_dont_allow_no_add_product_transactions(
sql_session: Session,
) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
product_count=1,
user_id=user.id,
product_id=product.id,
)
sql_session.add(transaction)
with pytest.raises(ValueError):
sql_session.commit()
def test_transaction_add_product_deny_amount_over_per_product_times_product_count( def test_transaction_add_product_deny_amount_over_per_product_times_product_count(
sql_session: Session, sql_session: Session,
) -> None: ) -> None:
+4 -2
View File
@@ -1,8 +1,10 @@
from datetime import datetime
import pytest import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import User from dibbler.models import Product, Transaction, User
def insert_test_data(sql_session: Session) -> User: def insert_test_data(sql_session: Session) -> User:
@@ -13,7 +15,7 @@ def insert_test_data(sql_session: Session) -> User:
return user return user
def test_ensure_no_duplicate_user_names(sql_session: Session) -> None: def test_ensure_no_duplicate_user_names(sql_session: Session):
user = insert_test_data(sql_session) user = insert_test_data(sql_session)
user2 = User(user.name) user2 = User(user.name)
View File
+1 -14
View File
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta from datetime import datetime
import pytest import pytest
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -15,18 +15,6 @@ def insert_test_data(sql_session: Session) -> User:
return user return user
def test_adjust_interest_uninitialized_user(sql_session: Session) -> None:
user = User("Uninitialized User")
with pytest.raises(ValueError, match="User must be persisted in the database."):
adjust_interest(
sql_session,
user=user,
new_interest=4,
message="Attempting to adjust interest for uninitialized user",
)
def test_adjust_interest_no_history(sql_session: Session) -> None: def test_adjust_interest_no_history(sql_session: Session) -> None:
user = insert_test_data(sql_session) user = insert_test_data(sql_session)
@@ -65,7 +53,6 @@ def test_adjust_interest_existing_history(sql_session: Session) -> None:
user=user, user=user,
new_interest=2, new_interest=2,
message="Adjusting interest rate", message="Adjusting interest rate",
time=transactions[-1].time + timedelta(days=1),
) )
sql_session.commit() sql_session.commit()
+3 -29
View File
@@ -1,11 +1,11 @@
from datetime import datetime, timedelta from datetime import datetime
import pytest import pytest
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Transaction, User from dibbler.models import Transaction, User
from dibbler.models.Transaction import ( from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENT, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_THRESHOLD,
) )
from dibbler.queries import adjust_penalty, current_penalty from dibbler.queries import adjust_penalty, current_penalty
@@ -19,30 +19,6 @@ def insert_test_data(sql_session: Session) -> User:
return user return user
def test_adjust_penalty_empty_not_allowed(sql_session: Session) -> None:
user = insert_test_data(sql_session)
with pytest.raises(ValueError):
adjust_penalty(
sql_session,
user=user,
message="No penalty or multiplier provided",
)
def test_adjust_penalty_uninitialized_user(sql_session: Session) -> None:
user = User("Uninitialized User")
with pytest.raises(ValueError):
adjust_penalty(
sql_session,
user=user,
new_penalty=-100,
new_penalty_multiplier=110,
message="Attempting to adjust penalty for uninitialized user",
)
def test_adjust_penalty_no_history(sql_session: Session) -> None: def test_adjust_penalty_no_history(sql_session: Session) -> None:
user = insert_test_data(sql_session) user = insert_test_data(sql_session)
@@ -57,7 +33,7 @@ def test_adjust_penalty_no_history(sql_session: Session) -> None:
(penalty, multiplier) = current_penalty(sql_session) (penalty, multiplier) = current_penalty(sql_session)
assert penalty == -200 assert penalty == -200
assert multiplier == DEFAULT_PENALTY_MULTIPLIER_PERCENT assert multiplier == DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE
def test_adjust_penalty_multiplier_no_history(sql_session: Session) -> None: def test_adjust_penalty_multiplier_no_history(sql_session: Session) -> None:
@@ -124,7 +100,6 @@ def test_adjust_penalty_existing_history(sql_session: Session) -> None:
user=user, user=user,
new_penalty=-250, new_penalty=-250,
message="Adjusting penalty threshold", message="Adjusting penalty threshold",
time=transactions[-1].time + timedelta(days=1),
) )
sql_session.commit() sql_session.commit()
@@ -155,7 +130,6 @@ def test_adjust_penalty_multiplier_existing_history(sql_session: Session) -> Non
user=user, user=user,
new_penalty_multiplier=130, new_penalty_multiplier=130,
message="Adjusting penalty multiplier", message="Adjusting penalty multiplier",
time=transactions[-1].time + timedelta(days=1),
) )
sql_session.commit() sql_session.commit()
(_, multiplier) = current_penalty(sql_session) (_, multiplier) = current_penalty(sql_session)
-74
View File
@@ -1,74 +0,0 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries import affected_products
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def insert_test_data(sql_session: Session) -> tuple[User, list[Product]]:
user = User("Test User")
products = []
for i in range(10):
product = Product(f"12345678901{i:02d}", f"Test Product {i}")
products.append(product)
sql_session.add(user)
sql_session.add_all(products)
sql_session.commit()
return user, products
def test_affected_products_no_history(sql_session: Session) -> None:
insert_test_data(sql_session)
result = affected_products(sql_session)
assert result == set()
def test_affected_products_basic_history(sql_session: Session) -> None:
user, products = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=10,
per_product=10,
user_id=user.id,
product_id=products[i].id,
product_count=1,
)
for i in range(5)
] + [
Transaction.buy_product(
user_id=user.id,
product_id=products[i].id,
product_count=1,
)
for i in range(3)
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
result = affected_products(sql_session)
expected_products = {products[i] for i in range(5)}
assert result == expected_products
# def test_affected_products_after(sql_session: Session) -> None:
# def test_affected_products_until(sql_session: Session) -> None:
# def test_affected_products_after_until(sql_session: Session) -> None:
# def test_affected_products_after_inclusive(sql_session: Session) -> None:
# def test_affected_products_until_inclusive(sql_session: Session) -> None:
# def test_affected_products_after_until_inclusive(sql_session: Session) -> None:
-74
View File
@@ -1,74 +0,0 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries import affected_users
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def insert_test_data(sql_session: Session) -> tuple[list[User], Product]:
users = []
for i in range(10):
user = User(f"Test User {i + 1}")
users.append(user)
product = Product("1234567890123", "Test Product")
sql_session.add_all(users)
sql_session.add(product)
sql_session.commit()
return users, product
def test_affected_users_no_history(sql_session: Session) -> None:
insert_test_data(sql_session)
result = affected_users(sql_session)
assert result == set()
def test_affected_users_basic_history(sql_session: Session) -> None:
users, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
amount=10,
per_product=10,
user_id=users[i].id,
product_id=product.id,
product_count=1,
)
for i in range(5)
] + [
Transaction.buy_product(
user_id=users[i].id,
product_id=product.id,
product_count=1,
)
for i in range(3)
]
assign_times(transactions)
sql_session.add_all(transactions)
sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
result = affected_users(sql_session)
expected_users = {users[i] for i in range(5)}
assert result == expected_users
# def test_affected_users_after(sql_session: Session) -> None:
# def test_affected_users_until(sql_session: Session) -> None:
# def test_affected_users_after_until(sql_session: Session) -> None:
# def test_affected_users_after_inclusive(sql_session: Session) -> None:
# def test_affected_users_until_inclusive(sql_session: Session) -> None:
# def test_affected_users_after_until_inclusive(sql_session: Session) -> None:
+6 -7
View File
@@ -1,13 +1,14 @@
from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
from dibbler.models import Transaction, User from dibbler.models import Transaction, User
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
from dibbler.queries import current_interest from dibbler.queries import current_interest
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def test_current_interest_no_history(sql_session: Session) -> None: def test_current_interest_no_history(sql_session: Session) -> None:
assert current_interest(sql_session) == DEFAULT_INTEREST_RATE_PERCENT assert current_interest(sql_session) == DEFAULT_INTEREST_RATE_PERCENTAGE
def test_current_interest_with_history(sql_session: Session) -> None: def test_current_interest_with_history(sql_session: Session) -> None:
@@ -17,20 +18,18 @@ def test_current_interest_with_history(sql_session: Session) -> None:
transactions = [ transactions = [
Transaction.adjust_interest( Transaction.adjust_interest(
time=datetime(2023, 10, 1, 10, 0, 0),
interest_rate_percent=5, interest_rate_percent=5,
user_id=user.id, user_id=user.id,
), ),
Transaction.adjust_interest( Transaction.adjust_interest(
time=datetime(2023, 11, 1, 10, 0, 0),
interest_rate_percent=7, interest_rate_percent=7,
user_id=user.id, user_id=user.id,
), ),
] ]
assign_times(transactions)
sql_session.add_all(transactions) sql_session.add_all(transactions)
sql_session.commit() sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert current_interest(sql_session) == 7 assert current_interest(sql_session) == 7
+6 -8
View File
@@ -1,18 +1,19 @@
from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Transaction, User from dibbler.models import Transaction, User
from dibbler.models.Transaction import ( from dibbler.models.Transaction import (
DEFAULT_PENALTY_MULTIPLIER_PERCENT, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_THRESHOLD,
) )
from dibbler.queries import current_penalty from dibbler.queries import current_penalty
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def test_current_penalty_no_history(sql_session: Session) -> None: def test_current_penalty_no_history(sql_session: Session) -> None:
assert current_penalty(sql_session) == ( assert current_penalty(sql_session) == (
DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_THRESHOLD,
DEFAULT_PENALTY_MULTIPLIER_PERCENT, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
) )
@@ -23,22 +24,19 @@ def test_current_penalty_with_history(sql_session: Session) -> None:
transactions = [ transactions = [
Transaction.adjust_penalty( Transaction.adjust_penalty(
time=datetime(2023, 10, 1, 10, 0, 0),
penalty_threshold=-200, penalty_threshold=-200,
penalty_multiplier_percent=150, penalty_multiplier_percent=150,
user_id=user.id, user_id=user.id,
), ),
Transaction.adjust_penalty( Transaction.adjust_penalty(
time=datetime(2023, 10, 2, 10, 0, 0),
penalty_threshold=-300, penalty_threshold=-300,
penalty_multiplier_percent=200, penalty_multiplier_percent=200,
user_id=user.id, user_id=user.id,
), ),
] ]
assign_times(transactions)
sql_session.add_all(transactions) sql_session.add_all(transactions)
sql_session.commit() sql_session.commit()
assert_id_order_similar_to_time_order(transactions)
assert current_penalty(sql_session) == (-300, 200) assert current_penalty(sql_session) == (-300, 200)
+12 -37
View File
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta from datetime import datetime
import pytest import pytest
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -24,7 +24,7 @@ def insert_test_data(sql_session: Session) -> tuple[User, User, User, Product]:
per_product=10, per_product=10,
product_count=3, product_count=3,
time=datetime(2024, 1, 1, 10, 0, 0), time=datetime(2024, 1, 1, 10, 0, 0),
), )
] ]
sql_session.add_all(transactions) sql_session.add_all(transactions)
@@ -33,12 +33,12 @@ def insert_test_data(sql_session: Session) -> tuple[User, User, User, Product]:
return user1, user2, user3, product return user1, user2, user3, product
def test_joint_buy_product_uninitialized_product(sql_session: Session) -> None: def test_joint_buy_product_missing_product(sql_session: Session) -> None:
user = User("Test User 1") user = User("Test User 1")
sql_session.add(user) sql_session.add(user)
sql_session.commit() sql_session.commit()
product = Product("1234567890123", "Uninitialized Product") product = Product("1234567890123", "Test Product")
with pytest.raises(ValueError): with pytest.raises(ValueError):
joint_buy_product( joint_buy_product(
@@ -50,42 +50,18 @@ def test_joint_buy_product_uninitialized_product(sql_session: Session) -> None:
) )
def test_joint_buy_product_no_users(sql_session: Session) -> None: def test_joint_buy_product_missing_user(sql_session: Session) -> None:
user, _, _, product = insert_test_data(sql_session) user = User("Test User 1")
product = Product("1234567890123", "Test Product")
sql_session.add(product)
sql_session.commit()
with pytest.raises(ValueError): with pytest.raises(ValueError):
joint_buy_product( joint_buy_product(
sql_session, sql_session,
instigator=user, instigator=user,
users=[], users=[user],
product=product,
product_count=1,
)
def test_joint_buy_product_uninitialized_instigator(sql_session: Session) -> None:
user, user2, _, product = insert_test_data(sql_session)
uninitialized_user = User("Uninitialized User")
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=uninitialized_user,
users=[user, user2],
product=product,
product_count=1,
)
def test_joint_buy_product_uninitialized_user_in_list(sql_session: Session) -> None:
user, _, _, product = insert_test_data(sql_session)
uninitialized_user = User("Uninitialized User")
with pytest.raises(ValueError):
joint_buy_product(
sql_session,
instigator=user,
users=[user, uninitialized_user],
product=product, product=product,
product_count=1, product_count=1,
) )
@@ -158,7 +134,7 @@ def test_joint_buy_product_out_of_stock(sql_session: Session) -> None:
product_id=product.id, product_id=product.id,
product_count=3, product_count=3,
time=datetime(2024, 1, 2, 10, 0, 0), time=datetime(2024, 1, 2, 10, 0, 0),
), )
] ]
sql_session.add_all(transactions) sql_session.add_all(transactions)
@@ -170,7 +146,6 @@ def test_joint_buy_product_out_of_stock(sql_session: Session) -> None:
users=[user, user2, user3], users=[user, user2, user3],
product=product, product=product,
product_count=10, product_count=10,
time=transactions[-1].time + timedelta(days=1),
) )

Some files were not shown because too many files have changed in this diff Show More