Skip to content

Commit 083da7e

Browse files
NiraliPopat, vipul-mittal, psriramsnc
authored
[Enhancement] Metrics changes and Eval demo tasks addition (#119)
* Metrics changes and Eval demo tasks addition * Unit tests modification * fix format * fix lint * fix tests * fix tests * add documentation * fix documentation link --------- Co-authored-by: Vipul Mittal <118464422+vipul-mittal@users.noreply.github.com> Co-authored-by: Sriram Puttagunta <sriram.puttagunta@servicenow.com>
1 parent 45fd394 commit 083da7e

21 files changed

Lines changed: 1339 additions & 983 deletions

File tree

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,19 @@ LLM-based nodes require a model configured in `models.yaml` and runtime paramete
172172

173173
As of now, LLM inference is supported for TGI, vLLM, OpenAI, Azure, Azure OpenAI, Ollama and Triton compatible servers. Model deployment is external and configured in `models.yaml`.
174174

175+
## SyGra as a Platform
176+
177+
SyGra can be used as a reusable platform to build different categories of tasks on top of the same graph execution engine, node types, processors, and metric infrastructure.
178+
179+
### Eval
180+
181+
Evaluation tasks live under `tasks/eval` and provide a standard pattern for:
182+
183+
- Computing **unit metrics** per record during graph execution
184+
- Computing **aggregator metrics** after the run via graph post-processing
185+
186+
See: [`tasks/eval/README.md`](https://github.com/ServiceNow/SyGra/blob/main/tasks/eval/README.md)
187+
175188
<!-- ![SygraComponents](https://raw.githubusercontent.com/ServiceNow/SyGra/refs/heads/main/docs/resources/images/sygra_usecase2framework.png) -->
176189

177190

sygra/core/base_task_executor.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,9 @@ def _repeat_to_merge_sequentially(
544544

545545
# merge the primary and secondary dataframe horizontally by randomly picking one and adding into primary
546546
# primary : M rows(a columns), secondary: N rows(b columns), merged: M rows(a+b columns)
547-
def _shuffle_and_extend(self, primary_df, secondary_df) -> pd.DataFrame:
547+
def _shuffle_and_extend(
548+
self, primary_df: pd.DataFrame, secondary_df: pd.DataFrame
549+
) -> pd.DataFrame:
548550
max_len = len(primary_df)
549551
# Shuffle the secondary dataframe
550552
shuffled_secondary = secondary_df.sample(frac=1).reset_index(drop=True)
@@ -560,7 +562,7 @@ def _shuffle_and_extend(self, primary_df, secondary_df) -> pd.DataFrame:
560562
final_secondary = pd.concat([shuffled_secondary, extra_rows], ignore_index=True)
561563

562564
# now both datasets are the same length, merge and return
563-
return pd.concat([primary_df, final_secondary], axis=1)
565+
return cast(pd.DataFrame, pd.concat([primary_df, final_secondary], axis=1))
564566

565567
def _load_source_data(
566568
self, data_config: dict
@@ -587,8 +589,8 @@ def _load_source_data(
587589
full_data = self.apply_transforms(source_config_obj, full_data)
588590
elif isinstance(source_config, list):
589591
# if multiple dataset configured as list
590-
dataset_list = []
591-
primary_df = None
592+
dataset_list: list[dict[str, Any]] = []
593+
primary_df: Optional[pd.DataFrame] = None
592594
primary_config = None
593595
# if multiple dataset, verify if join_type and alias is defined in each config(@source and @sink)
594596
if isinstance(source_config, list):
@@ -650,6 +652,9 @@ def _load_source_data(
650652
ds_conf: dict[str, Any] = ds.get("conf", {})
651653
join_type = ds_conf.get(constants.DATASET_JOIN_TYPE)
652654
current_df = ds.get("dataset")
655+
if current_df is None or not isinstance(current_df, pd.DataFrame):
656+
logger.error("Dataset is missing or not a dataframe")
657+
continue
653658
if join_type == constants.JOIN_TYPE_COLUMN:
654659
sec_alias_name = ds_conf.get(constants.DATASET_ALIAS)
655660
pri_alias_name = (
@@ -665,22 +670,26 @@ def _load_source_data(
665670
# where_clause = ds.get("conf").get("where_clause")
666671
primary_df = pd.merge(
667672
primary_df,
668-
current_df,
673+
cast(pd.DataFrame, current_df),
669674
left_on=primary_column,
670675
right_on=join_column,
671676
how="left",
672677
)
673678
elif join_type == constants.JOIN_TYPE_SEQUENTIAL:
674-
primary_df = self._repeat_to_merge_sequentially(primary_df, current_df)
679+
primary_df = self._repeat_to_merge_sequentially(
680+
primary_df, cast(pd.DataFrame, current_df)
681+
)
675682
elif join_type == constants.JOIN_TYPE_CROSS:
676-
primary_df = primary_df.merge(current_df, how="cross")
683+
primary_df = primary_df.merge(cast(pd.DataFrame, current_df), how="cross")
677684
elif join_type == constants.JOIN_TYPE_RANDOM:
678-
primary_df = self._shuffle_and_extend(primary_df, current_df)
685+
primary_df = self._shuffle_and_extend(
686+
primary_df, cast(pd.DataFrame, current_df)
687+
)
679688
else:
680689
logger.error("Not implemented join_type")
681690

682691
# now convert dataframe to list of dict (full_data)
683-
full_data = primary_df.to_dict(orient="records")
692+
full_data = cast(list[dict[str, Any]], primary_df.to_dict(orient="records"))
684693
else:
685694
logger.error("Unsupported source config type.")
686695

sygra/core/eval/metrics/aggregator_metrics/aggregator_metric_registry.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
# Avoid circular imports
88
from __future__ import annotations
99

10+
import importlib
11+
import pkgutil
1012
from typing import TYPE_CHECKING, Dict, List, Type
1113

1214
from sygra.logger.logger_config import logger
@@ -42,6 +44,33 @@ class AggregatorMetricRegistry:
4244

4345
# Class-level storage (create singleton to have central control)
4446
_metrics: Dict[str, Type[BaseAggregatorMetric]] = {}
47+
_discovered: bool = False
48+
49+
@classmethod
50+
def _ensure_discovered(cls) -> None:
51+
if cls._discovered:
52+
return
53+
54+
try:
55+
import sygra.core.eval.metrics.aggregator_metrics as aggregator_metrics_pkg
56+
57+
for module_info in pkgutil.iter_modules(
58+
aggregator_metrics_pkg.__path__, aggregator_metrics_pkg.__name__ + "."
59+
):
60+
module_name = module_info.name
61+
if module_name.endswith(
62+
(
63+
".base_aggregator_metric",
64+
".aggregator_metric_registry",
65+
)
66+
):
67+
continue
68+
importlib.import_module(module_name)
69+
70+
cls._discovered = True
71+
except Exception as e:
72+
logger.error(f"Failed to auto-discover aggregator metrics: {e}")
73+
cls._discovered = True
4574

4675
@classmethod
4776
def register(cls, name: str, metric_class: Type[BaseAggregatorMetric]) -> None:
@@ -105,6 +134,8 @@ def get_metric(cls, name: str, **kwargs) -> BaseAggregatorMetric:
105134
# Get metric with custom parameters
106135
topk = AggregatorMetricRegistry.get_metric("top_k_accuracy", k=5)
107136
"""
137+
cls._ensure_discovered()
138+
108139
if name not in cls._metrics:
109140
available = cls.list_metrics()
110141
raise KeyError(
@@ -135,6 +166,7 @@ def list_metrics(cls) -> List[str]:
135166
AggregatorMetricRegistry.list_metrics()
136167
['accuracy', 'confusion_matrix', 'f1', 'precision', 'recall']
137168
"""
169+
cls._ensure_discovered()
138170
return sorted(cls._metrics.keys())
139171

140172
@classmethod
@@ -149,6 +181,7 @@ def has_metric(cls, name: str) -> bool:
149181
if AggregatorMetricRegistry.has_metric("f1"):
150182
metric = AggregatorMetricRegistry.get_metric("f1")
151183
"""
184+
cls._ensure_discovered()
152185
return name in cls._metrics
153186

154187
@classmethod
@@ -163,6 +196,8 @@ def get_metric_class(cls, name: str) -> Type[BaseAggregatorMetric]:
163196
Raises:
164197
KeyError: If metric name is not registered
165198
"""
199+
cls._ensure_discovered()
200+
166201
if name not in cls._metrics:
167202
available = cls.list_metrics()
168203
raise KeyError(

sygra/core/eval/metrics/aggregator_metrics/f1_score.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from typing import Any, Dict, List
99

10-
from pydantic import BaseModel, Field, field_validator
10+
from pydantic import BaseModel, Field
1111

1212
from sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry import aggregator_metric
1313
from sygra.core.eval.metrics.aggregator_metrics.base_aggregator_metric import BaseAggregatorMetric
@@ -23,14 +23,6 @@ class F1ScoreMetricConfig(BaseModel):
2323

2424
predicted_key: str = Field(..., min_length=1, description="Key in predicted dict to check")
2525
golden_key: str = Field(..., min_length=1, description="Key in golden dict to check")
26-
positive_class: Any = Field(..., description="Value representing positive class")
27-
28-
@field_validator("positive_class")
29-
@classmethod
30-
def validate_positive_class(cls, v):
31-
if v is None:
32-
raise ValueError("positive_class is required (cannot be None)")
33-
return v
3426

3527

3628
@aggregator_metric("f1_score")
@@ -43,7 +35,6 @@ class F1ScoreMetric(BaseAggregatorMetric):
4335
Required configuration:
4436
predicted_key: Key in predicted dict to check (e.g., "tool")
4537
golden_key: Key in golden dict to check (e.g., "event")
46-
positive_class: Value representing the positive class (e.g., "click")
4738
"""
4839

4940
def __init__(self, **config):
@@ -60,15 +51,10 @@ def validate_config(self):
6051
# Store validated fields as instance attributes
6152
self.predicted_key = config_obj.predicted_key
6253
self.golden_key = config_obj.golden_key
63-
self.positive_class = config_obj.positive_class
6454

6555
# Create precision and recall metrics (reuse implementations)
66-
self.precision_metric = PrecisionMetric(
67-
predicted_key=self.predicted_key, positive_class=self.positive_class
68-
)
69-
self.recall_metric = RecallMetric(
70-
golden_key=self.golden_key, positive_class=self.positive_class
71-
)
56+
self.precision_metric = PrecisionMetric(predicted_key=self.predicted_key)
57+
self.recall_metric = RecallMetric(golden_key=self.golden_key)
7258

7359
def get_metadata(self) -> BaseMetricMetadata:
7460
"""Return metadata for F1 score metric"""
@@ -93,16 +79,27 @@ def calculate(self, results: List[UnitMetricResult]) -> Dict[str, Any]:
9379
"""
9480
if not results:
9581
logger.warning(f"{self.__class__.__name__}: No results provided")
96-
return {"f1_score": 0.0}
82+
return {"average_f1_score": 0.0, "f1_score_per_class": {}}
9783

84+
f1_score = dict()
9885
# Reuse existing metric implementations
9986
precision_result = self.precision_metric.calculate(results)
10087
recall_result = self.recall_metric.calculate(results)
10188

102-
precision = precision_result.get("precision", 0.0)
103-
recall = recall_result.get("recall", 0.0)
104-
10589
# Calculate F1 as harmonic mean of precision and recall
106-
f1_score = self._safe_divide(2 * precision * recall, precision + recall)
90+
average_precision = precision_result.get("average_precision", 0.0)
91+
average_recall = recall_result.get("average_recall", 0.0)
92+
average_f1_score = self._safe_divide(
93+
2 * average_precision * average_recall, average_precision + average_recall
94+
)
95+
96+
precision_classes = set(precision_result.get("precision_per_class", {}).keys())
97+
recall_classes = set(recall_result.get("recall_per_class", {}).keys())
98+
all_classes = precision_classes.union(recall_classes)
99+
100+
for class_ in all_classes:
101+
precision = precision_result.get("precision_per_class", {}).get(class_, 0.0)
102+
recall = recall_result.get("recall_per_class", {}).get(class_, 0.0)
103+
f1_score[class_] = self._safe_divide(2 * precision * recall, precision + recall)
107104

108-
return {"f1_score": f1_score}
105+
return {"average_f1_score": average_f1_score, "f1_score_per_class": f1_score}

sygra/core/eval/metrics/aggregator_metrics/precision.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
Measures: Of all predicted positives, how many were actually positive?
66
"""
77

8-
from typing import Any, Dict, List
8+
from collections import defaultdict
9+
from typing import Any, DefaultDict, Dict, List
910

10-
from pydantic import BaseModel, Field, field_validator
11+
from pydantic import BaseModel, Field
1112

1213
from sygra.core.eval.metrics.aggregator_metrics.aggregator_metric_registry import aggregator_metric
1314
from sygra.core.eval.metrics.aggregator_metrics.base_aggregator_metric import BaseAggregatorMetric
@@ -20,14 +21,6 @@ class PrecisionMetricConfig(BaseModel):
2021
"""Configuration for Precision Metric"""
2122

2223
predicted_key: str = Field(..., min_length=1, description="Key in predicted dict to check")
23-
positive_class: Any = Field(..., description="Value representing positive class")
24-
25-
@field_validator("positive_class")
26-
@classmethod
27-
def validate_positive_class(cls, v):
28-
if v is None:
29-
raise ValueError("positive_class is required (cannot be None)")
30-
return v
3124

3225

3326
@aggregator_metric("precision")
@@ -39,12 +32,12 @@ class PrecisionMetric(BaseAggregatorMetric):
3932
4033
Required configuration:
4134
predicted_key: Key in predicted dict to check (e.g., "tool")
42-
positive_class: Value representing the positive class (e.g., "click")
4335
"""
4436

4537
def __init__(self, **config):
4638
"""Initialize precision metric with two-phase initialization."""
4739
super().__init__(**config)
40+
self.predicted_key = None
4841
self.validate_config()
4942
self.metadata = self.get_metadata()
5043

@@ -55,7 +48,6 @@ def validate_config(self):
5548

5649
# Store validated fields as instance attributes
5750
self.predicted_key = config_obj.predicted_key
58-
self.positive_class = config_obj.positive_class
5951

6052
def get_metadata(self) -> BaseMetricMetadata:
6153
"""Return metadata for precision metric"""
@@ -76,23 +68,54 @@ def calculate(self, results: List[UnitMetricResult]) -> Dict[str, Any]:
7668
results: List of UnitMetricResult
7769
7870
Returns:
79-
dict: {"precision": float (0.0 to 1.0)}
71+
dict: {
72+
"average_precision": float (0.0 to 1.0)
73+
"precision_per_class": {
74+
"class_1": float (0.0 to 1.0),
75+
"class_2": float (0.0 to 1.0),
76+
...
77+
"class_n": float (0.0 to 1.0)
78+
}
79+
}
8080
"""
8181
if not results:
8282
logger.warning(f"{self.__class__.__name__}: No results provided")
83-
return {"precision": 0.0}
84-
85-
# Calculate TP and FP
86-
tp = sum(
87-
1
88-
for r in results
89-
if r.predicted.get(self.predicted_key) == self.positive_class and r.correct
90-
)
91-
fp = sum(
92-
1
93-
for r in results
94-
if r.predicted.get(self.predicted_key) == self.positive_class and not r.correct
83+
return {"average_precision": 0.0, "precision_per_class": {}}
84+
85+
predicted_count: DefaultDict[str, int] = defaultdict(int)
86+
true_positive: DefaultDict[str, int] = defaultdict(int)
87+
88+
for r in results:
89+
try:
90+
predicted_key = self.predicted_key
91+
if predicted_key is None:
92+
logger.warning(f"{self.__class__.__name__}: predicted_key is not configured")
93+
continue
94+
label = r.predicted[predicted_key]
95+
except KeyError:
96+
logger.warning(
97+
f"{self.__class__.__name__}: Missing predicted_key '{self.predicted_key}' in result"
98+
)
99+
continue
100+
101+
if not isinstance(label, str):
102+
label = str(label)
103+
104+
predicted_count[label] += 1
105+
if r.correct:
106+
true_positive[label] += 1
107+
108+
precision_per_class = {
109+
label: self._safe_divide(true_positive[label], count)
110+
for label, count in predicted_count.items()
111+
}
112+
113+
average_precision = self._safe_divide(
114+
sum(precision_per_class.values()),
115+
len(precision_per_class),
95116
)
96117

97-
precision = self._safe_divide(tp, tp + fp)
98-
return {"precision": precision}
118+
return {
119+
"average_precision": average_precision,
120+
"precision_per_class": precision_per_class,
121+
}

0 commit comments

Comments
 (0)