This commit is contained in:
2025-09-24 13:50:08 +02:00
parent d0fe31383a
commit 78b04c4445

View File

@@ -56,6 +56,20 @@ class CSP:
bool
False if a domain becomes empty, otherwise True
"""
def revise(csp, xi, xj) -> bool:
revised = False
for x in set(csp.domains[xi]):
if not any(
[
(x, y) in csp.binary_constraints[(xi, xj)]
for y in csp.domains[xj]
]
):
csp.domains[xi].remove(x)
revised = True
return revised
queue = Queue()
for edge in self.binary_constraints.keys():
queue.put(edge)
@@ -63,7 +77,7 @@ class CSP:
while not queue.empty():
(xi, xj) = queue.get()
if self._revise(xi, xj):
if revise(self, xi, xj):
if len(self.domains[xi]) == 0:
return False
for neighboring_edge in [
@@ -74,53 +88,6 @@ class CSP:
queue.put(neighboring_edge)
return True
def _revise(self, xi, xj) -> bool:
"""Internal of ac_3
Makes an arc consistent by shrinking xi's domain.
Parameters
----------
xi, xj
Nodes of the arc
Returns
-------
bool
True if a change was made to the domain of xi, False otherwise
"""
revised = False
for x in set(self.domains[xi]):
if not any(
[(x, y) in self.binary_constraints[(xi, xj)] for y in self.domains[xj]]
):
self.domains[xi].remove(x)
revised = True
return revised
def _consistent(self, var, value, assignment) -> bool:
for v in assignment:
if (var, v) in self.binary_constraints.keys():
if (value, assignment[v]) not in self.binary_constraints[(var, v)]:
return False
if (v, var) in self.binary_constraints.keys():
if (value, assignment[v]) not in self.binary_constraints[(v, var)]:
return False
return True
def _backtrack(self, assignment: dict[str, Any]) -> dict | None:
if len(assignment) == len(self.variables):
return assignment # base-case
var = self._select_unassigned_variable(assignment)
for value in self._order_domain_values(var, assignment):
if not self._consistent(var, value, assignment):
continue
assignment[var] = value
if result := self._backtrack(assignment):
return result
assignment.pop(var)
return None # failure
def backtracking_search(self) -> None | dict[str, Any]:
"""Performs backtracking search on the CSP.
@@ -130,32 +97,53 @@ class CSP:
A solution if any exists, otherwise None
"""
return self._backtrack({})
def backtrack(csp, assignment: dict[str, Any]) -> dict | None:
if len(assignment) == len(csp.variables):
return assignment # base-case
var = select_unassigned_variable(csp, assignment)
for value in order_domain_values(csp, var, assignment):
if not consistent(csp, var, value, assignment):
continue
assignment[var] = value
if result := backtrack(csp, assignment):
return result
assignment.pop(var)
return None # failure
def _select_unassigned_variable(self, assignment) -> str:
for v in self.variables:
def consistent(csp, var, value, assignment) -> bool:
for v in assignment:
if (var, v) in csp.binary_constraints.keys():
if (value, assignment[v]) not in csp.binary_constraints[(var, v)]:
return False
if (v, var) in csp.binary_constraints.keys():
if (value, assignment[v]) not in csp.binary_constraints[(v, var)]:
return False
return True
def select_unassigned_variable(csp, assignment) -> str:
for v in csp.variables:
if v not in assignment.keys():
return v
return "this shouldn't happen"
# TODO: may be wrong
def _order_domain_values(self, var, assignment) -> list[Any]:
"""least constrained value"""
# choose least constrained value
def order_domain_values(csp, var, assignment) -> list[Any]:
def compare(value) -> int:
s = 0
for neighbor in [
b for (a, b) in self.binary_constraints.keys() if a == var
b for (a, b) in csp.binary_constraints.keys() if a == var
]:
if neighbor in assignment.keys():
continue
for neighbor_value in self.domains[neighbor]:
s += (value, neighbor_value) not in self.binary_constraints[
for neighbor_value in csp.domains[neighbor]:
s += (value, neighbor_value) not in csp.binary_constraints[
(var, neighbor)
]
return s
return sorted(list(self.domains[var]), key=compare)
return sorted(list(csp.domains[var]), key=compare)
return backtrack(self, {})
def alldiff(variables: list[str]) -> list[tuple[str, str]]: