diff --git a/posthog/client.py b/posthog/client.py index 64db985d..27c63504 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -1418,6 +1418,7 @@ def _compute_flag_locally( person_properties=None, group_properties=None, warn_on_unknown_groups=True, + device_id=None, ) -> FlagValue: groups = groups or {} person_properties = person_properties or {} @@ -1465,6 +1466,7 @@ def _compute_flag_locally( focused_group_properties, self.feature_flags_by_key, evaluation_cache, + hashing_identifier=groups[group_name], ) else: return match_feature_flag_properties( @@ -1474,6 +1476,7 @@ def _compute_flag_locally( self.cohorts, self.feature_flags_by_key, evaluation_cache, + device_id=device_id, ) def feature_enabled( @@ -1580,8 +1583,12 @@ def _get_feature_flag_result( evaluated_at = None feature_flag_error: Optional[str] = None + # Resolve device_id from context if not provided + if device_id is None: + device_id = get_context_device_id() + flag_value = self._locally_evaluate_flag( - key, distinct_id, groups, person_properties, group_properties + key, distinct_id, groups, person_properties, group_properties, device_id ) flag_was_locally_evaluated = flag_value is not None @@ -1785,6 +1792,7 @@ def _locally_evaluate_flag( groups: dict[str, str], person_properties: dict[str, str], group_properties: dict[str, str], + device_id: Optional[str] = None, ) -> Optional[FlagValue]: if self.feature_flags is None and self.personal_api_key: self.load_feature_flags() @@ -1804,6 +1812,7 @@ def _locally_evaluate_flag( groups=groups, person_properties=person_properties, group_properties=group_properties, + device_id=device_id, ) self.log.debug( f"Successfully computed flag locally: {key} -> {response}" @@ -2106,12 +2115,17 @@ def get_all_flags_and_payloads( ) ) + # Resolve device_id from context if not provided + if device_id is None: + device_id = get_context_device_id() + response, fallback_to_flags = self._get_all_flags_and_payloads_locally( distinct_id, groups=groups, person_properties=person_properties, group_properties=group_properties, flag_keys_to_evaluate=flag_keys_to_evaluate, + device_id=device_id, ) if fallback_to_flags and not only_evaluate_locally: @@ -2142,6 +2156,7 @@ def _get_all_flags_and_payloads_locally( group_properties=None, warn_on_unknown_groups=False, flag_keys_to_evaluate: Optional[list[str]] = None, + device_id: Optional[str] = None, ) -> tuple[FlagsAndPayloads, bool]: person_properties = person_properties or {} group_properties = group_properties or {} @@ -2171,6 +2186,7 @@ def _get_all_flags_and_payloads_locally( person_properties=person_properties, group_properties=group_properties, warn_on_unknown_groups=warn_on_unknown_groups, + device_id=device_id, ) matched_payload = self._compute_payload_locally( flag["key"], flags[flag["key"]] diff --git a/posthog/feature_flags.py b/posthog/feature_flags.py index ef850e60..101d73d7 100644 --- a/posthog/feature_flags.py +++ b/posthog/feature_flags.py @@ -34,18 +34,18 @@ class RequiresServerEvaluation(Exception): pass -# This function takes a distinct_id and a feature flag key and returns a float between 0 and 1. -# Given the same distinct_id and key, it'll always return the same float. These floats are +# This function takes an identifier and a feature flag key and returns a float between 0 and 1. +# Given the same identifier and key, it'll always return the same float. These floats are # uniformly distributed between 0 and 1, so if we want to show this feature to 20% of traffic -# we can do _hash(key, distinct_id) < 0.2 -def _hash(key: str, distinct_id: str, salt: str = "") -> float: - hash_key = f"{key}.{distinct_id}{salt}" +# we can do _hash(key, identifier) < 0.2 +def _hash(key: str, identifier: str, salt: str = "") -> float: + hash_key = f"{key}.{identifier}{salt}" hash_val = int(hashlib.sha1(hash_key.encode("utf-8")).hexdigest()[:15], 16) return hash_val / __LONG_SCALE__ -def get_matching_variant(flag, distinct_id): - hash_value = _hash(flag["key"], distinct_id, salt="variant") +def get_matching_variant(flag, hashing_identifier): + hash_value = _hash(flag["key"], hashing_identifier, salt="variant") for variant in variant_lookup_table(flag): if hash_value >= variant["value_min"] and hash_value < variant["value_max"]: return variant["key"] @@ -68,7 +68,13 @@ def variant_lookup_table(feature_flag): def evaluate_flag_dependency( - property, flags_by_key, evaluation_cache, distinct_id, properties, cohort_properties + property, + flags_by_key, + evaluation_cache, + distinct_id, + properties, + cohort_properties, + device_id=None, ): """ Evaluate a flag dependency property according to the dependency chain algorithm. @@ -80,6 +86,7 @@ def evaluate_flag_dependency( distinct_id: The distinct ID being evaluated properties: Person properties for evaluation cohort_properties: Cohort properties for evaluation + device_id: The device ID for bucketing (optional) Returns: bool: True if all dependencies in the chain evaluate to True, False otherwise @@ -131,6 +138,7 @@ def evaluate_flag_dependency( cohort_properties, flags_by_key, evaluation_cache, + device_id=device_id, ) evaluation_cache[dep_flag_key] = dep_result except InconclusiveMatchError as e: @@ -222,16 +230,31 @@ def match_feature_flag_properties( cohort_properties=None, flags_by_key=None, evaluation_cache=None, + device_id=None, + hashing_identifier=None, ) -> FlagValue: - flag_conditions = (flag.get("filters") or {}).get("groups") or [] + flag_filters = flag.get("filters") or {} + flag_conditions = flag_filters.get("groups") or [] is_inconclusive = False cohort_properties = cohort_properties or {} # Some filters can be explicitly set to null, which require accessing variants like so - flag_variants = ((flag.get("filters") or {}).get("multivariate") or {}).get( - "variants" - ) or [] + flag_variants = (flag_filters.get("multivariate") or {}).get("variants") or [] valid_variant_keys = [variant["key"] for variant in flag_variants] + # Determine the hashing identifier: + # - If caller provided one explicitly (e.g. group key for group flags), use it directly + # - Otherwise resolve from the flag's bucketing_identifier setting + if hashing_identifier is None: + bucketing_identifier = flag_filters.get("bucketing_identifier") + if bucketing_identifier == "device_id": + if not device_id: + raise InconclusiveMatchError( + "Flag requires device_id for bucketing but none was provided" + ) + hashing_identifier = device_id + else: + hashing_identifier = distinct_id + for condition in flag_conditions: try: # if any one condition resolves to True, we can shortcircuit and return @@ -244,12 +267,14 @@ def match_feature_flag_properties( cohort_properties, flags_by_key, evaluation_cache, + hashing_identifier=hashing_identifier, + device_id=device_id, ): variant_override = condition.get("variant") if variant_override and variant_override in valid_variant_keys: variant = variant_override else: - variant = get_matching_variant(flag, distinct_id) + variant = get_matching_variant(flag, hashing_identifier) return variant or True except RequiresServerEvaluation: # Static cohort or other missing server-side data - must fallback to API @@ -277,6 +302,9 @@ def is_condition_match( cohort_properties, flags_by_key=None, evaluation_cache=None, + *, + hashing_identifier, + device_id=None, ) -> bool: rollout_percentage = condition.get("rollout_percentage") if len(condition.get("properties") or []) > 0: @@ -290,6 +318,7 @@ def is_condition_match( flags_by_key, evaluation_cache, distinct_id, + device_id=device_id, ) elif property_type == "flag": matches = evaluate_flag_dependency( @@ -299,6 +328,7 @@ def is_condition_match( distinct_id, properties, cohort_properties, + device_id=device_id, ) else: matches = match_property(prop, properties) @@ -308,9 +338,9 @@ def is_condition_match( if rollout_percentage is None: return True - if rollout_percentage is not None and _hash(feature_flag["key"], distinct_id) > ( - rollout_percentage / 100 - ): + if rollout_percentage is not None and _hash( + feature_flag["key"], hashing_identifier + ) > (rollout_percentage / 100): return False return True @@ -454,6 +484,7 @@ def match_cohort( flags_by_key=None, evaluation_cache=None, distinct_id=None, + device_id=None, ) -> bool: # Cohort properties are in the form of property groups like this: # { @@ -478,6 +509,7 @@ def match_cohort( flags_by_key, evaluation_cache, distinct_id, + device_id=device_id, ) @@ -488,6 +520,7 @@ def match_property_group( flags_by_key=None, evaluation_cache=None, distinct_id=None, + device_id=None, ) -> bool: if not property_group: return True @@ -512,6 +545,7 @@ def match_property_group( flags_by_key, evaluation_cache, distinct_id, + device_id=device_id, ) if property_group_type == "AND": if not matches: @@ -545,6 +579,7 @@ def match_property_group( flags_by_key, evaluation_cache, distinct_id, + device_id=device_id, ) elif prop.get("type") == "flag": matches = evaluate_flag_dependency( @@ -554,6 +589,7 @@ def match_property_group( distinct_id, property_values, cohort_properties, + device_id=device_id, ) else: matches = match_property(prop, property_values) diff --git a/posthog/test/test_feature_flags.py b/posthog/test/test_feature_flags.py index 783793f8..5eb7968a 100644 --- a/posthog/test/test_feature_flags.py +++ b/posthog/test/test_feature_flags.py @@ -3220,6 +3220,385 @@ def test_fallback_to_api_when_flag_has_static_cohort_in_multi_condition( # Verify API was called (fallback occurred) self.assertEqual(patch_flags.call_count, 1) + @mock.patch("posthog.client.flags") + def test_device_id_bucketing_uses_device_id_for_hash(self, patch_flags): + """ + When a flag has bucketing_identifier: "device_id", the device_id should be + used for hashing instead of distinct_id. + """ + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + # This flag uses device_id for bucketing + client.feature_flags = [ + { + "id": 1, + "key": "device-bucketed-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [ + { + "properties": [], + "rollout_percentage": 100, + } + ], + }, + } + ] + + # Same distinct_id with different device_ids should produce different results + # (based on rollout percentage, we check consistency) + result1 = client.get_feature_flag( + "device-bucketed-flag", "user-123", device_id="device-A" + ) + result2 = client.get_feature_flag( + "device-bucketed-flag", "user-123", device_id="device-A" + ) + + # Same device_id should give consistent results + self.assertEqual(result1, result2) + + # No API fallback should occur + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_device_id_bucketing_same_device_different_users_same_result( + self, patch_flags + ): + """ + When a flag uses device_id bucketing, different distinct_ids with the same + device_id should get the same result. + """ + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "device-bucketed-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [ + { + "properties": [], + "rollout_percentage": 50, + } + ], + }, + } + ] + + # Different distinct_ids with the same device_id should get the same result + result1 = client.get_feature_flag( + "device-bucketed-flag", "user-A", device_id="shared-device" + ) + result2 = client.get_feature_flag( + "device-bucketed-flag", "user-B", device_id="shared-device" + ) + result3 = client.get_feature_flag( + "device-bucketed-flag", "user-C", device_id="shared-device" + ) + + # All should be the same since device_id is the same + self.assertEqual(result1, result2) + self.assertEqual(result2, result3) + + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_device_id_bucketing_fallback_when_device_id_missing(self, patch_flags): + """ + When a flag requires device_id for bucketing but none is provided, + it should fallback to server evaluation. + """ + patch_flags.return_value = {"featureFlags": {"device-bucketed-flag": True}} + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "device-bucketed-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [ + { + "properties": [], + "rollout_percentage": 100, + } + ], + }, + } + ] + + # No device_id provided - should fallback to API + result = client.get_feature_flag("device-bucketed-flag", "user-123") + + self.assertTrue(result) + # API should have been called + self.assertEqual(patch_flags.call_count, 1) + + @mock.patch("posthog.client.flags") + def test_device_id_bucketing_returns_none_when_only_evaluate_locally_and_no_device_id( + self, patch_flags + ): + """ + When only_evaluate_locally=True and device_id is required but missing, + should return None instead of falling back to API. + """ + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "device-bucketed-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [ + { + "properties": [], + "rollout_percentage": 100, + } + ], + }, + } + ] + + # No device_id + only_evaluate_locally should return None + result = client.get_feature_flag( + "device-bucketed-flag", "user-123", only_evaluate_locally=True + ) + + self.assertIsNone(result) + # API should NOT have been called + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_default_bucketing_identifier_uses_distinct_id(self, patch_flags): + """ + When bucketing_identifier is not set or is 'distinct_id', should use + distinct_id for hashing (default behavior). + """ + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + # Flag without bucketing_identifier (defaults to distinct_id) + client.feature_flags = [ + { + "id": 1, + "key": "normal-flag", + "active": True, + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 50, + } + ], + }, + } + ] + + # Different distinct_ids should potentially produce different results + # but same distinct_id should produce same result + result1 = client.get_feature_flag("normal-flag", "user-A") + result2 = client.get_feature_flag("normal-flag", "user-A") + + self.assertEqual(result1, result2) + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_device_id_bucketing_with_multivariate_flag(self, patch_flags): + """ + Multivariate flag variant selection should use device_id when + bucketing_identifier is set to device_id. + """ + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "multivariate-device-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [ + { + "properties": [], + "rollout_percentage": 100, + } + ], + "multivariate": { + "variants": [ + {"key": "control", "rollout_percentage": 50}, + {"key": "test", "rollout_percentage": 50}, + ] + }, + }, + } + ] + + # Same device_id should give same variant + result1 = client.get_feature_flag( + "multivariate-device-flag", "user-A", device_id="device-1" + ) + result2 = client.get_feature_flag( + "multivariate-device-flag", "user-B", device_id="device-1" + ) + + # Both should get the same variant because device_id is the same + self.assertEqual(result1, result2) + self.assertIn(result1, ["control", "test"]) + + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_device_id_bucketing_from_context(self, patch_flags): + """ + When device_id is not passed as a parameter but is set in the context, + it should be resolved from context. + """ + from posthog.contexts import new_context, set_context_device_id + + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "device-bucketed-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [ + { + "properties": [], + "rollout_percentage": 100, + } + ], + }, + } + ] + + # Set device_id in context + with new_context(): + set_context_device_id("context-device-id") + result = client.get_feature_flag("device-bucketed-flag", "user-123") + + # Should evaluate locally using the context device_id + self.assertTrue(result) + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_group_flags_ignore_bucketing_identifier(self, patch_flags): + """ + Group flags should continue to use the group identifier for hashing, + regardless of the bucketing_identifier setting. + """ + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "group-flag", + "active": True, + "filters": { + "aggregation_group_type_index": 0, + "bucketing_identifier": "device_id", # Should be ignored for group flags + "groups": [ + { + "properties": [], + "rollout_percentage": 100, + } + ], + }, + } + ] + client.group_type_mapping = {"0": "company"} + + # Even with bucketing_identifier set to device_id, group flag should use group identifier + result = client.get_feature_flag( + "group-flag", + "user-123", + groups={"company": "acme-inc"}, + device_id="some-device", + ) + + self.assertTrue(result) + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_get_all_flags_with_device_id_bucketing(self, patch_flags): + """ + get_all_flags_and_payloads should properly handle flags with device_id bucketing. + """ + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "normal-flag", + "active": True, + "filters": { + "groups": [{"properties": [], "rollout_percentage": 100}], + }, + }, + { + "id": 2, + "key": "device-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [{"properties": [], "rollout_percentage": 100}], + }, + }, + ] + + # With device_id provided, both flags should be evaluated locally + result = client.get_all_flags("user-123", device_id="my-device") + + self.assertEqual(result["normal-flag"], True) + self.assertEqual(result["device-flag"], True) + self.assertEqual(patch_flags.call_count, 0) + + @mock.patch("posthog.client.flags") + def test_get_all_flags_fallback_when_device_id_missing_for_some_flags( + self, patch_flags + ): + """ + When some flags require device_id but it's not provided, those flags + should trigger fallback while others can be evaluated locally. + """ + patch_flags.return_value = { + "featureFlags": {"normal-flag": True, "device-flag": "from-api"} + } + client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) + + client.feature_flags = [ + { + "id": 1, + "key": "normal-flag", + "active": True, + "filters": { + "groups": [{"properties": [], "rollout_percentage": 100}], + }, + }, + { + "id": 2, + "key": "device-flag", + "active": True, + "filters": { + "bucketing_identifier": "device_id", + "groups": [{"properties": [], "rollout_percentage": 100}], + }, + }, + ] + + # Without device_id, device-flag can't be evaluated locally + client.get_all_flags("user-123") + + # Should fallback to API for all flags when any can't be evaluated locally + self.assertEqual(patch_flags.call_count, 1) + class TestMatchProperties(unittest.TestCase): def property(self, key, value, operator=None):