Compare commits
52 Commits
event-sour
...
event-sour
| Author | SHA1 | Date | |
|---|---|---|---|
|
a265fb921c
|
|||
|
aed85b4a06
|
|||
|
3fc7d78c1c
|
|||
|
38e9066300
|
|||
|
a9070fc680
|
|||
|
2ac7d26bcd
|
|||
|
c85a11eb89
|
|||
|
57f7d25cdf
|
|||
|
2a05bd7a58
|
|||
|
00afede3d9
|
|||
|
19ee9bebc2
|
|||
|
acb31992f8
|
|||
|
fb0f24cb67
|
|||
|
3d555ca9d1
|
|||
|
af5710d663
|
|||
|
4d88409e97
|
|||
|
72cd066414
|
|||
|
b1bb1e556b
|
|||
|
70b04c0c45
|
|||
|
7bea5b0b96
|
|||
|
3123b8b474
|
|||
|
9091adedad
|
|||
|
94955cb706
|
|||
|
3b6cd1d354
|
|||
|
c2ee66c394
|
|||
|
b5b2706085
|
|||
|
bf9cea7dfc
|
|||
|
cf945143ba
|
|||
|
e84b43e2a0
|
|||
|
17fc23ba97
|
|||
|
45179a9c43
|
|||
|
dfaa818f46
|
|||
|
ec43f67e58
|
|||
|
1b09a904cb
|
|||
|
8e84669d9b
|
|||
|
1d01e1b2cb
|
|||
|
019f419b12
|
|||
|
3bab62b3ac
|
|||
|
e771fb0240
|
|||
|
2331e53795
|
|||
|
2ae651a1fa
|
|||
|
76f07841be
|
|||
|
ecaec99212
|
|||
|
cb385097dc
|
|||
|
b86962ef0e
|
|||
|
9c0bd54be6
|
|||
|
919d7a5afe
|
|||
|
ddca959ad6
|
|||
|
1733843b77
|
|||
|
4ed68ff05c
|
|||
|
78161a96be
|
|||
|
f4b5e1d6d4
|
71
.gitea/workflows/benchmark.yaml
Normal file
71
.gitea/workflows/benchmark.yaml
Normal file
@@ -0,0 +1,71 @@
|
||||
name: Run benchmarks
|
||||
on:
|
||||
workflow_dispatch:
|
||||
# TODO: make this only workflow_dispatch when merged into main
|
||||
push:
|
||||
|
||||
jobs:
|
||||
run-tests:
|
||||
runs-on: debian-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --locked --group test
|
||||
|
||||
- name: Run benchmarks
|
||||
continue-on-error: true
|
||||
run: |
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
||||
PYTEST_ARGS=(
|
||||
-vv
|
||||
|
||||
--benchmark-only
|
||||
-k test_benchmark
|
||||
)
|
||||
|
||||
uv run -- pytest "${PYTEST_ARGS[@]}"
|
||||
|
||||
- name: Upload benchmark JSON report
|
||||
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v2
|
||||
with:
|
||||
source: ./benchmark/*/*.json
|
||||
quote-source: false
|
||||
target: ${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/benchmark.json
|
||||
username: gitea-web
|
||||
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
|
||||
host: pages.pvv.ntnu.no
|
||||
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
|
||||
|
||||
- name: Upload histograms
|
||||
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v2
|
||||
with:
|
||||
source: ./benchmark/*.svg
|
||||
quote-source: false
|
||||
target: ${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/
|
||||
username: gitea-web
|
||||
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
|
||||
host: pages.pvv.ntnu.no
|
||||
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
|
||||
|
||||
# NOTE: $GITHUB_STEP_SUMMARY when...
|
||||
- name: Run information
|
||||
run: |
|
||||
echo "Benchmark run ID: ${{ github.run_id }}"
|
||||
echo "Benchmark JSON: https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/benchmark.json"
|
||||
echo "Histograms: https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_owners.svg"
|
||||
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_price.svg"
|
||||
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-product_stock.svg"
|
||||
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-transaction_log.svg"
|
||||
echo " https://pages.pvv.ntnu.no/${{ gitea.repository }}/${{ gitea.ref_name }}/benchmark/${{ github.run_id }}/histogram-user_balance.svg"
|
||||
|
||||
- name: Check failure
|
||||
if: failure()
|
||||
run: |
|
||||
echo "Tests failed"
|
||||
exit 1
|
||||
@@ -16,66 +16,57 @@ jobs:
|
||||
run-tests:
|
||||
runs-on: debian-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --locked --group test
|
||||
- name: Install dependencies
|
||||
run: uv sync --locked --group test
|
||||
|
||||
- name: Run tests
|
||||
continue-on-error: true
|
||||
run: |
|
||||
set -euo pipefail
|
||||
set -x
|
||||
- name: Run tests
|
||||
continue-on-error: true
|
||||
run: |
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
||||
PYTEST_ARGS=(
|
||||
-vv
|
||||
|
||||
--cov=dibbler.lib
|
||||
--cov=dibbler.models
|
||||
--cov=dibbler.queries
|
||||
--cov-report=html
|
||||
--cov-branch
|
||||
|
||||
--self-contained-html
|
||||
--html=./test-report/index.html
|
||||
)
|
||||
|
||||
if [ "$DEBUG_SQL" == "true" ]; then
|
||||
PYTEST_ARGS+=(
|
||||
--debug-sql
|
||||
PYTEST_ARGS=(
|
||||
-vv
|
||||
)
|
||||
fi
|
||||
|
||||
uv run -- pytest "${PYTEST_ARGS[@]}"
|
||||
if [ "$DEBUG_SQL" == "true" ]; then
|
||||
PYTEST_ARGS+=(
|
||||
--debug-sql
|
||||
)
|
||||
fi
|
||||
|
||||
- name: Generate badge
|
||||
run: uv run -- coverage-badge -o htmlcov/badge.svg
|
||||
uv run -- pytest "${PYTEST_ARGS[@]}"
|
||||
|
||||
- name: Upload test report
|
||||
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
|
||||
with:
|
||||
source: ./test-report/
|
||||
target: ${{ gitea.ref_name }}/test-report/
|
||||
username: gitea-web
|
||||
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
|
||||
host: pages.pvv.ntnu.no
|
||||
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
|
||||
- name: Generate badge
|
||||
run: uv run -- coverage-badge -o htmlcov/badge.svg
|
||||
|
||||
- name: Upload coverage report
|
||||
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
|
||||
with:
|
||||
source: ./htmlcov/
|
||||
target: ${{ gitea.ref_name }}/coverage/
|
||||
username: gitea-web
|
||||
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
|
||||
host: pages.pvv.ntnu.no
|
||||
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
|
||||
- name: Upload test report
|
||||
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
|
||||
with:
|
||||
source: ./test-report/
|
||||
target: ${{ gitea.ref_name }}/test-report/
|
||||
username: gitea-web
|
||||
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
|
||||
host: pages.pvv.ntnu.no
|
||||
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
|
||||
|
||||
- name: Check failure
|
||||
if: failure()
|
||||
run: |
|
||||
echo "Tests failed"
|
||||
exit 1
|
||||
- name: Upload coverage report
|
||||
uses: https://git.pvv.ntnu.no/Projects/rsync-action@v1
|
||||
with:
|
||||
source: ./htmlcov/
|
||||
target: ${{ gitea.ref_name }}/coverage/
|
||||
username: gitea-web
|
||||
ssh-key: ${{ secrets.WEB_SYNC_SSH_KEY }}
|
||||
host: pages.pvv.ntnu.no
|
||||
known-hosts: "pages.pvv.ntnu.no ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH2QjfFB+city1SYqltkVqWACfo1j37k+oQQfj13mtgg"
|
||||
|
||||
- name: Check failure
|
||||
if: failure()
|
||||
run: |
|
||||
echo "Tests failed"
|
||||
exit 1
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -8,6 +8,12 @@ test.db
|
||||
|
||||
.ruff_cache
|
||||
|
||||
*.qcow2
|
||||
|
||||
dibbler/_version.py
|
||||
|
||||
.coverage
|
||||
.coverage.*
|
||||
htmlcov
|
||||
test-report
|
||||
/benchmark
|
||||
|
||||
38
README.md
38
README.md
@@ -23,8 +23,9 @@ Installer python, og lag og aktiver et venv. Installer så avhengighetene med `p
|
||||
Deretter kan du kjøre programmet med
|
||||
|
||||
```console
|
||||
python -m dibbler -c example-config.ini create-db
|
||||
python -m dibbler -c example-config.ini loop
|
||||
python -m dibbler -c example-config.toml create-db
|
||||
python -m dibbler -c example-config.toml seed-data
|
||||
python -m dibbler -c example-config.toml loop
|
||||
```
|
||||
|
||||
## Prosjektstruktur
|
||||
@@ -61,25 +62,30 @@ Her ligger enhetstester for prosjektet. Testene bruker `pytest` som testløper.
|
||||
|
||||
## Nix
|
||||
|
||||
### Bygge nytt image
|
||||
> [!NOTE]
|
||||
> Vi har skrevet nix-kode for å generere en QEMU-VM med tilnærmet produksjonsoppsett.
|
||||
> Det kjører ikke nødvendigvis noen VM-er i produksjon, og ihvertfall ikke denne VM-en.
|
||||
> Den er hovedsakelig laget for enkel interaktiv testing, og for å teste NixOS modulen.
|
||||
|
||||
For å bygge et image trenger du en builder som takler å bygge for arkitekturen du skal lage et image for.
|
||||
Du kan enklest komme i gang med nix-utvikling ved å kjøre test VM-en:
|
||||
|
||||
(Eller be til gudene om at cross compile funker)
|
||||
```console
|
||||
nix run .#vm
|
||||
|
||||
Flaket exposer en modul som autologger inn med en bruker som automatisk kjører dibbler, og setter opp et minimalistisk miljø.
|
||||
# Eller hvis du trenger tilgang til terminalen i VM-en også:
|
||||
nix run .#vm-non-kiosk
|
||||
```
|
||||
|
||||
Før du bygger imaget burde du kopiere og endre `example-config.ini` lokalt til å inneholde instillingene dine. **NB: Denne kommer til å ligge i nix storen, ikke si noe her som du ikke vil at moren din skal høre.**
|
||||
Du kan også bygge pakken manuelt, eller kjøre den direkte:
|
||||
|
||||
Du kan også endre hvilken config-fil som blir brukt direkte i pakken eller i modulen.
|
||||
```console
|
||||
nix build .#dibbler
|
||||
|
||||
Se eksempelet for hvordan skrot er satt opp i `flake.nix` og `nix/skrott.nix`
|
||||
nix run .# -- --config example-config.toml create-db
|
||||
nix run .# -- --config example-config.toml seed-data
|
||||
nix run .# -- --config example-config.toml loop
|
||||
```
|
||||
|
||||
### Bygge image for skrot
|
||||
## Produksjonssetting
|
||||
|
||||
Skrot har et image definert i flake.nix:
|
||||
|
||||
1. endre `example-config.ini`
|
||||
2. `nix build .#images.skrot`
|
||||
3. ???
|
||||
4. non-profit
|
||||
Se https://wiki.pvv.ntnu.no/wiki/Drift/Dibbler
|
||||
|
||||
@@ -1,6 +1,56 @@
|
||||
# This module is supposed to act as a singleton and be filled
|
||||
# with config variables by cli.py
|
||||
import os
|
||||
import sys
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import configparser
|
||||
from dibbler.lib.helpers import file_is_submissive_and_readable
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml")
|
||||
|
||||
|
||||
config: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def load_config(config_path: Path | None = None) -> None:
|
||||
global config
|
||||
if config_path is not None:
|
||||
with Path(config_path).open("rb") as file:
|
||||
config = tomllib.load(file)
|
||||
elif file_is_submissive_and_readable(DEFAULT_CONFIG_PATH):
|
||||
with DEFAULT_CONFIG_PATH.open("rb") as file:
|
||||
config = tomllib.load(file)
|
||||
else:
|
||||
print(
|
||||
"Could not read config file, it was neither provided nor readable in default location",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def config_db_string() -> str:
|
||||
db_type = config["database"]["type"]
|
||||
|
||||
if db_type == "sqlite":
|
||||
path = Path(config["database"]["sqlite"]["path"])
|
||||
return f"sqlite:///{path.absolute()}"
|
||||
|
||||
if db_type == "postgresql":
|
||||
host = config["database"]["postgresql"]["host"]
|
||||
port = config["database"]["postgresql"].get("port", 5432)
|
||||
username = config["database"]["postgresql"].get("username", "dibbler")
|
||||
dbname = config["database"]["postgresql"].get("dbname", "dibbler")
|
||||
|
||||
if "password_file" in config["database"]["postgresql"]:
|
||||
with Path(config["database"]["postgresql"]["password_file"]).open("r") as f:
|
||||
password = f.read().strip()
|
||||
elif "password" in config["database"]["postgresql"]:
|
||||
password = config["database"]["postgresql"]["password"]
|
||||
else:
|
||||
password = ""
|
||||
|
||||
if host.startswith("/"):
|
||||
return f"postgresql+psycopg2://{username}:{password}@/{dbname}?host={host}"
|
||||
return f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{dbname}"
|
||||
print(f"Error: unknown database type '{db_type}'")
|
||||
exit(1)
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from dibbler.conf import config
|
||||
|
||||
database_url: str | None = None
|
||||
|
||||
if (url := config.get("database", "url")) is not None:
|
||||
database_url = url
|
||||
elif (url_file := config.get("database", "url_file")) is not None:
|
||||
with Path(url_file).open() as file:
|
||||
database_url = file.read().strip()
|
||||
|
||||
assert database_url is not None, "Database URL must be specified in config"
|
||||
|
||||
engine = create_engine(database_url)
|
||||
Session = sessionmaker(bind=engine)
|
||||
@@ -1,71 +1,71 @@
|
||||
import os
|
||||
# import os
|
||||
|
||||
from PIL import ImageFont
|
||||
from barcode.writer import ImageWriter, mm2px
|
||||
from brother_ql.labels import ALL_LABELS
|
||||
# from PIL import ImageFont
|
||||
# from barcode.writer import ImageWriter, mm2px
|
||||
# from brother_ql.labels import ALL_LABELS
|
||||
|
||||
|
||||
def px2mm(px, dpi=300):
|
||||
return (25.4 * px) / dpi
|
||||
# def px2mm(px, dpi=300):
|
||||
# return (25.4 * px) / dpi
|
||||
|
||||
|
||||
class BrotherLabelWriter(ImageWriter):
|
||||
def __init__(self, typ="62", max_height=350, rot=False, text=None):
|
||||
super(BrotherLabelWriter, self).__init__()
|
||||
label = next([l for l in ALL_LABELS if l.identifier == typ])
|
||||
assert label is not None
|
||||
self.rot = rot
|
||||
if self.rot:
|
||||
self._h, self._w = label.dots_printable
|
||||
if self._w == 0 or self._w > max_height:
|
||||
self._w = min(max_height, self._h / 2)
|
||||
else:
|
||||
self._w, self._h = label.dots_printable
|
||||
if self._h == 0 or self._h > max_height:
|
||||
self._h = min(max_height, self._w / 2)
|
||||
self._xo = 0.0
|
||||
self._yo = 0.0
|
||||
self._title = text
|
||||
# class BrotherLabelWriter(ImageWriter):
|
||||
# def __init__(self, typ="62", max_height=350, rot=False, text=None):
|
||||
# super(BrotherLabelWriter, self).__init__()
|
||||
# label = next([l for l in ALL_LABELS if l.identifier == typ])
|
||||
# assert label is not None
|
||||
# self.rot = rot
|
||||
# if self.rot:
|
||||
# self._h, self._w = label.dots_printable
|
||||
# if self._w == 0 or self._w > max_height:
|
||||
# self._w = min(max_height, self._h / 2)
|
||||
# else:
|
||||
# self._w, self._h = label.dots_printable
|
||||
# if self._h == 0 or self._h > max_height:
|
||||
# self._h = min(max_height, self._w / 2)
|
||||
# self._xo = 0.0
|
||||
# self._yo = 0.0
|
||||
# self._title = text
|
||||
|
||||
def _init(self, code):
|
||||
self.text = None
|
||||
super(BrotherLabelWriter, self)._init(code)
|
||||
# def _init(self, code):
|
||||
# self.text = None
|
||||
# super(BrotherLabelWriter, self)._init(code)
|
||||
|
||||
def calculate_size(self, modules_per_line, number_of_lines, dpi=300):
|
||||
x, y = super(BrotherLabelWriter, self).calculate_size(
|
||||
modules_per_line, number_of_lines, dpi
|
||||
)
|
||||
# def calculate_size(self, modules_per_line, number_of_lines, dpi=300):
|
||||
# x, y = super(BrotherLabelWriter, self).calculate_size(
|
||||
# modules_per_line, number_of_lines, dpi
|
||||
# )
|
||||
|
||||
self._xo = (px2mm(self._w) - px2mm(x)) / 2
|
||||
self._yo = px2mm(self._h) - px2mm(y)
|
||||
assert self._xo >= 0
|
||||
assert self._yo >= 0
|
||||
# self._xo = (px2mm(self._w) - px2mm(x)) / 2
|
||||
# self._yo = px2mm(self._h) - px2mm(y)
|
||||
# assert self._xo >= 0
|
||||
# assert self._yo >= 0
|
||||
|
||||
return int(self._w), int(self._h)
|
||||
# return int(self._w), int(self._h)
|
||||
|
||||
def _paint_module(self, xpos, ypos, width, color):
|
||||
super(BrotherLabelWriter, self)._paint_module(
|
||||
xpos + self._xo, ypos + self._yo, width, color
|
||||
)
|
||||
# def _paint_module(self, xpos, ypos, width, color):
|
||||
# super(BrotherLabelWriter, self)._paint_module(
|
||||
# xpos + self._xo, ypos + self._yo, width, color
|
||||
# )
|
||||
|
||||
def _paint_text(self, xpos, ypos):
|
||||
super(BrotherLabelWriter, self)._paint_text(xpos + self._xo, ypos + self._yo)
|
||||
# def _paint_text(self, xpos, ypos):
|
||||
# super(BrotherLabelWriter, self)._paint_text(xpos + self._xo, ypos + self._yo)
|
||||
|
||||
def _finish(self):
|
||||
if self._title:
|
||||
width = self._w + 1
|
||||
height = 0
|
||||
max_h = self._h - mm2px(self._yo, self.dpi)
|
||||
fs = int(max_h / 1.2)
|
||||
font_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"Stranger back in the Night.ttf",
|
||||
)
|
||||
font = ImageFont.truetype(font_path, 10)
|
||||
while width > self._w or height > max_h:
|
||||
font = ImageFont.truetype(font_path, fs)
|
||||
width, height = font.getsize(self._title)
|
||||
fs -= 1
|
||||
pos = ((self._w - width) // 2, 0 - (height // 8))
|
||||
self._draw.text(pos, self._title, font=font, fill=self.foreground)
|
||||
return self._image
|
||||
# def _finish(self):
|
||||
# if self._title:
|
||||
# width = self._w + 1
|
||||
# height = 0
|
||||
# max_h = self._h - mm2px(self._yo, self.dpi)
|
||||
# fs = int(max_h / 1.2)
|
||||
# font_path = os.path.join(
|
||||
# os.path.dirname(os.path.realpath(__file__)),
|
||||
# "Stranger back in the Night.ttf",
|
||||
# )
|
||||
# font = ImageFont.truetype(font_path, 10)
|
||||
# while width > self._w or height > max_h:
|
||||
# font = ImageFont.truetype(font_path, fs)
|
||||
# width, height = font.getsize(self._title)
|
||||
# fs -= 1
|
||||
# pos = ((self._w - width) // 2, 0 - (height // 8))
|
||||
# self._draw.text(pos, self._title, font=font, fill=self.foreground)
|
||||
# return self._image
|
||||
|
||||
108
dibbler/lib/check_db_health.py
Normal file
108
dibbler/lib/check_db_health.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import Engine, create_engine, inspect, select
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||
from sqlalchemy.orm import RelationshipProperty
|
||||
from sqlalchemy.orm.clsregistry import _ModuleMarker
|
||||
|
||||
from dibbler.lib.helpers import file_is_submissive_and_readable
|
||||
from dibbler.models import Base
|
||||
|
||||
|
||||
def check_db_health(engine: Engine, verify_table_existence: bool = False) -> None:
|
||||
dialect_name = getattr(engine.dialect, "name", "").lower()
|
||||
|
||||
if "postgres" in dialect_name:
|
||||
check_postgres_ping(engine)
|
||||
|
||||
elif dialect_name == "sqlite":
|
||||
check_sqlite_file(engine)
|
||||
|
||||
if verify_table_existence:
|
||||
verify_tables_and_columns(engine)
|
||||
|
||||
|
||||
def check_postgres_ping(engine: Engine) -> None:
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(select(1))
|
||||
scalar = result.scalar()
|
||||
if scalar != 1 and scalar is not None:
|
||||
print(
|
||||
"Unexpected response from Postgres when running 'SELECT 1'",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
except (OperationalError, DBAPIError) as exc:
|
||||
print(f"Failed to connect to Postgres database: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_sqlite_file(engine: Engine) -> None:
|
||||
db_path = engine.url.database
|
||||
|
||||
# Don't verify in-memory databases or empty paths
|
||||
if db_path in (None, "", ":memory:"):
|
||||
return
|
||||
|
||||
db_path = db_path.removeprefix("file:").removeprefix("sqlite:")
|
||||
|
||||
# Strip query parameters
|
||||
if "?" in db_path:
|
||||
db_path = db_path.split("?", 1)[0]
|
||||
|
||||
path = Path(db_path)
|
||||
|
||||
if not path.exists():
|
||||
print(f"SQLite database file does not exist: {path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not path.is_file():
|
||||
print(f"SQLite database path is not a file: {path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not file_is_submissive_and_readable(path):
|
||||
print(f"SQLite database file is not submissive and readable: {path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def verify_tables_and_columns(engine: Engine) -> None:
|
||||
iengine = inspect(engine)
|
||||
errors = False
|
||||
tables = iengine.get_table_names()
|
||||
views = iengine.get_view_names()
|
||||
tables.extend(views)
|
||||
|
||||
for _name, klass in Base.registry._class_registry.items():
|
||||
if isinstance(klass, _ModuleMarker):
|
||||
continue
|
||||
|
||||
table = klass.__tablename__
|
||||
if table in tables:
|
||||
columns = [c["name"] for c in iengine.get_columns(table)]
|
||||
mapper = inspect(klass)
|
||||
|
||||
for column_prop in mapper.attrs:
|
||||
if isinstance(column_prop, RelationshipProperty):
|
||||
pass
|
||||
else:
|
||||
for column in column_prop.columns:
|
||||
if not column.key in columns:
|
||||
print(
|
||||
f"Model '{klass}' declares column '{column.key}' which does not exist in database {engine}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
errors = True
|
||||
else:
|
||||
print(
|
||||
f"Model '{klass}' declares table '{table}' which does not exist in database {engine}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
errors = True
|
||||
|
||||
if errors:
|
||||
print("Have you remembered to run `dibbler create-db?", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
@@ -2,9 +2,12 @@ import os
|
||||
import pwd
|
||||
import signal
|
||||
import subprocess
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
def system_user_exists(username):
|
||||
def system_user_exists(username: str) -> bool:
|
||||
try:
|
||||
pwd.getpwnam(username)
|
||||
except KeyError:
|
||||
@@ -15,7 +18,7 @@ def system_user_exists(username):
|
||||
return True
|
||||
|
||||
|
||||
def guess_data_type(string):
|
||||
def guess_data_type(string: str) -> Literal["card", "rfid", "bar_code", "username"] | None:
|
||||
if string.startswith("ntnu") and string[4:].isdigit():
|
||||
return "card"
|
||||
if string.isdigit() and len(string) == 10:
|
||||
@@ -29,7 +32,11 @@ def guess_data_type(string):
|
||||
return None
|
||||
|
||||
|
||||
def argmax(d, all=False, value=None):
|
||||
def argmax(
|
||||
d: dict[Any, Any],
|
||||
all_: bool = False,
|
||||
value: Callable[[Any], Any] | None = None,
|
||||
) -> Any | list[Any] | None:
|
||||
maxarg = None
|
||||
if value is not None:
|
||||
dd = d
|
||||
@@ -39,12 +46,12 @@ def argmax(d, all=False, value=None):
|
||||
for key in list(d.keys()):
|
||||
if maxarg is None or d[key] > d[maxarg]:
|
||||
maxarg = key
|
||||
if all:
|
||||
if all_:
|
||||
return [k for k in list(d.keys()) if d[k] == d[maxarg]]
|
||||
return maxarg
|
||||
|
||||
|
||||
def less(string):
|
||||
def less(string: str) -> None:
|
||||
"""
|
||||
Run less with string as input; wait until it finishes.
|
||||
"""
|
||||
@@ -56,3 +63,13 @@ def less(string):
|
||||
proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE)
|
||||
proc.communicate(string)
|
||||
signal.signal(signal.SIGINT, int_handler)
|
||||
|
||||
|
||||
def file_is_submissive_and_readable(file: Path) -> bool:
|
||||
return file.is_file() and any(
|
||||
[
|
||||
file.stat().st_mode & 0o400 and file.stat().st_uid == os.getuid(),
|
||||
file.stat().st_mode & 0o040 and file.stat().st_gid == os.getgid(),
|
||||
file.stat().st_mode & 0o004,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,98 +1,95 @@
|
||||
import os
|
||||
import datetime
|
||||
# import barcode
|
||||
# from brother_ql.brother_ql_create import create_label
|
||||
# from brother_ql.raster import BrotherQLRaster
|
||||
# from brother_ql.backends import backend_factory
|
||||
# from brother_ql.labels import ALL_LABELS
|
||||
# from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import barcode
|
||||
from brother_ql.brother_ql_create import create_label
|
||||
from brother_ql.raster import BrotherQLRaster
|
||||
from brother_ql.backends import backend_factory
|
||||
from brother_ql.labels import ALL_LABELS
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from .barcode_helpers import BrotherLabelWriter
|
||||
# from .barcode_helpers import BrotherLabelWriter
|
||||
|
||||
|
||||
def print_name_label(
|
||||
text,
|
||||
margin=10,
|
||||
rotate=False,
|
||||
label_type="62",
|
||||
printer_type="QL-700",
|
||||
):
|
||||
label = next([l for l in ALL_LABELS if l.identifier == label_type])
|
||||
if not rotate:
|
||||
width, height = label.dots_printable
|
||||
else:
|
||||
height, width = label.dots_printable
|
||||
# def print_name_label(
|
||||
# text,
|
||||
# margin=10,
|
||||
# rotate=False,
|
||||
# label_type="62",
|
||||
# printer_type="QL-700",
|
||||
# ):
|
||||
# label = next([l for l in ALL_LABELS if l.identifier == label_type])
|
||||
# if not rotate:
|
||||
# width, height = label.dots_printable
|
||||
# else:
|
||||
# height, width = label.dots_printable
|
||||
|
||||
font_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ChopinScript.ttf")
|
||||
fs = 2000
|
||||
tw, th = width, height
|
||||
if width == 0:
|
||||
while th + 2 * margin > height:
|
||||
font = ImageFont.truetype(font_path, fs)
|
||||
tw, th = font.getsize(text)
|
||||
fs -= 1
|
||||
width = tw + 2 * margin
|
||||
elif height == 0:
|
||||
while tw + 2 * margin > width:
|
||||
font = ImageFont.truetype(font_path, fs)
|
||||
tw, th = font.getsize(text)
|
||||
fs -= 1
|
||||
height = th + 2 * margin
|
||||
else:
|
||||
while tw + 2 * margin > width or th + 2 * margin > height:
|
||||
font = ImageFont.truetype(font_path, fs)
|
||||
tw, th = font.getsize(text)
|
||||
fs -= 1
|
||||
# font_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ChopinScript.ttf")
|
||||
# fs = 2000
|
||||
# tw, th = width, height
|
||||
# if width == 0:
|
||||
# while th + 2 * margin > height:
|
||||
# font = ImageFont.truetype(font_path, fs)
|
||||
# tw, th = font.getsize(text)
|
||||
# fs -= 1
|
||||
# width = tw + 2 * margin
|
||||
# elif height == 0:
|
||||
# while tw + 2 * margin > width:
|
||||
# font = ImageFont.truetype(font_path, fs)
|
||||
# tw, th = font.getsize(text)
|
||||
# fs -= 1
|
||||
# height = th + 2 * margin
|
||||
# else:
|
||||
# while tw + 2 * margin > width or th + 2 * margin > height:
|
||||
# font = ImageFont.truetype(font_path, fs)
|
||||
# tw, th = font.getsize(text)
|
||||
# fs -= 1
|
||||
|
||||
xp = (width // 2) - (tw // 2)
|
||||
yp = (height // 2) - (th // 2)
|
||||
# xp = (width // 2) - (tw // 2)
|
||||
# yp = (height // 2) - (th // 2)
|
||||
|
||||
im = Image.new("RGB", (width, height), (255, 255, 255))
|
||||
dr = ImageDraw.Draw(im)
|
||||
# im = Image.new("RGB", (width, height), (255, 255, 255))
|
||||
# dr = ImageDraw.Draw(im)
|
||||
|
||||
dr.text((xp, yp), text, fill=(0, 0, 0), font=font)
|
||||
now = datetime.datetime.now()
|
||||
date = now.strftime("%Y-%m-%d")
|
||||
dr.text((0, 0), date, fill=(0, 0, 0))
|
||||
# dr.text((xp, yp), text, fill=(0, 0, 0), font=font)
|
||||
# now = datetime.datetime.now()
|
||||
# date = now.strftime("%Y-%m-%d")
|
||||
# dr.text((0, 0), date, fill=(0, 0, 0))
|
||||
|
||||
base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
fn = os.path.join(base_path, "bar_codes", text + ".png")
|
||||
# base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
# fn = os.path.join(base_path, "bar_codes", text + ".png")
|
||||
|
||||
im.save(fn, "PNG")
|
||||
print_image(fn, printer_type, label_type)
|
||||
# im.save(fn, "PNG")
|
||||
# print_image(fn, printer_type, label_type)
|
||||
|
||||
|
||||
def print_bar_code(
|
||||
barcode_value,
|
||||
barcode_text,
|
||||
barcode_type="ean13",
|
||||
rotate=False,
|
||||
printer_type="QL-700",
|
||||
label_type="62",
|
||||
):
|
||||
bar_coder = barcode.get_barcode_class(barcode_type)
|
||||
wr = BrotherLabelWriter(typ=label_type, rot=rotate, text=barcode_text, max_height=1000)
|
||||
# def print_bar_code(
|
||||
# barcode_value,
|
||||
# barcode_text,
|
||||
# barcode_type="ean13",
|
||||
# rotate=False,
|
||||
# printer_type="QL-700",
|
||||
# label_type="62",
|
||||
# ):
|
||||
# bar_coder = barcode.get_barcode_class(barcode_type)
|
||||
# wr = BrotherLabelWriter(typ=label_type, rot=rotate, text=barcode_text, max_height=1000)
|
||||
|
||||
test = bar_coder(barcode_value, writer=wr)
|
||||
base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
fn = test.save(os.path.join(base_path, "bar_codes", barcode_value))
|
||||
print_image(fn, printer_type, label_type)
|
||||
# test = bar_coder(barcode_value, writer=wr)
|
||||
# base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
# fn = test.save(os.path.join(base_path, "bar_codes", barcode_value))
|
||||
# print_image(fn, printer_type, label_type)
|
||||
|
||||
|
||||
def print_image(fn, printer_type="QL-700", label_type="62"):
|
||||
qlr = BrotherQLRaster(printer_type)
|
||||
qlr.exception_on_warning = True
|
||||
create_label(qlr, fn, label_type, threshold=70, cut=True)
|
||||
# def print_image(fn, printer_type="QL-700", label_type="62"):
|
||||
# qlr = BrotherQLRaster(printer_type)
|
||||
# qlr.exception_on_warning = True
|
||||
# create_label(qlr, fn, label_type, threshold=70, cut=True)
|
||||
|
||||
be = backend_factory("pyusb")
|
||||
list_available_devices = be["list_available_devices"]
|
||||
BrotherQLBackend = be["backend_class"]
|
||||
# be = backend_factory("pyusb")
|
||||
# list_available_devices = be["list_available_devices"]
|
||||
# BrotherQLBackend = be["backend_class"]
|
||||
|
||||
ad = list_available_devices()
|
||||
assert ad
|
||||
string_descr = ad[0]["string_descr"]
|
||||
# ad = list_available_devices()
|
||||
# assert ad
|
||||
# string_descr = ad[0]["string_descr"]
|
||||
|
||||
printer = BrotherQLBackend(string_descr)
|
||||
# printer = BrotherQLBackend(string_descr)
|
||||
|
||||
printer.write(qlr.data)
|
||||
# printer.write(qlr.data)
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
from typing import TypeVar
|
||||
from sqlalchemy import BindParameter, literal
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def const(value: T) -> BindParameter[T]:
|
||||
"""
|
||||
Create a constant SQL literal bind parameter.
|
||||
|
||||
This is useful to avoid too many `?` bind parameters in SQL queries,
|
||||
when the input value is known to be safe.
|
||||
"""
|
||||
|
||||
return literal(value, literal_execute=True)
|
||||
|
||||
|
||||
CONST_ZERO: BindParameter[int] = const(0)
|
||||
CONST_ONE: BindParameter[int] = const(1)
|
||||
CONST_TRUE: BindParameter[bool] = const(True)
|
||||
CONST_FALSE: BindParameter[bool] = const(False)
|
||||
CONST_NONE: BindParameter[None] = const(None)
|
||||
@@ -1,3 +1,4 @@
|
||||
from dibbler.lib.render_tree import render_tree
|
||||
from dibbler.models import Transaction, TransactionType
|
||||
from dibbler.models.Transaction import EXPECTED_FIELDS
|
||||
|
||||
@@ -10,23 +11,19 @@ def render_transaction_log(transaction_log: list[Transaction]) -> str:
|
||||
aggregated_log = _aggregate_joint_transactions(transaction_log)
|
||||
|
||||
lines = []
|
||||
|
||||
for i, transaction in enumerate(aggregated_log):
|
||||
for transaction in aggregated_log:
|
||||
if isinstance(transaction, list):
|
||||
inner_lines = []
|
||||
is_last = i == len(aggregated_log) - 1
|
||||
lines.append(_render_transaction(transaction[0], is_last))
|
||||
for j, sub_transaction in enumerate(transaction[1:]):
|
||||
is_last_inner = j == len(transaction) - 2
|
||||
line = _render_transaction(sub_transaction, is_last_inner)
|
||||
lines.append(_render_transaction(transaction[0]))
|
||||
for sub_transaction in transaction[1:]:
|
||||
line = _render_transaction(sub_transaction)
|
||||
inner_lines.append(line)
|
||||
indented_inner_lines = _indent_lines(inner_lines, is_last=is_last)
|
||||
lines.extend(indented_inner_lines)
|
||||
lines.append(inner_lines)
|
||||
else:
|
||||
is_last = i == len(aggregated_log) - 1
|
||||
line = _render_transaction(transaction, is_last)
|
||||
line = _render_transaction(transaction)
|
||||
lines.append(line)
|
||||
return "\n".join(lines)
|
||||
|
||||
return render_tree(lines)
|
||||
|
||||
|
||||
def _aggregate_joint_transactions(
|
||||
@@ -61,17 +58,7 @@ def _aggregate_joint_transactions(
|
||||
return aggregated
|
||||
|
||||
|
||||
def _indent_lines(lines: list[str], is_last: bool = False) -> list[str]:
|
||||
indented_lines = []
|
||||
for line in lines:
|
||||
if is_last:
|
||||
indented_lines.append(" " + line)
|
||||
else:
|
||||
indented_lines.append("│ " + line)
|
||||
return indented_lines
|
||||
|
||||
|
||||
def _render_transaction(transaction: Transaction, is_last: bool) -> str:
|
||||
def _render_transaction(transaction: Transaction) -> str:
|
||||
match transaction.type_:
|
||||
case TransactionType.ADD_PRODUCT:
|
||||
line = f"ADD_PRODUCT({transaction.id}, {transaction.user.name}"
|
||||
@@ -125,5 +112,4 @@ def _render_transaction(transaction: Transaction, is_last: bool) -> str:
|
||||
line = (
|
||||
f"UNKNOWN[{transaction.type_}](id={transaction.id}, user_id={transaction.user_id})"
|
||||
)
|
||||
|
||||
return "└─ " + line if is_last else "├─ " + line
|
||||
return line
|
||||
|
||||
115
dibbler/lib/render_tree.py
Normal file
115
dibbler/lib/render_tree.py
Normal file
@@ -0,0 +1,115 @@
|
||||
_TREE_CHARS = {
|
||||
"normal": {
|
||||
"vertical": "│ ",
|
||||
"branch": "├─ ",
|
||||
"last": "└─ ",
|
||||
"empty": " ",
|
||||
},
|
||||
"ascii": {
|
||||
"vertical": "| ",
|
||||
"branch": "|-- ",
|
||||
"last": "`-- ",
|
||||
"empty": " ",
|
||||
},
|
||||
}
|
||||
|
||||
assert set(_TREE_CHARS["normal"].keys()) == set(_TREE_CHARS["ascii"].keys())
|
||||
assert all(len(v) == 3 for v in _TREE_CHARS["normal"].values())
|
||||
assert all(len(v) == 4 for v in _TREE_CHARS["ascii"].values())
|
||||
|
||||
|
||||
def render_tree(
|
||||
tree: list[str | list],
|
||||
ascii_only: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Render a tree structure as a string.
|
||||
|
||||
Each item in the `tree` list can be either a string (a leaf node)
|
||||
or another list (a subtree).
|
||||
|
||||
When `ascii_only` is `True`, only ASCII characters are used for drawing the tree.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
tree = [
|
||||
"root",
|
||||
[
|
||||
"child1",
|
||||
[
|
||||
"grandchild1",
|
||||
"grandchild2",
|
||||
],
|
||||
"child2",
|
||||
],
|
||||
"root2",
|
||||
]
|
||||
print(render_tree(tree, ascii_only=False))
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
├─ root
|
||||
│ ├─ child1
|
||||
│ │ ├─ grandchild1
|
||||
│ │ └─ grandchild2
|
||||
│ └─ child2
|
||||
└─ root2
|
||||
```
|
||||
|
||||
Example with ASCII only:
|
||||
|
||||
```python
|
||||
print(render_tree(tree, ascii_only=True))
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
|-- root
|
||||
| |-- child1
|
||||
| | |-- grandchild1
|
||||
| | `-- grandchild2
|
||||
| `-- child2
|
||||
`-- root2
|
||||
```
|
||||
"""
|
||||
|
||||
result: list[str] = []
|
||||
for index, item in enumerate(tree):
|
||||
is_last = index == len(tree) - 1
|
||||
item_lines = _render_tree_line(item, is_last, ascii_only)
|
||||
result.extend(item_lines)
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def _render_tree_line(
|
||||
item: str | list,
|
||||
is_last: bool,
|
||||
ascii_only: bool,
|
||||
prefix: str = "",
|
||||
) -> list[str]:
|
||||
chars = _TREE_CHARS["ascii"] if ascii_only else _TREE_CHARS["normal"]
|
||||
lines: list[str] = []
|
||||
|
||||
if isinstance(item, str):
|
||||
line_prefix = chars["last"] if is_last else chars["branch"]
|
||||
item_lines = item.splitlines()
|
||||
for line_index, line in enumerate(item_lines):
|
||||
if line_index == 0:
|
||||
lines.append(f"{prefix}{line_prefix}{line}")
|
||||
else:
|
||||
lines.append(f"{prefix}{chars['vertical']}{line}")
|
||||
|
||||
elif isinstance(item, list):
|
||||
new_prefix = prefix + (chars["empty"] if is_last else chars["vertical"])
|
||||
for sub_index, sub_item in enumerate(item):
|
||||
sub_is_last = sub_index == len(item) - 1
|
||||
sub_lines = _render_tree_line(sub_item, sub_is_last, ascii_only, new_prefix)
|
||||
lines.extend(sub_lines)
|
||||
else:
|
||||
raise ValueError("Item must be either a string or a list.")
|
||||
|
||||
return lines
|
||||
@@ -3,18 +3,20 @@
|
||||
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .helpers import *
|
||||
from ..models import Transaction
|
||||
from ..db import Session
|
||||
from .helpers import *
|
||||
|
||||
|
||||
def getUser():
|
||||
def getUser(sql_session: Session) -> str:
|
||||
assert sql_session is not None
|
||||
while 1:
|
||||
string = input("user? ")
|
||||
session = Session()
|
||||
user = search_user(string, session)
|
||||
session.close()
|
||||
user = search_user(string, sql_session)
|
||||
sql_session.close()
|
||||
if not isinstance(user, list):
|
||||
return user.name
|
||||
i = 0
|
||||
@@ -37,12 +39,11 @@ def getUser():
|
||||
return user[n].name
|
||||
|
||||
|
||||
def getProduct():
|
||||
def getProduct(sql_session: Session) -> str:
|
||||
assert sql_session is not None
|
||||
while 1:
|
||||
string = input("product? ")
|
||||
session = Session()
|
||||
product = search_product(string, session)
|
||||
session.close()
|
||||
product = search_product(string, sql_session)
|
||||
if not isinstance(product, list):
|
||||
return product.name
|
||||
i = 0
|
||||
@@ -89,7 +90,7 @@ class Database:
|
||||
|
||||
|
||||
class InputLine:
|
||||
def __init__(self, u, p, t):
|
||||
def __init__(self, u, p, t) -> None:
|
||||
self.inputUser = u
|
||||
self.inputProduct = p
|
||||
self.inputType = t
|
||||
@@ -122,17 +123,17 @@ def getInputType():
|
||||
return int(inp)
|
||||
|
||||
|
||||
def getProducts(products):
|
||||
def getProducts(products: str) -> list[tuple[str]]:
|
||||
product = []
|
||||
products = products.partition("¤")
|
||||
split_products = products.partition("¤")
|
||||
product.append(products[0])
|
||||
while products[1] == "¤":
|
||||
products = products[2].partition("¤")
|
||||
split_products = split_products[2].partition("¤")
|
||||
product.append(products[0])
|
||||
return product
|
||||
|
||||
|
||||
def getDateFile(date, inp):
|
||||
def getDateFile(date: str, inp: str) -> datetime.date:
|
||||
try:
|
||||
year = inp.partition("-")
|
||||
month = year[2].partition("-")
|
||||
@@ -176,7 +177,7 @@ def addLineToDatabase(database, inputLine):
|
||||
if abs(inputLine.price) > 90000:
|
||||
return database
|
||||
# fyller inn for varer
|
||||
if (not inputLine.product == "") and (
|
||||
if (inputLine.product != "") and (
|
||||
(inputLine.inputProduct == "") or (inputLine.inputProduct == inputLine.product)
|
||||
):
|
||||
database.varePersonAntall[inputLine.product][inputLine.user] = (
|
||||
@@ -190,7 +191,7 @@ def addLineToDatabase(database, inputLine):
|
||||
database.vareUkedagAntall[inputLine.product][inputLine.weekday] += 1
|
||||
# fyller inn for personer
|
||||
if (inputLine.inputUser == "") or (inputLine.inputUser == inputLine.user):
|
||||
if not inputLine.product == "":
|
||||
if inputLine.product != "":
|
||||
database.personVareAntall[inputLine.user][inputLine.product] = (
|
||||
database.personVareAntall[inputLine.user].setdefault(inputLine.product, 0) + 1
|
||||
)
|
||||
@@ -214,7 +215,7 @@ def addLineToDatabase(database, inputLine):
|
||||
database.personNegTransactions[inputLine.user] = (
|
||||
database.personNegTransactions.setdefault(inputLine.user, 0) + inputLine.price
|
||||
)
|
||||
elif not (inputLine.inputType == 1):
|
||||
elif inputLine.inputType != 1:
|
||||
database.globalVareAntall[inputLine.product] = (
|
||||
database.globalVareAntall.setdefault(inputLine.product, 0) + 1
|
||||
)
|
||||
@@ -225,7 +226,7 @@ def addLineToDatabase(database, inputLine):
|
||||
# fyller inn for global statistikk
|
||||
if (inputLine.inputType == 3) or (inputLine.inputType == 4):
|
||||
database.pengebeholdning[inputLine.dateNum] += inputLine.price
|
||||
if not (inputLine.product == ""):
|
||||
if inputLine.product != "":
|
||||
database.globalPersonAntall[inputLine.user] = (
|
||||
database.globalPersonAntall.setdefault(inputLine.user, 0) + 1
|
||||
)
|
||||
@@ -238,12 +239,12 @@ def addLineToDatabase(database, inputLine):
|
||||
return database
|
||||
|
||||
|
||||
def buildDatabaseFromDb(inputType, inputProduct, inputUser):
|
||||
def buildDatabaseFromDb(inputType, inputProduct, inputUser, sql_session: Session):
|
||||
assert sql_session is not None
|
||||
sdate = input("enter start date (yyyy-mm-dd)? ")
|
||||
edate = input("enter end date (yyyy-mm-dd)? ")
|
||||
print("building database...")
|
||||
session = Session()
|
||||
transaction_list = session.query(Transaction).all()
|
||||
transaction_list = sql_session.query(Transaction).all()
|
||||
inputLine = InputLine(inputUser, inputProduct, inputType)
|
||||
startDate = getDateDb(transaction_list[0].time, sdate)
|
||||
endDate = getDateDb(transaction_list[-1].time, edate)
|
||||
@@ -273,9 +274,9 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser):
|
||||
inputLine.price = 0
|
||||
|
||||
print("saving as default.dibblerlog...", end=" ")
|
||||
f = open("default.dibblerlog", "w")
|
||||
f = Path.open("default.dibblerlog", "w")
|
||||
line_format = "%s|%s|%s|%s|%s|%s\n"
|
||||
transaction_list = session.query(Transaction).all()
|
||||
transaction_list = sql_session.query(Transaction).all()
|
||||
for transaction in transaction_list:
|
||||
if transaction.purchase:
|
||||
products = "¤".join([ent.product.name for ent in transaction.purchase.entries])
|
||||
@@ -290,8 +291,7 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser):
|
||||
transaction.description,
|
||||
)
|
||||
f.write(line.encode("utf8"))
|
||||
session.close()
|
||||
f.close
|
||||
f.close()
|
||||
# bygg database.pengebeholdning
|
||||
if (inputType == 3) or (inputType == 4):
|
||||
for i in range(inputLine.numberOfDays + 1):
|
||||
@@ -311,7 +311,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
|
||||
sdate = input("enter start date (yyyy-mm-dd)? ")
|
||||
edate = input("enter end date (yyyy-mm-dd)? ")
|
||||
|
||||
f = open(inputFile)
|
||||
f = Path.open(inputFile)
|
||||
try:
|
||||
fileLines = f.readlines()
|
||||
finally:
|
||||
@@ -329,7 +329,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
|
||||
database.globalUkedagForbruk = [0] * 7
|
||||
database.pengebeholdning = [0] * (inputLine.numberOfDays + 1)
|
||||
for linje in fileLines:
|
||||
if not (linje[0] == "#") and not (linje == "\n"):
|
||||
if linje[0] != "#" and linje != "\n":
|
||||
# henter dateNum, products, user, price
|
||||
restDel = linje.partition("|")
|
||||
restDel = restDel[2].partition(" ")
|
||||
@@ -359,7 +359,7 @@ def buildDatabaseFromFile(inputFile, inputType, inputProduct, inputUser):
|
||||
return database, dateLine
|
||||
|
||||
|
||||
def printTopDict(dictionary, n, k):
|
||||
def printTopDict(dictionary: dict[str, Any], n: int, k: bool) -> None:
|
||||
i = 0
|
||||
for key in sorted(dictionary, key=dictionary.get, reverse=k):
|
||||
print(key, ": ", dictionary[key])
|
||||
@@ -369,7 +369,7 @@ def printTopDict(dictionary, n, k):
|
||||
break
|
||||
|
||||
|
||||
def printTopDict2(dictionary, dictionary2, n):
|
||||
def printTopDict2(dictionary, dictionary2, n) -> None:
|
||||
print("")
|
||||
print("product : price[kr] ( number )")
|
||||
i = 0
|
||||
@@ -381,7 +381,7 @@ def printTopDict2(dictionary, dictionary2, n):
|
||||
break
|
||||
|
||||
|
||||
def printWeekdays(week, days):
|
||||
def printWeekdays(week, days) -> None:
|
||||
if week == [] or days == 0:
|
||||
return
|
||||
print(
|
||||
@@ -404,10 +404,10 @@ def printWeekdays(week, days):
|
||||
print("")
|
||||
|
||||
|
||||
def printBalance(database, user):
|
||||
def printBalance(database, user) -> None:
|
||||
forbruk = 0
|
||||
if user in database.personVareVerdi:
|
||||
forbruk = sum([i for i in list(database.personVareVerdi[user].values())])
|
||||
forbruk = sum(database.personVareVerdi[user].values())
|
||||
print("totalt kjøpt for: ", forbruk, end=" ")
|
||||
if user in database.personNegTransactions:
|
||||
print("kr, totalt lagt til: ", -database.personNegTransactions[user], end=" ")
|
||||
@@ -419,14 +419,14 @@ def printBalance(database, user):
|
||||
print("")
|
||||
|
||||
|
||||
def printUser(database, dateLine, user, n):
|
||||
def printUser(database, dateLine, user, n) -> None:
|
||||
printTopDict2(database.personVareVerdi[user], database.personVareAntall[user], n)
|
||||
print("\nforbruk per ukedag [kr/dag],", end=" ")
|
||||
printWeekdays(database.personUkedagVerdi[user], len(dateLine))
|
||||
printBalance(database, user)
|
||||
|
||||
|
||||
def printProduct(database, dateLine, product, n):
|
||||
def printProduct(database, dateLine, product, n) -> None:
|
||||
printTopDict(database.varePersonAntall[product], n, 1)
|
||||
print("\nforbruk per ukedag [antall/dag],", end=" ")
|
||||
printWeekdays(database.vareUkedagAntall[product], len(dateLine))
|
||||
@@ -440,7 +440,7 @@ def printProduct(database, dateLine, product, n):
|
||||
)
|
||||
|
||||
|
||||
def printGlobal(database, dateLine, n):
|
||||
def printGlobal(database, dateLine, n) -> None:
|
||||
print("\nmest lagt til: ")
|
||||
printTopDict(database.personNegTransactions, n, 0)
|
||||
print("\nmest tatt fra:")
|
||||
@@ -454,9 +454,9 @@ def printGlobal(database, dateLine, n):
|
||||
"Det er solgt varer til en verdi av: ",
|
||||
sum(database.globalDatoForbruk),
|
||||
"kr, det er lagt til",
|
||||
-sum([i for i in list(database.personNegTransactions.values())]),
|
||||
-sum(database.personNegTransactions.values()),
|
||||
"og tatt fra",
|
||||
sum([i for i in list(database.personPosTransactions.values())]),
|
||||
sum(database.personPosTransactions.values()),
|
||||
end=" ",
|
||||
)
|
||||
print(
|
||||
@@ -466,23 +466,24 @@ def printGlobal(database, dateLine, n):
|
||||
)
|
||||
|
||||
|
||||
def alt4menuTextOnly(database, dateLine):
|
||||
def alt4menuTextOnly(database, dateLine, sql_session: Session) -> None:
|
||||
assert sql_session is not None
|
||||
n = 10
|
||||
while 1:
|
||||
print(
|
||||
"\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit"
|
||||
"\n1: user-statistics, 2: product-statistics, 3:global-statistics, n: adjust amount of data shown q:quit",
|
||||
)
|
||||
inp = input("")
|
||||
if inp == "q":
|
||||
break
|
||||
elif inp == "1":
|
||||
if inp == "1":
|
||||
try:
|
||||
printUser(database, dateLine, getUser(), n)
|
||||
printUser(database, dateLine, getUser(sql_session), n)
|
||||
except:
|
||||
print("\n\nSomething is not right, (last date prior to first date?)")
|
||||
elif inp == "2":
|
||||
try:
|
||||
printProduct(database, dateLine, getProduct(), n)
|
||||
printProduct(database, dateLine, getProduct(sql_session), n)
|
||||
except:
|
||||
print("\n\nSomething is not right, (last date prior to first date?)")
|
||||
elif inp == "3":
|
||||
@@ -494,15 +495,16 @@ def alt4menuTextOnly(database, dateLine):
|
||||
n = int(input("set number to show "))
|
||||
|
||||
|
||||
def statisticsTextOnly():
|
||||
def statisticsTextOnly(sql_session: Session) -> None:
|
||||
assert sql_session is not None
|
||||
inputType = 4
|
||||
product = ""
|
||||
user = ""
|
||||
print("\n0: from file, 1: from database, q:quit")
|
||||
inp = input("")
|
||||
if inp == "1":
|
||||
database, dateLine = buildDatabaseFromDb(inputType, product, user)
|
||||
database, dateLine = buildDatabaseFromDb(inputType, product, user, sql_session)
|
||||
elif inp == "0" or inp == "":
|
||||
database, dateLine = buildDatabaseFromFile("default.dibblerlog", inputType, product, user)
|
||||
if not inp == "q":
|
||||
alt4menuTextOnly(database, dateLine)
|
||||
if inp != "q":
|
||||
alt4menuTextOnly(database, dateLine, sql_session)
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from dibbler.conf import config
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.conf import config_db_string, load_config
|
||||
from dibbler.lib.check_db_health import check_db_health
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -11,13 +16,20 @@ parser.add_argument(
|
||||
help="Path to the config file",
|
||||
type=Path,
|
||||
metavar="FILE",
|
||||
default="config.ini",
|
||||
required=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-V",
|
||||
"--version",
|
||||
help="Show program version",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(
|
||||
title="subcommands",
|
||||
dest="subcommand",
|
||||
required=True,
|
||||
)
|
||||
subparsers.add_parser("loop", help="Run the dibbler loop")
|
||||
subparsers.add_parser("create-db", help="Create the database")
|
||||
@@ -26,29 +38,55 @@ subparsers.add_parser("seed-data", help="Fill with mock data")
|
||||
subparsers.add_parser("transaction-log", help="Print transaction log")
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
args = parser.parse_args()
|
||||
config.read(args.config)
|
||||
|
||||
if args.version:
|
||||
from ._version import commit_id, version
|
||||
|
||||
print(f"Dibbler version {version}, commit {commit_id if commit_id else '<unknown>'}")
|
||||
return
|
||||
|
||||
if not args.subcommand:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
load_config(args.config)
|
||||
|
||||
engine = create_engine(config_db_string())
|
||||
|
||||
sql_session = Session(
|
||||
engine,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
close_resets_only=True,
|
||||
)
|
||||
|
||||
check_db_health(
|
||||
engine,
|
||||
verify_table_existence=args.subcommand != "create-db",
|
||||
)
|
||||
|
||||
if args.subcommand == "loop":
|
||||
import dibbler.subcommands.loop as loop
|
||||
|
||||
loop.main()
|
||||
loop.main(sql_session)
|
||||
|
||||
elif args.subcommand == "create-db":
|
||||
import dibbler.subcommands.makedb as makedb
|
||||
|
||||
makedb.main()
|
||||
makedb.main(engine)
|
||||
|
||||
elif args.subcommand == "slabbedasker":
|
||||
import dibbler.subcommands.slabbedasker as slabbedasker
|
||||
|
||||
slabbedasker.main()
|
||||
slabbedasker.main(sql_session)
|
||||
|
||||
elif args.subcommand == "seed-data":
|
||||
import dibbler.subcommands.seed_test_data as seed_test_data
|
||||
|
||||
seed_test_data.main()
|
||||
seed_test_data.main(sql_session)
|
||||
|
||||
elif args.subcommand == "transaction-log":
|
||||
import dibbler.subcommands.transaction_log as transaction_log
|
||||
|
||||
@@ -26,28 +26,28 @@ __all__ = [
|
||||
from .addstock import AddStockMenu
|
||||
from .buymenu import BuyMenu
|
||||
from .editing import (
|
||||
AddUserMenu,
|
||||
EditUserMenu,
|
||||
AddProductMenu,
|
||||
EditProductMenu,
|
||||
AddUserMenu,
|
||||
AdjustStockMenu,
|
||||
CleanupStockMenu,
|
||||
EditProductMenu,
|
||||
EditUserMenu,
|
||||
)
|
||||
from .faq import FAQMenu
|
||||
from .helpermenus import Menu
|
||||
from .mainmenu import MainMenu
|
||||
from .miscmenus import (
|
||||
ProductSearchMenu,
|
||||
TransferMenu,
|
||||
AdjustCreditMenu,
|
||||
UserListMenu,
|
||||
ShowUserMenu,
|
||||
ProductListMenu,
|
||||
ProductSearchMenu,
|
||||
ShowUserMenu,
|
||||
TransferMenu,
|
||||
UserListMenu,
|
||||
)
|
||||
from .printermenu import PrintLabelMenu
|
||||
from .stats import (
|
||||
ProductPopularityMenu,
|
||||
ProductRevenueMenu,
|
||||
BalanceMenu,
|
||||
LoggedStatisticsMenu,
|
||||
ProductPopularityMenu,
|
||||
ProductRevenueMenu,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from math import ceil
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import (
|
||||
Product,
|
||||
@@ -9,12 +10,13 @@ from dibbler.models import (
|
||||
Transaction,
|
||||
User,
|
||||
)
|
||||
|
||||
from .helpermenus import Menu
|
||||
|
||||
|
||||
class AddStockMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Add stock and adjust credit", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Add stock and adjust credit", sql_session)
|
||||
self.help_text = """
|
||||
Enter what you have bought for PVVVV here, along with your user name and how
|
||||
much money you're due in credits for the purchase when prompted.\n"""
|
||||
@@ -23,7 +25,7 @@ much money you're due in credits for the purchase when prompted.\n"""
|
||||
self.products = {}
|
||||
self.price = 0
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> bool | None:
|
||||
questions = {
|
||||
(
|
||||
False,
|
||||
@@ -86,10 +88,10 @@ much money you're due in credits for the purchase when prompted.\n"""
|
||||
|
||||
self.perform_transaction()
|
||||
|
||||
def complete_input(self):
|
||||
return bool(self.users) and len(self.products) and self.price
|
||||
def complete_input(self) -> bool:
|
||||
return self.users is not None and len(self.products) > 0 and self.price > 0
|
||||
|
||||
def print_info(self):
|
||||
def print_info(self) -> None:
|
||||
width = 6 + Product.name_length
|
||||
print()
|
||||
print(width * "-")
|
||||
@@ -109,7 +111,12 @@ much money you're due in credits for the purchase when prompted.\n"""
|
||||
print(f"{self.products[product][0]}".rjust(width - len(product.name)))
|
||||
print(width * "-")
|
||||
|
||||
def add_thing_to_pending(self, thing, amount, price):
|
||||
def add_thing_to_pending(
|
||||
self,
|
||||
thing: User | Product,
|
||||
amount: int,
|
||||
price: int,
|
||||
) -> None:
|
||||
if isinstance(thing, User):
|
||||
self.users.append(thing)
|
||||
elif thing in list(self.products.keys()):
|
||||
@@ -119,7 +126,7 @@ much money you're due in credits for the purchase when prompted.\n"""
|
||||
else:
|
||||
self.products[thing] = [amount, price]
|
||||
|
||||
def perform_transaction(self):
|
||||
def perform_transaction(self) -> None:
|
||||
print("Did you pay a different price?")
|
||||
if self.confirm(">", default=False):
|
||||
self.price = self.input_int("How much did you pay?", 0, self.price, default=self.price)
|
||||
@@ -132,10 +139,11 @@ much money you're due in credits for the purchase when prompted.\n"""
|
||||
old_price = product.price
|
||||
old_hidden = product.hidden
|
||||
product.price = int(
|
||||
ceil(float(value) / (max(product.stock, 0) + self.products[product][0]))
|
||||
ceil(float(value) / (max(product.stock, 0) + self.products[product][0])),
|
||||
)
|
||||
product.stock = max(
|
||||
self.products[product][0], product.stock + self.products[product][0]
|
||||
self.products[product][0],
|
||||
product.stock + self.products[product][0],
|
||||
)
|
||||
product.hidden = False
|
||||
print(
|
||||
@@ -151,13 +159,14 @@ much money you're due in credits for the purchase when prompted.\n"""
|
||||
PurchaseEntry(purchase, product, -self.products[product][0])
|
||||
|
||||
purchase.perform_soft_purchase(-self.price, round_up=False)
|
||||
self.session.add(purchase)
|
||||
self.sql_session.add(purchase)
|
||||
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print("Success! Transaction performed:")
|
||||
# self.print_info()
|
||||
for user in self.users:
|
||||
print(f"User {user.name}'s credit is now {user.credit:d}")
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not perform transaction: {e}")
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.conf import config
|
||||
from dibbler.models import (
|
||||
@@ -13,10 +17,11 @@ from .helpermenus import Menu
|
||||
|
||||
|
||||
class BuyMenu(Menu):
|
||||
def __init__(self, session=None):
|
||||
Menu.__init__(self, "Buy", uses_db=True)
|
||||
if session:
|
||||
self.session = session
|
||||
superfast_mode: bool
|
||||
purchase: Purchase
|
||||
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Buy", sql_session)
|
||||
self.superfast_mode = False
|
||||
self.help_text = """
|
||||
Each purchase may contain one or more products and one or more buyers.
|
||||
@@ -28,7 +33,7 @@ addition, and you can type 'what' at any time to redisplay it.
|
||||
When finished, write an empty line to confirm the purchase.\n"""
|
||||
|
||||
@staticmethod
|
||||
def credit_check(user):
|
||||
def credit_check(user: User) -> bool:
|
||||
"""
|
||||
|
||||
:param user:
|
||||
@@ -37,28 +42,32 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
"""
|
||||
assert isinstance(user, User)
|
||||
|
||||
return user.credit > config.getint("limits", "low_credit_warning_limit")
|
||||
return user.credit > config["limits"]["low_credit_warning_limit"]
|
||||
|
||||
def low_credit_warning(self, user, timeout=False):
|
||||
def low_credit_warning(
|
||||
self,
|
||||
user: User,
|
||||
timeout: bool = False,
|
||||
) -> bool:
|
||||
assert isinstance(user, User)
|
||||
|
||||
print("***********************************************************************")
|
||||
print("***********************************************************************")
|
||||
print("")
|
||||
print("$$\ $$\ $$$$$$\ $$$$$$$\ $$\ $$\ $$$$$$\ $$\ $$\ $$$$$$\\")
|
||||
print("$$ | $\ $$ |$$ __$$\ $$ __$$\ $$$\ $$ |\_$$ _|$$$\ $$ |$$ __$$\\")
|
||||
print("$$ |$$$\ $$ |$$ / $$ |$$ | $$ |$$$$\ $$ | $$ | $$$$\ $$ |$$ / \__|")
|
||||
print("$$ $$ $$\$$ |$$$$$$$$ |$$$$$$$ |$$ $$\$$ | $$ | $$ $$\$$ |$$ |$$$$\\")
|
||||
print("$$$$ _$$$$ |$$ __$$ |$$ __$$< $$ \$$$$ | $$ | $$ \$$$$ |$$ |\_$$ |")
|
||||
print("$$$ / \$$$ |$$ | $$ |$$ | $$ |$$ |\$$$ | $$ | $$ |\$$$ |$$ | $$ |")
|
||||
print("$$ / \$$ |$$ | $$ |$$ | $$ |$$ | \$$ |$$$$$$\ $$ | \$$ |\$$$$$$ |")
|
||||
print("\__/ \__|\__| \__|\__| \__|\__| \__|\______|\__| \__| \______/")
|
||||
print("")
|
||||
print("***********************************************************************")
|
||||
print("***********************************************************************")
|
||||
print("")
|
||||
print(r"***********************************************************************")
|
||||
print(r"***********************************************************************")
|
||||
print(r"")
|
||||
print(r"$$\ $$\ $$$$$$\ $$$$$$$\ $$\ $$\ $$$$$$\ $$\ $$\ $$$$$$\\")
|
||||
print(r"$$ | $\ $$ |$$ __$$\ $$ __$$\ $$$\ $$ |\_$$ _|$$$\ $$ |$$ __$$\\")
|
||||
print(r"$$ |$$$\ $$ |$$ / $$ |$$ | $$ |$$$$\ $$ | $$ | $$$$\ $$ |$$ / \__|")
|
||||
print(r"$$ $$ $$\$$ |$$$$$$$$ |$$$$$$$ |$$ $$\$$ | $$ | $$ $$\$$ |$$ |$$$$\\")
|
||||
print(r"$$$$ _$$$$ |$$ __$$ |$$ __$$< $$ \$$$$ | $$ | $$ \$$$$ |$$ |\_$$ |")
|
||||
print(r"$$$ / \$$$ |$$ | $$ |$$ | $$ |$$ |\$$$ | $$ | $$ |\$$$ |$$ | $$ |")
|
||||
print(r"$$ / \$$ |$$ | $$ |$$ | $$ |$$ | \$$ |$$$$$$\ $$ | \$$ |\$$$$$$ |")
|
||||
print(r"\__/ \__|\__| \__|\__| \__|\__| \__|\______|\__| \__| \______/")
|
||||
print(r"")
|
||||
print(r"***********************************************************************")
|
||||
print(r"***********************************************************************")
|
||||
print(r"")
|
||||
print(
|
||||
f"USER {user.name} HAS LOWER CREDIT THAN {config.getint('limits', 'low_credit_warning_limit'):d}."
|
||||
f"USER {user.name} HAS LOWER CREDIT THAN {config['limits']['low_credit_warning_limit']:d}.",
|
||||
)
|
||||
print("THIS PURCHASE WILL CHARGE YOUR CREDIT TWICE AS MUCH.")
|
||||
print("CONSIDER PUTTING MONEY IN THE BOX TO AVOID THIS.")
|
||||
@@ -68,10 +77,13 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
if timeout:
|
||||
print("THIS PURCHASE WILL AUTOMATICALLY BE PERFORMED IN 3 MINUTES!")
|
||||
return self.confirm(prompt=">", default=True, timeout=180)
|
||||
else:
|
||||
return self.confirm(prompt=">", default=True)
|
||||
return self.confirm(prompt=">", default=True)
|
||||
|
||||
def add_thing_to_purchase(self, thing, amount=1):
|
||||
def add_thing_to_purchase(
|
||||
self,
|
||||
thing: User | Product,
|
||||
amount: int = 1,
|
||||
) -> bool:
|
||||
if isinstance(thing, User):
|
||||
if thing.is_anonymous():
|
||||
print("---------------------------------------------")
|
||||
@@ -80,7 +92,10 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
print("---------------------------------------------")
|
||||
|
||||
if not self.credit_check(thing):
|
||||
if self.low_credit_warning(user=thing, timeout=self.superfast_mode):
|
||||
if self.low_credit_warning(
|
||||
user=thing,
|
||||
timeout=self.superfast_mode,
|
||||
):
|
||||
Transaction(thing, purchase=self.purchase, penalty=2)
|
||||
else:
|
||||
return False
|
||||
@@ -95,7 +110,11 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
PurchaseEntry(self.purchase, thing, amount)
|
||||
return True
|
||||
|
||||
def _execute(self, initial_contents=None):
|
||||
def _execute(
|
||||
self,
|
||||
initial_contents: list[tuple[User | Product, int]] | None = None,
|
||||
**_kwargs,
|
||||
) -> bool:
|
||||
self.print_header()
|
||||
self.purchase = Purchase()
|
||||
self.exit_confirm_msg = None
|
||||
@@ -107,7 +126,7 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
for thing, num in initial_contents:
|
||||
self.add_thing_to_purchase(thing, num)
|
||||
|
||||
def is_product(candidate):
|
||||
def is_product(candidate: Any) -> bool:
|
||||
return isinstance(candidate[0], Product)
|
||||
|
||||
if len(initial_contents) > 0 and all(map(is_product, initial_contents)):
|
||||
@@ -129,7 +148,7 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
True,
|
||||
True,
|
||||
): "Enter more products or users, or an empty line to confirm",
|
||||
}[(len(self.purchase.transactions) > 0, len(self.purchase.entries) > 0)]
|
||||
}[(len(self.purchase.transactions) > 0, len(self.purchase.entries) > 0)],
|
||||
)
|
||||
|
||||
# Read in a 'thing' (product or user):
|
||||
@@ -147,16 +166,16 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
if thing is None:
|
||||
if not self.complete_input():
|
||||
if self.confirm(
|
||||
"Not enough information entered. Abort purchase?", default=True
|
||||
"Not enough information entered. Abort purchase?",
|
||||
default=True,
|
||||
):
|
||||
return False
|
||||
continue
|
||||
break
|
||||
else:
|
||||
# once we get something in the
|
||||
# purchase, we want to protect the
|
||||
# user from accidentally killing it
|
||||
self.exit_confirm_msg = "Abort purchase?"
|
||||
# once we get something in the
|
||||
# purchase, we want to protect the
|
||||
# user from accidentally killing it
|
||||
self.exit_confirm_msg = "Abort purchase?"
|
||||
|
||||
# Add the thing to our purchase object:
|
||||
if not self.add_thing_to_purchase(thing, amount=num):
|
||||
@@ -167,10 +186,11 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
break
|
||||
|
||||
self.purchase.perform_purchase()
|
||||
self.session.add(self.purchase)
|
||||
self.sql_session.add(self.purchase)
|
||||
try:
|
||||
self.session.commit()
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
self.sql_session.commit()
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store purchase: {e}")
|
||||
else:
|
||||
print("Purchase stored.")
|
||||
@@ -178,9 +198,9 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
for t in self.purchase.transactions:
|
||||
if not t.user.is_anonymous():
|
||||
print(f"User {t.user.name}'s credit is now {t.user.credit:d} kr")
|
||||
if t.user.credit < config.getint("limits", "low_credit_warning_limit"):
|
||||
if t.user.credit < config["limits"]["low_credit_warning_limit"]:
|
||||
print(
|
||||
f"USER {t.user.name} HAS LOWER CREDIT THAN {config.getint('limits', 'low_credit_warning_limit'):d},",
|
||||
f"USER {t.user.name} HAS LOWER CREDIT THAN {config['limits']['low_credit_warning_limit']:d},",
|
||||
"AND SHOULD CONSIDER PUTTING SOME MONEY IN THE BOX.",
|
||||
)
|
||||
|
||||
@@ -189,10 +209,10 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
print("")
|
||||
return True
|
||||
|
||||
def complete_input(self):
|
||||
def complete_input(self) -> bool:
|
||||
return self.purchase.is_complete()
|
||||
|
||||
def format_purchase(self):
|
||||
def format_purchase(self) -> str | None:
|
||||
self.purchase.set_price()
|
||||
transactions = self.purchase.transactions
|
||||
entries = self.purchase.entries
|
||||
@@ -204,7 +224,10 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
string += "(empty)"
|
||||
else:
|
||||
string += ", ".join(
|
||||
[t.user.name + ("*" if not self.credit_check(t.user) else "") for t in transactions]
|
||||
[
|
||||
t.user.name + ("*" if not self.credit_check(t.user) else "")
|
||||
for t in transactions
|
||||
],
|
||||
)
|
||||
string += "\n products: "
|
||||
if len(entries) == 0:
|
||||
@@ -212,7 +235,7 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
else:
|
||||
string += "\n "
|
||||
string += "\n ".join(
|
||||
[f"{e.amount:d}x {e.product.name} ({e.product.price:d} kr)" for e in entries]
|
||||
[f"{e.amount:d}x {e.product.name} ({e.product.price:d} kr)" for e in entries],
|
||||
)
|
||||
if len(transactions) > 1:
|
||||
string += f"\n price per person: {self.purchase.price_per_transaction():d} kr"
|
||||
@@ -228,7 +251,7 @@ When finished, write an empty line to confirm the purchase.\n"""
|
||||
|
||||
return string
|
||||
|
||||
def print_purchase(self):
|
||||
def print_purchase(self) -> None:
|
||||
info = self.format_purchase()
|
||||
if info is not None:
|
||||
self.set_context(info)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import sqlalchemy
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, User
|
||||
|
||||
from dibbler.models import User, Product
|
||||
from .helpermenus import Menu, Selector
|
||||
|
||||
__all__ = [
|
||||
@@ -14,32 +17,48 @@ __all__ = [
|
||||
|
||||
|
||||
class AddUserMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Add user", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Add user", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
username = self.input_str(
|
||||
"Username (should be same as PVV username)",
|
||||
regex=User.name_re,
|
||||
length_range=(1, 10),
|
||||
)
|
||||
cardnum = self.input_str("Card number (optional)", regex=User.card_re, length_range=(0, 10))
|
||||
cardnum = cardnum.lower()
|
||||
rfid = self.input_str("RFID (optional)", regex=User.rfid_re, length_range=(0, 10))
|
||||
assert username is not None
|
||||
|
||||
cardnum = self.input_str(
|
||||
"Card number (optional)",
|
||||
regex=User.card_re,
|
||||
length_range=(0, 10),
|
||||
empty_string_is_none=True,
|
||||
)
|
||||
if cardnum is not None:
|
||||
cardnum = cardnum.lower()
|
||||
|
||||
rfid = self.input_str(
|
||||
"RFID (optional)",
|
||||
regex=User.rfid_re,
|
||||
length_range=(0, 10),
|
||||
empty_string_is_none=True,
|
||||
)
|
||||
|
||||
user = User(username, cardnum, rfid)
|
||||
self.session.add(user)
|
||||
self.sql_session.add(user)
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print(f"User {username} stored")
|
||||
except sqlalchemy.exc.IntegrityError as e:
|
||||
except IntegrityError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store user {username}: {e}")
|
||||
self.pause()
|
||||
|
||||
|
||||
class EditUserMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Edit user", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Edit user", sql_session)
|
||||
self.help_text = """
|
||||
The only editable part of a user is its card number and rfid.
|
||||
|
||||
@@ -47,7 +66,7 @@ First select an existing user, then enter a new card number for that
|
||||
user, then rfid (write an empty line to remove the card number or rfid).
|
||||
"""
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
user = self.input_user("User")
|
||||
self.printc(f"Editing user {user.name}")
|
||||
@@ -69,43 +88,50 @@ user, then rfid (write an empty line to remove the card number or rfid).
|
||||
empty_string_is_none=True,
|
||||
)
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print(f"User {user.name} stored")
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store user {user.name}: {e}")
|
||||
self.pause()
|
||||
|
||||
|
||||
class AddProductMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Add product", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Add product", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
bar_code = self.input_str("Bar code", regex=Product.bar_code_re, length_range=(8, 13))
|
||||
assert bar_code is not None
|
||||
|
||||
name = self.input_str("Name", regex=Product.name_re, length_range=(1, Product.name_length))
|
||||
assert name is not None
|
||||
|
||||
price = self.input_int("Price", 1, 100000)
|
||||
product = Product(bar_code, name, price)
|
||||
self.session.add(product)
|
||||
self.sql_session.add(product)
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print(f"Product {name} stored")
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store product {name}: {e}")
|
||||
self.pause()
|
||||
|
||||
|
||||
class EditProductMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Edit product", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Edit product", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
product = self.input_product("Product")
|
||||
self.printc(f"Editing product {product.name}")
|
||||
while True:
|
||||
selector = Selector(
|
||||
f"Do what with {product.name}?",
|
||||
sql_session=self.sql_session,
|
||||
items=[
|
||||
("name", "Edit name"),
|
||||
("price", "Edit price"),
|
||||
@@ -135,9 +161,10 @@ class EditProductMenu(Menu):
|
||||
product.hidden = self.confirm(f"Hidden(currently {product.hidden})", default=False)
|
||||
elif what == "store":
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print(f"Product {product.name} stored")
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store product {product.name}: {e}")
|
||||
self.pause()
|
||||
return
|
||||
@@ -149,10 +176,10 @@ class EditProductMenu(Menu):
|
||||
|
||||
|
||||
class AdjustStockMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Adjust stock", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Adjust stock", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
product = self.input_product("Product")
|
||||
|
||||
@@ -168,10 +195,11 @@ class AdjustStockMenu(Menu):
|
||||
product.stock += add_stock
|
||||
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print("Stock is now stored")
|
||||
self.pause()
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store stock: {e}")
|
||||
self.pause()
|
||||
return
|
||||
@@ -179,13 +207,13 @@ class AdjustStockMenu(Menu):
|
||||
|
||||
|
||||
class CleanupStockMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Stock Cleanup", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Stock Cleanup", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
|
||||
products = self.session.query(Product).filter(Product.stock != 0).all()
|
||||
products = self.sql_session.query(Product).filter(Product.stock != 0).all()
|
||||
|
||||
print("Every product in stock will be printed.")
|
||||
print("Entering no value will keep current stock or set it to 0 if it is negative.")
|
||||
@@ -199,15 +227,16 @@ class CleanupStockMenu(Menu):
|
||||
for product in products:
|
||||
oldstock = product.stock
|
||||
product.stock = self.input_int(product.name, 0, 10000, default=max(0, oldstock))
|
||||
self.session.add(product)
|
||||
self.sql_session.add(product)
|
||||
if oldstock != product.stock:
|
||||
changed_products.append((product, oldstock))
|
||||
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print("New stocks are now stored.")
|
||||
self.pause()
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store stock: {e}")
|
||||
self.pause()
|
||||
return
|
||||
|
||||
@@ -1,129 +1,146 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from textwrap import dedent
|
||||
|
||||
from .helpermenus import MessageMenu, Menu
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .helpermenus import Menu, MessageMenu
|
||||
|
||||
|
||||
class FAQMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Frequently Asked Questions")
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Frequently Asked Questions", sql_session)
|
||||
self.items = [
|
||||
MessageMenu(
|
||||
"What is the meaning with this program?",
|
||||
"""
|
||||
We want to avoid keeping lots of cash in PVVVV's money box and to
|
||||
make it easy to pay for stuff without using money. (Without using
|
||||
money each time, that is. You do of course have to pay for the things
|
||||
you buy eventually).
|
||||
dedent("""
|
||||
We want to avoid keeping lots of cash in PVVVV's money box and to
|
||||
make it easy to pay for stuff without using money. (Without using
|
||||
money each time, that is. You do of course have to pay for the things
|
||||
you buy eventually).
|
||||
|
||||
Dibbler stores a "credit" amount for each user. When you register a
|
||||
purchase in Dibbler, this amount is decreased. To increase your
|
||||
credit, purchase products for dibbler, and register them using "Add
|
||||
stock and adjust credit".
|
||||
Alternatively, add money to the money box and use "Adjust credit" to
|
||||
tell Dibbler about it.
|
||||
""",
|
||||
Dibbler stores a "credit" amount for each user. When you register a
|
||||
purchase in Dibbler, this amount is decreased. To increase your
|
||||
credit, purchase products for dibbler, and register them using "Add
|
||||
stock and adjust credit".
|
||||
Alternatively, add money to the money box and use "Adjust credit" to
|
||||
tell Dibbler about it.
|
||||
"""),
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"Can I still pay for stuff using cash?",
|
||||
"""
|
||||
Please put money in the money box and use "Adjust Credit" so that
|
||||
dibbler can keep track of credit and purchases.""",
|
||||
dedent("""
|
||||
Please put money in the money box and use "Adjust Credit" so that
|
||||
dibbler can keep track of credit and purchases.
|
||||
"""),
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"How do I exit from a submenu/dialog/thing?",
|
||||
'Type "exit", "q", or ^d.',
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu("How do I exit from a submenu/dialog/thing?", 'Type "exit", "q", or ^d.'),
|
||||
MessageMenu(
|
||||
'What does "." mean?',
|
||||
"""
|
||||
The "." character, known as "full stop" or "period", is most often
|
||||
used to indicate the end of a sentence.
|
||||
dedent("""
|
||||
The "." character, known as "full stop" or "period", is most often
|
||||
used to indicate the end of a sentence.
|
||||
|
||||
It is also used by Dibbler to indicate that the program wants you to
|
||||
read some text before continuing. Whenever some output ends with a
|
||||
line containing only a period, you should read the lines above and
|
||||
then press enter to continue.
|
||||
""",
|
||||
It is also used by Dibbler to indicate that the program wants you to
|
||||
read some text before continuing. Whenever some output ends with a
|
||||
line containing only a period, you should read the lines above and
|
||||
then press enter to continue.
|
||||
"""),
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"Why is the user interface so terribly unintuitive?",
|
||||
"""
|
||||
Answer #1: It is not.
|
||||
dedent("""
|
||||
Answer #1: It is not.
|
||||
|
||||
Answer #2: We are trying to compete with PVV's microwave oven in
|
||||
userfriendliness.
|
||||
Answer #2: We are trying to compete with PVV's microwave oven in
|
||||
userfriendliness.
|
||||
|
||||
Answer #3: YOU are unintuitive.
|
||||
""",
|
||||
Answer #3: YOU are unintuitive.
|
||||
"""),
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"Why is there no help command?",
|
||||
'There is. Have you tried typing "help"?',
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
'Where are the easter eggs? I tried saying "moo", but nothing happened.',
|
||||
'Don\'t say "moo".',
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"Why does the program speak English when all the users are Norwegians?",
|
||||
"Godt spørsmål. Det virket sikkert som en god idé der og da.",
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"Why does the screen have strange colours?",
|
||||
"""
|
||||
Type "c" on the main menu to change the colours of the display, or
|
||||
"cs" if you are a boring person.
|
||||
""",
|
||||
dedent("""
|
||||
Type "c" on the main menu to change the colours of the display, or
|
||||
"cs" if you are a boring person.
|
||||
"""),
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"I found a bug; is there a reward?",
|
||||
"""
|
||||
No.
|
||||
dedent("""
|
||||
No.
|
||||
|
||||
But if you are certain that it is a bug, not a feature, then you
|
||||
should fix it (or better: force someone else to do it).
|
||||
But if you are certain that it is a bug, not a feature, then you
|
||||
should fix it (or better: force someone else to do it).
|
||||
|
||||
Follow this procedure:
|
||||
Follow this procedure:
|
||||
|
||||
1. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
|
||||
1. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
|
||||
|
||||
2. Fix the bug.
|
||||
2. Fix the bug.
|
||||
|
||||
3. Check that the program still runs (and, preferably, that the bug is
|
||||
in fact fixed).
|
||||
3. Check that the program still runs (and, preferably, that the bug is
|
||||
in fact fixed).
|
||||
|
||||
4. Commit.
|
||||
4. Commit.
|
||||
|
||||
5. Update the running copy from svn:
|
||||
5. Update the running copy from svn:
|
||||
|
||||
$ su -
|
||||
# su -l -s /bin/bash pvvvv
|
||||
$ cd dibbler
|
||||
$ git pull
|
||||
$ su -
|
||||
# su -l -s /bin/bash pvvvv
|
||||
$ cd dibbler
|
||||
$ git pull
|
||||
|
||||
6. Type "restart" in Dibbler to replace the running process by a new
|
||||
one using the updated files.
|
||||
""",
|
||||
6. Type "restart" in Dibbler to replace the running process by a new
|
||||
one using the updated files.
|
||||
"""),
|
||||
sql_session,
|
||||
),
|
||||
MessageMenu(
|
||||
"My question isn't listed here; what do I do?",
|
||||
"""
|
||||
DON'T PANIC.
|
||||
dedent("""
|
||||
DON'T PANIC.
|
||||
|
||||
Follow this procedure:
|
||||
Follow this procedure:
|
||||
|
||||
1. Ask someone (or read the source code) and get an answer.
|
||||
1. Ask someone (or read the source code) and get an answer.
|
||||
|
||||
2. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
|
||||
2. Check out the Dibbler code: https://github.com/Programvareverkstedet/dibbler
|
||||
|
||||
3. Add your question (with answer) to the FAQ and commit.
|
||||
3. Add your question (with answer) to the FAQ and commit.
|
||||
|
||||
4. Update the running copy from svn:
|
||||
4. Update the running copy from svn:
|
||||
|
||||
$ su -
|
||||
# su -l -s /bin/bash pvvvv
|
||||
$ cd dibbler
|
||||
$ git pull
|
||||
$ su -
|
||||
# su -l -s /bin/bash pvvvv
|
||||
$ cd dibbler
|
||||
$ git pull
|
||||
|
||||
5. Type "restart" in Dibbler to replace the running process by a new
|
||||
one using the updated files.
|
||||
""",
|
||||
5. Type "restart" in Dibbler to replace the running process by a new
|
||||
one using the updated files.
|
||||
"""),
|
||||
sql_session,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -1,44 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from select import select
|
||||
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar
|
||||
|
||||
from dibbler.db import Session
|
||||
from dibbler.models import User
|
||||
from dibbler.lib.helpers import (
|
||||
search_user,
|
||||
search_product,
|
||||
guess_data_type,
|
||||
argmax,
|
||||
guess_data_type,
|
||||
search_product,
|
||||
search_user,
|
||||
)
|
||||
from dibbler.models import Product, User
|
||||
|
||||
exit_commands = ["exit", "abort", "quit", "bye", "eat flaming death", "q"]
|
||||
help_commands = ["help", "?"]
|
||||
context_commands = ["what", "??"]
|
||||
local_help_commands = ["help!", "???"]
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
exit_commands: list[str] = ["exit", "abort", "quit", "bye", "eat flaming death", "q"]
|
||||
help_commands: list[str] = ["help", "?"]
|
||||
context_commands: list[str] = ["what", "??"]
|
||||
local_help_commands: list[str] = ["help!", "???"]
|
||||
|
||||
|
||||
class ExitMenu(Exception):
|
||||
class ExitMenuException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Menu(object):
|
||||
MenuItemType = TypeVar("MenuItemType", bound="Menu")
|
||||
|
||||
|
||||
class Menu:
|
||||
name: str
|
||||
sql_session: Session
|
||||
items: list[Menu | tuple[MenuItemType, str] | str]
|
||||
prompt: str | None
|
||||
end_prompt: str | None
|
||||
return_index: bool
|
||||
exit_msg: str | None
|
||||
exit_confirm_msg: str | None
|
||||
exit_disallowed_msg: str | None
|
||||
help_text: str | None
|
||||
context: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
items=None,
|
||||
prompt=None,
|
||||
end_prompt="> ",
|
||||
return_index=True,
|
||||
exit_msg=None,
|
||||
exit_confirm_msg=None,
|
||||
exit_disallowed_msg=None,
|
||||
help_text=None,
|
||||
uses_db=False,
|
||||
):
|
||||
self.name = name
|
||||
name: str,
|
||||
sql_session: Session,
|
||||
items: list[Self | tuple[MenuItemType, str] | str] | None = None,
|
||||
prompt: str | None = None,
|
||||
end_prompt: str | None = "> ",
|
||||
return_index: bool = True,
|
||||
exit_msg: str | None = None,
|
||||
exit_confirm_msg: str | None = None,
|
||||
exit_disallowed_msg: str | None = None,
|
||||
help_text: str | None = None,
|
||||
) -> None:
|
||||
self.name: str = name
|
||||
self.sql_session: Session = sql_session
|
||||
self.items = items if items is not None else []
|
||||
self.prompt = prompt
|
||||
self.end_prompt = end_prompt
|
||||
@@ -48,54 +68,61 @@ class Menu(object):
|
||||
self.exit_disallowed_msg = exit_disallowed_msg
|
||||
self.help_text = help_text
|
||||
self.context = None
|
||||
self.uses_db = uses_db
|
||||
self.session = None
|
||||
|
||||
def exit_menu(self):
|
||||
assert name is not None
|
||||
assert self.sql_session is not None
|
||||
|
||||
def exit_menu(self) -> None:
|
||||
if self.exit_disallowed_msg is not None:
|
||||
print(self.exit_disallowed_msg)
|
||||
return
|
||||
if self.exit_confirm_msg is not None:
|
||||
if not self.confirm(self.exit_confirm_msg, default=True):
|
||||
return
|
||||
raise ExitMenu()
|
||||
raise ExitMenuException()
|
||||
|
||||
def at_exit(self):
|
||||
def at_exit(self) -> None:
|
||||
if self.exit_msg:
|
||||
print(self.exit_msg)
|
||||
|
||||
def set_context(self, string, display=True):
|
||||
def set_context(
|
||||
self,
|
||||
string: str | None,
|
||||
display: bool = True,
|
||||
) -> None:
|
||||
self.context = string
|
||||
if self.context is not None and display:
|
||||
print(self.context)
|
||||
|
||||
def add_to_context(self, string):
|
||||
self.context += string
|
||||
def add_to_context(self, string: str) -> None:
|
||||
if self.context is not None:
|
||||
self.context += string
|
||||
else:
|
||||
self.context = string
|
||||
|
||||
def printc(self, string):
|
||||
def printc(self, string: str) -> None:
|
||||
print(string)
|
||||
if self.context is None:
|
||||
self.context = string
|
||||
else:
|
||||
self.context += "\n" + string
|
||||
|
||||
def show_context(self):
|
||||
def show_context(self) -> None:
|
||||
print(self.header())
|
||||
if self.context is not None:
|
||||
print(self.context)
|
||||
|
||||
def item_is_submenu(self, i):
|
||||
def item_is_submenu(self, i: int) -> bool:
|
||||
return isinstance(self.items[i], Menu)
|
||||
|
||||
def item_name(self, i):
|
||||
def item_name(self, i: int) -> str:
|
||||
if self.item_is_submenu(i):
|
||||
return self.items[i].name
|
||||
elif isinstance(self.items[i], tuple):
|
||||
if isinstance(self.items[i], tuple):
|
||||
return self.items[i][1]
|
||||
else:
|
||||
return self.items[i]
|
||||
return self.items[i]
|
||||
|
||||
def item_value(self, i):
|
||||
def item_value(self, i: int) -> MenuItemType | int:
|
||||
if isinstance(self.items[i], tuple):
|
||||
return self.items[i][0]
|
||||
if self.return_index:
|
||||
@@ -104,14 +131,14 @@ class Menu(object):
|
||||
|
||||
def input_str(
|
||||
self,
|
||||
prompt=None,
|
||||
end_prompt=None,
|
||||
regex=None,
|
||||
length_range=(None, None),
|
||||
empty_string_is_none=False,
|
||||
timeout=None,
|
||||
default=None,
|
||||
):
|
||||
prompt: str | None = None,
|
||||
end_prompt: str | None = None,
|
||||
regex: str | None = None,
|
||||
length_range: tuple[int | None, int | None] = (None, None),
|
||||
empty_string_is_none: bool = False,
|
||||
timeout: int | None = None,
|
||||
default: str | None = None,
|
||||
) -> str | None:
|
||||
if prompt is None:
|
||||
prompt = self.prompt if self.prompt is not None else ""
|
||||
if default is not None:
|
||||
@@ -168,7 +195,7 @@ class Menu(object):
|
||||
):
|
||||
if length_range[0] and length_range[1]:
|
||||
print(
|
||||
f"Value must have length in range [{length_range[0]:d}, {length_range[1]:d}]"
|
||||
f"Value must have length in range [{length_range[0]:d}, {length_range[1]:d}]",
|
||||
)
|
||||
elif length_range[0]:
|
||||
print(f"Value must have length at least {length_range[0]:d}")
|
||||
@@ -177,7 +204,7 @@ class Menu(object):
|
||||
continue
|
||||
return result
|
||||
|
||||
def special_input_options(self, result):
|
||||
def special_input_options(self, result) -> bool:
|
||||
"""
|
||||
Handles special, magic input for input_str
|
||||
|
||||
@@ -187,7 +214,7 @@ class Menu(object):
|
||||
"""
|
||||
return False
|
||||
|
||||
def special_input_choice(self, in_str):
|
||||
def special_input_choice(self, in_str: str) -> bool:
|
||||
"""
|
||||
Handle choices which are not simply menu items.
|
||||
|
||||
@@ -197,33 +224,39 @@ class Menu(object):
|
||||
"""
|
||||
return False
|
||||
|
||||
def input_choice(self, number_of_choices, prompt=None, end_prompt=None):
|
||||
def input_choice(
|
||||
self,
|
||||
number_of_choices: int,
|
||||
prompt: str | None = None,
|
||||
end_prompt: str | None = None,
|
||||
) -> int:
|
||||
while True:
|
||||
result = self.input_str(prompt, end_prompt)
|
||||
assert result is not None
|
||||
if result == "":
|
||||
print("Please enter something")
|
||||
else:
|
||||
if result.isdigit():
|
||||
choice = int(result)
|
||||
if choice == 0 and 10 <= number_of_choices:
|
||||
if choice == 0 and number_of_choices >= 10:
|
||||
return 10
|
||||
if 0 < choice <= number_of_choices:
|
||||
return choice
|
||||
if not self.special_input_choice(result):
|
||||
self.invalid_menu_choice(result)
|
||||
|
||||
def invalid_menu_choice(self, in_str):
|
||||
def invalid_menu_choice(self, in_str: str) -> None:
|
||||
print("Please enter a valid choice.")
|
||||
|
||||
def input_int(
|
||||
self,
|
||||
prompt=None,
|
||||
minimum=None,
|
||||
maximum=None,
|
||||
null_allowed=False,
|
||||
zero_allowed=True,
|
||||
default=None,
|
||||
):
|
||||
prompt: str,
|
||||
minimum: int | None = None,
|
||||
maximum: int | None = None,
|
||||
null_allowed: bool = False,
|
||||
zero_allowed: bool = True,
|
||||
default: int | None = None,
|
||||
) -> int | Literal[False]:
|
||||
if minimum is not None and maximum is not None:
|
||||
end_prompt = f"({minimum}-{maximum})>"
|
||||
elif minimum is not None:
|
||||
@@ -234,7 +267,11 @@ class Menu(object):
|
||||
end_prompt = ""
|
||||
|
||||
while True:
|
||||
result = self.input_str(prompt + end_prompt, default=default)
|
||||
result = self.input_str(
|
||||
prompt + end_prompt,
|
||||
default=str(default) if default is not None else None,
|
||||
)
|
||||
assert result is not None
|
||||
if result == "" and null_allowed:
|
||||
return False
|
||||
try:
|
||||
@@ -252,93 +289,115 @@ class Menu(object):
|
||||
except ValueError:
|
||||
print("Please enter an integer")
|
||||
|
||||
def input_user(self, prompt=None, end_prompt=None):
|
||||
def input_user(
|
||||
self,
|
||||
prompt: str | None = None,
|
||||
end_prompt: str | None = None,
|
||||
) -> User:
|
||||
user = None
|
||||
while user is None:
|
||||
user = self.retrieve_user(self.input_str(prompt, end_prompt))
|
||||
search_string = self.input_str(prompt, end_prompt)
|
||||
assert search_string is not None
|
||||
user = self.retrieve_user(search_string)
|
||||
return user
|
||||
|
||||
def retrieve_user(self, search_str):
|
||||
def retrieve_user(self, search_str: str) -> User | None:
|
||||
return self.search_ui(search_user, search_str, "user")
|
||||
|
||||
def input_product(self, prompt=None, end_prompt=None):
|
||||
def input_product(
|
||||
self,
|
||||
prompt: str | None = None,
|
||||
end_prompt: str | None = None,
|
||||
) -> Product:
|
||||
product = None
|
||||
while product is None:
|
||||
product = self.retrieve_product(self.input_str(prompt, end_prompt))
|
||||
search_string = self.input_str(prompt, end_prompt)
|
||||
assert search_string is not None
|
||||
product = self.retrieve_product(search_string)
|
||||
return product
|
||||
|
||||
def retrieve_product(self, search_str):
|
||||
def retrieve_product(self, search_str: str) -> Product | None:
|
||||
return self.search_ui(search_product, search_str, "product")
|
||||
|
||||
def input_thing(
|
||||
self,
|
||||
prompt=None,
|
||||
end_prompt=None,
|
||||
permitted_things=("user", "product"),
|
||||
add_nonexisting=(),
|
||||
empty_input_permitted=False,
|
||||
find_hidden_products=True,
|
||||
):
|
||||
prompt: str | None = None,
|
||||
end_prompt: str | None = None,
|
||||
permitted_things: Iterable[str] = ("user", "product"),
|
||||
add_nonexisting: Iterable[str] = (),
|
||||
empty_input_permitted: bool = False,
|
||||
find_hidden_products: bool = True,
|
||||
) -> User | Product | None:
|
||||
result = None
|
||||
while result is None:
|
||||
search_str = self.input_str(prompt, end_prompt)
|
||||
assert search_str is not None
|
||||
if search_str == "" and empty_input_permitted:
|
||||
return None
|
||||
result = self.search_for_thing(
|
||||
search_str, permitted_things, add_nonexisting, find_hidden_products
|
||||
search_str,
|
||||
permitted_things,
|
||||
add_nonexisting,
|
||||
find_hidden_products,
|
||||
)
|
||||
return result
|
||||
|
||||
def input_multiple(
|
||||
self,
|
||||
prompt=None,
|
||||
end_prompt=None,
|
||||
permitted_things=("user", "product"),
|
||||
add_nonexisting=(),
|
||||
empty_input_permitted=False,
|
||||
find_hidden_products=True,
|
||||
):
|
||||
prompt: str | None = None,
|
||||
end_prompt: str | None = None,
|
||||
permitted_things: Iterable[str] = ("user", "product"),
|
||||
add_nonexisting: Iterable[str] = (),
|
||||
empty_input_permitted: bool = False,
|
||||
find_hidden_products: bool = True,
|
||||
) -> tuple[User | Product, int] | None:
|
||||
result = None
|
||||
num = 0
|
||||
while result is None:
|
||||
search_str = self.input_str(prompt, end_prompt)
|
||||
assert search_str is not None
|
||||
search_lst = search_str.split(" ")
|
||||
if search_str == "" and empty_input_permitted:
|
||||
return None
|
||||
else:
|
||||
result = self.search_for_thing(
|
||||
search_str, permitted_things, add_nonexisting, find_hidden_products
|
||||
)
|
||||
num = 1
|
||||
result = self.search_for_thing(
|
||||
search_str,
|
||||
permitted_things,
|
||||
add_nonexisting,
|
||||
find_hidden_products,
|
||||
)
|
||||
num = 1
|
||||
|
||||
if (result is None) and (len(search_lst) > 1):
|
||||
print('Interpreting input as "<number> <product>"')
|
||||
try:
|
||||
num = int(search_lst[0])
|
||||
result = self.search_for_thing(
|
||||
" ".join(search_lst[1:]),
|
||||
permitted_things,
|
||||
add_nonexisting,
|
||||
find_hidden_products,
|
||||
)
|
||||
# Her kan det legges inn en except ValueError,
|
||||
# men da blir det fort mye plaging av brukeren
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if (result is None) and (len(search_lst) > 1):
|
||||
print('Interpreting input as "<number> <product>"')
|
||||
try:
|
||||
num = int(search_lst[0])
|
||||
result = self.search_for_thing(
|
||||
" ".join(search_lst[1:]),
|
||||
permitted_things,
|
||||
add_nonexisting,
|
||||
find_hidden_products,
|
||||
)
|
||||
# Her kan det legges inn en except ValueError,
|
||||
# men da blir det fort mye plaging av brukeren
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return result, num
|
||||
|
||||
def search_for_thing(
|
||||
self,
|
||||
search_str,
|
||||
permitted_things=("user", "product"),
|
||||
add_non_existing=(),
|
||||
find_hidden_products=True,
|
||||
):
|
||||
search_fun = {"user": search_user, "product": search_product}
|
||||
search_str: str,
|
||||
permitted_things: Iterable[str] = ("user", "product"),
|
||||
add_non_existing: Iterable[str] = (),
|
||||
find_hidden_products: bool = True,
|
||||
) -> User | Product | None:
|
||||
search_fun = {
|
||||
"user": search_user,
|
||||
"product": search_product,
|
||||
}
|
||||
results = {}
|
||||
result_values = {}
|
||||
for thing in permitted_things:
|
||||
results[thing] = search_fun[thing](search_str, self.session, find_hidden_products)
|
||||
results[thing] = search_fun[thing](search_str, self.sql_session, find_hidden_products)
|
||||
result_values[thing] = self.search_result_value(results[thing])
|
||||
selected_thing = argmax(result_values)
|
||||
if not results[selected_thing]:
|
||||
@@ -353,10 +412,14 @@ class Menu(object):
|
||||
return self.search_add(search_str)
|
||||
# print('No match found for "%s".' % search_str)
|
||||
return None
|
||||
return self.search_ui2(search_str, results[selected_thing], selected_thing)
|
||||
return self.search_ui2(
|
||||
search_str,
|
||||
results[selected_thing],
|
||||
selected_thing,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def search_result_value(result):
|
||||
def search_result_value(result) -> Literal[0, 1, 2, 3]:
|
||||
if result is None:
|
||||
return 0
|
||||
if not isinstance(result, list):
|
||||
@@ -367,18 +430,19 @@ class Menu(object):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
def search_add(self, string):
|
||||
def search_add(self, string: str) -> User | None:
|
||||
type_guess = guess_data_type(string)
|
||||
if type_guess == "username":
|
||||
print(f'"{string}" looks like a username, but no such user exists.')
|
||||
if self.confirm(f"Create user {string}?"):
|
||||
user = User(string, None)
|
||||
self.session.add(user)
|
||||
self.sql_session.add(user)
|
||||
return user
|
||||
return None
|
||||
if type_guess == "card":
|
||||
selector = Selector(
|
||||
f'"{string}" looks like a card number, but no user with that card number exists.',
|
||||
self.sql_session,
|
||||
[
|
||||
("create", f"Create user with card number {string}"),
|
||||
("set", f"Set card number of an existing user to {string}"),
|
||||
@@ -387,12 +451,14 @@ class Menu(object):
|
||||
selection = selector.execute()
|
||||
if selection == "create":
|
||||
username = self.input_str(
|
||||
"Username for new user (should be same as PVV username)",
|
||||
User.name_re,
|
||||
(1, 10),
|
||||
prompt="Username for new user (should be same as PVV username)",
|
||||
end_prompt=None,
|
||||
regex=User.name_re,
|
||||
length_range=(1, 10),
|
||||
)
|
||||
assert username is not None
|
||||
user = User(username, string)
|
||||
self.session.add(user)
|
||||
self.sql_session.add(user)
|
||||
return user
|
||||
if selection == "set":
|
||||
user = self.input_user("User to set card number for")
|
||||
@@ -405,11 +471,21 @@ class Menu(object):
|
||||
print(f'"{string}" looks like the bar code for a product, but no such product exists.')
|
||||
return None
|
||||
|
||||
def search_ui(self, search_fun, search_str, thing):
|
||||
result = search_fun(search_str, self.session)
|
||||
def search_ui(
|
||||
self,
|
||||
search_fun: Callable[[str, Session], list[Any] | Any],
|
||||
search_str: str,
|
||||
thing: str,
|
||||
) -> Any:
|
||||
result = search_fun(search_str, self.sql_session)
|
||||
return self.search_ui2(search_str, result, thing)
|
||||
|
||||
def search_ui2(self, search_str, result, thing):
|
||||
def search_ui2(
|
||||
self,
|
||||
search_str: str,
|
||||
result: list[Any] | Any,
|
||||
thing: str,
|
||||
) -> Any:
|
||||
if not isinstance(result, list):
|
||||
return result
|
||||
if len(result) == 0:
|
||||
@@ -429,25 +505,41 @@ class Menu(object):
|
||||
else:
|
||||
select_header = f'{len(result):d} {thing}s matching "{search_str}"'
|
||||
select_items = result
|
||||
selector = Selector(select_header, items=select_items, return_index=False)
|
||||
selector = Selector(
|
||||
select_header,
|
||||
self.sql_session,
|
||||
items=select_items,
|
||||
return_index=False,
|
||||
)
|
||||
return selector.execute()
|
||||
|
||||
@staticmethod
|
||||
def confirm(prompt, end_prompt=None, default=None, timeout=None):
|
||||
return ConfirmMenu(prompt, end_prompt=None, default=default, timeout=timeout).execute()
|
||||
def confirm(
|
||||
self,
|
||||
prompt: str,
|
||||
end_prompt: str | None = None,
|
||||
default: bool | None = None,
|
||||
timeout: int | None = None,
|
||||
) -> bool:
|
||||
return ConfirmMenu(
|
||||
self.sql_session,
|
||||
prompt,
|
||||
end_prompt=None,
|
||||
default=default,
|
||||
timeout=timeout,
|
||||
).execute()
|
||||
|
||||
def header(self):
|
||||
def header(self) -> str:
|
||||
return f"[{self.name}]"
|
||||
|
||||
def print_header(self):
|
||||
def print_header(self) -> None:
|
||||
print("")
|
||||
print(self.header())
|
||||
|
||||
def pause(self):
|
||||
def pause(self) -> None:
|
||||
self.input_str(".", end_prompt="")
|
||||
|
||||
@staticmethod
|
||||
def general_help():
|
||||
def general_help() -> None:
|
||||
print(
|
||||
"""
|
||||
DIBBLER HELP
|
||||
@@ -470,10 +562,10 @@ class Menu(object):
|
||||
of money PVVVV owes the user. This value decreases with the
|
||||
appropriate amount when you register a purchase, and you may increase
|
||||
it by putting money in the box and using the "Adjust credit" menu.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
def local_help(self):
|
||||
def local_help(self) -> None:
|
||||
if self.help_text is None:
|
||||
print("no help here")
|
||||
else:
|
||||
@@ -481,21 +573,15 @@ class Menu(object):
|
||||
print(f"Help for {self.header()}:")
|
||||
print(self.help_text)
|
||||
|
||||
def execute(self, **kwargs):
|
||||
def execute(self, **_kwargs) -> MenuItemType | int | None:
|
||||
self.set_context(None)
|
||||
try:
|
||||
if self.uses_db and not self.session:
|
||||
self.session = Session()
|
||||
return self._execute(**kwargs)
|
||||
except ExitMenu:
|
||||
return self._execute(**_kwargs)
|
||||
except ExitMenuException:
|
||||
self.at_exit()
|
||||
return None
|
||||
finally:
|
||||
if self.session is not None:
|
||||
self.session.close()
|
||||
self.session = None
|
||||
|
||||
def _execute(self, **kwargs):
|
||||
def _execute(self, **_kwargs) -> MenuItemType | int | None:
|
||||
while True:
|
||||
self.print_header()
|
||||
self.set_context(None)
|
||||
@@ -514,12 +600,21 @@ class Menu(object):
|
||||
|
||||
|
||||
class MessageMenu(Menu):
|
||||
def __init__(self, name, message, pause_after_message=True):
|
||||
Menu.__init__(self, name)
|
||||
message: str
|
||||
pause_after_message: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
message: str,
|
||||
sql_session: Session,
|
||||
pause_after_message: bool = True,
|
||||
) -> None:
|
||||
super().__init__(name, sql_session)
|
||||
self.message = message.strip()
|
||||
self.pause_after_message = pause_after_message
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
print("")
|
||||
print(self.message)
|
||||
@@ -528,10 +623,17 @@ class MessageMenu(Menu):
|
||||
|
||||
|
||||
class ConfirmMenu(Menu):
|
||||
def __init__(self, prompt="confirm? ", end_prompt=": ", default=None, timeout=0):
|
||||
Menu.__init__(
|
||||
self,
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Session,
|
||||
prompt: str = "confirm? ",
|
||||
end_prompt: str | None = ": ",
|
||||
default: bool | None = None,
|
||||
timeout: int | None = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
"question",
|
||||
sql_session,
|
||||
prompt=prompt,
|
||||
end_prompt=end_prompt,
|
||||
exit_disallowed_msg="Please answer yes or no",
|
||||
@@ -539,45 +641,55 @@ class ConfirmMenu(Menu):
|
||||
self.default = default
|
||||
self.timeout = timeout
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> bool:
|
||||
options = {True: "[y]/n", False: "y/[n]", None: "y/n"}[self.default]
|
||||
while True:
|
||||
result = self.input_str(
|
||||
f"{self.prompt} ({options})", end_prompt=": ", timeout=self.timeout
|
||||
f"{self.prompt} ({options})",
|
||||
end_prompt=": ",
|
||||
timeout=self.timeout,
|
||||
)
|
||||
result = result.lower().strip()
|
||||
if result in ["y", "yes"]:
|
||||
return True
|
||||
elif result in ["n", "no"]:
|
||||
if result in ["n", "no"]:
|
||||
return False
|
||||
elif self.default is not None and result == "":
|
||||
if self.default is not None and result == "":
|
||||
return self.default
|
||||
else:
|
||||
print("Please answer yes or no")
|
||||
print("Please answer yes or no")
|
||||
|
||||
|
||||
class Selector(Menu):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
items=None,
|
||||
prompt="select",
|
||||
return_index=True,
|
||||
exit_msg=None,
|
||||
exit_confirm_msg=None,
|
||||
help_text=None,
|
||||
):
|
||||
name: str,
|
||||
sql_session: Session,
|
||||
items: list[Self | tuple[MenuItemType, str] | str] | None = None,
|
||||
prompt: str | None = "select",
|
||||
return_index: bool = True,
|
||||
exit_msg: str | None = None,
|
||||
exit_confirm_msg: str | None = None,
|
||||
help_text: str | None = None,
|
||||
) -> None:
|
||||
if items is None:
|
||||
items = []
|
||||
Menu.__init__(self, name, items, prompt, return_index=return_index, exit_msg=exit_msg)
|
||||
super().__init__(
|
||||
name,
|
||||
sql_session,
|
||||
items,
|
||||
prompt,
|
||||
return_index=return_index,
|
||||
exit_msg=exit_msg,
|
||||
help_text=help_text,
|
||||
)
|
||||
|
||||
def header(self):
|
||||
def header(self) -> str:
|
||||
return self.name
|
||||
|
||||
def print_header(self):
|
||||
def print_header(self) -> None:
|
||||
print(self.header())
|
||||
|
||||
def local_help(self):
|
||||
def local_help(self) -> None:
|
||||
if self.help_text is None:
|
||||
print("This is a selection menu. Enter one of the listed numbers, or")
|
||||
print("'exit' to go out and do something else.")
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
from dibbler.db import Session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .buymenu import BuyMenu
|
||||
from .faq import FAQMenu
|
||||
@@ -13,14 +12,17 @@ faq_commands = ["faq"]
|
||||
restart_commands = ["restart"]
|
||||
|
||||
|
||||
def restart():
|
||||
def restart() -> None:
|
||||
# Does not work if the script is not executable, or if it was
|
||||
# started by searching $PATH.
|
||||
os.execv(sys.argv[0], sys.argv)
|
||||
|
||||
|
||||
class MainMenu(Menu):
|
||||
def special_input_choice(self, in_str):
|
||||
def __init__(self, sql_session: Session, **_kwargs) -> None:
|
||||
super().__init__("Dibbler main menu", sql_session, **_kwargs)
|
||||
|
||||
def special_input_choice(self, in_str: str) -> bool:
|
||||
mv = in_str.split()
|
||||
if len(mv) == 2 and mv[0].isdigit():
|
||||
num = int(mv[0])
|
||||
@@ -28,7 +30,7 @@ class MainMenu(Menu):
|
||||
else:
|
||||
num = 1
|
||||
item_name = in_str
|
||||
buy_menu = BuyMenu(Session())
|
||||
buy_menu = BuyMenu(self.sql_session)
|
||||
thing = buy_menu.search_for_thing(item_name, find_hidden_products=False)
|
||||
if thing:
|
||||
buy_menu.execute(initial_contents=[(thing, num)])
|
||||
@@ -36,32 +38,26 @@ class MainMenu(Menu):
|
||||
return True
|
||||
return False
|
||||
|
||||
def special_input_options(self, result):
|
||||
def special_input_options(self, result: str) -> bool:
|
||||
if result in faq_commands:
|
||||
FAQMenu().execute()
|
||||
FAQMenu(self.sql_session).execute()
|
||||
return True
|
||||
if result in restart_commands:
|
||||
if self.confirm("Restart Dibbler?"):
|
||||
restart()
|
||||
pass
|
||||
return True
|
||||
elif result == "c":
|
||||
os.system(
|
||||
'echo -e "\033['
|
||||
+ str(random.randint(40, 49))
|
||||
+ ";"
|
||||
+ str(random.randint(30, 37))
|
||||
+ ';5m"'
|
||||
)
|
||||
os.system("clear")
|
||||
if result == "c":
|
||||
print(f"\033[{random.randint(40, 49)};{random.randint(30, 37)};5m")
|
||||
print("\033[2J")
|
||||
self.show_context()
|
||||
return True
|
||||
elif result == "cs":
|
||||
os.system('echo -e "\033[0m"')
|
||||
os.system("clear")
|
||||
if result == "cs":
|
||||
print("\033[0m")
|
||||
print("\033[2J")
|
||||
self.show_context()
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalid_menu_choice(self, in_str):
|
||||
def invalid_menu_choice(self, in_str: str) -> None:
|
||||
print(self.show_context())
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
import sqlalchemy
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.conf import config
|
||||
from dibbler.models import Transaction, Product, User
|
||||
from dibbler.lib.helpers import less
|
||||
from dibbler.models import Product, Transaction, User
|
||||
|
||||
from .helpermenus import Menu, Selector
|
||||
|
||||
|
||||
class TransferMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Transfer credit between users", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Transfer credit between users", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
amount = self.input_int("Transfer amount", 1, 100000)
|
||||
self.set_context(f"Transferring {amount:d} kr", display=False)
|
||||
@@ -26,24 +28,25 @@ class TransferMenu(Menu):
|
||||
t2 = Transaction(user2, -amount, f'transfer from {user1.name} "{comment}"')
|
||||
t1.perform_transaction()
|
||||
t2.perform_transaction()
|
||||
self.session.add(t1)
|
||||
self.session.add(t2)
|
||||
self.sql_session.add(t1)
|
||||
self.sql_session.add(t2)
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print(f"Transferred {amount:d} kr from {user1} to {user2}")
|
||||
print(f"User {user1}'s credit is now {user1.credit:d} kr")
|
||||
print(f"User {user2}'s credit is now {user2.credit:d} kr")
|
||||
print(f"Comment: {comment}")
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not perform transfer: {e}")
|
||||
# self.pause()
|
||||
|
||||
|
||||
class ShowUserMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Show user", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Show user", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
user = self.input_user("User name, card number or RFID")
|
||||
print(f"User name: {user.name}")
|
||||
@@ -52,11 +55,12 @@ class ShowUserMenu(Menu):
|
||||
print(f"Credit: {user.credit} kr")
|
||||
selector = Selector(
|
||||
f"What do you want to know about {user.name}?",
|
||||
self.sql_session,
|
||||
items=[
|
||||
(
|
||||
"transactions",
|
||||
"Recent transactions (List of last "
|
||||
+ str(config.getint("limits", "user_recent_transaction_limit"))
|
||||
+ str(config["limits"]["user_recent_transaction_limit"])
|
||||
+ ")",
|
||||
),
|
||||
("products", f"Which products {user.name} has bought, and how many"),
|
||||
@@ -65,7 +69,7 @@ class ShowUserMenu(Menu):
|
||||
)
|
||||
what = selector.execute()
|
||||
if what == "transactions":
|
||||
self.print_transactions(user, config.getint("limits", "user_recent_transaction_limit"))
|
||||
self.print_transactions(user, config["limits"]["user_recent_transaction_limit"])
|
||||
elif what == "products":
|
||||
self.print_purchased_products(user)
|
||||
elif what == "transactions-all":
|
||||
@@ -74,7 +78,7 @@ class ShowUserMenu(Menu):
|
||||
print("What what?")
|
||||
|
||||
@staticmethod
|
||||
def print_transactions(user, limit=None):
|
||||
def print_transactions(user: User, limit: int | None = None) -> None:
|
||||
num_trans = len(user.transactions)
|
||||
if limit is None:
|
||||
limit = num_trans
|
||||
@@ -87,10 +91,7 @@ class ShowUserMenu(Menu):
|
||||
if t.purchase:
|
||||
products = []
|
||||
for entry in t.purchase.entries:
|
||||
if abs(entry.amount) != 1:
|
||||
amount = f"{abs(entry.amount)}x "
|
||||
else:
|
||||
amount = ""
|
||||
amount = f"{abs(entry.amount)}x " if abs(entry.amount) != 1 else ""
|
||||
product = f"{amount}{entry.product.name}"
|
||||
products.append(product)
|
||||
string += "purchase ("
|
||||
@@ -98,13 +99,13 @@ class ShowUserMenu(Menu):
|
||||
string += ")"
|
||||
if t.penalty > 1:
|
||||
string += f" * {t.penalty:d}x penalty applied"
|
||||
else:
|
||||
elif t.description is not None:
|
||||
string += t.description
|
||||
string += "\n"
|
||||
less(string)
|
||||
|
||||
@staticmethod
|
||||
def print_purchased_products(user):
|
||||
def print_purchased_products(user: User) -> None:
|
||||
products = []
|
||||
for ref in user.products:
|
||||
product = ref.product
|
||||
@@ -123,13 +124,13 @@ class ShowUserMenu(Menu):
|
||||
|
||||
|
||||
class UserListMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "User list", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("User list", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
user_list = self.session.query(User).all()
|
||||
total_credit = self.session.query(sqlalchemy.func.sum(User.credit)).first()[0]
|
||||
user_list = self.sql_session.query(User).all()
|
||||
total_credit = self.sql_session.query(sqlalchemy.func.sum(User.credit)).first()[0]
|
||||
|
||||
line_format = "%-12s | %6s\n"
|
||||
hline = "---------------------\n"
|
||||
@@ -144,10 +145,10 @@ class UserListMenu(Menu):
|
||||
|
||||
|
||||
class AdjustCreditMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Adjust credit", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Adjust credit", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
user = self.input_user("User")
|
||||
print(f"User {user.name}'s credit is {user.credit:d} kr")
|
||||
@@ -164,24 +165,25 @@ class AdjustCreditMenu(Menu):
|
||||
description = "manually adjusted credit"
|
||||
transaction = Transaction(user, -amount, description)
|
||||
transaction.perform_transaction()
|
||||
self.session.add(transaction)
|
||||
self.sql_session.add(transaction)
|
||||
try:
|
||||
self.session.commit()
|
||||
self.sql_session.commit()
|
||||
print(f"User {user.name}'s credit is now {user.credit:d} kr")
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Could not store transaction: {e}")
|
||||
# self.pause()
|
||||
|
||||
|
||||
class ProductListMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Product list", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Product list", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
text = ""
|
||||
product_list = (
|
||||
self.session.query(Product)
|
||||
self.sql_session.query(Product)
|
||||
.filter(Product.hidden.is_(False))
|
||||
.order_by(Product.stock.desc())
|
||||
)
|
||||
@@ -204,21 +206,22 @@ class ProductListMenu(Menu):
|
||||
|
||||
|
||||
class ProductSearchMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Product search", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Product search", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
self.set_context("Enter (part of) product name or bar code")
|
||||
product = self.input_product()
|
||||
print(
|
||||
"Result: %s, price: %d kr, bar code: %s, stock: %d, hidden: %s"
|
||||
% (
|
||||
product.name,
|
||||
product.price,
|
||||
product.bar_code,
|
||||
product.stock,
|
||||
("Y" if product.hidden else "N"),
|
||||
)
|
||||
", ".join(
|
||||
[
|
||||
f"Result: {product.name}",
|
||||
f"price: {product.price} kr",
|
||||
f"bar code: {product.bar_code}",
|
||||
f"stock: {product.stock}",
|
||||
f"hidden: {'Y' if product.hidden else 'N'}",
|
||||
],
|
||||
),
|
||||
)
|
||||
# self.pause()
|
||||
|
||||
@@ -1,45 +1,46 @@
|
||||
import re
|
||||
|
||||
from dibbler.conf import config
|
||||
from dibbler.models import Product, User
|
||||
from dibbler.lib.printer_helpers import print_bar_code, print_name_label
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# from dibbler.lib.printer_helpers import print_bar_code, print_name_label
|
||||
from .helpermenus import Menu
|
||||
|
||||
|
||||
class PrintLabelMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Print a label", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Print a label", sql_session)
|
||||
self.help_text = """
|
||||
Prints out a product bar code on the printer
|
||||
|
||||
Put it up somewhere in the vicinity.
|
||||
"""
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
|
||||
thing = self.input_thing("Product/User")
|
||||
print("Printer menu is under renovation, please be patient")
|
||||
|
||||
if isinstance(thing, Product):
|
||||
if re.match(r"^[0-9]{13}$", thing.bar_code):
|
||||
bar_type = "ean13"
|
||||
elif re.match(r"^[0-9]{8}$", thing.bar_code):
|
||||
bar_type = "ean8"
|
||||
else:
|
||||
bar_type = "code39"
|
||||
print_bar_code(
|
||||
thing.bar_code,
|
||||
thing.name,
|
||||
barcode_type=bar_type,
|
||||
rotate=config.getboolean("printer", "rotate"),
|
||||
printer_type="QL-700",
|
||||
label_type=config.get("printer", "label_type"),
|
||||
)
|
||||
elif isinstance(thing, User):
|
||||
print_name_label(
|
||||
text=thing.name,
|
||||
label_type=config.get("printer", "label_type"),
|
||||
rotate=config.getboolean("printer", "rotate"),
|
||||
printer_type="QL-700",
|
||||
)
|
||||
return
|
||||
|
||||
# thing = self.input_thing("Product/User")
|
||||
|
||||
# if isinstance(thing, Product):
|
||||
# if re.match(r"^[0-9]{13}$", thing.bar_code):
|
||||
# bar_type = "ean13"
|
||||
# elif re.match(r"^[0-9]{8}$", thing.bar_code):
|
||||
# bar_type = "ean8"
|
||||
# else:
|
||||
# bar_type = "code39"
|
||||
# print_bar_code(
|
||||
# thing.bar_code,
|
||||
# thing.name,
|
||||
# barcode_type=bar_type,
|
||||
# rotate=config["printer"]["rotate"],
|
||||
# printer_type="QL-700",
|
||||
# label_type=config.get("printer", "label_type"),
|
||||
# )
|
||||
# elif isinstance(thing, User):
|
||||
# print_name_label(
|
||||
# text=thing.name,
|
||||
# label_type=config["printer"]["label_type"],
|
||||
# rotate=config["printer"]["rotate"],
|
||||
# printer_type="QL-700",
|
||||
# )
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.helpers import less
|
||||
from dibbler.models import PurchaseEntry, Product, User
|
||||
from dibbler.lib.statistikkHelpers import statisticsTextOnly
|
||||
from dibbler.models import Product, PurchaseEntry, User
|
||||
|
||||
from .helpermenus import Menu
|
||||
|
||||
@@ -15,14 +16,14 @@ __all__ = [
|
||||
|
||||
|
||||
class ProductPopularityMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Products by popularity", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Products by popularity", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
text = ""
|
||||
sub = (
|
||||
self.session.query(
|
||||
self.sql_session.query(
|
||||
PurchaseEntry.product_id,
|
||||
func.sum(PurchaseEntry.amount).label("purchase_count"),
|
||||
)
|
||||
@@ -31,8 +32,8 @@ class ProductPopularityMenu(Menu):
|
||||
.subquery()
|
||||
)
|
||||
product_list = (
|
||||
self.session.query(Product, sub.c.purchase_count)
|
||||
.outerjoin((sub, Product.product_id == sub.c.product_id))
|
||||
self.sql_session.query(Product, sub.c.purchase_count)
|
||||
.outerjoin(sub, Product.product_id == sub.c.product_id)
|
||||
.order_by(desc(sub.c.purchase_count))
|
||||
.filter(sub.c.purchase_count is not None)
|
||||
.all()
|
||||
@@ -48,14 +49,14 @@ class ProductPopularityMenu(Menu):
|
||||
|
||||
|
||||
class ProductRevenueMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Products by revenue", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Products by revenue", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
text = ""
|
||||
sub = (
|
||||
self.session.query(
|
||||
self.sql_session.query(
|
||||
PurchaseEntry.product_id,
|
||||
func.sum(PurchaseEntry.amount).label("purchase_count"),
|
||||
)
|
||||
@@ -64,8 +65,8 @@ class ProductRevenueMenu(Menu):
|
||||
.subquery()
|
||||
)
|
||||
product_list = (
|
||||
self.session.query(Product, sub.c.purchase_count)
|
||||
.outerjoin((sub, Product.product_id == sub.c.product_id))
|
||||
self.sql_session.query(Product, sub.c.purchase_count)
|
||||
.outerjoin(sub, Product.product_id == sub.c.product_id)
|
||||
.order_by(desc(sub.c.purchase_count * Product.price))
|
||||
.filter(sub.c.purchase_count is not None)
|
||||
.all()
|
||||
@@ -86,22 +87,26 @@ class ProductRevenueMenu(Menu):
|
||||
|
||||
|
||||
class BalanceMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Total balance of PVVVV", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Total balance of PVVVV", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
self.print_header()
|
||||
text = ""
|
||||
total_value = 0
|
||||
product_list = self.session.query(Product).filter(Product.stock > 0).all()
|
||||
product_list = self.sql_session.query(Product).filter(Product.stock > 0).all()
|
||||
for p in product_list:
|
||||
total_value += p.stock * p.price
|
||||
|
||||
total_positive_credit = (
|
||||
self.session.query(func.sum(User.credit)).filter(User.credit > 0).first()[0]
|
||||
self.sql_session.query(func.coalesce(func.sum(User.credit), 0))
|
||||
.filter(User.credit > 0)
|
||||
.first()[0]
|
||||
)
|
||||
total_negative_credit = (
|
||||
self.session.query(func.sum(User.credit)).filter(User.credit < 0).first()[0]
|
||||
self.sql_session.query(func.coalesce(func.sum(User.credit), 0))
|
||||
.filter(User.credit < 0)
|
||||
.first()[0]
|
||||
)
|
||||
|
||||
total_credit = total_positive_credit + total_negative_credit
|
||||
@@ -119,8 +124,8 @@ class BalanceMenu(Menu):
|
||||
|
||||
|
||||
class LoggedStatisticsMenu(Menu):
|
||||
def __init__(self):
|
||||
Menu.__init__(self, "Statistics from log", uses_db=True)
|
||||
def __init__(self, sql_session: Session) -> None:
|
||||
super().__init__("Statistics from log", sql_session)
|
||||
|
||||
def _execute(self):
|
||||
statisticsTextOnly()
|
||||
def _execute(self, **_kwargs) -> None:
|
||||
statisticsTextOnly(self.sql_session)
|
||||
|
||||
@@ -17,16 +17,19 @@ def _pascal_case_to_snake_case(name: str) -> str:
|
||||
class Base(DeclarativeBase):
|
||||
metadata = MetaData(
|
||||
naming_convention={
|
||||
"ix": "ix_%(table_name)s_%(column_0_label)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
"ix": "ix__%(table_name)s__%(column_0_label)s",
|
||||
"uq": "uq__%(table_name)s__%(column_0_name)s",
|
||||
"ck": "ck__%(table_name)s__%(constraint_name)s",
|
||||
"fk": "fk__%(table_name)s__%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk__%(table_name)s",
|
||||
},
|
||||
)
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str:
|
||||
if hasattr(cls, "__table_name__"):
|
||||
assert isinstance(cls.__table_name__, str)
|
||||
return cls.__table_name__
|
||||
return _pascal_case_to_snake_case(cls.__name__)
|
||||
|
||||
# NOTE: This is the default implementation of __repr__ for all tables,
|
||||
@@ -46,7 +49,7 @@ class Base(DeclarativeBase):
|
||||
isinstance(v, InstrumentedList),
|
||||
isinstance(v, InstrumentedSet),
|
||||
isinstance(v, InstrumentedDict),
|
||||
]
|
||||
],
|
||||
)
|
||||
)
|
||||
return f"<{self.__class__.__name__}({columns})>"
|
||||
|
||||
27
dibbler/models/LastCacheTransaction.py
Normal file
27
dibbler/models/LastCacheTransaction.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dibbler.models import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dibbler.models import Transaction
|
||||
|
||||
|
||||
class LastCacheTransaction(Base):
|
||||
"""Tracks the last transaction that affected various caches."""
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
"""Internal database ID"""
|
||||
|
||||
transaction_id: Mapped[int | None] = mapped_column(ForeignKey("trx.id"), index=True)
|
||||
"""The ID of the last transaction that affected the cache(s)."""
|
||||
|
||||
transaction: Mapped[Transaction | None] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[transaction_id],
|
||||
)
|
||||
"""The last transaction that affected the cache(s)."""
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Self
|
||||
from typing import TYPE_CHECKING, Self
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
@@ -19,6 +19,7 @@ class Product(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
"""Internal database ID"""
|
||||
|
||||
# TODO: add more validation for barcode
|
||||
bar_code: Mapped[str] = mapped_column(String(13), unique=True)
|
||||
"""
|
||||
The bar code of the product.
|
||||
|
||||
@@ -1,16 +1,33 @@
|
||||
from datetime import datetime
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Integer, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dibbler.models import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dibbler.models import LastCacheTransaction, Product
|
||||
|
||||
|
||||
class ProductCache(Base):
|
||||
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
"""Internal database ID"""
|
||||
|
||||
product_id: Mapped[int] = mapped_column(ForeignKey("product.id"))
|
||||
product: Mapped[Product] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[product_id],
|
||||
)
|
||||
|
||||
price: Mapped[int] = mapped_column(Integer)
|
||||
price_timestamp: Mapped[datetime] = mapped_column(DateTime)
|
||||
|
||||
stock: Mapped[int] = mapped_column(Integer)
|
||||
stock_timestamp: Mapped[datetime] = mapped_column(DateTime)
|
||||
|
||||
last_cache_transaction_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("last_cache_transaction.id"), nullable=True,
|
||||
)
|
||||
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[last_cache_transaction_id],
|
||||
)
|
||||
|
||||
@@ -33,11 +33,10 @@ if TYPE_CHECKING:
|
||||
from .Product import Product
|
||||
from .User import User
|
||||
|
||||
# TODO: rename to *_PERCENT
|
||||
# NOTE: these only matter when there are no adjustments made in the database.
|
||||
DEFAULT_INTEREST_RATE_PERCENTAGE = 100
|
||||
DEFAULT_INTEREST_RATE_PERCENT = 100
|
||||
DEFAULT_PENALTY_THRESHOLD = -100
|
||||
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200
|
||||
DEFAULT_PENALTY_MULTIPLIER_PERCENT = 200
|
||||
|
||||
_DYNAMIC_FIELDS: set[str] = {
|
||||
"amount",
|
||||
@@ -88,6 +87,7 @@ def _transaction_type_field_constraints(
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
__tablename__ = "trx"
|
||||
__table_args__ = (
|
||||
*[
|
||||
_transaction_type_field_constraints(transaction_type, expected_fields)
|
||||
@@ -131,12 +131,12 @@ class Transaction(Base):
|
||||
),
|
||||
name="trx_joint_transaction_id_not_self",
|
||||
),
|
||||
# Speed up product count calculation
|
||||
Index("product_user_time", "product_id", "user_id", "time"),
|
||||
# Speed up product stock calculation
|
||||
Index("ix__transaction__product_id_type_time", "product_id", "type", "time"),
|
||||
# Speed up product owner calculation
|
||||
Index("user_product_time", "user_id", "product_id", "time"),
|
||||
Index("ix__transaction__user_id_product_time", "user_id", "product_id", "time"),
|
||||
# Speed up user transaction list / credit calculation
|
||||
Index("user_time", "user_id", "time"),
|
||||
Index("ix__transaction__user_id_time", "user_id", "time"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
@@ -146,7 +146,7 @@ class Transaction(Base):
|
||||
Not used for anything else than identifying the transaction in the database.
|
||||
"""
|
||||
|
||||
time: Mapped[datetime] = mapped_column(DateTime)
|
||||
time: Mapped[datetime] = mapped_column(DateTime, index=True)
|
||||
"""
|
||||
The time when the transaction took place.
|
||||
|
||||
@@ -162,7 +162,7 @@ class Transaction(Base):
|
||||
This is not used for any calculations, but can be useful for debugging.
|
||||
"""
|
||||
|
||||
type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type")
|
||||
type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type", index=True)
|
||||
"""
|
||||
Which type of transaction this is.
|
||||
|
||||
@@ -189,7 +189,7 @@ class Transaction(Base):
|
||||
that the user paid in the store would be stored in the `amount` field.
|
||||
"""
|
||||
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), index=True)
|
||||
"""The user who performs the transaction. See `user` for more details."""
|
||||
user: Mapped[User] = relationship(
|
||||
lazy="joined",
|
||||
@@ -207,7 +207,10 @@ class Transaction(Base):
|
||||
In the case of `JOINT` transactions, this is the user who initiated the joint transaction.
|
||||
"""
|
||||
|
||||
joint_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("transaction.id"))
|
||||
joint_transaction_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("trx.id"),
|
||||
index=True,
|
||||
)
|
||||
"""
|
||||
An optional ID to group multiple transactions together as part of a joint transaction.
|
||||
|
||||
@@ -223,7 +226,7 @@ class Transaction(Base):
|
||||
"""
|
||||
|
||||
# Receiving user when moving credit from one user to another
|
||||
transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
|
||||
transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"), index=True)
|
||||
"""The user who receives money in a `TRANSFER` transaction."""
|
||||
transfer_user: Mapped[User | None] = relationship(
|
||||
lazy="joined",
|
||||
@@ -232,7 +235,7 @@ class Transaction(Base):
|
||||
"""The user who receives money in a `TRANSFER` transaction."""
|
||||
|
||||
# The product that is either being added or bought
|
||||
product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"))
|
||||
product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id"), index=True)
|
||||
"""The product being added or bought."""
|
||||
product: Mapped[Product | None] = relationship(lazy="joined")
|
||||
"""The product being added or bought."""
|
||||
@@ -330,7 +333,6 @@ class Transaction(Base):
|
||||
Validates the transaction's fields based on its type.
|
||||
Raises `ValueError` if the transaction is invalid.
|
||||
"""
|
||||
# TODO: do we allow free products?
|
||||
if self.amount == 0:
|
||||
raise ValueError("Amount must not be zero.")
|
||||
|
||||
@@ -352,7 +354,7 @@ class Transaction(Base):
|
||||
and self.amount > self.per_product * self.product_count
|
||||
):
|
||||
raise ValueError(
|
||||
"The real amount of the transaction must be less than the total value of the products."
|
||||
"The real amount of the transaction must be less than the total value of the products.",
|
||||
)
|
||||
|
||||
# TODO: improve printing further
|
||||
@@ -383,7 +385,7 @@ class Transaction(Base):
|
||||
isinstance(v, InstrumentedSet),
|
||||
isinstance(v, InstrumentedDict),
|
||||
*[k in (_DYNAMIC_FIELDS - EXPECTED_FIELDS[self.type_])],
|
||||
]
|
||||
],
|
||||
)
|
||||
)
|
||||
return f"{self.type_.upper()}({columns})"
|
||||
@@ -400,6 +402,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating an `ADJUST_BALANCE` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.ADJUST_BALANCE,
|
||||
@@ -416,6 +423,14 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating an `ADJUST_INTEREST` transaction.
|
||||
|
||||
Note that the `interest_rate_percent` is absolute, not relative to the previous interest rate.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.ADJUST_INTEREST,
|
||||
@@ -433,6 +448,14 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating an `ADJUST_PENALTY` transaction.
|
||||
|
||||
Note that both `penalty_multiplier_percent` and `penalty_threshold` are absolute,
|
||||
not relative to the previous settings.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.ADJUST_PENALTY,
|
||||
@@ -451,6 +474,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating an `ADJUST_STOCK` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.ADJUST_STOCK,
|
||||
@@ -471,6 +499,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating an `ADD_PRODUCT` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.ADD_PRODUCT,
|
||||
@@ -491,6 +524,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating a `BUY_PRODUCT` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.BUY_PRODUCT,
|
||||
@@ -509,6 +547,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating a `JOINT` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.JOINT,
|
||||
@@ -526,6 +569,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating a `JOINT_BUY_PRODUCT` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.JOINT_BUY_PRODUCT,
|
||||
@@ -543,6 +591,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating a `TRANSFER` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.TRANSFER,
|
||||
@@ -561,6 +614,11 @@ class Transaction(Base):
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Convenience constructor for creating a `THROW_PRODUCT` transaction.
|
||||
|
||||
Should NOT be used directly in the application code; use the various queries instead.
|
||||
"""
|
||||
return cls(
|
||||
time=time,
|
||||
type_=TransactionType.THROW_PRODUCT,
|
||||
|
||||
@@ -1,14 +1,33 @@
|
||||
from datetime import datetime
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Integer, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dibbler.models import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dibbler.models import LastCacheTransaction, User
|
||||
|
||||
|
||||
# More like user balance cash money flow, amirite?
|
||||
class UserBalanceCache(Base):
|
||||
user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
class UserCache(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
"""internal database id"""
|
||||
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
|
||||
user: Mapped[User] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[user_id],
|
||||
)
|
||||
|
||||
balance: Mapped[int] = mapped_column(Integer)
|
||||
timestamp: Mapped[datetime] = mapped_column(DateTime)
|
||||
|
||||
last_cache_transaction_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("last_cache_transaction.id"), nullable=True,
|
||||
)
|
||||
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[last_cache_transaction_id],
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
__all__ = [
|
||||
"Base",
|
||||
"LastCacheTransaction",
|
||||
"Product",
|
||||
"ProductCache",
|
||||
"Transaction",
|
||||
"TransactionType",
|
||||
"User",
|
||||
"UserCache",
|
||||
]
|
||||
|
||||
from .Base import Base
|
||||
from .LastCacheTransaction import LastCacheTransaction
|
||||
from .Product import Product
|
||||
from .ProductCache import ProductCache
|
||||
from .Transaction import Transaction
|
||||
from .TransactionType import TransactionType
|
||||
from .User import User
|
||||
from .UserCache import UserCache
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
__all__ = [
|
||||
# "add_product",
|
||||
# "add_user",
|
||||
"add_product",
|
||||
"adjust_balance",
|
||||
"adjust_interest",
|
||||
"adjust_penalty",
|
||||
"adjust_stock",
|
||||
"affected_products",
|
||||
"affected_users",
|
||||
"create_product",
|
||||
"create_user",
|
||||
"current_interest",
|
||||
"current_penalty",
|
||||
"joint_buy_product",
|
||||
@@ -11,27 +16,37 @@ __all__ = [
|
||||
"product_price",
|
||||
"product_price_log",
|
||||
"product_stock",
|
||||
# "products_owned_by_user",
|
||||
"search_product",
|
||||
"search_user",
|
||||
"throw_product",
|
||||
"transaction_log",
|
||||
"transfer",
|
||||
"update_cache",
|
||||
"user_balance",
|
||||
"user_balance_log",
|
||||
"user_products",
|
||||
]
|
||||
|
||||
# from .add_product import add_product
|
||||
# from .add_user import add_user
|
||||
from .add_product import add_product
|
||||
from .adjust_balance import adjust_balance
|
||||
from .adjust_interest import adjust_interest
|
||||
from .adjust_penalty import adjust_penalty
|
||||
from .adjust_stock import adjust_stock
|
||||
from .affected_products import affected_products
|
||||
from .affected_users import affected_users
|
||||
from .create_product import create_product
|
||||
from .create_user import create_user
|
||||
from .current_interest import current_interest
|
||||
from .current_penalty import current_penalty
|
||||
from .joint_buy_product import joint_buy_product
|
||||
from .product_owners import product_owners, product_owners_log
|
||||
from .product_price import product_price, product_price_log
|
||||
from .product_stock import product_stock
|
||||
|
||||
# from .products_owned_by_user import products_owned_by_user
|
||||
from .search_product import search_product
|
||||
from .search_user import search_user
|
||||
from .throw_product import throw_product
|
||||
from .transaction_log import transaction_log
|
||||
from .transfer import transfer
|
||||
from .update_cache import update_cache
|
||||
from .user_balance import user_balance, user_balance_log
|
||||
from .user_products import user_products
|
||||
|
||||
@@ -1 +1,51 @@
|
||||
# TODO: implement me
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
|
||||
|
||||
def add_product(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
product: Product,
|
||||
amount: int,
|
||||
per_product: int,
|
||||
product_count: int,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Transaction:
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount must be positive.")
|
||||
|
||||
if per_product <= 0:
|
||||
raise ValueError("Per product price must be positive.")
|
||||
|
||||
if product_count <= 0:
|
||||
raise ValueError("Product count must be positive.")
|
||||
|
||||
if per_product * product_count < amount:
|
||||
raise ValueError("Total per product price must be at least equal to amount.")
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
transaction = Transaction.add_product(
|
||||
user_id=user.id,
|
||||
product_id=product.id,
|
||||
amount=amount,
|
||||
per_product=per_product,
|
||||
product_count=product_count,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# TODO: implement me
|
||||
33
dibbler/queries/adjust_balance.py
Normal file
33
dibbler/queries/adjust_balance.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, User
|
||||
|
||||
|
||||
def adjust_balance(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
amount: int,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Transaction:
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if amount == 0:
|
||||
raise ValueError("Amount must be non-zero.")
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
transaction = Transaction.adjust_balance(
|
||||
user_id=user.id,
|
||||
amount=amount,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, User
|
||||
@@ -10,19 +12,25 @@ def adjust_interest(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
new_interest: int,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
) -> Transaction:
|
||||
if new_interest < 0:
|
||||
raise ValueError("Interest rate cannot be negative")
|
||||
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
transaction = Transaction.adjust_interest(
|
||||
user_id=user.id,
|
||||
interest_rate_percent=new_interest,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, User
|
||||
@@ -12,8 +14,9 @@ def adjust_penalty(
|
||||
user: User,
|
||||
new_penalty: int | None = None,
|
||||
new_penalty_multiplier: int | None = None,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
) -> Transaction:
|
||||
if new_penalty is None and new_penalty_multiplier is None:
|
||||
raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided")
|
||||
|
||||
@@ -30,12 +33,17 @@ def adjust_penalty(
|
||||
if new_penalty_multiplier is None:
|
||||
new_penalty_multiplier = existing_penalty_multiplier
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
transaction = Transaction.adjust_penalty(
|
||||
user_id=user.id,
|
||||
penalty_threshold=new_penalty,
|
||||
penalty_multiplier_percent=new_penalty_multiplier,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
|
||||
40
dibbler/queries/adjust_stock.py
Normal file
40
dibbler/queries/adjust_stock.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
|
||||
|
||||
def adjust_stock(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
product: Product,
|
||||
product_count: int,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Transaction:
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if product_count == 0:
|
||||
raise ValueError("Product count must be non-zero.")
|
||||
|
||||
# TODO: it should not be possible to reduce stock below zero.
|
||||
#
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
transaction = Transaction.adjust_stock(
|
||||
user_id=user.id,
|
||||
product_id=product.id,
|
||||
product_count=product_count,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
88
dibbler/queries/affected_products.py
Normal file
88
dibbler/queries/affected_products.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, TransactionType
|
||||
from dibbler.queries.query_helpers import after_filter, until_filter
|
||||
|
||||
|
||||
def affected_products(
|
||||
sql_session: Session,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: BindParameter[Transaction] | Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
after_time: BindParameter[datetime] | datetime | None = None,
|
||||
after_transaction: Transaction | None = None,
|
||||
after_inclusive: bool = True,
|
||||
) -> set[Product]:
|
||||
"""
|
||||
Get all products where attributes were affected over a given interval.
|
||||
"""
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
if not (after_time is None or after_transaction is None):
|
||||
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
|
||||
|
||||
if isinstance(after_time, datetime):
|
||||
after_time = BindParameter("after_time", value=after_time)
|
||||
|
||||
if isinstance(after_transaction, Transaction):
|
||||
if after_transaction.id is None:
|
||||
raise ValueError("after_transaction must be persisted in the database.")
|
||||
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
|
||||
else:
|
||||
after_transaction_id = None
|
||||
|
||||
if after_time is not None and until_time is not None:
|
||||
assert isinstance(after_time.value, datetime)
|
||||
assert isinstance(until_time.value, datetime)
|
||||
|
||||
if after_time.value > until_time.value:
|
||||
raise ValueError("after_time cannot be after until_time.")
|
||||
|
||||
if after_transaction is not None and until_transaction is not None:
|
||||
assert after_transaction.time is not None
|
||||
assert until_transaction.time is not None
|
||||
|
||||
if after_transaction.time > until_transaction.time:
|
||||
raise ValueError("after_transaction cannot be after until_transaction.")
|
||||
|
||||
result = sql_session.scalars(
|
||||
select(Product)
|
||||
.distinct()
|
||||
.join(Transaction, Product.id == Transaction.product_id)
|
||||
.where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.JOINT.as_literal_column(),
|
||||
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
],
|
||||
),
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
after_filter(
|
||||
after_time=after_time,
|
||||
after_transaction_id=after_transaction_id,
|
||||
after_inclusive=after_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc()),
|
||||
).all()
|
||||
|
||||
return set(result)
|
||||
90
dibbler/queries/affected_users.py
Normal file
90
dibbler/queries/affected_users.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, TransactionType, User
|
||||
from dibbler.queries.query_helpers import after_filter, until_filter
|
||||
|
||||
|
||||
def affected_users(
|
||||
sql_session: Session,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: BindParameter[Transaction] | Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
after_time: BindParameter[datetime] | datetime | None = None,
|
||||
after_transaction: Transaction | None = None,
|
||||
after_inclusive: bool = True,
|
||||
) -> set[User]:
|
||||
"""
|
||||
Get all users where attributes were affected over a given interval.
|
||||
"""
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
if not (after_time is None or after_transaction is None):
|
||||
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
|
||||
|
||||
if isinstance(after_time, datetime):
|
||||
after_time = BindParameter("after_time", value=after_time)
|
||||
|
||||
if isinstance(after_transaction, Transaction):
|
||||
if after_transaction.id is None:
|
||||
raise ValueError("after_transaction must be persisted in the database.")
|
||||
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
|
||||
else:
|
||||
after_transaction_id = None
|
||||
|
||||
if after_time is not None and until_time is not None:
|
||||
assert isinstance(after_time.value, datetime)
|
||||
assert isinstance(until_time.value, datetime)
|
||||
|
||||
if after_time.value > until_time.value:
|
||||
raise ValueError("after_time cannot be after until_time.")
|
||||
|
||||
if after_transaction is not None and until_transaction is not None:
|
||||
assert after_transaction.time is not None
|
||||
assert until_transaction.time is not None
|
||||
|
||||
if after_transaction.time > until_transaction.time:
|
||||
raise ValueError("after_transaction cannot be after until_transaction.")
|
||||
|
||||
result = sql_session.scalars(
|
||||
select(User)
|
||||
.distinct()
|
||||
.join(
|
||||
Transaction,
|
||||
or_(User.id == Transaction.user_id, User.id == Transaction.transfer_user_id),
|
||||
)
|
||||
.where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.TRANSFER.as_literal_column(),
|
||||
],
|
||||
),
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
after_filter(
|
||||
after_time=after_time,
|
||||
after_transaction_id=after_transaction_id,
|
||||
after_inclusive=after_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc()),
|
||||
).all()
|
||||
|
||||
return set(result)
|
||||
38
dibbler/queries/buy_product.py
Normal file
38
dibbler/queries/buy_product.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
|
||||
|
||||
def buy_product(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
product: Product,
|
||||
product_count: int,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Transaction:
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if product_count <= 0:
|
||||
raise ValueError("Product count must be positive.")
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
transaction = Transaction.buy_product(
|
||||
user_id=user.id,
|
||||
product_id=product.id,
|
||||
product_count=product_count,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
25
dibbler/queries/create_product.py
Normal file
25
dibbler/queries/create_product.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product
|
||||
|
||||
|
||||
def create_product(
|
||||
sql_session: Session,
|
||||
name: str,
|
||||
barcode: str,
|
||||
) -> Product:
|
||||
if not name:
|
||||
raise ValueError("Name cannot be empty.")
|
||||
|
||||
if not barcode:
|
||||
raise ValueError("Barcode cannot be empty.")
|
||||
|
||||
# TODO: check for duplicate names, barcodes
|
||||
|
||||
# TODO: add more validation for barcode
|
||||
|
||||
product = Product(barcode, name)
|
||||
sql_session.add(product)
|
||||
sql_session.commit()
|
||||
|
||||
return product
|
||||
21
dibbler/queries/create_user.py
Normal file
21
dibbler/queries/create_user.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import User
|
||||
|
||||
|
||||
def create_user(
|
||||
sql_session: Session,
|
||||
name: str,
|
||||
card: str | None,
|
||||
rfid: str | None,
|
||||
) -> User:
|
||||
if not name:
|
||||
raise ValueError("Name cannot be empty.")
|
||||
|
||||
# TODO: check for duplicate names, cards, rfids
|
||||
|
||||
user = User(name=name, card=card, rfid=rfid)
|
||||
sql_session.add(user)
|
||||
sql_session.commit()
|
||||
|
||||
return user
|
||||
@@ -1,24 +1,55 @@
|
||||
from sqlalchemy import select
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, bindparam, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, TransactionType
|
||||
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
|
||||
from dibbler.queries.query_helpers import until_filter
|
||||
|
||||
|
||||
# TODO: add until transaction parameter
|
||||
# TODO: add until datetime parameter
|
||||
def current_interest(
|
||||
sql_session: Session,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: BindParameter[Transaction] | Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
Get the current interest rate percentage as of a given time or transaction.
|
||||
|
||||
Returns the interest rate percentage as an integer.
|
||||
"""
|
||||
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
def current_interest(sql_session: Session) -> int:
|
||||
result = sql_session.scalars(
|
||||
select(Transaction)
|
||||
.where(Transaction.type_ == TransactionType.ADJUST_INTEREST)
|
||||
.where(
|
||||
Transaction.type_ == TransactionType.ADJUST_INTEREST,
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
.limit(1)
|
||||
.limit(1),
|
||||
).one_or_none()
|
||||
|
||||
if result is None:
|
||||
return DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||
return DEFAULT_INTEREST_RATE_PERCENT
|
||||
elif result.interest_rate_percent is None:
|
||||
return DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||
return DEFAULT_INTEREST_RATE_PERCENT
|
||||
else:
|
||||
return result.interest_rate_percent
|
||||
|
||||
@@ -1,26 +1,57 @@
|
||||
from sqlalchemy import select
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, bindparam, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, TransactionType
|
||||
from dibbler.models.Transaction import (
|
||||
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
|
||||
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
|
||||
DEFAULT_PENALTY_THRESHOLD,
|
||||
)
|
||||
from dibbler.queries.query_helpers import until_filter
|
||||
|
||||
|
||||
# TODO: add until transaction parameter
|
||||
# TODO: add until datetime parameter
|
||||
def current_penalty(
|
||||
sql_session: Session,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: BindParameter[Transaction] | Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Get the current penalty settings (threshold and multiplier percentage) as of a given time or transaction.
|
||||
|
||||
Returns a tuple of `(penalty_threshold, penalty_multiplier_percentage)`.
|
||||
"""
|
||||
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
def current_penalty(sql_session: Session) -> tuple[int, int]:
|
||||
result = sql_session.scalars(
|
||||
select(Transaction)
|
||||
.where(Transaction.type_ == TransactionType.ADJUST_PENALTY)
|
||||
.where(
|
||||
Transaction.type_ == TransactionType.ADJUST_PENALTY,
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
.limit(1)
|
||||
.limit(1),
|
||||
).one_or_none()
|
||||
|
||||
if result is None:
|
||||
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE
|
||||
return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENT
|
||||
|
||||
assert result.penalty_threshold is not None, "Penalty threshold must be set"
|
||||
assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set"
|
||||
|
||||
@@ -17,7 +17,7 @@ def joint_buy_product(
|
||||
users: list[User],
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
) -> list[Transaction]:
|
||||
"""
|
||||
Create buy product transactions for multiple users at once.
|
||||
"""
|
||||
@@ -25,15 +25,23 @@ def joint_buy_product(
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if instigator not in users:
|
||||
raise ValueError("Instigator must be in the list of users buying the product.")
|
||||
if instigator.id is None:
|
||||
raise ValueError("Instigator must be persisted in the database.")
|
||||
|
||||
if len(users) == 0:
|
||||
raise ValueError("At least bying one user must be specified.")
|
||||
|
||||
if any(user.id is None for user in users):
|
||||
raise ValueError("All users must be persisted in the database.")
|
||||
|
||||
if instigator not in users:
|
||||
raise ValueError("Instigator must be in the list of users buying the product.")
|
||||
|
||||
if product_count <= 0:
|
||||
raise ValueError("Product count must be positive.")
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
joint_transaction = Transaction.joint(
|
||||
user_id=instigator.id,
|
||||
product_id=product.id,
|
||||
@@ -44,6 +52,8 @@ def joint_buy_product(
|
||||
sql_session.add(joint_transaction)
|
||||
sql_session.flush() # Ensure joint_transaction gets an ID
|
||||
|
||||
transactions = [joint_transaction]
|
||||
|
||||
for user in users:
|
||||
buy_transaction = Transaction.joint_buy_product(
|
||||
user_id=user.id,
|
||||
@@ -52,5 +62,7 @@ def joint_buy_product(
|
||||
message=message,
|
||||
)
|
||||
sql_session.add(buy_transaction)
|
||||
transactions.append(buy_transaction)
|
||||
|
||||
sql_session.commit()
|
||||
return transactions
|
||||
|
||||
@@ -8,11 +8,11 @@ from sqlalchemy import (
|
||||
bindparam,
|
||||
case,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO
|
||||
from dibbler.models import (
|
||||
Product,
|
||||
Transaction,
|
||||
@@ -20,13 +20,22 @@ from dibbler.models import (
|
||||
User,
|
||||
)
|
||||
from dibbler.queries.product_stock import _product_stock_query
|
||||
from dibbler.queries.query_helpers import (
|
||||
CONST_NONE,
|
||||
CONST_ONE,
|
||||
CONST_ZERO,
|
||||
until_filter,
|
||||
)
|
||||
|
||||
|
||||
def _product_owners_query(
|
||||
product_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until: BindParameter[datetime] | datetime | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
cte_name: str = "rec_cte",
|
||||
trx_subset_name: str = "trx_subset",
|
||||
) -> CTE:
|
||||
"""
|
||||
The inner query for inferring the owners of a given product.
|
||||
@@ -38,13 +47,25 @@ def _product_owners_query(
|
||||
if isinstance(product_id, int):
|
||||
product_id = bindparam("product_id", value=product_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until)
|
||||
if until_time is not None and until_transaction is not None:
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = bindparam("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
product_stock = _product_stock_query(
|
||||
product_id=product_id,
|
||||
use_cache=use_cache,
|
||||
until=until,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
# Subset of transactions that we'll want to iterate over.
|
||||
@@ -57,22 +78,23 @@ def _product_owners_query(
|
||||
Transaction.user_id,
|
||||
Transaction.product_count,
|
||||
)
|
||||
# TODO: maybe add value constraint on ADJUST_STOCK?
|
||||
.where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
# TransactionType.BUY_PRODUCT,
|
||||
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
# TransactionType.JOINT,
|
||||
# TransactionType.THROW_PRODUCT,
|
||||
]
|
||||
or_(
|
||||
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
and_(
|
||||
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
Transaction.product_count > CONST_ZERO,
|
||||
),
|
||||
),
|
||||
Transaction.product_id == product_id,
|
||||
CONST_TRUE if until is None else Transaction.time <= until,
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
.subquery()
|
||||
.subquery(trx_subset_name)
|
||||
)
|
||||
|
||||
initial_element = select(
|
||||
@@ -117,35 +139,19 @@ def _product_owners_query(
|
||||
).label("product_count"),
|
||||
# How many products left to account for
|
||||
case(
|
||||
# Someone adds the product -> increase the number of products left to account for
|
||||
# Someone adds the product -> known owner, decrease the number of products left to account for
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
||||
),
|
||||
# Someone buys/joins/throws the product -> decrease the number of products left to account for
|
||||
# (
|
||||
# trx_subset.c.type_.in_(
|
||||
# [
|
||||
# TransactionType.BUY_PRODUCT,
|
||||
# TransactionType.JOINT,
|
||||
# TransactionType.THROW_PRODUCT,
|
||||
# ]
|
||||
# ),
|
||||
# recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
||||
# ),
|
||||
# Someone adjusts the stock ->
|
||||
# If adjusted upwards -> products owned by nobody, decrease products left to account for
|
||||
# If adjusted downwards -> products taken away from owners, decrease products left to account for
|
||||
# Stock got adjusted upwards -> none owner, decrease the number of products left to account for
|
||||
(
|
||||
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
|
||||
and (trx_subset.c.product_count > CONST_ZERO),
|
||||
and_(
|
||||
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
trx_subset.c.product_count > CONST_ZERO,
|
||||
),
|
||||
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
|
||||
),
|
||||
# (
|
||||
# (trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
|
||||
# and (trx_subset.c.product_count < 0),
|
||||
# recursive_cte.c.products_left_to_account_for + trx_subset.c.product_count,
|
||||
# ),
|
||||
else_=recursive_cte.c.products_left_to_account_for,
|
||||
).label("products_left_to_account_for"),
|
||||
)
|
||||
@@ -153,8 +159,9 @@ def _product_owners_query(
|
||||
.where(
|
||||
and_(
|
||||
trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
|
||||
# Base case: stop if we've accounted for all products
|
||||
recursive_cte.c.products_left_to_account_for > CONST_ZERO,
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -167,13 +174,14 @@ class ProductOwnersLogEntry:
|
||||
user: User | None
|
||||
products_left_to_account_for: int
|
||||
|
||||
# TODO: add until datetime parameter
|
||||
|
||||
def product_owners_log(
|
||||
sql_session: Session,
|
||||
product: Product,
|
||||
use_cache: bool = True,
|
||||
until: Transaction | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> list[ProductOwnersLogEntry]:
|
||||
"""
|
||||
Returns a log of the product ownership calculation for the given product.
|
||||
@@ -184,13 +192,12 @@ def product_owners_log(
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if until is not None and until.id is None:
|
||||
raise ValueError("'until' transaction must be persisted in the database.")
|
||||
|
||||
recursive_cte = _product_owners_query(
|
||||
product_id=product.id,
|
||||
use_cache=use_cache,
|
||||
until=until.time if until else None,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
result = sql_session.execute(
|
||||
@@ -209,13 +216,13 @@ def product_owners_log(
|
||||
onclause=User.id == recursive_cte.c.user_id,
|
||||
isouter=True,
|
||||
)
|
||||
.order_by(recursive_cte.c.time.desc())
|
||||
.order_by(recursive_cte.c.time.desc()),
|
||||
).all()
|
||||
|
||||
if result is None:
|
||||
# If there are no transactions for this product, the query should return an empty list, not None.
|
||||
raise RuntimeError(
|
||||
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})."
|
||||
f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id}).",
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -228,13 +235,13 @@ def product_owners_log(
|
||||
]
|
||||
|
||||
|
||||
# TODO: add until transaction parameter
|
||||
|
||||
def product_owners(
|
||||
sql_session: Session,
|
||||
product: Product,
|
||||
use_cache: bool = True,
|
||||
until: datetime | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> list[User | None]:
|
||||
"""
|
||||
Returns an ordered list of users owning the given product.
|
||||
@@ -248,7 +255,9 @@ def product_owners(
|
||||
recursive_cte = _product_owners_query(
|
||||
product_id=product.id,
|
||||
use_cache=use_cache,
|
||||
until=until,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
db_result = sql_session.execute(
|
||||
@@ -258,7 +267,7 @@ def product_owners(
|
||||
User,
|
||||
)
|
||||
.join(User, User.id == recursive_cte.c.user_id, isouter=True)
|
||||
.order_by(recursive_cte.c.time.desc())
|
||||
.order_by(recursive_cte.c.time.desc()),
|
||||
).all()
|
||||
|
||||
print(db_result)
|
||||
@@ -292,7 +301,7 @@ def product_owners(
|
||||
|
||||
result.extend([None] * none_count)
|
||||
|
||||
# # NOTE: if the last line exeeds the product count, we need to truncate it
|
||||
# # NOTE: if the last line exceeds the product count, we need to truncate it
|
||||
# result.extend([user] * min(user_count, products_left_to_account_for))
|
||||
|
||||
# redistribute the user counts to a list of users
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy import (
|
||||
BindParameter,
|
||||
ColumnElement,
|
||||
Integer,
|
||||
asc,
|
||||
bindparam,
|
||||
case,
|
||||
cast,
|
||||
func,
|
||||
@@ -14,52 +14,111 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO
|
||||
from dibbler.models import (
|
||||
LastCacheTransaction,
|
||||
Product,
|
||||
ProductCache,
|
||||
Transaction,
|
||||
TransactionType,
|
||||
)
|
||||
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT
|
||||
from dibbler.queries.query_helpers import (
|
||||
CONST_NONE,
|
||||
CONST_ONE,
|
||||
CONST_ZERO,
|
||||
after_filter,
|
||||
until_filter,
|
||||
)
|
||||
|
||||
|
||||
def _product_price_query(
|
||||
product_id: int | ColumnElement[int],
|
||||
use_cache: bool = True,
|
||||
until: BindParameter[datetime] | datetime | None = None,
|
||||
until_including: BindParameter[bool] | bool = True,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
cte_name: str = "rec_cte",
|
||||
trx_subset_name: str = "trx_subset",
|
||||
):
|
||||
"""
|
||||
The inner query for calculating the product price.
|
||||
"""
|
||||
|
||||
if use_cache:
|
||||
print("WARNING: Using cache for product price query is not implemented yet.")
|
||||
|
||||
if isinstance(product_id, int):
|
||||
product_id = BindParameter("product_id", value=product_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until)
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_including, bool):
|
||||
until_including = BindParameter("until_including", value=until_including)
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
initial_element = select(
|
||||
CONST_ZERO.label("i"),
|
||||
CONST_ZERO.label("time"),
|
||||
CONST_NONE.label("transaction_id"),
|
||||
CONST_ZERO.label("price"),
|
||||
CONST_ZERO.label("product_count"),
|
||||
)
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
if use_cache:
|
||||
initial_element_fields = (
|
||||
select(
|
||||
Transaction.time.label("time"),
|
||||
Transaction.id.label("transaction_id"),
|
||||
ProductCache.price.label("price"),
|
||||
ProductCache.stock.label("product_count"),
|
||||
)
|
||||
.select_from(ProductCache)
|
||||
.join(
|
||||
LastCacheTransaction,
|
||||
ProductCache.last_cache_transaction_id == LastCacheTransaction.id,
|
||||
)
|
||||
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
|
||||
.where(
|
||||
ProductCache.product_id == product_id,
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
.union(
|
||||
select(
|
||||
CONST_ZERO.label("time"),
|
||||
CONST_NONE.label("transaction_id"),
|
||||
CONST_ZERO.label("price"),
|
||||
CONST_ZERO.label("product_count"),
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
.limit(CONST_ONE)
|
||||
.offset(CONST_ZERO)
|
||||
.subquery()
|
||||
.alias("initial_element_fields")
|
||||
)
|
||||
|
||||
initial_element = select(
|
||||
CONST_ZERO.label("i"),
|
||||
initial_element_fields.c.time,
|
||||
initial_element_fields.c.transaction_id,
|
||||
initial_element_fields.c.price,
|
||||
initial_element_fields.c.product_count,
|
||||
).select_from(initial_element_fields)
|
||||
else:
|
||||
initial_element = select(
|
||||
CONST_ZERO.label("i"),
|
||||
CONST_ZERO.label("time"),
|
||||
CONST_NONE.label("transaction_id"),
|
||||
CONST_ZERO.label("price"),
|
||||
CONST_ZERO.label("product_count"),
|
||||
)
|
||||
|
||||
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
||||
|
||||
# Subset of transactions that we'll want to iterate over.
|
||||
trx_subset = (
|
||||
select(
|
||||
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
|
||||
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
|
||||
Transaction.id,
|
||||
Transaction.time,
|
||||
Transaction.type_,
|
||||
@@ -73,18 +132,22 @@ def _product_price_query(
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
TransactionType.JOINT.as_literal_column(),
|
||||
]
|
||||
],
|
||||
),
|
||||
Transaction.product_id == product_id,
|
||||
case(
|
||||
(until_including, Transaction.time <= until),
|
||||
else_=Transaction.time < until,
|
||||
)
|
||||
if until is not None
|
||||
else CONST_TRUE,
|
||||
after_filter(
|
||||
after_time=None,
|
||||
after_transaction_id=recursive_cte.c.transaction_id,
|
||||
after_inclusive=False,
|
||||
),
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.asc())
|
||||
.alias("trx_subset")
|
||||
.subquery(trx_subset_name)
|
||||
)
|
||||
|
||||
recursive_elements = (
|
||||
@@ -115,7 +178,7 @@ def _product_price_query(
|
||||
# and other disastrous phenomena.
|
||||
func.max(recursive_cte.c.product_count, CONST_ZERO)
|
||||
+ trx_subset.c.product_count
|
||||
)
|
||||
),
|
||||
),
|
||||
Integer,
|
||||
),
|
||||
@@ -170,13 +233,13 @@ class ProductPriceLogEntry:
|
||||
product_count: int
|
||||
|
||||
|
||||
# TODO: add until datetime parameter
|
||||
|
||||
def product_price_log(
|
||||
sql_session: Session,
|
||||
product: Product,
|
||||
use_cache: bool = True,
|
||||
until: Transaction | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> list[ProductPriceLogEntry]:
|
||||
"""
|
||||
Calculates the price of a product and returns a log of the price changes.
|
||||
@@ -185,13 +248,12 @@ def product_price_log(
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if until is not None and until.id is None:
|
||||
raise ValueError("'until' transaction must be persisted in the database.")
|
||||
|
||||
recursive_cte = _product_price_query(
|
||||
product.id,
|
||||
use_cache=use_cache,
|
||||
until=until.time if until else None,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
result = sql_session.execute(
|
||||
@@ -205,13 +267,13 @@ def product_price_log(
|
||||
Transaction,
|
||||
onclause=Transaction.id == recursive_cte.c.transaction_id,
|
||||
)
|
||||
.order_by(recursive_cte.c.i.asc())
|
||||
.order_by(recursive_cte.c.i.asc()),
|
||||
).all()
|
||||
|
||||
if result is None:
|
||||
# If there are no transactions for this product, the query should return an empty list, not None.
|
||||
raise RuntimeError(
|
||||
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})."
|
||||
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id}).",
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -224,13 +286,13 @@ def product_price_log(
|
||||
]
|
||||
|
||||
|
||||
# TODO: add until datetime parameter
|
||||
|
||||
def product_price(
|
||||
sql_session: Session,
|
||||
product: Product,
|
||||
use_cache: bool = True,
|
||||
until: Transaction | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
include_interest: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -240,13 +302,22 @@ def product_price(
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if until is not None and until.id is None:
|
||||
raise ValueError("'until' transaction must be persisted in the database.")
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
recursive_cte = _product_price_query(
|
||||
product.id,
|
||||
use_cache=use_cache,
|
||||
until=until.time if until else None,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
# TODO: optionally verify subresults:
|
||||
@@ -257,13 +328,13 @@ def product_price(
|
||||
select(recursive_cte.c.price)
|
||||
.order_by(recursive_cte.c.i.desc())
|
||||
.limit(CONST_ONE)
|
||||
.offset(CONST_ZERO)
|
||||
.offset(CONST_ZERO),
|
||||
).one_or_none()
|
||||
|
||||
if result is None:
|
||||
# If there are no transactions for this product, the query should return 0, not None.
|
||||
raise RuntimeError(
|
||||
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
|
||||
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id}).",
|
||||
)
|
||||
|
||||
if include_interest:
|
||||
@@ -272,12 +343,16 @@ def product_price(
|
||||
select(Transaction.interest_rate_percent)
|
||||
.where(
|
||||
Transaction.type_ == TransactionType.ADJUST_INTEREST,
|
||||
CONST_TRUE if until is None else Transaction.time <= until.time,
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
.limit(CONST_ONE)
|
||||
.limit(CONST_ONE),
|
||||
)
|
||||
or DEFAULT_INTEREST_RATE_PERCENTAGE
|
||||
or DEFAULT_INTEREST_RATE_PERCENT
|
||||
)
|
||||
result = math.ceil(result * interest_rate / 100)
|
||||
|
||||
|
||||
@@ -1,27 +1,31 @@
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
|
||||
from sqlalchemy import (
|
||||
BindParameter,
|
||||
Select,
|
||||
bindparam,
|
||||
case,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_TRUE
|
||||
from dibbler.models import (
|
||||
Product,
|
||||
Transaction,
|
||||
TransactionType,
|
||||
)
|
||||
from dibbler.queries.query_helpers import until_filter
|
||||
|
||||
|
||||
def _product_stock_query(
|
||||
product_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until: BindParameter[datetime] | datetime | None = None,
|
||||
) -> Select:
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> Select[tuple[int]]:
|
||||
"""
|
||||
The inner query for calculating the product stock.
|
||||
"""
|
||||
@@ -32,8 +36,18 @@ def _product_stock_query(
|
||||
if isinstance(product_id, int):
|
||||
product_id = BindParameter("product_id", value=product_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until)
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
query = select(
|
||||
func.sum(
|
||||
@@ -59,8 +73,8 @@ def _product_stock_query(
|
||||
-Transaction.product_count,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
)
|
||||
),
|
||||
).label("stock"),
|
||||
).where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
@@ -69,22 +83,26 @@ def _product_stock_query(
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.JOINT.as_literal_column(),
|
||||
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
]
|
||||
],
|
||||
),
|
||||
Transaction.product_id == product_id,
|
||||
Transaction.time <= until if until is not None else CONST_TRUE,
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
# TODO: add until transaction parameter
|
||||
|
||||
def product_stock(
|
||||
sql_session: Session,
|
||||
product: Product,
|
||||
use_cache: bool = True,
|
||||
until: datetime | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
Returns the number of products in stock.
|
||||
@@ -98,7 +116,9 @@ def product_stock(
|
||||
query = _product_stock_query(
|
||||
product_id=product.id,
|
||||
use_cache=use_cache,
|
||||
until=until,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
result = sql_session.scalars(query).one_or_none()
|
||||
|
||||
119
dibbler/queries/query_helpers.py
Normal file
119
dibbler/queries/query_helpers.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy import (
|
||||
BindParameter,
|
||||
ColumnExpressionArgument,
|
||||
literal,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import QueryableAttribute
|
||||
|
||||
from dibbler.models import Transaction
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def const(value: T) -> BindParameter[T]:
|
||||
"""
|
||||
Create a constant SQL literal bind parameter.
|
||||
|
||||
This is useful to avoid too many `?` bind parameters in SQL queries,
|
||||
when the input value is known to be safe.
|
||||
"""
|
||||
|
||||
return literal(value, literal_execute=True)
|
||||
|
||||
|
||||
CONST_ZERO: BindParameter[int] = const(0)
|
||||
"""A constant SQL expression `0`. This will render as a literal `0` in SQL queries."""
|
||||
|
||||
CONST_ONE: BindParameter[int] = const(1)
|
||||
"""A constant SQL expression `1`. This will render as a literal `1` in SQL queries."""
|
||||
|
||||
CONST_TRUE: BindParameter[bool] = const(True)
|
||||
"""A constant SQL expression `TRUE`. This will render as a literal `TRUE` in SQL queries."""
|
||||
|
||||
CONST_FALSE: BindParameter[bool] = const(False)
|
||||
"""A constant SQL expression `FALSE`. This will render as a literal `FALSE` in SQL queries."""
|
||||
|
||||
CONST_NONE: BindParameter[None] = const(None)
|
||||
"""A constant SQL expression `NULL`. This will render as a literal `NULL` in SQL queries."""
|
||||
|
||||
|
||||
def until_filter(
|
||||
until_time: BindParameter[datetime] | None = None,
|
||||
until_transaction_id: BindParameter[int] | None = None,
|
||||
until_inclusive: bool = True,
|
||||
transaction_time: QueryableAttribute = Transaction.time,
|
||||
) -> ColumnExpressionArgument[bool]:
|
||||
"""
|
||||
Create a filter condition for transactions up to a given time or transaction.
|
||||
|
||||
Only one of `until_time` or `until_transaction_id` may be specified.
|
||||
"""
|
||||
|
||||
assert not (until_time is not None and until_transaction_id is not None), (
|
||||
"Cannot filter by both until_time and until_transaction_id."
|
||||
)
|
||||
|
||||
match (until_time, until_transaction_id, until_inclusive):
|
||||
case (BindParameter(), None, True):
|
||||
return transaction_time <= until_time
|
||||
case (BindParameter(), None, False):
|
||||
return transaction_time < until_time
|
||||
case (None, BindParameter(), True):
|
||||
return (
|
||||
transaction_time
|
||||
<= select(Transaction.time)
|
||||
.where(Transaction.id == until_transaction_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
case (None, BindParameter(), False):
|
||||
return (
|
||||
transaction_time
|
||||
< select(Transaction.time)
|
||||
.where(Transaction.id == until_transaction_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
|
||||
return CONST_TRUE
|
||||
|
||||
|
||||
def after_filter(
|
||||
after_time: BindParameter[datetime] | None = None,
|
||||
after_transaction_id: BindParameter[int] | None = None,
|
||||
after_inclusive: bool = True,
|
||||
transaction_time: QueryableAttribute = Transaction.time,
|
||||
) -> ColumnExpressionArgument[bool]:
|
||||
"""
|
||||
Create a filter condition for transactions after a given time or transaction.
|
||||
|
||||
Only one of `after_time` or `after_transaction_id` may be specified.
|
||||
"""
|
||||
|
||||
assert not (after_time is not None and after_transaction_id is not None), (
|
||||
"Cannot filter by both after_time and after_transaction_id."
|
||||
)
|
||||
|
||||
match (after_time, after_transaction_id, after_inclusive):
|
||||
case (BindParameter(), None, True):
|
||||
return transaction_time >= after_time
|
||||
case (BindParameter(), None, False):
|
||||
return transaction_time > after_time
|
||||
case (None, BindParameter(), True):
|
||||
return (
|
||||
transaction_time
|
||||
>= select(Transaction.time)
|
||||
.where(Transaction.id == after_transaction_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
case (None, BindParameter(), False):
|
||||
return (
|
||||
transaction_time
|
||||
> select(Transaction.time)
|
||||
.where(Transaction.id == after_transaction_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
|
||||
return CONST_TRUE
|
||||
@@ -9,6 +9,9 @@ def search_product(
|
||||
sql_session: Session,
|
||||
find_hidden_products=False,
|
||||
) -> Product | list[Product]:
|
||||
if not string:
|
||||
raise ValueError("Search string cannot be empty.")
|
||||
|
||||
exact_match = sql_session.scalars(
|
||||
select(Product).where(
|
||||
or_(
|
||||
@@ -17,8 +20,8 @@ def search_product(
|
||||
Product.name == string,
|
||||
literal(True) if find_hidden_products else not_(Product.hidden),
|
||||
),
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
).first()
|
||||
|
||||
if exact_match:
|
||||
@@ -32,8 +35,8 @@ def search_product(
|
||||
Product.name.ilike(f"%{string}%"),
|
||||
literal(True) if find_hidden_products else not_(Product.hidden),
|
||||
),
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
).all()
|
||||
|
||||
return list(product_list)
|
||||
|
||||
@@ -8,6 +8,9 @@ def search_user(
|
||||
string: str,
|
||||
sql_session: Session,
|
||||
) -> User | list[User]:
|
||||
if not string:
|
||||
raise ValueError("Search string cannot be empty.")
|
||||
|
||||
string = string.lower()
|
||||
|
||||
exact_match = sql_session.scalars(
|
||||
@@ -16,8 +19,8 @@ def search_user(
|
||||
User.name == string,
|
||||
User.card == string,
|
||||
User.rfid == string,
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
).first()
|
||||
|
||||
if exact_match:
|
||||
@@ -29,8 +32,8 @@ def search_user(
|
||||
User.name.ilike(f"%{string}%"),
|
||||
User.card.ilike(f"%{string}%"),
|
||||
User.rfid.ilike(f"%{string}%"),
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
).all()
|
||||
|
||||
return list(user_list)
|
||||
|
||||
42
dibbler/queries/throw_product.py
Normal file
42
dibbler/queries/throw_product.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
|
||||
|
||||
def throw_product(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
product: Product,
|
||||
product_count: int,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Transaction:
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
|
||||
if product_count <= 0:
|
||||
raise ValueError("Product count must be positive.")
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
raise NotImplementedError(
|
||||
"Please don't use this function until relevant calculations have been added to user_balance.",
|
||||
)
|
||||
|
||||
transaction = Transaction.throw_product(
|
||||
user_id=user.id,
|
||||
product_id=product.id,
|
||||
product_count=product_count,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import select
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import (
|
||||
@@ -15,12 +17,12 @@ def transaction_log(
|
||||
sql_session: Session,
|
||||
user: User | None = None,
|
||||
product: Product | None = None,
|
||||
exclusive_after: bool = False,
|
||||
after_time=None,
|
||||
after_transaction_id: int | None = None,
|
||||
exclusive_before: bool = False,
|
||||
before_time=None,
|
||||
before_transaction_id: int | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
after_time: BindParameter[datetime] | datetime | None = None,
|
||||
after_transaction: Transaction | None = None,
|
||||
after_inclusive: bool = True,
|
||||
transaction_type: list[TransactionType] | None = None,
|
||||
negate_transaction_type_filter: bool = False,
|
||||
limit: int | None = None,
|
||||
@@ -29,51 +31,101 @@ def transaction_log(
|
||||
Retrieve the transaction log, optionally filtered.
|
||||
|
||||
Only one of `user` or `product` may be specified.
|
||||
Only one of `until_time` or `until_transaction_id` may be specified.
|
||||
Only one of `after_time` or `after_transaction_id` may be specified.
|
||||
Only one of `before_time` or `before_transaction_id` may be specified.
|
||||
|
||||
The before and after filters are inclusive by default.
|
||||
The after and after filters are inclusive by default.
|
||||
"""
|
||||
|
||||
if not (user is None or product is None):
|
||||
raise ValueError("Cannot filter by both user and product.")
|
||||
|
||||
if user is not None and user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
if isinstance(user, User):
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
user_id = BindParameter("user_id", value=user.id)
|
||||
else:
|
||||
user_id = None
|
||||
|
||||
if product is not None and product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
if isinstance(product, Product):
|
||||
if product.id is None:
|
||||
raise ValueError("Product must be persisted in the database.")
|
||||
product_id = BindParameter("product_id", value=product.id)
|
||||
else:
|
||||
product_id = None
|
||||
|
||||
if not (after_time is None or after_transaction_id is None):
|
||||
raise ValueError("Cannot filter by both from_time and from_transaction_id.")
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
if not (after_time is None or after_transaction is None):
|
||||
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
|
||||
|
||||
if isinstance(after_time, datetime):
|
||||
after_time = BindParameter("after_time", value=after_time)
|
||||
|
||||
if isinstance(after_transaction, Transaction):
|
||||
if after_transaction.id is None:
|
||||
raise ValueError("after_transaction must be persisted in the database.")
|
||||
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
|
||||
else:
|
||||
after_transaction_id = None
|
||||
|
||||
if after_time is not None and until_time is not None:
|
||||
assert isinstance(after_time.value, datetime)
|
||||
assert isinstance(until_time.value, datetime)
|
||||
|
||||
if after_time.value > until_time.value:
|
||||
raise ValueError("after_time cannot be after until_time.")
|
||||
|
||||
if after_transaction is not None and until_transaction is not None:
|
||||
assert after_transaction.time is not None
|
||||
assert until_transaction.time is not None
|
||||
|
||||
if after_transaction.time > until_transaction.time:
|
||||
raise ValueError("after_transaction cannot be after until_transaction.")
|
||||
|
||||
if limit is not None and limit <= 0:
|
||||
raise ValueError("Limit must be positive.")
|
||||
|
||||
query = select(Transaction)
|
||||
if user is not None:
|
||||
query = query.where(Transaction.user_id == user.id)
|
||||
query = query.where(Transaction.user_id == user_id)
|
||||
if product is not None:
|
||||
query = query.where(Transaction.product_id == product.id)
|
||||
query = query.where(Transaction.product_id == product_id)
|
||||
|
||||
if after_time is not None:
|
||||
if exclusive_after:
|
||||
query = query.where(Transaction.time > after_time)
|
||||
else:
|
||||
match (until_time, until_transaction_id, until_inclusive):
|
||||
case (BindParameter(), None, True):
|
||||
query = query.where(Transaction.time <= until_time)
|
||||
case (BindParameter(), None, False):
|
||||
query = query.where(Transaction.time < until_time)
|
||||
case (None, BindParameter(), True):
|
||||
query = query.where(Transaction.id <= until_transaction_id)
|
||||
case (None, BindParameter(), False):
|
||||
query = query.where(Transaction.id < until_transaction_id)
|
||||
case _:
|
||||
pass
|
||||
|
||||
match (after_time, after_transaction_id, after_inclusive):
|
||||
case (BindParameter(), None, True):
|
||||
query = query.where(Transaction.time >= after_time)
|
||||
if after_transaction_id is not None:
|
||||
if exclusive_after:
|
||||
query = query.where(Transaction.id > after_transaction_id)
|
||||
else:
|
||||
case (BindParameter(), None, False):
|
||||
query = query.where(Transaction.time > after_time)
|
||||
case (None, BindParameter(), True):
|
||||
query = query.where(Transaction.id >= after_transaction_id)
|
||||
|
||||
if before_time is not None:
|
||||
if exclusive_before:
|
||||
query = query.where(Transaction.time < before_time)
|
||||
else:
|
||||
query = query.where(Transaction.time <= before_time)
|
||||
if before_transaction_id is not None:
|
||||
if exclusive_before:
|
||||
query = query.where(Transaction.id < before_transaction_id)
|
||||
else:
|
||||
query = query.where(Transaction.id <= before_transaction_id)
|
||||
case (None, BindParameter(), False):
|
||||
query = query.where(Transaction.id > after_transaction_id)
|
||||
case _:
|
||||
pass
|
||||
|
||||
if transaction_type is not None:
|
||||
if negate_transaction_type_filter:
|
||||
|
||||
38
dibbler/queries/transfer.py
Normal file
38
dibbler/queries/transfer.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, User
|
||||
|
||||
|
||||
def transfer(
|
||||
sql_session: Session,
|
||||
from_user: User,
|
||||
to_user: User,
|
||||
amount: int,
|
||||
time: datetime | None = None,
|
||||
message: str | None = None,
|
||||
) -> Transaction:
|
||||
if from_user.id is None:
|
||||
raise ValueError("From user must be persisted in the database.")
|
||||
|
||||
if to_user.id is None:
|
||||
raise ValueError("To user must be persisted in the database.")
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount must be positive.")
|
||||
|
||||
# TODO: verify time is not behind last transaction's time
|
||||
|
||||
transaction = Transaction.transfer(
|
||||
user_id=from_user.id,
|
||||
transfer_user_id=to_user.id,
|
||||
amount=amount,
|
||||
time=time,
|
||||
message=message,
|
||||
)
|
||||
|
||||
sql_session.add(transaction)
|
||||
sql_session.commit()
|
||||
|
||||
return transaction
|
||||
118
dibbler/queries/update_cache.py
Normal file
118
dibbler/queries/update_cache.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from sqlalchemy import insert, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import LastCacheTransaction, ProductCache, Transaction, UserCache
|
||||
from dibbler.queries.affected_products import affected_products
|
||||
from dibbler.queries.affected_users import affected_users
|
||||
from dibbler.queries.product_price import product_price
|
||||
from dibbler.queries.product_stock import product_stock
|
||||
from dibbler.queries.user_balance import user_balance
|
||||
|
||||
|
||||
def update_cache(
|
||||
sql_session: Session,
|
||||
use_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Update the cache used for searching products.
|
||||
|
||||
If `use_cache` is False, the cache will be rebuilt from scratch.
|
||||
"""
|
||||
|
||||
last_transaction = sql_session.scalars(
|
||||
select(Transaction).order_by(Transaction.time.desc()).limit(1),
|
||||
).one_or_none()
|
||||
|
||||
print(last_transaction)
|
||||
|
||||
if last_transaction is None:
|
||||
# No transactions exist, nothing to update
|
||||
return
|
||||
|
||||
if use_cache:
|
||||
last_cache_transaction = sql_session.scalars(
|
||||
select(LastCacheTransaction)
|
||||
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
|
||||
.order_by(Transaction.time.desc())
|
||||
.limit(1),
|
||||
).one_or_none()
|
||||
if last_cache_transaction is not None:
|
||||
last_cache_transaction = last_cache_transaction.transaction
|
||||
else:
|
||||
last_cache_transaction = None
|
||||
|
||||
if last_cache_transaction is not None and last_cache_transaction.id == last_transaction.id:
|
||||
# Cache is already up to date
|
||||
return
|
||||
|
||||
users = affected_users(
|
||||
sql_session,
|
||||
after_transaction=last_cache_transaction,
|
||||
after_inclusive=False,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
products = affected_products(
|
||||
sql_session,
|
||||
after_transaction=last_cache_transaction,
|
||||
after_inclusive=False,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
|
||||
user_balances = {}
|
||||
for user in users:
|
||||
x = user_balance(
|
||||
sql_session,
|
||||
user,
|
||||
use_cache=use_cache,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
user_balances[user.id] = x
|
||||
|
||||
product_stocks = {}
|
||||
product_prices = {}
|
||||
for product in products:
|
||||
product_stocks[product.id] = product_stock(
|
||||
sql_session,
|
||||
product,
|
||||
use_cache=use_cache,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
product_prices[product.id] = product_price(
|
||||
sql_session,
|
||||
product,
|
||||
use_cache=use_cache,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
|
||||
next_cache_transaction = LastCacheTransaction(transaction_id=last_transaction.id)
|
||||
sql_session.add(next_cache_transaction)
|
||||
sql_session.flush()
|
||||
|
||||
if not len(users) == 0:
|
||||
sql_session.execute(
|
||||
insert(UserCache),
|
||||
[
|
||||
{
|
||||
"user_id": user.id,
|
||||
"balance": user_balances[user.id],
|
||||
"last_cache_transaction_id": next_cache_transaction.id,
|
||||
}
|
||||
for user in users
|
||||
],
|
||||
)
|
||||
|
||||
if not len(products) == 0:
|
||||
sql_session.execute(
|
||||
insert(ProductCache),
|
||||
[
|
||||
{
|
||||
"product_id": product.id,
|
||||
"stock": product_stocks[product.id],
|
||||
"price": product_prices[product.id],
|
||||
"last_cache_transaction_id": next_cache_transaction.id,
|
||||
}
|
||||
for product in products
|
||||
],
|
||||
)
|
||||
|
||||
sql_session.commit()
|
||||
@@ -1,12 +1,15 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
|
||||
from sqlalchemy import (
|
||||
CTE,
|
||||
BindParameter,
|
||||
Float,
|
||||
Integer,
|
||||
Select,
|
||||
and_,
|
||||
bindparam,
|
||||
case,
|
||||
cast,
|
||||
column,
|
||||
@@ -14,28 +17,243 @@ from sqlalchemy import (
|
||||
or_,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, aliased
|
||||
from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
|
||||
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
|
||||
from dibbler.models import (
|
||||
Transaction,
|
||||
TransactionType,
|
||||
User,
|
||||
)
|
||||
from dibbler.models.Transaction import (
|
||||
DEFAULT_INTEREST_RATE_PERCENTAGE,
|
||||
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE,
|
||||
DEFAULT_INTEREST_RATE_PERCENT,
|
||||
DEFAULT_PENALTY_MULTIPLIER_PERCENT,
|
||||
DEFAULT_PENALTY_THRESHOLD,
|
||||
)
|
||||
from dibbler.queries.product_price import _product_price_query
|
||||
from dibbler.queries.query_helpers import (
|
||||
CONST_NONE,
|
||||
CONST_ONE,
|
||||
CONST_ZERO,
|
||||
const,
|
||||
until_filter,
|
||||
)
|
||||
|
||||
|
||||
def _joint_transaction_query(
|
||||
user_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until_time: BindParameter[datetime] | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> Select[tuple[int, int, int]]:
|
||||
"""
|
||||
The inner query for getting joint transactions relevant to a user.
|
||||
|
||||
This scans for JOINT_BUY_PRODUCT transactions made by the user,
|
||||
then finds the corresponding JOINT transactions, and counts how many "shares"
|
||||
of the joint transaction the user has, as well as the total number of shares.
|
||||
"""
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
# First, select all joint buy product transactions for the given user
|
||||
# sub_joint_transaction = aliased(Transaction, name="right_trx")
|
||||
sub_joint_transaction = (
|
||||
select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id"))
|
||||
.where(
|
||||
Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
|
||||
Transaction.user_id == user_id,
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
transaction_time=Transaction.time,
|
||||
),
|
||||
)
|
||||
.subquery("sub_joint_transaction")
|
||||
)
|
||||
|
||||
# Join those with their main joint transaction
|
||||
# (just use Transaction)
|
||||
|
||||
# Then, count how many users are involved in each joint transaction
|
||||
joint_transaction_count = aliased(Transaction, name="count_trx")
|
||||
|
||||
joint_transaction = (
|
||||
select(
|
||||
Transaction.id,
|
||||
# Shares the user has in the transaction,
|
||||
func.sum(
|
||||
case(
|
||||
(joint_transaction_count.user_id == user_id, CONST_ONE),
|
||||
else_=CONST_ZERO,
|
||||
),
|
||||
).label("user_shares"),
|
||||
# The total number of shares in the transaction,
|
||||
func.count(joint_transaction_count.id).label("user_count"),
|
||||
)
|
||||
.select_from(sub_joint_transaction)
|
||||
.join(
|
||||
Transaction,
|
||||
onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id,
|
||||
)
|
||||
.join(
|
||||
joint_transaction_count,
|
||||
onclause=joint_transaction_count.joint_transaction_id == Transaction.id,
|
||||
)
|
||||
.group_by(joint_transaction_count.joint_transaction_id)
|
||||
)
|
||||
|
||||
return joint_transaction
|
||||
|
||||
|
||||
def _non_joint_transaction_query(
|
||||
user_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until_time: BindParameter[datetime] | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> Select[tuple[int, None, None]]:
|
||||
"""
|
||||
The inner query for getting non-joint transactions relevant to a user.
|
||||
"""
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
query = select(
|
||||
Transaction.id,
|
||||
CONST_NONE.label("user_shares"),
|
||||
CONST_NONE.label("user_count"),
|
||||
).where(
|
||||
or_(
|
||||
and_(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.TRANSFER.as_literal_column(),
|
||||
],
|
||||
),
|
||||
),
|
||||
and_(
|
||||
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||
Transaction.transfer_user_id == user_id,
|
||||
),
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_INTEREST.as_literal_column(),
|
||||
TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||
],
|
||||
),
|
||||
),
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def _product_cost_expression(
|
||||
product_count_column: KeyedColumnElement[int],
|
||||
product_id_column: KeyedColumnElement[int],
|
||||
interest_rate_percent_column: KeyedColumnElement[int],
|
||||
user_balance_column: KeyedColumnElement[int],
|
||||
penalty_threshold_column: KeyedColumnElement[int],
|
||||
penalty_multiplier_percent_column: KeyedColumnElement[int],
|
||||
joint_user_shares_column: KeyedColumnElement[int],
|
||||
joint_user_count_column: KeyedColumnElement[int],
|
||||
use_cache: bool = True,
|
||||
until_time: BindParameter[datetime] | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
cte_name: str = "product_price_cte",
|
||||
trx_subset_name: str = "product_price_trx_subset",
|
||||
):
|
||||
# TODO: This can get quite expensive real quick, so we should do some caching of the
|
||||
# product prices somehow.
|
||||
expression = (
|
||||
select(
|
||||
cast(
|
||||
func.ceil(
|
||||
# Base price
|
||||
(
|
||||
cast(
|
||||
column("price") * product_count_column * joint_user_shares_column,
|
||||
Float,
|
||||
)
|
||||
/ joint_user_count_column
|
||||
)
|
||||
# Interest
|
||||
+ (
|
||||
cast(
|
||||
column("price") * product_count_column * joint_user_shares_column,
|
||||
Float,
|
||||
)
|
||||
/ joint_user_count_column
|
||||
* cast(interest_rate_percent_column - const(100), Float)
|
||||
/ const(100.0)
|
||||
)
|
||||
# Penalty
|
||||
+ (
|
||||
(
|
||||
cast(
|
||||
column("price") * product_count_column * joint_user_shares_column,
|
||||
Float,
|
||||
)
|
||||
/ joint_user_count_column
|
||||
)
|
||||
* cast(penalty_multiplier_percent_column - const(100), Float)
|
||||
/ const(100.0)
|
||||
* cast(user_balance_column < penalty_threshold_column, Integer)
|
||||
),
|
||||
),
|
||||
Integer,
|
||||
),
|
||||
)
|
||||
.select_from(
|
||||
_product_price_query(
|
||||
product_id_column,
|
||||
use_cache=use_cache,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
cte_name=cte_name,
|
||||
trx_subset_name=trx_subset_name,
|
||||
),
|
||||
)
|
||||
.order_by(column("i").desc())
|
||||
.limit(CONST_ONE)
|
||||
.scalar_subquery()
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _user_balance_query(
|
||||
user_id: BindParameter[int] | int,
|
||||
use_cache: bool = True,
|
||||
until: BindParameter[datetime] | BindParameter[None] | datetime | None = None,
|
||||
until_including: BindParameter[bool] | bool = True,
|
||||
until_time: BindParameter[datetime] | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
cte_name: str = "rec_cte",
|
||||
trx_subset_name: str = "trx_subset",
|
||||
) -> CTE:
|
||||
"""
|
||||
The inner query for calculating the user's balance.
|
||||
@@ -47,30 +265,44 @@ def _user_balance_query(
|
||||
if isinstance(user_id, int):
|
||||
user_id = BindParameter("user_id", value=user_id)
|
||||
|
||||
if isinstance(until, datetime):
|
||||
until = BindParameter("until", value=until, type_=datetime)
|
||||
|
||||
if isinstance(until_including, bool):
|
||||
until_including = BindParameter("until_including", value=until_including, type_=bool)
|
||||
|
||||
initial_element = select(
|
||||
CONST_ZERO.label("i"),
|
||||
CONST_ZERO.label("time"),
|
||||
CONST_NONE.label("transaction_id"),
|
||||
CONST_ZERO.label("balance"),
|
||||
const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
|
||||
const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"),
|
||||
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
|
||||
const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
|
||||
const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"),
|
||||
)
|
||||
|
||||
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
|
||||
|
||||
trx_subset_subset = (
|
||||
_non_joint_transaction_query(
|
||||
user_id=user_id,
|
||||
use_cache=use_cache,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
.union_all(
|
||||
_joint_transaction_query(
|
||||
user_id=user_id,
|
||||
use_cache=use_cache,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
)
|
||||
.subquery(f"{trx_subset_name}_subset")
|
||||
)
|
||||
|
||||
# Subset of transactions that we'll want to iterate over.
|
||||
trx_subset = (
|
||||
select(
|
||||
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
|
||||
Transaction.amount,
|
||||
Transaction.id,
|
||||
Transaction.amount,
|
||||
Transaction.interest_rate_percent,
|
||||
Transaction.penalty_multiplier_percent,
|
||||
Transaction.penalty_threshold,
|
||||
@@ -79,44 +311,16 @@ def _user_balance_query(
|
||||
Transaction.time,
|
||||
Transaction.transfer_user_id,
|
||||
Transaction.type_,
|
||||
trx_subset_subset.c.user_shares,
|
||||
trx_subset_subset.c.user_count,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
and_(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.TRANSFER.as_literal_column(),
|
||||
# TODO: join this with the JOINT transactions, and determine
|
||||
# how much the current user paid for the product.
|
||||
TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
|
||||
]
|
||||
),
|
||||
),
|
||||
and_(
|
||||
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
|
||||
Transaction.transfer_user_id == user_id,
|
||||
),
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_INTEREST.as_literal_column(),
|
||||
TransactionType.ADJUST_PENALTY.as_literal_column(),
|
||||
]
|
||||
),
|
||||
),
|
||||
case(
|
||||
(until_including, Transaction.time <= until),
|
||||
else_=Transaction.time < until,
|
||||
)
|
||||
if until is not None
|
||||
else CONST_TRUE,
|
||||
.select_from(trx_subset_subset)
|
||||
.join(
|
||||
Transaction,
|
||||
onclause=Transaction.id == trx_subset_subset.c.id,
|
||||
)
|
||||
.order_by(Transaction.time.asc())
|
||||
.alias("trx_subset")
|
||||
.subquery(trx_subset_name)
|
||||
)
|
||||
|
||||
recursive_elements = (
|
||||
@@ -139,49 +343,43 @@ def _user_balance_query(
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
recursive_cte.c.balance
|
||||
- (
|
||||
trx_subset.c.product_count
|
||||
# Price of a single product, accounted for penalties and interest.
|
||||
* cast(
|
||||
func.ceil(
|
||||
# TODO: This can get quite expensive real quick, so we should do some caching of the
|
||||
# product prices somehow.
|
||||
# Base price
|
||||
(
|
||||
# FIXME: this always returns 0 for some reason...
|
||||
select(cast(column("price"), Float))
|
||||
.select_from(
|
||||
_product_price_query(
|
||||
trx_subset.c.product_id,
|
||||
use_cache=use_cache,
|
||||
until=trx_subset.c.time,
|
||||
until_including=False,
|
||||
cte_name="product_price_cte",
|
||||
)
|
||||
)
|
||||
.order_by(column("i").desc())
|
||||
.limit(CONST_ONE)
|
||||
).scalar_subquery()
|
||||
# TODO: should interest be applied before or after the penalty multiplier?
|
||||
# at the moment of writing, after sound right, but maybe ask someone?
|
||||
# Interest
|
||||
* (cast(recursive_cte.c.interest_rate_percent, Float) / const(100))
|
||||
# TODO: these should be added together, not multiplied, see specification
|
||||
# Penalty
|
||||
* case(
|
||||
(
|
||||
recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
|
||||
(
|
||||
cast(recursive_cte.c.penalty_multiplier_percent, Float)
|
||||
/ const(100)
|
||||
),
|
||||
),
|
||||
else_=const(1.0),
|
||||
)
|
||||
),
|
||||
Integer,
|
||||
)
|
||||
),
|
||||
- _product_cost_expression(
|
||||
product_count_column=trx_subset.c.product_count,
|
||||
product_id_column=trx_subset.c.product_id,
|
||||
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
|
||||
user_balance_column=recursive_cte.c.balance,
|
||||
penalty_threshold_column=recursive_cte.c.penalty_threshold,
|
||||
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
|
||||
joint_user_shares_column=CONST_ONE,
|
||||
joint_user_count_column=CONST_ONE,
|
||||
use_cache=use_cache,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
cte_name=f"{cte_name}_price",
|
||||
trx_subset_name=f"{trx_subset_name}_price",
|
||||
).label("product_cost"),
|
||||
),
|
||||
# Joint transaction -> balance decreases proportionally
|
||||
(
|
||||
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
|
||||
recursive_cte.c.balance
|
||||
- _product_cost_expression(
|
||||
product_count_column=trx_subset.c.product_count,
|
||||
product_id_column=trx_subset.c.product_id,
|
||||
interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
|
||||
user_balance_column=recursive_cte.c.balance,
|
||||
penalty_threshold_column=recursive_cte.c.penalty_threshold,
|
||||
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
|
||||
joint_user_shares_column=trx_subset.c.user_shares,
|
||||
joint_user_count_column=trx_subset.c.user_count,
|
||||
use_cache=use_cache,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
cte_name=f"{cte_name}_joint_price",
|
||||
trx_subset_name=f"{trx_subset_name}_joint_price",
|
||||
).label("joint_product_cost"),
|
||||
),
|
||||
# Transfers money to self -> balance increases
|
||||
(
|
||||
@@ -200,8 +398,7 @@ def _user_balance_query(
|
||||
recursive_cte.c.balance - trx_subset.c.amount,
|
||||
),
|
||||
# Throws a product -> if the user is considered to have bought it, balance increases
|
||||
# TODO:
|
||||
# (
|
||||
# TODO: # (
|
||||
# trx_subset.c.type_ == TransactionType.THROW_PRODUCT,
|
||||
# recursive_cte.c.balance + trx_subset.c.amount,
|
||||
# ),
|
||||
@@ -255,17 +452,16 @@ class UserBalanceLogEntry:
|
||||
Returns whether this exact transaction is penalized.
|
||||
"""
|
||||
|
||||
return False
|
||||
raise NotImplementedError("is_penalized is not implemented yet.")
|
||||
|
||||
# return self.transaction.type_ == TransactionType.BUY_PRODUCT and prev?
|
||||
|
||||
# TODO: add until datetime parameter
|
||||
|
||||
def user_balance_log(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
use_cache: bool = True,
|
||||
until: Transaction | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> list[UserBalanceLogEntry]:
|
||||
"""
|
||||
Returns a log of the user's balance over time, including interest and penalty adjustments.
|
||||
@@ -276,13 +472,18 @@ def user_balance_log(
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if until is not None and until.id is None:
|
||||
raise ValueError("'until' transaction must be persisted in the database.")
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
recursive_cte = _user_balance_query(
|
||||
user.id,
|
||||
use_cache=use_cache,
|
||||
until=until.time if until else None,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
result = sql_session.execute(
|
||||
@@ -298,13 +499,13 @@ def user_balance_log(
|
||||
Transaction,
|
||||
onclause=Transaction.id == recursive_cte.c.transaction_id,
|
||||
)
|
||||
.order_by(recursive_cte.c.i.asc())
|
||||
.order_by(recursive_cte.c.i.asc()),
|
||||
).all()
|
||||
|
||||
if result is None:
|
||||
# If there are no transactions for this user, the query should return 0, not None.
|
||||
raise RuntimeError(
|
||||
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
|
||||
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id}).",
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -319,13 +520,13 @@ def user_balance_log(
|
||||
]
|
||||
|
||||
|
||||
# TODO: add until datetime parameter
|
||||
|
||||
def user_balance(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
use_cache: bool = True,
|
||||
until: Transaction | None = None,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
Calculates the balance of a user.
|
||||
@@ -336,23 +537,31 @@ def user_balance(
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
recursive_cte = _user_balance_query(
|
||||
user.id,
|
||||
use_cache=use_cache,
|
||||
until=until.time if until else None,
|
||||
until_time=until_time,
|
||||
until_transaction=until_transaction,
|
||||
until_inclusive=until_inclusive,
|
||||
)
|
||||
|
||||
result = sql_session.scalar(
|
||||
select(recursive_cte.c.balance)
|
||||
.order_by(recursive_cte.c.i.desc())
|
||||
.limit(CONST_ONE)
|
||||
.offset(CONST_ZERO)
|
||||
.offset(CONST_ZERO),
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# If there are no transactions for this user, the query should return 0, not None.
|
||||
raise RuntimeError(
|
||||
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
|
||||
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id}).",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
# This absoulutely needs a cache, else we can't stop recursing until we know all owners for all products...
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, bindparam
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
|
||||
# NOTE: This absolutely needs a cache, else we can't stop recursing until we know all owners for all products...
|
||||
#
|
||||
# Since we know that the non-owned products will not get renowned by the user by other means,
|
||||
# we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user.
|
||||
@@ -8,3 +15,34 @@
|
||||
# but we still need to check if the user passes out of ownership for the item, without needing to check past
|
||||
# the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if
|
||||
# a user has products multiple places in the queue, interleaved with other users?
|
||||
|
||||
|
||||
def user_products(
|
||||
sql_session: Session,
|
||||
user: User,
|
||||
use_cache: bool = True,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
) -> list[tuple[Product, int]]:
|
||||
"""
|
||||
Returns the list of products owned by the user, along with how many of each product they own.
|
||||
"""
|
||||
|
||||
if user.id is None:
|
||||
raise ValueError("User must be persisted in the database.")
|
||||
|
||||
if not (until_time is None or until_transaction is None):
|
||||
raise ValueError("Cannot filter by both until_time and until_transaction.")
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
raise NotImplementedError("Not implemented yet, needs caching system first.")
|
||||
|
||||
@@ -1,79 +1,111 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
from signal import (
|
||||
SIG_IGN,
|
||||
SIGQUIT,
|
||||
SIGTSTP,
|
||||
)
|
||||
from signal import (
|
||||
signal as set_signal_handler,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..conf import config
|
||||
from ..lib.helpers import *
|
||||
from ..menus import *
|
||||
from ..menus import (
|
||||
AddProductMenu,
|
||||
AddStockMenu,
|
||||
AddUserMenu,
|
||||
AdjustCreditMenu,
|
||||
AdjustStockMenu,
|
||||
BalanceMenu,
|
||||
BuyMenu,
|
||||
CleanupStockMenu,
|
||||
EditProductMenu,
|
||||
EditUserMenu,
|
||||
FAQMenu,
|
||||
LoggedStatisticsMenu,
|
||||