From d6b002f676c29c58523872b06df559781950e59f Mon Sep 17 00:00:00 2001 From: h7x4 Date: Sun, 25 Jan 2026 19:53:03 +0900 Subject: [PATCH] repl: better typing --- src/libdib/repl.py | 142 +++++++++++++++++++++++++-------------------- 1 file changed, 79 insertions(+), 63 deletions(-) diff --git a/src/libdib/repl.py b/src/libdib/repl.py index 53c7c09..20c21ed 100644 --- a/src/libdib/repl.py +++ b/src/libdib/repl.py @@ -1,49 +1,63 @@ from cmd import Cmd from datetime import datetime -from typing import Any, Callable +from typing import Any, Callable, TypeVar, Generic, Self from sqlalchemy import select from sqlalchemy.orm import Session + def prompt_yes_no(question: str, default: bool | None = None) -> bool: prompt = { - None: '[y/n]', - True: '[Y/n]', - False: '[y/N]', + None: "[y/n]", + True: "[Y/n]", + False: "[y/N]", }[default] - while not any([ - (answer := input(f'{question} {prompt} ').lower()) in ('y','n'), - (default != None and answer.strip() == '') - ]): + while not any( + [ + (answer := input(f"{question} {prompt} ").lower()) in ("y", "n"), + (default is not None and answer.strip() == ""), + ] + ): pass return { - 'y': True, - 'n': False, - '': default, + "y": True, + "n": False, + "": default, }[answer] -def format_date(date: datetime): +def format_date(date: datetime) -> str: return date.strftime("%a %b %d, %Y") -class InteractiveItemSelector(Cmd): +InteractiveItemSelectorType = TypeVar("InteractiveItemSelectorType") + + +class InteractiveItemSelector(Generic[InteractiveItemSelectorType], Cmd): + sql_session: Session + execute_selection: Callable[[Session, type, str], list[Any]] + complete_selection: Callable[[Session, type, str], list[str]] + default_item: InteractiveItemSelectorType | None + result: InteractiveItemSelectorType | None + def __init__( self, cls: type, sql_session: Session, - execute_selection: Callable[[Session, type, str], list[Any]] = lambda session, cls, arg: session.scalars( - select(cls) - .where(cls.name == arg), + execute_selection: Callable[ + [Session, type, str], list[InteractiveItemSelectorType] + ] = lambda session, cls, arg: session.scalars( + select(cls).where(cls.name == arg), ).all(), - complete_selection: Callable[[Session, type, str], list[str]] = lambda session, cls, text: session.scalars( - select(cls.name) - .where(cls.name.istartswith(text)), + complete_selection: Callable[[Session, type, str], list[str]] = lambda session, + cls, + text: session.scalars( + select(cls.name).where(cls.name.istartswith(text)), ).all(), - default: Any | None = None, - ): + default: InteractiveItemSelectorType | None = None, + ) -> None: """ This is a utility class for prompting the user to select an item from the database. The default functions assumes that @@ -61,10 +75,9 @@ class InteractiveItemSelector(Cmd): self.result = None if default is not None: - self.prompt = f'Select {cls.__name__} [{default.name}]> ' + self.prompt = f"Select {cls.__name__} [{default.name}]> " else: - self.prompt = f'Select {cls.__name__}> ' - + self.prompt = f"Select {cls.__name__}> " def emptyline(self) -> bool: if self.default_item is not None: @@ -72,15 +85,14 @@ class InteractiveItemSelector(Cmd): return True return False - - def default(self, arg: str): + def default(self, arg: str) -> bool | None: try: - result = self.execute_selection(self.sql_session, self.cls, arg) + result = self.execute_selection(self.sql_session, self.cls, arg) except Exception as e: - print(f'Error executing selection: {e}') + print(f"Error executing selection: {e}") if len(result) != 1: - print(f'No such {self.cls.__name__} found: {arg}') + print(f"No such {self.cls.__name__} found: {arg}") return self.result = result[0] @@ -132,72 +144,77 @@ class NumberedCmd(Cmd): ``` """ - prompt_header: str | None = None - funcs: dict[int, dict[str, str | Callable[[Any, str], bool | None]]] + funcs: dict[ + int, + dict[ + str, + str | Callable[[Self, str], bool | None], + ], + ] - - def __init__(self): + def __init__(self) -> None: super().__init__() - def _generate_usage_list(self) -> str: - result = '' + result = "" for i, func in self.funcs.items(): if i == 0: - i = '*' - result += f'{i}) {func["doc"]}\n' + i = "*" + result += f"{i}) {func['doc']}\n" return result - - def _default(self, arg: str): + def _default(self, arg: str) -> bool | None: try: i = int(arg) self.funcs[i] except (ValueError, KeyError): return - return self.funcs[i]['f'](self, arg) + return self.funcs[i]["f"](self, arg) - - def default(self, arg: str): + def default(self, arg: str) -> bool | None: return self._default(arg) - def _postcmd(self, stop: bool, _: str) -> bool: if not stop: - print() - print('-----------------') - print() + print() + print("-----------------") + print() return stop - def postcmd(self, stop: bool, line: str) -> bool: return self._postcmd(stop, line) - @property - def prompt(self): - result = '' + def prompt(self) -> str: + result = "" - if self.prompt_header != None: - result += self.prompt_header + '\n' + if self.prompt_header is not None: + result += self.prompt_header + "\n" result += self._generate_usage_list() - if self.lastcmd == '': - result += f'> ' + if self.lastcmd == "": + result += "> " else: - result += f'[{self.lastcmd}]> ' + result += f"[{self.lastcmd}]> " return result -class NumberedItemSelector(NumberedCmd): +NumberedItemSelectorType = TypeVar("NumberedItemSelectorType") + + +class NumberedItemSelector(Generic[NumberedItemSelectorType], NumberedCmd): + items: list[NumberedItemSelectorType] + stringify: Callable[[NumberedItemSelectorType], str] + result: NumberedItemSelectorType | None + def __init__( self, - items: list[Any], - stringify: Callable[[Any], str] = lambda x: str(x), + items: list[NumberedItemSelectorType], + stringify: Callable[[NumberedItemSelectorType], str] = lambda x: str(x), ): super().__init__() self.items = items @@ -205,13 +222,12 @@ class NumberedItemSelector(NumberedCmd): self.result = None self.funcs = { i: { - 'f': self._select_item, - 'doc': self.stringify(item), + "f": self._select_item, + "doc": self.stringify(item), } for i, item in enumerate(items, start=1) } - - def _select_item(self, *a): - self.result = self.items[int(self.lastcmd)-1] + def _select_item(self, *a) -> bool: + self.result = self.items[int(self.lastcmd) - 1] return True