diff --git a/api/environments/identities/helpers.py b/api/environments/identities/helpers.py deleted file mode 100644 index 3bafa62c6003..000000000000 --- a/api/environments/identities/helpers.py +++ /dev/null @@ -1,31 +0,0 @@ -import hashlib -import typing - - -def get_hashed_percentage_for_object_ids( - object_ids: typing.Iterable[typing.Union[str, int]], iterations: int = 1 -) -> float: - """ - Given a list of object ids, get a floating point number between 0 and 1 based on - the hash of those ids. This should give the same value every time for any - list of ids. - - :param object_ids: list of object ids to calculate the has for - :param iterations: num times to include each id in the generated string to hash - :return: (float) number between 0 (inclusive) and 1 (exclusive) - """ - - to_hash = ",".join(str(id_) for id_ in list(object_ids) * iterations) - hashed_value = hashlib.md5(to_hash.encode("utf-8")) - hashed_value_as_int = int(hashed_value.hexdigest(), base=16) - value = (hashed_value_as_int % 9999) / 9998 - - if value == 1: - # since we want a number between 0 (inclusive) and 1 (exclusive), in the - # unlikely case that we get the exact number 1, we call the method again - # and increase the number of iterations to ensure we get a different result - return get_hashed_percentage_for_object_ids( - object_ids=object_ids, iterations=iterations + 1 - ) - - return value diff --git a/api/environments/identities/models.py b/api/environments/identities/models.py index 4d0d1ae1a63b..fd931e79b232 100644 --- a/api/environments/identities/models.py +++ b/api/environments/identities/models.py @@ -2,6 +2,7 @@ from django.db import models from django.db.models import Prefetch, Q +from flag_engine.engine import get_evaluation_result from environments.identities.managers import IdentityManager from environments.identities.traits.models import Trait @@ -11,15 +12,7 @@ from features.multivariate.models import MultivariateFeatureStateValue from features.versioning.versioning_service import get_environment_flags_list from segments.models import Segment -from util.engine_models.context.mappers import ( - is_context_in_segment, - map_environment_identity_to_context, -) -from util.mappers.engine import ( - map_identity_to_engine, - map_segment_to_engine, - map_traits_to_engine, -) +from util.mappers.engine import map_environment_to_evaluation_context class Identity(models.Model): @@ -154,7 +147,6 @@ def get_segments( :param overrides_only: only retrieve the segments which have a valid override in the environment :return: List of matching segments """ - matching_segments = [] db_traits = ( self.identity_traits.all() if (traits is None and self.id) else traits or [] ) @@ -164,29 +156,18 @@ def get_segments( else: all_segments = self.environment.project.get_segments_from_cache() - engine_identity = map_identity_to_engine( - self, - with_overrides=False, - with_traits=False, + segments_by_pk = {segment.pk: segment for segment in all_segments} + context = map_environment_to_evaluation_context( + identity=self, + environment=self.environment, + traits=db_traits, + segments=all_segments, ) - engine_traits = map_traits_to_engine(db_traits) - - for segment in all_segments: - engine_segment = map_segment_to_engine(segment) - - context = map_environment_identity_to_context( - environment=self.environment, - identity=engine_identity, - override_traits=engine_traits, - ) - - if is_context_in_segment( - context=context, - segment=engine_segment, - ): - matching_segments.append(segment) - - return matching_segments + result = get_evaluation_result(context) + return [ + segments_by_pk[segment_result["metadata"]["pk"]] + for segment_result in result["segments"] + ] def get_all_user_traits(self): # type: ignore[no-untyped-def] # this is pointless, we should probably replace all uses with the below code diff --git a/api/features/models.py b/api/features/models.py index 5e32ce7ef7df..616d4141545c 100644 --- a/api/features/models.py +++ b/api/features/models.py @@ -25,6 +25,7 @@ LifecycleModelMixin, hook, ) +from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids from ordered_model.models import OrderedModelBase # type: ignore[import-untyped] from simple_history.models import HistoricalRecords # type: ignore[import-untyped] @@ -49,9 +50,6 @@ SoftDeleteExportableModel, abstract_base_auditable_model_factory, ) -from environments.identities.helpers import ( - get_hashed_percentage_for_object_ids, -) from features.constants import ENVIRONMENT, FEATURE_SEGMENT, IDENTITY from features.custom_lifecycle import CustomLifecycleModelMixin from features.feature_states.models import AbstractBaseFeatureValueModel @@ -750,8 +748,8 @@ def get_multivariate_feature_state_value( # avoid further queries to the DB mv_options = list(self.multivariate_feature_state_values.all()) - percentage_value = ( - get_hashed_percentage_for_object_ids([self.id, identity_hash_key]) * 100 + percentage_value = get_hashed_percentage_for_object_ids( + [self.id, identity_hash_key] ) # Iterate over the mv options in order of id (so we get the same value each diff --git a/api/integrations/webhook/serializers.py b/api/integrations/webhook/serializers.py index 9710541d3a2e..b844ff6b7308 100644 --- a/api/integrations/webhook/serializers.py +++ b/api/integrations/webhook/serializers.py @@ -1,6 +1,7 @@ import typing from django.db.models import Q +from flag_engine.engine import get_evaluation_result from rest_framework import serializers from features.serializers import FeatureStateSerializerFull @@ -8,12 +9,7 @@ BaseEnvironmentIntegrationModelSerializer, ) from segments.models import Segment -from util.engine_models.context.mappers import is_context_in_segment -from util.mappers.engine import ( - map_engine_identity_to_context, - map_identity_to_engine, - map_segment_to_engine, -) +from util.mappers.engine import map_environment_to_evaluation_context from .models import WebhookConfiguration @@ -32,16 +28,14 @@ class Meta: fields = ("id", "name", "member") def get_member(self, obj: Segment) -> bool: - engine_identity = map_identity_to_engine( - self.context.get("identity"), # type: ignore[arg-type] - with_overrides=False, - ) - engine_segment = map_segment_to_engine(obj) - context = map_engine_identity_to_context(engine_identity) - return is_context_in_segment( - context=context, - segment=engine_segment, + identity = self.context["identity"] + context = map_environment_to_evaluation_context( + identity=identity, + environment=identity.environment, + segments=[obj], ) + result = get_evaluation_result(context) + return bool(result["segments"]) class IntegrationFeatureStateSerializer(FeatureStateSerializerFull): diff --git a/api/segments/types.py b/api/segments/types.py new file mode 100644 index 000000000000..ee463461b5ec --- /dev/null +++ b/api/segments/types.py @@ -0,0 +1,5 @@ +from typing_extensions import TypedDict + + +class SegmentEngineMetadata(TypedDict): + pk: int diff --git a/api/tests/integration/environments/identities/test_integration_identities.py b/api/tests/integration/environments/identities/test_integration_identities.py index ae16baa46445..fb5d3884b8ad 100644 --- a/api/tests/integration/environments/identities/test_integration_identities.py +++ b/api/tests/integration/environments/identities/test_integration_identities.py @@ -31,9 +31,9 @@ @pytest.mark.parametrize( "hashed_percentage, expected_mv_value", ( - (variant_1_percentage_allocation / 100 - 0.01, variant_1_value), - (total_variance_percentage / 100 - 0.01, variant_2_value), - (total_variance_percentage / 100 + 0.01, control_value), + (variant_1_percentage_allocation - 1, variant_1_value), + (total_variance_percentage - 1, variant_2_value), + (total_variance_percentage + 1, control_value), ), ) @mock.patch("features.models.get_hashed_percentage_for_object_ids") diff --git a/api/tests/unit/environments/identities/test_unit_identities_helpers.py b/api/tests/unit/environments/identities/test_unit_identities_helpers.py deleted file mode 100644 index 20fd0d9714d8..000000000000 --- a/api/tests/unit/environments/identities/test_unit_identities_helpers.py +++ /dev/null @@ -1,124 +0,0 @@ -import itertools -from unittest import mock - -from environments.identities.helpers import ( - get_hashed_percentage_for_object_ids, -) - - -def test_get_hashed_percentage_for_object_ids_is_number_between_0_inc_and_1_exc(): # type: ignore[no-untyped-def] - assert 1 > get_hashed_percentage_for_object_ids([12, 93]) >= 0 - - -def test_get_hashed_percentage_for_object_ids_is_the_same_each_time(): # type: ignore[no-untyped-def] - # Given - object_ids = [30, 73] - - # When - result_1 = get_hashed_percentage_for_object_ids(object_ids) - result_2 = get_hashed_percentage_for_object_ids(object_ids) - - # Then - assert result_1 == result_2 - - -def test_percentage_value_is_unique_for_different_identities(): # type: ignore[no-untyped-def] - # Given - first_object_ids = [14, 106] - second_object_ids = [53, 200] - - # When - result_1 = get_hashed_percentage_for_object_ids(first_object_ids) - result_2 = get_hashed_percentage_for_object_ids(second_object_ids) - - # Then - assert result_1 != result_2 - - -def test_get_hashed_percentage_for_object_ids_should_be_evenly_distributed(): # type: ignore[no-untyped-def] - """ - This test checks if the percentage value returned by the helper function returns - evenly distributed values. - - Note that since it's technically random, it's not guaranteed to pass every time, - however, it should pass 99/100 times. It will likely be more accurate by increasing - the test_sample value and / or decreasing the num_test_buckets value. - """ - test_sample = 500 # number of ids to sample in each list - num_test_buckets = 50 # split the sample into 'buckets' to check that the values are evenly distributed - test_bucket_size = int(test_sample / num_test_buckets) - error_factor = 0.1 - - # Given - object_id_pairs = itertools.product(range(test_sample), range(test_sample)) - - # When - values = sorted( - get_hashed_percentage_for_object_ids(pair) for pair in object_id_pairs - ) - - # Then - for i in range(num_test_buckets): - bucket_start = i * test_bucket_size - bucket_end = (i + 1) * test_bucket_size - bucket_value_limit = min( - (i + 1) / num_test_buckets + error_factor * ((i + 1) / num_test_buckets), - 1, - ) - - assert all( - [value <= bucket_value_limit for value in values[bucket_start:bucket_end]] - ) - - -@mock.patch("environments.identities.helpers.hashlib") -def test_get_hashed_percentage_does_not_return_1(mock_hashlib): # type: ignore[no-untyped-def] - """ - Quite complex test to ensure that the function will never return 1. - - To achieve this, we mock the hashlib module to return a magic mock so that we can - subsequently mock the hexdigest method to return known strings. These strings are - chosen such that they can be converted (via `int(s, base=16)`) to known integers. - """ - - # Given - object_ids = [12, 93] - - # -- SETTING UP THE MOCKS -- - # hash strings specifically created to return specific values when converted to - # integers via int(s, base=16). Note that the reverse function was created - # courtesy of https://code.i-harness.com/en/q/1f7c41 - hash_string_to_return_1 = "270e" - hash_string_to_return_0 = "270f" - hashed_values = [hash_string_to_return_0, hash_string_to_return_1] - - def hexdigest_side_effect(): # type: ignore[no-untyped-def] - return hashed_values.pop() - - mock_hash = mock.MagicMock() - mock_hashlib.md5.return_value = mock_hash - - mock_hash.hexdigest.side_effect = hexdigest_side_effect - - # -- FINISH SETTING UP THE MOCKS -- - - # When - # we get the hashed percentage value for the given object ids - value = get_hashed_percentage_for_object_ids(object_ids) - - # Then - # The value is 0 as defined by the mock data - assert value == 0 - - # and the md5 function was called twice - # (i.e. the get_hashed_percentage_for_object_ids function was also called twice) - call_list = mock_hashlib.md5.call_args_list - assert len(call_list) == 2 - - # the first call, with a string (in bytes) that contains each object id once - expected_bytes_1 = ",".join(str(id_) for id_ in object_ids).encode("utf-8") - assert call_list[0][0][0] == expected_bytes_1 - - # the second call, with a string (in bytes) that contains each object id twice - expected_bytes_2 = ",".join(str(id_) for id_ in object_ids * 2).encode("utf-8") - assert call_list[1][0][0] == expected_bytes_2 diff --git a/api/tests/unit/environments/identities/test_unit_identities_views.py b/api/tests/unit/environments/identities/test_unit_identities_views.py index dff9ed67bead..ca319459c7c2 100644 --- a/api/tests/unit/environments/identities/test_unit_identities_views.py +++ b/api/tests/unit/environments/identities/test_unit_identities_views.py @@ -23,9 +23,6 @@ STRING, ) from environments.identities import views -from environments.identities.helpers import ( - get_hashed_percentage_for_object_ids, -) from environments.identities.models import Identity from environments.identities.traits.models import Trait from environments.models import Environment, EnvironmentAPIKey @@ -598,14 +595,9 @@ def test_identities_endpoint_returns_value_for_segment_if_rule_type_percentage_s segment=segment, type=SegmentRule.ALL_RULE ) - identity_percentage_value = get_hashed_percentage_for_object_ids( - [segment.id, identity.id] - ) Condition.objects.create( operator=PERCENTAGE_SPLIT, - value=int( - (identity_percentage_value + (1 - identity_percentage_value) / 2) * 100.0 - ), + value=100, rule=segment_rule, ) feature_segment = FeatureSegment.objects.create( @@ -651,12 +643,9 @@ def test_identities_endpoint_returns_default_value_if_rule_type_percentage_split segment=segment, type=SegmentRule.ALL_RULE ) - identity_percentage_value = get_hashed_percentage_for_object_ids( - [segment.id, identity.id] - ) Condition.objects.create( operator=PERCENTAGE_SPLIT, - value=int(identity_percentage_value / 2), + value=0, rule=segment_rule, ) feature_segment = FeatureSegment.objects.create( diff --git a/api/tests/unit/features/test_unit_features_models.py b/api/tests/unit/features/test_unit_features_models.py index b10ba936dda8..b5d2b24990f0 100644 --- a/api/tests/unit/features/test_unit_features_models.py +++ b/api/tests/unit/features/test_unit_features_models.py @@ -594,7 +594,7 @@ def test_feature_state_type_feature_segment( assert feature_state.type == FEATURE_SEGMENT -@pytest.mark.parametrize("hashed_percentage", (0.0, 0.3, 0.5, 0.8, 0.999999)) +@pytest.mark.parametrize("hashed_percentage", (0.0, 30.0, 50.0, 80.0, 99.9999)) @mock.patch("features.models.get_hashed_percentage_for_object_ids") def test_get_multivariate_value_returns_correct_value_when_we_pass_identity( # type: ignore[no-untyped-def] mock_get_hashed_percentage, diff --git a/api/tests/unit/util/mappers/test_unit_mappers_engine.py b/api/tests/unit/util/mappers/test_unit_mappers_engine.py index 52cd0f04895b..cd2020a4012c 100644 --- a/api/tests/unit/util/mappers/test_unit_mappers_engine.py +++ b/api/tests/unit/util/mappers/test_unit_mappers_engine.py @@ -6,6 +6,8 @@ from django.utils import timezone from pytest_mock import MockerFixture +from environments.identities.models import Identity +from environments.identities.traits.models import Trait from environments.models import Environment from features.models import FeatureSegment, FeatureState from features.versioning.models import EnvironmentFeatureVersion @@ -15,7 +17,7 @@ from integrations.mixpanel.models import MixpanelConfiguration from integrations.segment.models import SegmentConfiguration from integrations.webhook.models import WebhookConfiguration -from segments.models import Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule from users.models import FFAdminUser from util.engine_models.environments.integrations.models import IntegrationModel from util.engine_models.environments.models import ( @@ -45,7 +47,6 @@ from util.mappers import engine if TYPE_CHECKING: - from environments.identities import Identity, Trait # type: ignore[attr-defined] from environments.models import EnvironmentAPIKey from features.models import Feature from projects.models import Project @@ -657,3 +658,207 @@ def test_map_environment_to_engine_v2_versioning_segment_overrides( environment_model.project.segments[0].feature_states[0].django_id == v3_segment_override.id ) + + +def test_map_environment_to_evaluation_context__no_identity__returns_environment_only( + environment: Environment, +) -> None: + # When + result = engine.map_environment_to_evaluation_context(environment=environment) + + # Then + assert result == { + "environment": { + "key": environment.api_key, + "name": environment.name, + }, + } + + +def test_map_environment_to_evaluation_context__with_identity__returns_identity_context( + environment: Environment, + identity: Identity, +) -> None: + # When + result = engine.map_environment_to_evaluation_context( + environment=environment, + identity=identity, + ) + + # Then + assert result == { + "environment": { + "key": environment.api_key, + "name": environment.name, + }, + "identity": { + "identifier": identity.identifier, + "key": identity.get_hash_key( + environment.use_identity_composite_key_for_hashing + ), + "traits": {}, + }, + } + + +def test_map_environment_to_evaluation_context__with_explicit_traits__returns_given_traits( + environment: Environment, + identity: Identity, + trait: Trait, +) -> None: + # When + result = engine.map_environment_to_evaluation_context( + environment=environment, + identity=identity, + traits=[trait], + ) + + # Then + assert result == { + "environment": { + "key": environment.api_key, + "name": environment.name, + }, + "identity": { + "identifier": identity.identifier, + "key": identity.get_hash_key( + environment.use_identity_composite_key_for_hashing + ), + "traits": {trait.trait_key: trait.trait_value}, + }, + } + + +def test_map_environment_to_evaluation_context__no_explicit_traits__returns_identity_traits( + environment: Environment, + identity: Identity, + trait: Trait, +) -> None: + # When + result = engine.map_environment_to_evaluation_context( + environment=environment, + identity=identity, + ) + + # Then + assert result == { + "environment": { + "key": environment.api_key, + "name": environment.name, + }, + "identity": { + "identifier": identity.identifier, + "key": identity.get_hash_key( + environment.use_identity_composite_key_for_hashing + ), + "traits": {trait.trait_key: trait.trait_value}, + }, + } + + +def test_map_environment_to_evaluation_context__with_segments__returns_segment_contexts( + environment: Environment, + identity_matching_segment: Segment, +) -> None: + # When + result = engine.map_environment_to_evaluation_context( + environment=environment, + segments=[identity_matching_segment], + ) + + # Then + segment_key = str(identity_matching_segment.pk) + assert result == { + "environment": { + "key": environment.api_key, + "name": environment.name, + }, + "segments": { + segment_key: engine.map_segment_to_segment_context( + identity_matching_segment + ), + }, + } + + +def test_map_segment_to_segment_context__segment_with_rule__returns_expected( + identity_matching_segment: Segment, +) -> None: + # Given + condition = Condition.objects.get( + rule__segment=identity_matching_segment, + ) + + # When + result = engine.map_segment_to_segment_context(identity_matching_segment) + + # Then + assert result == { + "key": str(identity_matching_segment.pk), + "name": identity_matching_segment.name, + "rules": [ + { + "type": "ALL", + "conditions": [ + { + "property": condition.property, + "operator": condition.operator, + "value": condition.value, + }, + ], + "rules": [], + }, + ], + "metadata": {"pk": identity_matching_segment.pk}, + } + + +def test_map_rule_to_segment_rule__with_nested_rule__returns_expected( + segment_rule: SegmentRule, + identity_matching_segment: Segment, +) -> None: + # Given + matching_rule = SegmentRule.objects.get(segment=identity_matching_segment) + matching_rule.rules.add(segment_rule) + condition = Condition.objects.get(rule=matching_rule) + + # When + result = engine.map_rule_to_segment_rule(matching_rule) + + # Then + assert result == { + "type": "ALL", + "conditions": [ + { + "property": condition.property, + "operator": condition.operator, + "value": condition.value, + }, + ], + "rules": [ + { + "type": "ALL", + "conditions": [], + "rules": [], + }, + ], + } + + +def test_map_condition_to_segment_condition__valid_condition__returns_expected( + identity_matching_segment: Segment, +) -> None: + # Given + condition = Condition.objects.get( + rule__segment=identity_matching_segment, + ) + + # When + result = engine.map_condition_to_segment_condition(condition) + + # Then + assert result == { + "property": condition.property, + "operator": condition.operator, + "value": condition.value, + } diff --git a/api/util/mappers/engine.py b/api/util/mappers/engine.py index 3d6a06c5c694..eac85292cc9b 100644 --- a/api/util/mappers/engine.py +++ b/api/util/mappers/engine.py @@ -3,10 +3,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional from uuid import UUID -from flag_engine.context.types import EvaluationContext +from flag_engine.context import types as engine_types +from flag_engine.segments.types import ConditionOperator, RuleType +from pydantic import TypeAdapter from environments.constants import IDENTITY_INTEGRATIONS_RELATION_NAMES from features.versioning.models import EnvironmentFeatureVersion +from segments.types import SegmentEngineMetadata from util.engine_models.environments.integrations.models import IntegrationModel from util.engine_models.environments.models import ( EnvironmentAPIKeyModel, @@ -45,16 +48,20 @@ from integrations.webhook.models import WebhookConfiguration from organisations.models import Organisation from projects.models import Project - from segments.models import Segment, SegmentRule + from segments.models import Condition, Segment, SegmentRule __all__ = ( + "map_condition_to_segment_condition", "map_environment_api_key_to_engine", "map_environment_to_engine", "map_feature_to_engine", "map_identity_to_engine", + "map_environment_to_evaluation_context", "map_mv_option_to_engine", + "map_rule_to_segment_rule", "map_segment_to_engine", + "map_segment_to_segment_context", "map_traits_to_engine", ) @@ -417,24 +424,76 @@ def map_identity_to_engine( ) -def map_engine_identity_to_context( - identity: IdentityModel, -) -> "EvaluationContext": - """ - A special mapper to produce a minimal EvaluationContext - in an environment-less form. - Used when an environment object is not available, like when evaluating segments for webhooks. - """ - return { - "environment": {"key": identity.environment_api_key, "name": ""}, - "identity": { - "identifier": identity.identifier, - "key": str(identity.django_id or identity.composite_key), - "traits": { - trait.trait_key: trait.trait_value for trait in identity.identity_traits - }, +_rule_type_adapter: TypeAdapter[RuleType] = TypeAdapter(RuleType) +_condition_operator_adapter: TypeAdapter[ConditionOperator] = TypeAdapter( + ConditionOperator +) + + +def map_environment_to_evaluation_context( + *, + environment: "Environment", + identity: "Identity | None" = None, + traits: "Iterable[Trait] | None" = None, + segments: "Iterable[Segment] | None" = None, +) -> "engine_types.EvaluationContext[SegmentEngineMetadata, object]": + """Map Django ORM Environment (and optionally Identity) to a flag-engine EvaluationContext.""" + context: engine_types.EvaluationContext[SegmentEngineMetadata, object] = { + "environment": { + "key": environment.api_key, + "name": environment.name or "", }, } + if identity is not None: + trait_items: "Iterable[Trait]" = ( + traits if traits is not None else identity.identity_traits.all() + ) + context["identity"] = { + "identifier": identity.identifier, + "key": identity.get_hash_key( + environment.use_identity_composite_key_for_hashing + ), + "traits": {trait.trait_key: trait.trait_value for trait in trait_items}, + } + if segments is not None: + context["segments"] = { + str(segment.pk): map_segment_to_segment_context(segment) + for segment in segments + } + return context + + +def map_segment_to_segment_context( + segment: "Segment", +) -> "engine_types.SegmentContext[SegmentEngineMetadata, object]": + """Map a Django ORM Segment to a flag-engine SegmentContext TypedDict.""" + return { + "key": str(segment.pk), + "name": segment.name, + "rules": [map_rule_to_segment_rule(rule) for rule in segment.rules.all()], + "metadata": SegmentEngineMetadata(pk=segment.pk), + } + + +def map_rule_to_segment_rule(rule: "SegmentRule") -> engine_types.SegmentRule: + return { + "type": _rule_type_adapter.validate_python(rule.type), + "conditions": [ + map_condition_to_segment_condition(condition) + for condition in rule.conditions.all() + ], + "rules": [map_rule_to_segment_rule(sub_rule) for sub_rule in rule.rules.all()], + } + + +def map_condition_to_segment_condition( + condition: "Condition", +) -> engine_types.StrValueSegmentCondition: + return { + "property": condition.property or "", + "operator": _condition_operator_adapter.validate_python(condition.operator), + "value": condition.value or "", + } def _get_prioritised_feature_states(