Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,15 @@ def get_device(name: str) -> DeviceModel:
return DeviceModel.from_device(device)


def submit_task(task_request: TaskRequest) -> str:
def submit_task(
task_request: TaskRequest, metadata: dict[str, Any] | None = None
) -> str:
"""Submit a task to be run on begin_task"""
metadata: dict[str, Any] = {
"instrument_session": task_request.instrument_session,
}
# Can't default arg to mutable data structure:
if metadata is None:
metadata = {}

metadata["instrument_session"] = task_request.instrument_session
if context().tiled_conf is not None:
md = config().env.metadata
# We raise an InvalidConfigError on setting tiled_conf if this isn't set
Expand Down
22 changes: 16 additions & 6 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Annotated
from typing import Annotated, Any

import jwt
from fastapi import (
Expand Down Expand Up @@ -114,7 +114,7 @@ def get_app(config: ApplicationConfig):
)
dependencies = []
if config.oidc:
dependencies.append(Depends(verify_access_token(config.oidc)))
dependencies.append(Depends(decode_access_token(config.oidc)))
app.swagger_ui_init_oauth = {
"clientId": "NOT_SUPPORTED",
}
Expand All @@ -136,24 +136,25 @@ def get_app(config: ApplicationConfig):
return app


def verify_access_token(config: OIDCConfig):
def decode_access_token(config: OIDCConfig):
jwkclient = jwt.PyJWKClient(config.jwks_uri)
oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=config.authorization_endpoint,
tokenUrl=config.token_endpoint,
refreshUrl=config.token_endpoint,
)

def inner(access_token: str = Depends(oauth_scheme)):
def inner(request: Request, access_token: str = Depends(oauth_scheme)):
signing_key = jwkclient.get_signing_key_from_jwt(access_token)
jwt.decode(
decoded: dict[str, Any] = jwt.decode(
access_token,
signing_key.key,
algorithms=config.id_token_signing_alg_values_supported,
verify=True,
audience=config.client_audience,
issuer=config.issuer,
)
request.state.decoded_access_token = decoded

return inner

Expand Down Expand Up @@ -283,7 +284,16 @@ def submit_task(
) -> TaskResponse:
"""Submit a task to the worker."""
try:
task_id: str = runner.run(interface.submit_task, task_request)
# Extract user from jwt if using OIDC (if jwt exists)
access_token: dict[str, Any] | None = getattr(
request.state, "decoded_access_token", None
)
if access_token:
user: str = access_token.get("fedid", "Unknown")
else:
user = "Unknown"

task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
Expand Down
1 change: 1 addition & 0 deletions tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def test_instrument_session_propagated(client: BlueapiClient):
response = client.create_task(_SIMPLE_TASK)
trackable_task = client.get_task(response.task_id)
assert trackable_task.task.metadata == {
"user": "alice",
"instrument_session": AUTHORIZED_INSTRUMENT_SESSION,
"tiled_access_tags": [
'{"proposal": 12345, "visit": 1, "beamline": "adsim"}',
Expand Down
10 changes: 5 additions & 5 deletions tests/unit_tests/service/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Any
from unittest.mock import patch
from unittest.mock import Mock, patch

import jwt
import pytest
Expand Down Expand Up @@ -117,18 +117,18 @@ def test_poll_for_token_timeout(
def test_server_raises_exception_for_invalid_token(
oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock
):
inner = main.verify_access_token(oidc_config)
inner = main.decode_access_token(oidc_config)
with pytest.raises(jwt.PyJWTError):
inner(access_token="Invalid Token")
inner(Mock(), access_token="Invalid Token")


def test_processes_valid_token(
oidc_config: OIDCConfig,
mock_authn_server: responses.RequestsMock,
valid_token_with_jwt,
):
inner = main.verify_access_token(oidc_config)
inner(access_token=valid_token_with_jwt["access_token"])
inner = main.decode_access_token(oidc_config)
inner(Mock(), access_token=valid_token_with_jwt["access_token"])


def test_session_cache_manager_returns_writable_file_path(tmp_path):
Expand Down
32 changes: 31 additions & 1 deletion tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_get_task_by_id(
TaskRequest(
name="my_plan",
instrument_session=FAKE_INSTRUMENT_SESSION,
)
),
)

expected_metadata: dict[str, Any] = {
Expand Down Expand Up @@ -366,6 +366,36 @@ def test_get_task_by_id(
)


@patch("blueapi.service.interface.context")
def test_submit_task_inserts_metadata(context_mock: MagicMock):
context = BlueskyContext()
context.register_plan(my_plan)
context_mock.return_value = context

metadata = {"foo": "bar"}

task_id = interface.submit_task(
TaskRequest(
name="my_plan",
instrument_session=FAKE_INSTRUMENT_SESSION,
),
metadata,
)

assert interface.get_task_by_id(task_id) == TrackableTask.model_construct(
task_id=task_id,
request_id=ANY,
task=Task(
name="my_plan",
params={},
metadata=metadata,
),
is_complete=False,
is_pending=True,
errors=[],
)


@patch("blueapi.service.interface.TiledWriter")
@patch("blueapi.service.interface.from_uri")
@patch("blueapi.service.interface.context")
Expand Down
27 changes: 25 additions & 2 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def client(mock_runner: Mock) -> Iterator[TestClient]:

@pytest.fixture
def client_with_auth(
mock_runner: Mock, oidc_config: OIDCConfig, valid_token_with_jwt: dict[str, Any]
mock_runner: Mock,
oidc_config: OIDCConfig,
valid_token_with_jwt: dict[str, Any],
mock_authn_server,
) -> Iterator[TestClient]:
with patch("blueapi.service.interface.worker"):
main.setup_runner(runner=mock_runner)
Expand Down Expand Up @@ -248,10 +251,30 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None:

response = client.post("/tasks", json=task.model_dump())

mock_runner.run.assert_called_with(submit_task, task)
mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"})
assert response.json() == {"task_id": task_id}


def test_create_task_inserts_auth_metadata(
mock_runner: Mock,
client_with_auth: TestClient,
) -> None:
task = TaskRequest(
name="count",
params={"detectors": ["x"]},
instrument_session=FAKE_INSTRUMENT_SESSION,
)
client_with_auth.follow_redirects = False
task_id = str(uuid.uuid4())

# mock_runner.run.side_effect = [task_id]
mock_runner.run.return_value = [task_id]

client_with_auth.post("/tasks", json=task.model_dump())

mock_runner.run.assert_called_with(submit_task, task, {"user": "jd1"})


def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None:
mock_runner.run.side_effect = [
ValidationError.from_exception_data(
Expand Down