diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 9bc8bcef8..6acc29ab7 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -147,11 +147,15 @@ def get_device(name: str) -> DeviceModel: return DeviceModel.from_device(device) -def submit_task(task_request: TaskRequest) -> str: +def submit_task( + task_request: TaskRequest, metadata: dict[str, Any] | None = None +) -> str: """Submit a task to be run on begin_task""" - metadata: dict[str, Any] = { - "instrument_session": task_request.instrument_session, - } + # Can't default arg to mutable data structure: + if metadata is None: + metadata = {} + + metadata["instrument_session"] = task_request.instrument_session if context().tiled_conf is not None: md = config().env.metadata # We raise an InvalidConfigError on setting tiled_conf if this isn't set diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 0e6faf9e6..aaaa66503 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,7 +2,7 @@ import urllib.parse from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager -from typing import Annotated +from typing import Annotated, Any import jwt from fastapi import ( @@ -114,7 +114,7 @@ def get_app(config: ApplicationConfig): ) dependencies = [] if config.oidc: - dependencies.append(Depends(verify_access_token(config.oidc))) + dependencies.append(Depends(decode_access_token(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } @@ -136,7 +136,7 @@ def get_app(config: ApplicationConfig): return app -def verify_access_token(config: OIDCConfig): +def decode_access_token(config: OIDCConfig): jwkclient = jwt.PyJWKClient(config.jwks_uri) oauth_scheme = OAuth2AuthorizationCodeBearer( authorizationUrl=config.authorization_endpoint, @@ -144,9 +144,9 @@ def verify_access_token(config: OIDCConfig): refreshUrl=config.token_endpoint, ) - def inner(access_token: str = Depends(oauth_scheme)): + def inner(request: Request, access_token: str = Depends(oauth_scheme)): signing_key = jwkclient.get_signing_key_from_jwt(access_token) - jwt.decode( + decoded: dict[str, Any] = jwt.decode( access_token, signing_key.key, algorithms=config.id_token_signing_alg_values_supported, @@ -154,6 +154,7 @@ def inner(access_token: str = Depends(oauth_scheme)): audience=config.client_audience, issuer=config.issuer, ) + request.state.decoded_access_token = decoded return inner @@ -283,7 +284,16 @@ def submit_task( ) -> TaskResponse: """Submit a task to the worker.""" try: - task_id: str = runner.run(interface.submit_task, task_request) + # Extract user from jwt if using OIDC (if jwt exists) + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + if access_token: + user: str = access_token.get("fedid", "Unknown") + else: + user = "Unknown" + + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) except ValidationError as e: diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 3b3915349..6c8304cc1 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -265,6 +265,7 @@ def test_instrument_session_propagated(client: BlueapiClient): response = client.create_task(_SIMPLE_TASK) trackable_task = client.get_task(response.task_id) assert trackable_task.task.metadata == { + "user": "alice", "instrument_session": AUTHORIZED_INSTRUMENT_SESSION, "tiled_access_tags": [ '{"proposal": 12345, "visit": 1, "beamline": "adsim"}', diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index e86dbc490..281c2be03 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -1,7 +1,7 @@ import os from pathlib import Path from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch import jwt import pytest @@ -117,9 +117,9 @@ def test_poll_for_token_timeout( def test_server_raises_exception_for_invalid_token( oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock ): - inner = main.verify_access_token(oidc_config) + inner = main.decode_access_token(oidc_config) with pytest.raises(jwt.PyJWTError): - inner(access_token="Invalid Token") + inner(Mock(), access_token="Invalid Token") def test_processes_valid_token( @@ -127,8 +127,8 @@ def test_processes_valid_token( mock_authn_server: responses.RequestsMock, valid_token_with_jwt, ): - inner = main.verify_access_token(oidc_config) - inner(access_token=valid_token_with_jwt["access_token"]) + inner = main.decode_access_token(oidc_config) + inner(Mock(), access_token=valid_token_with_jwt["access_token"]) def test_session_cache_manager_returns_writable_file_path(tmp_path): diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index da6f1a4d9..ecefb26cc 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -337,7 +337,7 @@ def test_get_task_by_id( TaskRequest( name="my_plan", instrument_session=FAKE_INSTRUMENT_SESSION, - ) + ), ) expected_metadata: dict[str, Any] = { @@ -366,6 +366,36 @@ def test_get_task_by_id( ) +@patch("blueapi.service.interface.context") +def test_submit_task_inserts_metadata(context_mock: MagicMock): + context = BlueskyContext() + context.register_plan(my_plan) + context_mock.return_value = context + + metadata = {"foo": "bar"} + + task_id = interface.submit_task( + TaskRequest( + name="my_plan", + instrument_session=FAKE_INSTRUMENT_SESSION, + ), + metadata, + ) + + assert interface.get_task_by_id(task_id) == TrackableTask.model_construct( + task_id=task_id, + request_id=ANY, + task=Task( + name="my_plan", + params={}, + metadata=metadata, + ), + is_complete=False, + is_pending=True, + errors=[], + ) + + @patch("blueapi.service.interface.TiledWriter") @patch("blueapi.service.interface.from_uri") @patch("blueapi.service.interface.context") diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 8e89e7b6e..b209d110e 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -64,7 +64,10 @@ def client(mock_runner: Mock) -> Iterator[TestClient]: @pytest.fixture def client_with_auth( - mock_runner: Mock, oidc_config: OIDCConfig, valid_token_with_jwt: dict[str, Any] + mock_runner: Mock, + oidc_config: OIDCConfig, + valid_token_with_jwt: dict[str, Any], + mock_authn_server, ) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): main.setup_runner(runner=mock_runner) @@ -248,10 +251,30 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task) + mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) assert response.json() == {"task_id": task_id} +def test_create_task_inserts_auth_metadata( + mock_runner: Mock, + client_with_auth: TestClient, +) -> None: + task = TaskRequest( + name="count", + params={"detectors": ["x"]}, + instrument_session=FAKE_INSTRUMENT_SESSION, + ) + client_with_auth.follow_redirects = False + task_id = str(uuid.uuid4()) + + # mock_runner.run.side_effect = [task_id] + mock_runner.run.return_value = [task_id] + + client_with_auth.post("/tasks", json=task.model_dump()) + + mock_runner.run.assert_called_with(submit_task, task, {"user": "jd1"}) + + def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None: mock_runner.run.side_effect = [ ValidationError.from_exception_data(