Skip to content

Commit dab880d

Browse files
committed
v0.3.2 Initial Greynet integration into GreyJack Solver with launchable NQueens example (TODO: debug incorrect state management, that causes wrong results while solving)
1 parent 6ecf95d commit dab880d

10 files changed

Lines changed: 485 additions & 148 deletions

File tree

examples/object_oriented/nqueens/persistence/CotwinBuilderNQueens.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from examples.object_oriented.nqueens.cotwin.NQueensCotwin import NQueensCotwin
66
from examples.object_oriented.nqueens.score.PlainScoreCalculatorNQueens import PlainScoreCalculatorNQueens
77
from examples.object_oriented.nqueens.score.IncrementalScoreCalculatorNQueens import IncrementalScoreCalculatorNQueens
8+
from examples.object_oriented.nqueens.score.GreynetScoreCalculatorNQueens import greynet_score_calculator_nqueens
89
from examples.object_oriented.nqueens.cotwin.CotQueen import CotQueen
910
from greyjack.variables.GJInteger import GJInteger
1011

1112

1213
class CotwinBuilderNQueens(CotwinBuilderBase):
13-
def __init__(self, use_incremental_score_calculator):
14-
self.use_incremental_score_calculator = use_incremental_score_calculator
14+
def __init__(self, scorer_name):
15+
self.scorer_name = scorer_name
1516
pass
1617

1718
def build_cotwin(self, domain_model, is_already_initialized):
@@ -29,10 +30,14 @@ def build_cotwin(self, domain_model, is_already_initialized):
2930

3031
nqueens_cotwin = NQueensCotwin()
3132
nqueens_cotwin.add_planning_entities_list( cot_queens, "queens" )
32-
if self.use_incremental_score_calculator:
33+
if self.scorer_name == "plain":
34+
nqueens_cotwin.set_score_calculator( PlainScoreCalculatorNQueens() )
35+
elif self.scorer_name == "pseudo":
3336
nqueens_cotwin.set_score_calculator( IncrementalScoreCalculatorNQueens() )
37+
elif self.scorer_name == "greynet":
38+
nqueens_cotwin.set_score_calculator( greynet_score_calculator_nqueens )
3439
else:
35-
nqueens_cotwin.set_score_calculator( PlainScoreCalculatorNQueens() )
40+
raise ValueError("Available score calculators: plain, pseudo, greynet")
3641

3742
return nqueens_cotwin
3843

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# file: nqueens_constraint_builder.py
2+
3+
from greyjack.score_calculation.score_calculators.GreynetScoreCalculator import GreynetScoreCalculator
4+
from greyjack.score_calculation.greynet.builder import ConstraintBuilder
5+
from greyjack.score_calculation.scores.SimpleScore import SimpleScore
6+
from greyjack.score_calculation.scores.ScoreVariants import ScoreVariants
7+
from ..cotwin.CotQueen import CotQueen
8+
9+
cb = ConstraintBuilder(name="NQueens", score_class=SimpleScore)
10+
11+
@cb.constraint("Row Conflict", default_weight=1.0)
12+
def row_conflict():
13+
return (cb.for_each_unique_pair(CotQueen)
14+
.filter(lambda q1, q2: q1.row_id == q2.row_id)
15+
.penalize_simple(1))
16+
17+
@cb.constraint("Ascending Diagonal Conflict", default_weight=1.0)
18+
def ascending_diagonal_conflict():
19+
return (cb.for_each_unique_pair(CotQueen)
20+
.filter(lambda q1, q2: (q1.row_id - q1.column_id) == (q2.row_id - q2.column_id))
21+
.penalize_simple(1))
22+
23+
@cb.constraint("Descending Diagonal Conflict", default_weight=1.0)
24+
def descending_diagonal_conflict():
25+
return (
26+
cb.for_each_unique_pair(CotQueen)
27+
.filter(lambda q1, q2: (q1.row_id + q1.column_id) == (q2.row_id + q2.column_id))
28+
.penalize_simple(1)
29+
)
30+
31+
greynet_score_calculator_nqueens = GreynetScoreCalculator(constraint_builder=cb, score_variant=ScoreVariants.SimpleScore)

examples/object_oriented/nqueens/scripts/solve_nqueens.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,42 @@
1616
from greyjack.agents.base.LoggingLevel import LoggingLevel
1717
from greyjack.agents.base.ParallelizationBackend import ParallelizationBackend
1818
from greyjack.agents import *
19+
import traceback
1920

2021
if __name__ == "__main__":
2122

22-
# build domain model
23-
domain_builder = DomainBuilderNQueens(10000, random_seed=45)
24-
cotwin_builder = CotwinBuilderNQueens(use_incremental_score_calculator=True)
25-
26-
#termination_strategy = StepsLimit(step_count_limit=1000)
27-
#termination_strategy = TimeSpentLimit(time_seconds_limit=60)
28-
#termination_strategy = ScoreNoImprovement(time_seconds_limit=15)
29-
termination_strategy = ScoreLimit(score_to_compare=[0])
30-
agent = TabuSearch(neighbours_count=20, tabu_entity_rate=0.0,
31-
mutation_rate_multiplier=None, move_probas=[0, 1, 0, 0, 0, 0],
32-
migration_frequency=10, termination_strategy=termination_strategy)
33-
"""agent = GeneticAlgorithm(population_size=128, crossover_probability=0.5, p_best_rate=0.05,
34-
tabu_entity_rate=0.0, mutation_rate_multiplier=1.0, move_probas=[0, 1, 0, 0, 0, 0],
35-
migration_rate=0.00001, migration_frequency=1, termination_strategy=termination_strategy)"""
36-
"""agent = LateAcceptance(late_acceptance_size=10, tabu_entity_rate=0.0,
37-
mutation_rate_multiplier=None, move_probas=[0, 1, 0, 0, 0, 0],
38-
compare_to_global_frequency=1, termination_strategy=termination_strategy)"""
39-
"""agent = SimulatedAnnealing(initial_temperature=[1.0], cooling_rate=0.9999, tabu_entity_rate=0.0,
40-
mutation_rate_multiplier=None, move_probas=[0, 1, 0, 0, 0, 0],
41-
migration_frequency=10, compare_to_global_frequency=10, termination_strategy=termination_strategy)"""
42-
43-
solver = SolverOOP(domain_builder, cotwin_builder, agent,
44-
ParallelizationBackend.Multiprocessing, LoggingLevel.FreshOnly,
45-
n_jobs=10, score_precision=[0])
46-
solution = solver.solve()
47-
#print( "Cotwin solution looks that: " )
48-
#print( solution )
49-
50-
domain = domain_builder.build_from_solution(solution)
51-
print(domain)
23+
try:
24+
# build domain model
25+
domain_builder = DomainBuilderNQueens(10000, random_seed=45)
26+
cotwin_builder = CotwinBuilderNQueens(scorer_name="pseudo")
27+
28+
#termination_strategy = StepsLimit(step_count_limit=1000)
29+
#termination_strategy = TimeSpentLimit(time_seconds_limit=60)
30+
#termination_strategy = ScoreNoImprovement(time_seconds_limit=15)
31+
termination_strategy = ScoreLimit(score_to_compare=[0])
32+
agent = TabuSearch(neighbours_count=20, tabu_entity_rate=0.0,
33+
mutation_rate_multiplier=None, move_probas=[0, 1, 0, 0, 0, 0],
34+
migration_frequency=10, termination_strategy=termination_strategy)
35+
"""agent = GeneticAlgorithm(population_size=128, crossover_probability=0.5, p_best_rate=0.05,
36+
tabu_entity_rate=0.0, mutation_rate_multiplier=1.0, move_probas=[0, 1, 0, 0, 0, 0],
37+
migration_rate=0.00001, migration_frequency=1, termination_strategy=termination_strategy)"""
38+
"""agent = LateAcceptance(late_acceptance_size=10, tabu_entity_rate=0.0,
39+
mutation_rate_multiplier=None, move_probas=[0, 1, 0, 0, 0, 0],
40+
compare_to_global_frequency=1, termination_strategy=termination_strategy)"""
41+
"""agent = SimulatedAnnealing(initial_temperature=[1.0], cooling_rate=0.9999, tabu_entity_rate=0.0,
42+
mutation_rate_multiplier=None, move_probas=[0, 1, 0, 0, 0, 0],
43+
migration_frequency=10, compare_to_global_frequency=10, termination_strategy=termination_strategy)"""
44+
45+
solver = SolverOOP(domain_builder, cotwin_builder, agent,
46+
ParallelizationBackend.Multiprocessing, LoggingLevel.FreshOnly,
47+
n_jobs=10, score_precision=[0])
48+
solution = solver.solve()
49+
#print( "Cotwin solution looks that: " )
50+
#print( solution )
51+
52+
domain = domain_builder.build_from_solution(solution)
53+
print(domain)
54+
except Exception as e:
55+
print(traceback.format_exc())
5256

5357
print( "done" )

greyjack/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

greyjack/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "greyjack"
3-
version = "0.3.0"
3+
version = "0.3.2"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

greyjack/greyjack/agents/TabuSearch.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,15 @@ def __init__(
3131
self.is_win_from_comparing_with_global = True
3232

3333
def _build_metaheuristic_base(self):
34-
35-
# when I use issubclass() solver dies silently, so check specific attributes
3634
if hasattr(self.cotwin, "planning_entities"):
3735
self.score_requester = OOPScoreRequester(self.cotwin)
3836
score_variant = self.cotwin.score_calculator.score_variant
3937
elif isinstance(self.cotwin, MathModel):
4038
self.score_requester = PureMathScoreRequester(self.cotwin)
4139
score_variant = self.cotwin.score_variant
42-
self.cotwin.score_calculator.is_incremental = False # if True, currently works badder. Will try improve later
40+
self.cotwin.score_calculator.is_incremental = False
4341
else:
44-
raise Exception("Cotwin must be either subclass of CotwinBase, either be instance of MathModel")
42+
raise Exception("Cotwin must be either subclass of CotwinBase, or an instance of MathModel")
4543

4644
semantic_groups_dict = self.score_requester.variables_manager.semantic_groups_map.copy()
4745
discrete_ids = self.score_requester.variables_manager.discrete_ids
@@ -57,7 +55,6 @@ def _build_metaheuristic_base(self):
5755
discrete_ids,
5856
)
5957

60-
# to remove redundant clonning
6158
self.metaheuristic_name = self.metaheuristic_base.metaheuristic_name
6259
self.metaheuristic_kind = self.metaheuristic_base.metaheuristic_kind
6360

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# greyjack/score_calculation/score_calculators/GreynetScoreCalculator.py
2+
3+
from __future__ import annotations
4+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
5+
6+
from greyjack.score_calculation.scores.ScoreVariants import ScoreVariants
7+
from greyjack.variables.GJFloat import GJFloat
8+
from greyjack.variables.GJInteger import GJInteger
9+
from greyjack.variables.GJBinary import GJBinary
10+
from copy import deepcopy
11+
import numpy as np
12+
13+
if TYPE_CHECKING:
14+
from greyjack.score_calculation.greynet.builder import ConstraintBuilder
15+
from greyjack.score_calculation.greynet.session import Session
16+
17+
class GreynetScoreCalculator:
18+
"""
19+
An incremental score calculator that uses the Greynet rule engine.
20+
This calculator holds the Greynet session and provides methods for
21+
the ScoreRequester to interact with it efficiently.
22+
"""
23+
def __init__(self, constraint_builder: 'ConstraintBuilder', score_variant: ScoreVariants):
24+
"""
25+
Initializes the calculator by building the Greynet session from the
26+
provided constraint definitions.
27+
28+
Args:
29+
constraint_builder (ConstraintBuilder): The Greynet constraint builder
30+
containing all the rules for the problem.
31+
score_variant (ScoreVariants): The score variant enumeration that
32+
corresponds to the score class used in the constraint builder.
33+
"""
34+
from greyjack.score_calculation.greynet.builder import ConstraintBuilder as GreynetConstraintBuilder
35+
if not isinstance(constraint_builder, GreynetConstraintBuilder):
36+
raise TypeError("constraint_builder must be an instance of greynet.ConstraintBuilder")
37+
38+
self.session: 'Session' = constraint_builder.build()
39+
self.score_variant = score_variant
40+
self.is_incremental = True
41+
self.score_type = self.session.score_class
42+
43+
# This mapping is populated by the ScoreRequester during initialization.
44+
# It is essential for translating the solver's variable indices to domain objects.
45+
# Key: var_idx (int) -> Value: (fact_object, attribute_name_str)
46+
self.var_idx_to_entity_map: Dict[int, Tuple[Any, str]] = {}
47+
48+
self.first_call_apply_deltas_internal = True
49+
50+
def initial_load(self, planning_entities: Dict[str, List[Any]], problem_facts: Dict[str, List[Any]]):
51+
"""
52+
Performs the initial population of the Greynet session with all facts
53+
from the problem domain. This should only be called once.
54+
55+
Args:
56+
planning_entities (dict): A dictionary of planning entity lists.
57+
problem_facts (dict): A dictionary of problem fact lists.
58+
"""
59+
self.session.clear()
60+
61+
for group_name in problem_facts:
62+
self.session.insert_batch(problem_facts[group_name])
63+
64+
for group_name in planning_entities:
65+
initialized_entities = self.build_initialized_entities(planning_entities, group_name)
66+
self.session.insert_batch(initialized_entities)
67+
68+
self.session.flush()
69+
70+
def build_initialized_entities(self, planning_entities, group_name):
71+
72+
current_planning_entities_group = planning_entities[group_name]
73+
initialized_entities = []
74+
75+
for entity in current_planning_entities_group:
76+
new_entity = self.build_initialized_entity(entity)
77+
initialized_entities.append(new_entity)
78+
79+
return initialized_entities
80+
81+
def build_initialized_entity(self, entity):
82+
entity_attributes_dict = entity.__dict__
83+
new_entity_kwargs = {}
84+
for attribute_name in entity_attributes_dict:
85+
attribute_value = entity_attributes_dict[attribute_name]
86+
87+
if type(attribute_value) in {GJFloat, GJInteger, GJBinary}:
88+
value = attribute_value.planning_variable.initial_value
89+
90+
if value is None:
91+
raise ValueError("All planning variables must have initial value for scoring by greynet")
92+
else:
93+
value = attribute_value
94+
95+
new_entity_kwargs[attribute_name] = value
96+
97+
new_entity = type(entity)(**new_entity_kwargs)
98+
return new_entity
99+
100+
def get_score(self) -> Any:
101+
"""
102+
Retrieves the current total score from the Greynet session.
103+
Assumes all pending changes have been flushed.
104+
105+
Returns:
106+
A score object (e.g., HardSoftScore) representing the current state.
107+
"""
108+
score = self.session.get_score()
109+
return score
110+
111+
def _full_sync_and_get_score(self, sample: List[float]) -> Any:
112+
"""
113+
A non-incremental way to get a score for a full solution vector.
114+
This modifies the session state and is primarily for debugging or fallback.
115+
"""
116+
changed_facts, original_vals = self._apply_deltas_internal(list(enumerate(sample)))
117+
score = self.get_score()
118+
self._revert_deltas_internal(changed_facts, original_vals)
119+
return score
120+
121+
def _apply_and_get_score_for_batch(self, deltas: List[List[Tuple[int, float]]]) -> List[Any]:
122+
"""
123+
Applies a batch of deltas, gets the score for each, and reverts the state
124+
between each delta application. This is the primary method for incremental scoring.
125+
"""
126+
scores = []
127+
for delta_set in deltas:
128+
if not delta_set:
129+
scores.append(self.get_score())
130+
continue
131+
132+
changed_facts, original_values = self._apply_deltas_internal(delta_set)
133+
scores.append(self.get_score())
134+
self._revert_deltas_internal(changed_facts, original_values)
135+
136+
137+
return scores
138+
139+
def _apply_deltas_internal(self, deltas: List[Tuple[int, float]]) -> Tuple[Dict[int, Any], Dict[int, Any]]:
140+
"""
141+
Internal helper to apply changes to the session state.
142+
143+
Returns:
144+
A tuple containing (dict of changed_facts, dict of original_values) for reverting.
145+
"""
146+
changed_facts: Dict[int, Any] = {}
147+
original_facts: Dict[int, Any] = {}
148+
149+
if self.first_call_apply_deltas_internal:
150+
self.session.clear()
151+
for var_idx, new_value in deltas:
152+
old_entity, attr_name = self.var_idx_to_entity_map[var_idx]
153+
new_entity = self.build_initialized_entity(old_entity)
154+
self.var_idx_to_entity_map[var_idx] = (new_entity, attr_name)
155+
self.first_call_apply_deltas_internal = False
156+
157+
# removing previous values
158+
for var_idx, new_value in deltas:
159+
entity, attr_name = self.var_idx_to_entity_map[var_idx]
160+
entity_id = id(entity)
161+
162+
if entity_id not in original_facts:
163+
original_facts[entity_id] = entity
164+
165+
self.session.retract_batch(list(original_facts.values()))
166+
167+
# inserting new values
168+
for var_idx, new_value in deltas:
169+
entity, attr_name = self.var_idx_to_entity_map[var_idx]
170+
entity_id = id(entity)
171+
172+
if entity_id not in changed_facts:
173+
changed_facts[entity_id] = deepcopy(entity)
174+
175+
setattr(changed_facts[entity_id], attr_name, new_value)
176+
177+
self.session.insert_batch(list(changed_facts.values()))
178+
self.session.flush()
179+
180+
return changed_facts, original_facts
181+
182+
def _revert_deltas_internal(self, changed_facts: Dict[int, Any], original_facts: Dict[int, Any]):
183+
"""Internal helper to revert changes to the session state."""
184+
185+
self.session.retract_batch(list(changed_facts.values()))
186+
self.session.insert_batch(list(original_facts.values()))
187+
self.session.flush()
188+
189+
pass

0 commit comments

Comments
 (0)