repl: better typing

This commit is contained in:
2026-01-25 19:53:03 +09:00
parent 2b42130bee
commit d6b002f676
+79 -63
View File
@@ -1,49 +1,63 @@
from cmd import Cmd from cmd import Cmd
from datetime import datetime from datetime import datetime
from typing import Any, Callable from typing import Any, Callable, TypeVar, Generic, Self
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
def prompt_yes_no(question: str, default: bool | None = None) -> bool: def prompt_yes_no(question: str, default: bool | None = None) -> bool:
prompt = { prompt = {
None: '[y/n]', None: "[y/n]",
True: '[Y/n]', True: "[Y/n]",
False: '[y/N]', False: "[y/N]",
}[default] }[default]
while not any([ while not any(
(answer := input(f'{question} {prompt} ').lower()) in ('y','n'), [
(default != None and answer.strip() == '') (answer := input(f"{question} {prompt} ").lower()) in ("y", "n"),
]): (default is not None and answer.strip() == ""),
]
):
pass pass
return { return {
'y': True, "y": True,
'n': False, "n": False,
'': default, "": default,
}[answer] }[answer]
def format_date(date: datetime): def format_date(date: datetime) -> str:
return date.strftime("%a %b %d, %Y") return date.strftime("%a %b %d, %Y")
class InteractiveItemSelector(Cmd): InteractiveItemSelectorType = TypeVar("InteractiveItemSelectorType")
class InteractiveItemSelector(Generic[InteractiveItemSelectorType], Cmd):
sql_session: Session
execute_selection: Callable[[Session, type, str], list[Any]]
complete_selection: Callable[[Session, type, str], list[str]]
default_item: InteractiveItemSelectorType | None
result: InteractiveItemSelectorType | None
def __init__( def __init__(
self, self,
cls: type, cls: type,
sql_session: Session, sql_session: Session,
execute_selection: Callable[[Session, type, str], list[Any]] = lambda session, cls, arg: session.scalars( execute_selection: Callable[
select(cls) [Session, type, str], list[InteractiveItemSelectorType]
.where(cls.name == arg), ] = lambda session, cls, arg: session.scalars(
select(cls).where(cls.name == arg),
).all(), ).all(),
complete_selection: Callable[[Session, type, str], list[str]] = lambda session, cls, text: session.scalars( complete_selection: Callable[[Session, type, str], list[str]] = lambda session,
select(cls.name) cls,
.where(cls.name.istartswith(text)), text: session.scalars(
select(cls.name).where(cls.name.istartswith(text)),
).all(), ).all(),
default: Any | None = None, default: InteractiveItemSelectorType | None = None,
): ) -> None:
""" """
This is a utility class for prompting the user to select an This is a utility class for prompting the user to select an
item from the database. The default functions assumes that item from the database. The default functions assumes that
@@ -61,10 +75,9 @@ class InteractiveItemSelector(Cmd):
self.result = None self.result = None
if default is not None: if default is not None:
self.prompt = f'Select {cls.__name__} [{default.name}]> ' self.prompt = f"Select {cls.__name__} [{default.name}]> "
else: else:
self.prompt = f'Select {cls.__name__}> ' self.prompt = f"Select {cls.__name__}> "
def emptyline(self) -> bool: def emptyline(self) -> bool:
if self.default_item is not None: if self.default_item is not None:
@@ -72,15 +85,14 @@ class InteractiveItemSelector(Cmd):
return True return True
return False return False
def default(self, arg: str) -> bool | None:
def default(self, arg: str):
try: try:
result = self.execute_selection(self.sql_session, self.cls, arg) result = self.execute_selection(self.sql_session, self.cls, arg)
except Exception as e: except Exception as e:
print(f'Error executing selection: {e}') print(f"Error executing selection: {e}")
if len(result) != 1: if len(result) != 1:
print(f'No such {self.cls.__name__} found: {arg}') print(f"No such {self.cls.__name__} found: {arg}")
return return
self.result = result[0] self.result = result[0]
@@ -132,72 +144,77 @@ class NumberedCmd(Cmd):
``` ```
""" """
prompt_header: str | None = None prompt_header: str | None = None
funcs: dict[int, dict[str, str | Callable[[Any, str], bool | None]]] funcs: dict[
int,
dict[
str,
str | Callable[[Self, str], bool | None],
],
]
def __init__(self) -> None:
def __init__(self):
super().__init__() super().__init__()
def _generate_usage_list(self) -> str: def _generate_usage_list(self) -> str:
result = '' result = ""
for i, func in self.funcs.items(): for i, func in self.funcs.items():
if i == 0: if i == 0:
i = '*' i = "*"
result += f'{i}) {func["doc"]}\n' result += f"{i}) {func['doc']}\n"
return result return result
def _default(self, arg: str) -> bool | None:
def _default(self, arg: str):
try: try:
i = int(arg) i = int(arg)
self.funcs[i] self.funcs[i]
except (ValueError, KeyError): except (ValueError, KeyError):
return return
return self.funcs[i]['f'](self, arg) return self.funcs[i]["f"](self, arg)
def default(self, arg: str) -> bool | None:
def default(self, arg: str):
return self._default(arg) return self._default(arg)
def _postcmd(self, stop: bool, _: str) -> bool: def _postcmd(self, stop: bool, _: str) -> bool:
if not stop: if not stop:
print() print()
print('-----------------') print("-----------------")
print() print()
return stop return stop
def postcmd(self, stop: bool, line: str) -> bool: def postcmd(self, stop: bool, line: str) -> bool:
return self._postcmd(stop, line) return self._postcmd(stop, line)
@property @property
def prompt(self): def prompt(self) -> str:
result = '' result = ""
if self.prompt_header != None: if self.prompt_header is not None:
result += self.prompt_header + '\n' result += self.prompt_header + "\n"
result += self._generate_usage_list() result += self._generate_usage_list()
if self.lastcmd == '': if self.lastcmd == "":
result += f'> ' result += "> "
else: else:
result += f'[{self.lastcmd}]> ' result += f"[{self.lastcmd}]> "
return result return result
class NumberedItemSelector(NumberedCmd): NumberedItemSelectorType = TypeVar("NumberedItemSelectorType")
class NumberedItemSelector(Generic[NumberedItemSelectorType], NumberedCmd):
items: list[NumberedItemSelectorType]
stringify: Callable[[NumberedItemSelectorType], str]
result: NumberedItemSelectorType | None
def __init__( def __init__(
self, self,
items: list[Any], items: list[NumberedItemSelectorType],
stringify: Callable[[Any], str] = lambda x: str(x), stringify: Callable[[NumberedItemSelectorType], str] = lambda x: str(x),
): ):
super().__init__() super().__init__()
self.items = items self.items = items
@@ -205,13 +222,12 @@ class NumberedItemSelector(NumberedCmd):
self.result = None self.result = None
self.funcs = { self.funcs = {
i: { i: {
'f': self._select_item, "f": self._select_item,
'doc': self.stringify(item), "doc": self.stringify(item),
} }
for i, item in enumerate(items, start=1) for i, item in enumerate(items, start=1)
} }
def _select_item(self, *a) -> bool:
def _select_item(self, *a): self.result = self.items[int(self.lastcmd) - 1]
self.result = self.items[int(self.lastcmd)-1]
return True return True