167 lines
5.8 KiB
Python
167 lines
5.8 KiB
Python
from queue import Queue
|
|
from typing import Any
|
|
|
|
|
|
class CSP:
|
|
def __init__(
|
|
self,
|
|
variables: list[str],
|
|
domains: dict[str, set],
|
|
edges: list[tuple[str, str]],
|
|
):
|
|
"""Constructs a CSP instance with the given variables, domains and edges.
|
|
|
|
Parameters
|
|
----------
|
|
variables : list[str]
|
|
The variables for the CSP
|
|
domains : dict[str, set]
|
|
The domains of the variables
|
|
edges : list[tuple[str, str]]
|
|
Pairs of variables that must not be assigned the same value
|
|
"""
|
|
self.variables = variables
|
|
self.domains = domains
|
|
|
|
# Binary constraints as a dictionary mapping variable pairs to a set of value pairs.
|
|
#
|
|
# To check if variable1=value1, variable2=value2 is in violation of a binary constraint:
|
|
# if (
|
|
# (variable1, variable2) in self.binary_constraints and
|
|
# (value1, value2) not in self.binary_constraints[(variable1, variable2)]
|
|
# ) or (
|
|
# (variable2, variable1) in self.binary_constraints and
|
|
# (value1, value2) not in self.binary_constraints[(variable2, variable1)]
|
|
# ):
|
|
# Violates a binary constraint
|
|
self.binary_constraints: dict[tuple[str, str], set] = {}
|
|
for variable1, variable2 in edges:
|
|
self.binary_constraints[(variable1, variable2)] = set()
|
|
for value1 in self.domains[variable1]:
|
|
for value2 in self.domains[variable2]:
|
|
if value1 != value2:
|
|
self.binary_constraints[(variable1, variable2)].add(
|
|
(value1, value2)
|
|
)
|
|
self.binary_constraints[(variable1, variable2)].add(
|
|
(value2, value1)
|
|
)
|
|
|
|
def ac_3(self) -> bool:
|
|
"""Performs AC-3 on the CSP.
|
|
Meant to be run prior to calling backtracking_search() to reduce the search for some problems.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
False if a domain becomes empty, otherwise True
|
|
"""
|
|
|
|
def revise(csp, xi, xj) -> bool:
|
|
revised = False
|
|
for x in set(csp.domains[xi]):
|
|
if not any(
|
|
[
|
|
(x, y) in csp.binary_constraints[(xi, xj)]
|
|
for y in csp.domains[xj]
|
|
]
|
|
):
|
|
csp.domains[xi].remove(x)
|
|
revised = True
|
|
return revised
|
|
|
|
queue = Queue()
|
|
for edge in self.binary_constraints.keys():
|
|
queue.put(edge)
|
|
|
|
while not queue.empty():
|
|
(xi, xj) = queue.get()
|
|
|
|
if revise(self, xi, xj):
|
|
if len(self.domains[xi]) == 0:
|
|
return False
|
|
for neighboring_edge in [
|
|
(a, b)
|
|
for (a, b) in self.binary_constraints.keys()
|
|
if a != xj and b == xi
|
|
]:
|
|
queue.put(neighboring_edge)
|
|
return True
|
|
|
|
def backtracking_search(self) -> None | dict[str, Any]:
|
|
"""Performs backtracking search on the CSP.
|
|
|
|
Returns
|
|
-------
|
|
None | dict[str, Any]
|
|
A solution if any exists, otherwise None
|
|
"""
|
|
|
|
def backtrack(csp, assignment: dict[str, Any]) -> dict | None:
|
|
if len(assignment) == len(csp.variables):
|
|
return assignment # base-case
|
|
var = select_unassigned_variable(csp, assignment)
|
|
for value in order_domain_values(csp, var, assignment):
|
|
if not consistent(csp, var, value, assignment):
|
|
continue
|
|
assignment[var] = value
|
|
if result := backtrack(csp, assignment):
|
|
return result
|
|
assignment.pop(var)
|
|
return None # failure
|
|
|
|
def consistent(csp, var, value, assignment) -> bool:
|
|
for v in assignment:
|
|
if (var, v) in csp.binary_constraints.keys():
|
|
if (value, assignment[v]) not in csp.binary_constraints[(var, v)]:
|
|
return False
|
|
if (v, var) in csp.binary_constraints.keys():
|
|
if (value, assignment[v]) not in csp.binary_constraints[(v, var)]:
|
|
return False
|
|
return True
|
|
|
|
def select_unassigned_variable(csp, assignment) -> str:
|
|
for v in csp.variables:
|
|
if v not in assignment.keys():
|
|
return v
|
|
return "this shouldn't happen"
|
|
|
|
# choose least constrained value
|
|
def order_domain_values(csp, var, assignment) -> list[Any]:
|
|
def compare(value) -> int:
|
|
s = 0
|
|
for neighbor in [
|
|
b for (a, b) in csp.binary_constraints.keys() if a == var
|
|
]:
|
|
if neighbor in assignment.keys():
|
|
continue
|
|
for neighbor_value in csp.domains[neighbor]:
|
|
s += (value, neighbor_value) not in csp.binary_constraints[
|
|
(var, neighbor)
|
|
]
|
|
return s
|
|
|
|
return sorted(list(csp.domains[var]), key=compare)
|
|
|
|
return backtrack(self, {})
|
|
|
|
|
|
def alldiff(variables: list[str]) -> list[tuple[str, str]]:
    """Returns a list of edges interconnecting all of the input variables

    Each unordered pair appears once, as (earlier, later) in input order;
    fewer than two variables yields an empty list.

    Parameters
    ----------
    variables : list[str]
        The variables that all must be different

    Returns
    -------
    list[tuple[str, str]]
        List of edges in the form (a, b)
    """
    # Pair every variable with each one that follows it in the list.
    return [
        (first, second)
        for position, first in enumerate(variables)
        for second in variables[position + 1:]
    ]
|