Skip to content

Commit f3ea713

Browse files
committed
v0.3.9 Reformulated greynet constraints for NQueens + reverse index for beta node (improves retraction performance) TODO: add optimizations from Rust version with greynet integration (SmallRng instead StdRng; corrected delta commiting in Agent, Greynet score calculator, requester; mover enhancements (sampling, etc))
1 parent e42b384 commit f3ea713

File tree

6 files changed

+89
-46
lines changed

6 files changed

+89
-46
lines changed

examples/object_oriented/nqueens/score/GreynetScoreCalculatorNQueens.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,53 @@
1-
# file: nqueens_constraint_builder.py
1+
# file: GreynetScoreCalculatorNQueens.py
22

33
from greyjack.score_calculation.score_calculators.GreynetScoreCalculator import GreynetScoreCalculator
44
from greyjack.score_calculation.greynet.builder import ConstraintBuilder
55
from greyjack.score_calculation.scores.SimpleScore import SimpleScore
66
from greyjack.score_calculation.scores.ScoreVariants import ScoreVariants
7+
from greyjack.score_calculation.greynet.common.joiner_type import JoinerType
78
from ..cotwin.CotQueen import CotQueen
89

910
cb = ConstraintBuilder(name="NQueens", score_class=SimpleScore)
1011

11-
@cb.constraint("Row Conflict", default_weight=1.0)
12-
def row_conflict():
12+
@cb.constraint("Same Row", default_weight=1.0)
13+
def same_row():
1314
return (
14-
cb.for_each_unique_pair(CotQueen)
15-
.filter(lambda q1, q2: q1.row_id == q2.row_id)
15+
cb.for_each(CotQueen)
16+
.join(
17+
cb.for_each(CotQueen),
18+
JoinerType.EQUAL,
19+
lambda q: q.row_id,
20+
lambda q: q.row_id
21+
)
22+
.filter(lambda q1, q2: q1.queen_id < q2.queen_id)
1623
.penalize_simple(1)
1724
)
1825

19-
@cb.constraint("Ascending Diagonal Conflict", default_weight=1.0)
20-
def ascending_diagonal_conflict():
26+
@cb.constraint("Ascending Diagonal", default_weight=1.0)
27+
def ascending_diagonal():
2128
return (
22-
cb.for_each_unique_pair(CotQueen)
23-
.filter(lambda q1, q2: (q1.row_id - q1.column_id) == (q2.row_id - q2.column_id))
29+
cb.for_each(CotQueen)
30+
.join(
31+
cb.for_each(CotQueen),
32+
JoinerType.EQUAL,
33+
lambda q: q.column_id - q.row_id,
34+
lambda q: q.column_id - q.row_id
35+
)
36+
.filter(lambda q1, q2: q1.queen_id < q2.queen_id)
2437
.penalize_simple(1)
2538
)
2639

27-
@cb.constraint("Descending Diagonal Conflict", default_weight=1.0)
28-
def descending_diagonal_conflict():
40+
@cb.constraint("Descending Diagonal", default_weight=1.0)
41+
def descending_diagonal():
2942
return (
30-
cb.for_each_unique_pair(CotQueen)
31-
.filter(lambda q1, q2: (q1.row_id + q1.column_id) == (q2.row_id + q2.column_id))
43+
cb.for_each(CotQueen)
44+
.join(
45+
cb.for_each(CotQueen),
46+
JoinerType.EQUAL,
47+
lambda q: q.column_id + q.row_id,
48+
lambda q: q.column_id + q.row_id
49+
)
50+
.filter(lambda q1, q2: q1.queen_id < q2.queen_id)
3251
.penalize_simple(1)
3352
)
3453

examples/object_oriented/nqueens/scripts/solve_nqueens_greynet_experimental.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
if __name__ == "__main__":
2121

2222
# build domain model
23-
domain_builder = DomainBuilderNQueens(16, random_seed=45)
23+
domain_builder = DomainBuilderNQueens(1000, random_seed=45)
2424
cotwin_builder = CotwinBuilderNQueens(scorer_name="greynet_incremental")
2525

2626
#domain = domain_builder.build_domain_from_scratch()
@@ -44,13 +44,13 @@
4444
migration_frequency=10, compare_to_global_frequency=10, termination_strategy=termination_strategy)"""
4545

4646
solver = SolverOOP(domain_builder, cotwin_builder, agent,
47-
ParallelizationBackend.Threading, LoggingLevel.Info,
47+
ParallelizationBackend.Multiprocessing, LoggingLevel.FreshOnly,
4848
n_jobs=1, score_precision=[0])
4949
solution = solver.solve()
5050
#print( "Cotwin solution looks that: " )
5151
#print( solution )
5252

5353
domain = domain_builder.build_from_solution(solution)
54-
print(domain)
54+
#print(domain)
5555

5656
print( "done" )

greyjack/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

greyjack/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "greyjack"
3-
version = "0.3.8"
3+
version = "0.3.9"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

greyjack/greyjack/score_calculation/greynet/nodes/beta_node.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22
from abc import ABC, abstractmethod
3+
from collections import defaultdict
34

45
from ..nodes.abstract_node import AbstractNode
56
from ..common.index.uni_index import UniIndex
67
from ..common.index.advanced_index import AdvancedIndex
78
from ..core.tuple import TupleState, AbstractTuple
89
from ..common.joiner_type import JoinerType
910

10-
# --- Start of Bug Fix ---
1111
# Defines the logical inverse for each joiner type, which is essential for
1212
# creating symmetrical join logic within the BetaNode.
1313
JOINER_INVERSES = {
@@ -21,43 +21,34 @@
2121
JoinerType.RANGE_CONTAINS: JoinerType.RANGE_WITHIN,
2222
JoinerType.RANGE_WITHIN: JoinerType.RANGE_CONTAINS,
2323
}
24-
# --- End of Bug Fix ---
2524

2625

2726
class BetaNode(AbstractNode, ABC):
2827
"""
2928
An abstract base class for all join nodes (Bi, Tri, Quad, etc.).
3029
It contains the common logic for indexing, matching, and propagating tuples.
30+
This version includes reverse indices for O(1) retraction performance.
3131
"""
3232
def __init__(self, node_id, joiner_type, left_index_properties, right_index_properties, scheduler, tuple_pool):
3333
super().__init__(node_id)
3434
self.scheduler = scheduler
3535
self.tuple_pool = tuple_pool
3636
self.joiner_type = joiner_type
37-
38-
# --- Start of Bug Fix ---
39-
# The join is defined from left to right (e.g., left.key < right.key).
40-
# When a new right_tuple arrives, we probe the left_index to find left_tuples
41-
# that satisfy the original joiner (e.g., left.key < new_right.key).
42-
self.left_index = self._create_index(left_index_properties, joiner_type)
4337

44-
# When a new left_tuple arrives, we must probe the right_index using the
45-
# inverse joiner to find right_tuples that satisfy the condition.
46-
# e.g., to find 'right' where 'new_left.key < right.key', we must query for
47-
# 'right.key > new_left.key'.
38+
self.left_index = self._create_index(left_index_properties, joiner_type)
4839
inverse_joiner = JOINER_INVERSES.get(joiner_type)
4940
if inverse_joiner is None:
5041
raise ValueError(f"Joiner type {joiner_type} has no defined inverse.")
5142
self.right_index = self._create_index(right_index_properties, inverse_joiner)
52-
# --- End of Bug Fix ---
53-
43+
5444
self.beta_memory = {}
45+
self.left_to_pairs = defaultdict(list)
46+
self.right_to_pairs = defaultdict(list)
5547

5648
def __repr__(self) -> str:
5749
"""Overrides base representation to include the joiner type."""
5850
return f"<{self.__class__.__name__} id={self._node_id} joiner={self.joiner_type.name}>"
5951

60-
6152
def _create_index(self, props, joiner):
6253
"""Factory method to create the appropriate index based on joiner type."""
6354
if joiner == JoinerType.EQUAL:
@@ -70,40 +61,73 @@ def _create_child_tuple(self, left_tuple: AbstractTuple, right_tuple: AbstractTu
7061
pass
7162

7263
def insert_left(self, left_tuple: AbstractTuple):
64+
"""Handles insertion of a tuple from the left parent."""
7365
self.left_index.put(left_tuple)
7466
key = self.left_index._index_properties.get_property(left_tuple)
67+
# Probe the opposing index for matches
7568
right_matches = self.right_index.get_matches(key) if hasattr(self.right_index, 'get_matches') else self.right_index.get(key)
7669
for right_tuple in right_matches:
77-
self.create_and_schedule_child(left_tuple, right_tuple)
70+
self.create_and_propagate_child(left_tuple, right_tuple)
7871

7972
def insert_right(self, right_tuple: AbstractTuple):
73+
"""Handles insertion of a tuple from the right parent."""
8074
self.right_index.put(right_tuple)
8175
key = self.right_index._index_properties.get_property(right_tuple)
76+
# Probe the opposing index for matches
8277
left_matches = self.left_index.get_matches(key) if hasattr(self.left_index, 'get_matches') else self.left_index.get(key)
8378
for left_tuple in left_matches:
84-
self.create_and_schedule_child(left_tuple, right_tuple)
79+
self.create_and_propagate_child(left_tuple, right_tuple)
8580

8681
def retract_left(self, left_tuple: AbstractTuple):
82+
"""Handles retraction of a tuple from the left parent using the reverse index."""
8783
self.left_index.remove(left_tuple)
88-
pairs_to_remove = [p for p in self.beta_memory if p[0] == left_tuple]
89-
for pair in pairs_to_remove:
90-
self.retract_and_schedule_child(pair[0], pair[1])
84+
85+
if left_tuple in self.left_to_pairs:
86+
pairs_to_remove = list(self.left_to_pairs[left_tuple])
87+
88+
for pair in pairs_to_remove:
89+
self.retract_and_propagate_child(pair)
9190

9291
def retract_right(self, right_tuple: AbstractTuple):
92+
"""Handles retraction of a tuple from the right parent using the reverse index."""
9393
self.right_index.remove(right_tuple)
94-
pairs_to_remove = [p for p in self.beta_memory if p[1] == right_tuple]
95-
for pair in pairs_to_remove:
96-
self.retract_and_schedule_child(pair[0], pair[1])
94+
95+
if right_tuple in self.right_to_pairs:
96+
pairs_to_remove = list(self.right_to_pairs[right_tuple])
97+
98+
for pair in pairs_to_remove:
99+
self.retract_and_propagate_child(pair)
100+
101+
def create_and_propagate_child(self, left_tuple: AbstractTuple, right_tuple: AbstractTuple):
102+
"""Creates a new child tuple, stores it, updates indices, and schedules propagation."""
103+
pair = (left_tuple, right_tuple)
104+
if pair in self.beta_memory:
105+
return # Avoid creating duplicate children
97106

98-
def create_and_schedule_child(self, left_tuple: AbstractTuple, right_tuple: AbstractTuple):
99107
child = self._create_child_tuple(left_tuple, right_tuple)
100108
child.node, child.state = self, TupleState.CREATING
101-
self.beta_memory[(left_tuple, right_tuple)] = child
109+
self.beta_memory[pair] = child
110+
111+
self.left_to_pairs[left_tuple].append(pair)
112+
self.right_to_pairs[right_tuple].append(pair)
113+
102114
self.scheduler.schedule(child)
103115

104-
def retract_and_schedule_child(self, left: AbstractTuple, right: AbstractTuple):
105-
child = self.beta_memory.pop((left, right), None)
116+
def retract_and_propagate_child(self, pair: tuple[AbstractTuple, AbstractTuple]):
117+
"""Removes a child tuple, cleans up all indices, and schedules retraction."""
118+
child = self.beta_memory.pop(pair, None)
106119
if child:
120+
left, right = pair
121+
if left in self.left_to_pairs:
122+
self.left_to_pairs[left].remove(pair)
123+
if not self.left_to_pairs[left]:
124+
del self.left_to_pairs[left]
125+
126+
if right in self.right_to_pairs:
127+
self.right_to_pairs[right].remove(pair)
128+
if not self.right_to_pairs[right]:
129+
del self.right_to_pairs[right]
130+
107131
if child.state == TupleState.CREATING:
108132
child.state = TupleState.ABORTING
109133
elif not child.state.is_dirty():

greyjack/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ features = ["pyo3/extension-module"]
77

88
[project]
99
name = "greyjack"
10-
version = "0.3.8"
10+
version = "0.3.9"
1111
requires-python = ">=3.9"
1212
dependencies = [
1313
"bitarray>=3.5.0",

0 commit comments

Comments
 (0)