Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion posthog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,7 @@ def _compute_flag_locally(
person_properties=None,
group_properties=None,
warn_on_unknown_groups=True,
device_id=None,
) -> FlagValue:
groups = groups or {}
person_properties = person_properties or {}
Expand Down Expand Up @@ -1465,6 +1466,7 @@ def _compute_flag_locally(
focused_group_properties,
self.feature_flags_by_key,
evaluation_cache,
hashing_identifier=groups[group_name],
)
else:
return match_feature_flag_properties(
Expand All @@ -1474,6 +1476,7 @@ def _compute_flag_locally(
self.cohorts,
self.feature_flags_by_key,
evaluation_cache,
device_id=device_id,
)

def feature_enabled(
Expand Down Expand Up @@ -1580,8 +1583,12 @@ def _get_feature_flag_result(
evaluated_at = None
feature_flag_error: Optional[str] = None

# Resolve device_id from context if not provided
if device_id is None:
device_id = get_context_device_id()

flag_value = self._locally_evaluate_flag(
key, distinct_id, groups, person_properties, group_properties
key, distinct_id, groups, person_properties, group_properties, device_id
)
flag_was_locally_evaluated = flag_value is not None

Expand Down Expand Up @@ -1785,6 +1792,7 @@ def _locally_evaluate_flag(
groups: dict[str, str],
person_properties: dict[str, str],
group_properties: dict[str, str],
device_id: Optional[str] = None,
) -> Optional[FlagValue]:
if self.feature_flags is None and self.personal_api_key:
self.load_feature_flags()
Expand All @@ -1804,6 +1812,7 @@ def _locally_evaluate_flag(
groups=groups,
person_properties=person_properties,
group_properties=group_properties,
device_id=device_id,
)
self.log.debug(
f"Successfully computed flag locally: {key} -> {response}"
Expand Down Expand Up @@ -2106,12 +2115,17 @@ def get_all_flags_and_payloads(
)
)

# Resolve device_id from context if not provided
if device_id is None:
device_id = get_context_device_id()

response, fallback_to_flags = self._get_all_flags_and_payloads_locally(
distinct_id,
groups=groups,
person_properties=person_properties,
group_properties=group_properties,
flag_keys_to_evaluate=flag_keys_to_evaluate,
device_id=device_id,
)

if fallback_to_flags and not only_evaluate_locally:
Expand Down Expand Up @@ -2142,6 +2156,7 @@ def _get_all_flags_and_payloads_locally(
group_properties=None,
warn_on_unknown_groups=False,
flag_keys_to_evaluate: Optional[list[str]] = None,
device_id: Optional[str] = None,
) -> tuple[FlagsAndPayloads, bool]:
person_properties = person_properties or {}
group_properties = group_properties or {}
Expand Down Expand Up @@ -2171,6 +2186,7 @@ def _get_all_flags_and_payloads_locally(
person_properties=person_properties,
group_properties=group_properties,
warn_on_unknown_groups=warn_on_unknown_groups,
device_id=device_id,
)
matched_payload = self._compute_payload_locally(
flag["key"], flags[flag["key"]]
Expand Down
68 changes: 52 additions & 16 deletions posthog/feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ class RequiresServerEvaluation(Exception):
pass


# This function takes a distinct_id and a feature flag key and returns a float between 0 and 1.
# Given the same distinct_id and key, it'll always return the same float. These floats are
# This function takes an identifier and a feature flag key and returns a float between 0 and 1.
# Given the same identifier and key, it'll always return the same float. These floats are
# uniformly distributed between 0 and 1, so if we want to show this feature to 20% of traffic
# we can do _hash(key, distinct_id) < 0.2
def _hash(key: str, distinct_id: str, salt: str = "") -> float:
hash_key = f"{key}.{distinct_id}{salt}"
# we can do _hash(key, identifier) < 0.2
def _hash(key: str, identifier: str, salt: str = "") -> float:
hash_key = f"{key}.{identifier}{salt}"
hash_val = int(hashlib.sha1(hash_key.encode("utf-8")).hexdigest()[:15], 16)
return hash_val / __LONG_SCALE__


def get_matching_variant(flag, distinct_id):
hash_value = _hash(flag["key"], distinct_id, salt="variant")
def get_matching_variant(flag, hashing_identifier):
hash_value = _hash(flag["key"], hashing_identifier, salt="variant")
for variant in variant_lookup_table(flag):
if hash_value >= variant["value_min"] and hash_value < variant["value_max"]:
return variant["key"]
Expand All @@ -68,7 +68,13 @@ def variant_lookup_table(feature_flag):


def evaluate_flag_dependency(
property, flags_by_key, evaluation_cache, distinct_id, properties, cohort_properties
property,
flags_by_key,
evaluation_cache,
distinct_id,
properties,
cohort_properties,
device_id=None,
):
"""
Evaluate a flag dependency property according to the dependency chain algorithm.
Expand All @@ -80,6 +86,7 @@ def evaluate_flag_dependency(
distinct_id: The distinct ID being evaluated
properties: Person properties for evaluation
cohort_properties: Cohort properties for evaluation
device_id: The device ID for bucketing (optional)

Returns:
bool: True if all dependencies in the chain evaluate to True, False otherwise
Expand Down Expand Up @@ -131,6 +138,7 @@ def evaluate_flag_dependency(
cohort_properties,
flags_by_key,
evaluation_cache,
device_id=device_id,
)
evaluation_cache[dep_flag_key] = dep_result
except InconclusiveMatchError as e:
Expand Down Expand Up @@ -222,16 +230,31 @@ def match_feature_flag_properties(
cohort_properties=None,
flags_by_key=None,
evaluation_cache=None,
device_id=None,
hashing_identifier=None,
) -> FlagValue:
flag_conditions = (flag.get("filters") or {}).get("groups") or []
flag_filters = flag.get("filters") or {}
flag_conditions = flag_filters.get("groups") or []
is_inconclusive = False
cohort_properties = cohort_properties or {}
# Some filters can be explicitly set to null, which require accessing variants like so
flag_variants = ((flag.get("filters") or {}).get("multivariate") or {}).get(
"variants"
) or []
flag_variants = (flag_filters.get("multivariate") or {}).get("variants") or []
valid_variant_keys = [variant["key"] for variant in flag_variants]

# Determine the hashing identifier:
# - If caller provided one explicitly (e.g. group key for group flags), use it directly
# - Otherwise resolve from the flag's bucketing_identifier setting
if hashing_identifier is None:
bucketing_identifier = flag_filters.get("bucketing_identifier")
if bucketing_identifier == "device_id":
if not device_id:
raise InconclusiveMatchError(
"Flag requires device_id for bucketing but none was provided"
)
hashing_identifier = device_id
else:
hashing_identifier = distinct_id

for condition in flag_conditions:
try:
# if any one condition resolves to True, we can shortcircuit and return
Expand All @@ -244,12 +267,14 @@ def match_feature_flag_properties(
cohort_properties,
flags_by_key,
evaluation_cache,
hashing_identifier=hashing_identifier,
device_id=device_id,
):
variant_override = condition.get("variant")
if variant_override and variant_override in valid_variant_keys:
variant = variant_override
else:
variant = get_matching_variant(flag, distinct_id)
variant = get_matching_variant(flag, hashing_identifier)
return variant or True
except RequiresServerEvaluation:
# Static cohort or other missing server-side data - must fallback to API
Expand Down Expand Up @@ -277,6 +302,9 @@ def is_condition_match(
cohort_properties,
flags_by_key=None,
evaluation_cache=None,
*,
hashing_identifier,
device_id=None,
) -> bool:
rollout_percentage = condition.get("rollout_percentage")
if len(condition.get("properties") or []) > 0:
Expand All @@ -290,6 +318,7 @@ def is_condition_match(
flags_by_key,
evaluation_cache,
distinct_id,
device_id=device_id,
)
elif property_type == "flag":
matches = evaluate_flag_dependency(
Expand All @@ -299,6 +328,7 @@ def is_condition_match(
distinct_id,
properties,
cohort_properties,
device_id=device_id,
)
else:
matches = match_property(prop, properties)
Expand All @@ -308,9 +338,9 @@ def is_condition_match(
if rollout_percentage is None:
return True

if rollout_percentage is not None and _hash(feature_flag["key"], distinct_id) > (
rollout_percentage / 100
):
if rollout_percentage is not None and _hash(
feature_flag["key"], hashing_identifier
) > (rollout_percentage / 100):
return False

return True
Expand Down Expand Up @@ -454,6 +484,7 @@ def match_cohort(
flags_by_key=None,
evaluation_cache=None,
distinct_id=None,
device_id=None,
) -> bool:
# Cohort properties are in the form of property groups like this:
# {
Expand All @@ -478,6 +509,7 @@ def match_cohort(
flags_by_key,
evaluation_cache,
distinct_id,
device_id=device_id,
)


Expand All @@ -488,6 +520,7 @@ def match_property_group(
flags_by_key=None,
evaluation_cache=None,
distinct_id=None,
device_id=None,
) -> bool:
if not property_group:
return True
Expand All @@ -512,6 +545,7 @@ def match_property_group(
flags_by_key,
evaluation_cache,
distinct_id,
device_id=device_id,
)
if property_group_type == "AND":
if not matches:
Expand Down Expand Up @@ -545,6 +579,7 @@ def match_property_group(
flags_by_key,
evaluation_cache,
distinct_id,
device_id=device_id,
)
elif prop.get("type") == "flag":
matches = evaluate_flag_dependency(
Expand All @@ -554,6 +589,7 @@ def match_property_group(
distinct_id,
property_values,
cohort_properties,
device_id=device_id,
)
else:
matches = match_property(prop, property_values)
Expand Down
Loading
Loading