Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### Bugs Fixed

- Fixed `PipelineClient.format_url` to preserve trailing slash in the base URL when the URL template is query-string-only (e.g., `?key=value`). #45365
- Fixed `SensitiveHeaderCleanupPolicy` to persist the `insecure_domain_change` flag across retries after a cross-domain redirect. #45518

### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,6 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT
raise ex from HttpResponseError(response=response.http_response)

if request_authorized:
# if we receive a challenge response, we retrieve a new token
# which matches the new target. In this case, we don't want to remove
# token from the request so clear the 'insecure_domain_change' tag
request.context.options.pop("insecure_domain_change", False)
try:
response = self.next.send(request)
self.on_response(request, response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ async def send(
raise ex from HttpResponseError(response=response.http_response)

if request_authorized:
# if we receive a challenge response, we retrieve a new token
# which matches the new target. In this case, we don't want to remove
# token from the request so clear the 'insecure_domain_change' tag
request.context.options.pop("insecure_domain_change", False)
try:
response = await self.next.send(request)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,8 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT
if domain_changed(original_domain, request.http_request.url):
# "insecure_domain_change" is used to indicate that a redirect
# has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy
# to clean up sensitive headers. We need to remove it before sending the request
# to the transport layer.
request.context.options["insecure_domain_change"] = True
# to clean up sensitive headers.
request.context["insecure_domain_change"] = True
continue
return response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ async def send(
if domain_changed(original_domain, request.http_request.url):
# "insecure_domain_change" is used to indicate that a redirect
# has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy
# to clean up sensitive headers. We need to remove it before sending the request
# to the transport layer.
request.context.options["insecure_domain_change"] = True
# to clean up sensitive headers.
request.context["insecure_domain_change"] = True
continue
return response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
"""
# "insecure_domain_change" is used to indicate that a redirect
# has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy
# to clean up sensitive headers. We need to remove it before sending the request
# to the transport layer.
insecure_domain_change = request.context.options.pop("insecure_domain_change", False)
# to clean up sensitive headers.
insecure_domain_change = request.context.get("insecure_domain_change", False)
if not self._disable_redirect_cleanup and insecure_domain_change:
for header in self._blocked_redirect_headers:
request.http_request.headers.pop(header, None)
139 changes: 139 additions & 0 deletions sdk/core/azure-core/tests/async_tests/test_authentication_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AsyncBearerTokenCredentialPolicy,
SansIOHTTPPolicy,
AsyncRedirectPolicy,
AsyncRetryPolicy,
SensitiveHeaderCleanupPolicy,
)
from azure.core.pipeline.policies._authentication import MAX_REFRESH_JITTER_SECONDS
Expand Down Expand Up @@ -768,3 +769,141 @@ async def test_jitter_set_on_token_request_async():

assert policy._refresh_jitter == 25
mock_randint.assert_called_once_with(0, MAX_REFRESH_JITTER_SECONDS)


@pytest.mark.asyncio
async def test_challenge_auth_header_stripped_after_redirect():
"""Assuming the SensitiveHeaderCleanupPolicy is in the pipeline, the authorization header should be stripped after
a redirect to a different domain by default, and preserved if the policy is configured to disable cleanup."""

class MockTransport(AsyncHttpTransport):
def __init__(self, cleanup_disabled=False):
self._first = True
self._cleanup_disabled = cleanup_disabled

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass

async def close(self):
pass

async def open(self):
pass

async def send(self, request, **kwargs):
if self._first:
self._first = False
assert request.headers["Authorization"] == "Bearer {}".format(auth_header)
response = Response()
response.status_code = 307
response.headers["location"] = "https://redirect-target.example.invalid"
return response

# Second request: after redirect
if self._cleanup_disabled:
assert request.headers.get("Authorization")
else:
assert not request.headers.get("Authorization")
response = Response()
response.status_code = 401
response.headers["WWW-Authenticate"] = (
'Bearer error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsiZm9vIjoiYmFyIn19"'
)
return response

auth_header = "token"
get_token_call_count = 0

async def mock_get_token(*_, **__):
nonlocal get_token_call_count
get_token_call_count += 1
return AccessToken(auth_header, 0)

credential = Mock(spec_set=["get_token"], get_token=mock_get_token)
auth_policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
redirect_policy = AsyncRedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
pipeline = AsyncPipeline(transport=MockTransport(), policies=[redirect_policy, auth_policy, header_clean_up_policy])

response = await pipeline.run(HttpRequest("GET", "https://legitimate.azure.com"))
assert response.http_response.status_code == 401

header_clean_up_policy = SensitiveHeaderCleanupPolicy(disable_redirect_cleanup=True)
pipeline = AsyncPipeline(
transport=MockTransport(cleanup_disabled=True),
policies=[redirect_policy, auth_policy, header_clean_up_policy],
)
response = await pipeline.run(HttpRequest("GET", "https://legitimate.azure.com"))
assert response.http_response.status_code == 401


@pytest.mark.asyncio
async def test_auth_header_stripped_after_cross_domain_redirect_with_retry():
"""After a cross-domain redirect, if the redirected-to endpoint returns a retryable status code,
the Authorization header should still be stripped on the retry attempt. This verifies that the
insecure_domain_change flag persists across retries so SensitiveHeaderCleanupPolicy continues to
remove the Authorization header."""

class MockTransport(AsyncHttpTransport):
def __init__(self):
self._request_count = 0

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass

async def close(self):
pass

async def open(self):
pass

async def send(self, request, **kwargs):
self._request_count += 1

if self._request_count == 1:
# First request: to the original domain — should have auth header
assert request.headers.get("Authorization") == "Bearer {}".format(auth_header)
response = Response()
response.status_code = 307
response.headers["location"] = "https://redirect-target.example.invalid"
return response

if self._request_count == 2:
# Second request: after redirect to attacker domain — auth header should be stripped
assert not request.headers.get(
"Authorization"
), "Authorization header should be stripped on first request to redirected domain"
response = Response()
response.status_code = 500
return response

if self._request_count == 3:
# Third request: retry to attacker domain — auth header should STILL be stripped
assert not request.headers.get(
"Authorization"
), "Authorization header should be stripped on retry to redirected domain"
response = Response()
response.status_code = 200
return response

raise RuntimeError("Unexpected request count: {}".format(self._request_count))

auth_header = "token"

async def mock_get_token(*_, **__):
return AccessToken(auth_header, 0)

credential = Mock(spec_set=["get_token"], get_token=mock_get_token)
auth_policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
redirect_policy = AsyncRedirectPolicy()
retry_policy = AsyncRetryPolicy(retry_total=1, retry_backoff_factor=0)
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
transport = MockTransport()
# Pipeline order matches the real default: redirect -> retry -> auth -> ... -> sensitive header cleanup
pipeline = AsyncPipeline(
transport=transport,
policies=[redirect_policy, retry_policy, auth_policy, header_clean_up_policy],
)
response = await pipeline.run(HttpRequest("GET", "https://legitimate.azure.com"))
assert response.http_response.status_code == 200
assert transport._request_count == 3
132 changes: 132 additions & 0 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from azure.core.pipeline.policies import (
BearerTokenCredentialPolicy,
RedirectPolicy,
RetryPolicy,
SansIOHTTPPolicy,
AzureKeyCredentialPolicy,
AzureSasCredentialPolicy,
Expand Down Expand Up @@ -1088,3 +1089,134 @@ def failing_get_token(*scopes, **kwargs):
# Verify the exception chaining
assert exc_info.value.__cause__ is not None
assert isinstance(exc_info.value.__cause__, HttpResponseError)


def test_challenge_auth_header_stripped_after_redirect():
"""Assuming the SensitiveHeaderCleanupPolicy is in the pipeline, the authorization header should be stripped after
a redirect to a different domain by default, and preserved if the policy is configured to disable cleanup."""

class MockTransport(HttpTransport):
def __init__(self, cleanup_disabled=False):
self._first = True
self._cleanup_disabled = cleanup_disabled

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def close(self):
pass

def open(self):
pass

def send(self, request, **kwargs):
if self._first:
self._first = False
assert request.headers["Authorization"] == "Bearer {}".format(auth_header)
response = Response()
response.status_code = 307
response.headers["location"] = "https://redirect-target.example.invalid"
return response

# Second request: after redirect
if self._cleanup_disabled:
assert request.headers.get("Authorization")
else:
assert not request.headers.get("Authorization")
response = Response()
response.status_code = 401
response.headers["WWW-Authenticate"] = (
'Bearer error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsiZm9vIjoiYmFyIn19"'
)
return response

auth_header = "token"
get_token_call_count = 0

def mock_get_token(*_, **__):
nonlocal get_token_call_count
get_token_call_count += 1
return AccessToken(auth_header, 0)

credential = Mock(spec_set=["get_token"], get_token=mock_get_token)
auth_policy = BearerTokenCredentialPolicy(credential, "scope")
redirect_policy = RedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
pipeline = Pipeline(transport=MockTransport(), policies=[redirect_policy, auth_policy, header_clean_up_policy])
response = pipeline.run(HttpRequest("GET", "https://legitimate.azure.com"))
assert response.http_response.status_code == 401

header_clean_up_policy = SensitiveHeaderCleanupPolicy(disable_redirect_cleanup=True)
pipeline = Pipeline(
transport=MockTransport(cleanup_disabled=True), policies=[redirect_policy, auth_policy, header_clean_up_policy]
)
response = pipeline.run(HttpRequest("GET", "https://legitimate.azure.com"))
assert response.http_response.status_code == 401


def test_auth_header_stripped_after_cross_domain_redirect_with_retry():
"""After a cross-domain redirect, if the redirected-to endpoint returns a retryable status code,
the Authorization header should still be stripped on the retry attempt. This verifies that the
insecure_domain_change flag persists across retries so SensitiveHeaderCleanupPolicy continues to
remove the Authorization header."""

class MockTransport(HttpTransport):
def __init__(self):
self._request_count = 0

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def close(self):
pass

def open(self):
pass

def send(self, request, **kwargs):
self._request_count += 1

if self._request_count == 1:
# First request: to the original domain — should have auth header
assert request.headers.get("Authorization") == "Bearer {}".format(auth_header)
response = Response()
response.status_code = 307
response.headers["location"] = "https://redirect-target.example.invalid"
return response

if self._request_count == 2:
# Second request: after redirect to attacker domain — auth header should be stripped
assert not request.headers.get(
"Authorization"
), "Authorization header should be stripped on first request to redirected domain"
response = Response()
response.status_code = 500
return response

if self._request_count == 3:
# Third request: retry to attacker domain — auth header should STILL be stripped
assert not request.headers.get(
"Authorization"
), "Authorization header should be stripped on retry to redirected domain"
response = Response()
response.status_code = 200
return response

raise RuntimeError("Unexpected request count: {}".format(self._request_count))

auth_header = "token"
token = AccessToken(auth_header, 0)
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token))
auth_policy = BearerTokenCredentialPolicy(credential, "scope")
redirect_policy = RedirectPolicy()
retry_policy = RetryPolicy(retry_total=1, retry_backoff_factor=0)
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
transport = MockTransport()
# Pipeline order matches the real default: redirect -> retry -> auth -> ... -> sensitive header cleanup
pipeline = Pipeline(
transport=transport,
policies=[redirect_policy, retry_policy, auth_policy, header_clean_up_policy],
)
response = pipeline.run(HttpRequest("GET", "https://legitimate.azure.com"))
assert response.http_response.status_code == 200
assert transport._request_count == 3
Loading