From 315b8e562d2ea17c2fbbc06636414b6b6744d968 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 4 Mar 2026 23:36:20 +0100 Subject: [PATCH 1/7] chore: removed dead code, extracted copy-pasted code into separate funcs, updated docs --- backend/app/db/docs/saga.py | 5 +- backend/app/domain/events/typed.py | 28 ++-- backend/app/domain/saga/__init__.py | 2 - backend/app/domain/saga/models.py | 19 --- .../services/event_replay/replay_service.py | 62 ++++---- backend/app/services/notification_service.py | 71 ++++----- backend/app/services/saga/__init__.py | 3 +- backend/workers/bootstrap.py | 83 +++++++++++ backend/workers/run_dlq_processor.py | 85 ++++------- backend/workers/run_event_replay.py | 81 ++++------- backend/workers/run_k8s_worker.py | 64 +++------ backend/workers/run_pod_monitor.py | 136 +++++++----------- backend/workers/run_result_processor.py | 50 ++----- backend/workers/run_saga_orchestrator.py | 102 +++++-------- docs/architecture/event-system-design.md | 22 +++ docs/architecture/lifecycle.md | 51 ++++--- docs/architecture/services-overview.md | 2 +- docs/components/workers/index.md | 4 + 18 files changed, 409 insertions(+), 461 deletions(-) create mode 100644 backend/workers/bootstrap.py diff --git a/backend/app/db/docs/saga.py b/backend/app/db/docs/saga.py index 10669b67..1c64bb05 100644 --- a/backend/app/db/docs/saga.py +++ b/backend/app/db/docs/saga.py @@ -11,10 +11,7 @@ class SagaDocument(Document): - """Domain model for saga stored in database. - - Copied from Saga/SagaInstance dataclass. 
- """ + """Domain model for saga stored in database.""" saga_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid4())) # type: ignore[valid-type] saga_name: Indexed(str) # type: ignore[valid-type] diff --git a/backend/app/domain/events/typed.py b/backend/app/domain/events/typed.py index 531cf125..5851baf3 100644 --- a/backend/app/domain/events/typed.py +++ b/backend/app/domain/events/typed.py @@ -58,11 +58,12 @@ class BaseEvent(BaseModel): metadata: EventMetadata -# --- Execution Events --- +# --- Execution Spec (shared fields between ExecutionRequestedEvent and CreatePodCommandEvent) --- -class ExecutionRequestedEvent(BaseEvent): - event_type: Literal[EventType.EXECUTION_REQUESTED] = EventType.EXECUTION_REQUESTED +class ExecutionSpec(BaseModel): + """Shared execution specification fields (mixin for ExecutionRequestedEvent and CreatePodCommandEvent).""" + execution_id: str script: str language: str @@ -78,6 +79,13 @@ class ExecutionRequestedEvent(BaseEvent): priority: QueuePriority = QueuePriority.NORMAL +# --- Execution Events --- + + +class ExecutionRequestedEvent(BaseEvent, ExecutionSpec): + event_type: Literal[EventType.EXECUTION_REQUESTED] = EventType.EXECUTION_REQUESTED + + class ExecutionAcceptedEvent(BaseEvent): event_type: Literal[EventType.EXECUTION_ACCEPTED] = EventType.EXECUTION_ACCEPTED execution_id: str @@ -413,22 +421,10 @@ class SagaCompensatedEvent(BaseEvent): # --- Saga Command Events --- -class CreatePodCommandEvent(BaseEvent): +class CreatePodCommandEvent(BaseEvent, ExecutionSpec): event_type: Literal[EventType.CREATE_POD_COMMAND] = EventType.CREATE_POD_COMMAND saga_id: str - execution_id: str - script: str - language: str - language_version: str - runtime_image: str runtime_command: list[str] = Field(default_factory=list) - runtime_filename: str - timeout_seconds: int - cpu_limit: str - memory_limit: str - cpu_request: str - memory_request: str - priority: QueuePriority = QueuePriority.NORMAL class 
DeletePodCommandEvent(BaseEvent): diff --git a/backend/app/domain/saga/__init__.py b/backend/app/domain/saga/__init__.py index d132dccf..c764a706 100644 --- a/backend/app/domain/saga/__init__.py +++ b/backend/app/domain/saga/__init__.py @@ -13,7 +13,6 @@ SagaConfig, SagaContextData, SagaFilter, - SagaInstance, SagaListResult, SagaQuery, ) @@ -25,7 +24,6 @@ "SagaCancellationResult", "SagaConfig", "SagaContextData", - "SagaInstance", "SagaFilter", "SagaListResult", "SagaQuery", diff --git a/backend/app/domain/saga/models.py b/backend/app/domain/saga/models.py index c419b04d..4a117137 100644 --- a/backend/app/domain/saga/models.py +++ b/backend/app/domain/saga/models.py @@ -92,25 +92,6 @@ class SagaConfig: publish_commands: bool = True -@dataclass -class SagaInstance: - """Runtime instance of a saga execution (domain).""" - - saga_name: str - execution_id: str - state: SagaState = SagaState.CREATED - saga_id: str = field(default_factory=lambda: str(uuid4())) - current_step: str | None = None - completed_steps: list[str] = field(default_factory=list) - compensated_steps: list[str] = field(default_factory=list) - context_data: SagaContextData = field(default_factory=SagaContextData) - error_message: str | None = None - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - completed_at: datetime | None = None - retry_count: int = 0 - - @dataclass class SagaCancellationResult: """Domain result for saga cancellation operations.""" diff --git a/backend/app/services/event_replay/replay_service.py b/backend/app/services/event_replay/replay_service.py index b92b5e73..3e6316ef 100644 --- a/backend/app/services/event_replay/replay_service.py +++ b/backend/app/services/event_replay/replay_service.py @@ -183,19 +183,45 @@ async def cleanup_old_sessions(self, older_than_hours: int = 24) -> CleanupResul self.logger.info("Cleaned up old replay sessions", 
removed_count=total_removed) return CleanupResult(removed_sessions=total_removed, message=f"Removed {total_removed} old sessions") + async def _load_next_event(self, session: ReplaySessionState) -> DomainEvent | None: + """Pop the next event from the buffer, loading a new batch if needed.""" + event = self._pop_next_event(session.session_id) + if event is not None: + return event + if not await self._load_next_batch(session.session_id): + return None + return self._pop_next_event(session.session_id) + + def _calculate_replay_delay(self, session: ReplaySessionState) -> float: + """Calculate the delay before dispatching the next event based on speed multiplier.""" + next_event = self._peek_next_event(session.session_id) + if next_event and session.last_event_at and session.config.speed_multiplier < 100: + time_diff = (next_event.timestamp - session.last_event_at).total_seconds() + return max(time_diff / session.config.speed_multiplier, 0) + return 0.0 + + def _reschedule_dispatch(self, session: ReplaySessionState, delay: float) -> None: + """Schedule the next _dispatch_next call if the session is still running.""" + scheduler = self._schedulers.get(session.session_id) + if scheduler and scheduler.running and session.status == ReplayStatus.RUNNING: + scheduler.add_job( + self._dispatch_next, + trigger="date", + run_date=datetime.now(timezone.utc) + timedelta(seconds=delay), + args=[session], + id=f"dispatch_{session.session_id}", + replace_existing=True, + misfire_grace_time=None, + ) + async def _dispatch_next(self, session: ReplaySessionState) -> None: if session.status != ReplayStatus.RUNNING: return - event = self._pop_next_event(session.session_id) + event = await self._load_next_event(session) if event is None: - if not await self._load_next_batch(session.session_id): - await self._finalize_session(session, ReplayStatus.COMPLETED) - return - event = self._pop_next_event(session.session_id) - if event is None: - await self._finalize_session(session, 
ReplayStatus.COMPLETED) - return + await self._finalize_session(session, ReplayStatus.COMPLETED) + return buf = self._event_buffers.get(session.session_id, []) idx = self._buffer_indices.get(session.session_id, 0) @@ -229,26 +255,10 @@ async def _dispatch_next(self, session: ReplaySessionState) -> None: session.last_event_at = event.timestamp await self._update_session_in_db(session) - next_event = self._peek_next_event(session.session_id) - delay = 0.0 - if next_event and session.last_event_at and session.config.speed_multiplier < 100: - time_diff = (next_event.timestamp - session.last_event_at).total_seconds() - delay = max(time_diff / session.config.speed_multiplier, 0) - + delay = self._calculate_replay_delay(session) if delay > 0: self._metrics.record_delay_applied(delay) - - scheduler = self._schedulers.get(session.session_id) - if scheduler and scheduler.running and session.status == ReplayStatus.RUNNING: - scheduler.add_job( - self._dispatch_next, - trigger="date", - run_date=datetime.now(timezone.utc) + timedelta(seconds=delay), - args=[session], - id=f"dispatch_{session.session_id}", - replace_existing=True, - misfire_grace_time=None, - ) + self._reschedule_dispatch(session, delay) def _pop_next_event(self, session_id: str) -> DomainEvent | None: idx = self._buffer_indices.get(session_id, 0) diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index b9a63b83..52df3a92 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -122,6 +122,38 @@ def __init__( } # --8<-- [end:channel_handlers] + def _validate_scheduled_time(self, scheduled_for: datetime) -> None: + """Validate that scheduled_for is in the future and within the max schedule window.""" + if scheduled_for < datetime.now(UTC): + raise NotificationValidationError("scheduled_for must be in the future") + max_days = self.settings.NOTIF_MAX_SCHEDULE_DAYS + max_schedule = datetime.now(UTC) + 
timedelta(days=max_days) + if scheduled_for > max_schedule: + raise NotificationValidationError(f"scheduled_for cannot exceed {max_days} days from now") + + async def _check_throttle(self, user_id: str, severity: NotificationSeverity, source: str) -> None: + """Check throttle and raise NotificationThrottledError if rate limit exceeded.""" + if self.settings.ENVIRONMENT == "test": + return + throttled = await self._throttle_cache.check_throttle( + user_id, + severity, + window_hours=self.settings.NOTIF_THROTTLE_WINDOW_HOURS, + max_per_hour=self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, + ) + if throttled: + self.logger.warning( + f"Notification rate limit exceeded for user {user_id}. " + f"Max {self.settings.NOTIF_THROTTLE_MAX_PER_HOUR} " + f"per {self.settings.NOTIF_THROTTLE_WINDOW_HOURS} hour(s)" + ) + self.metrics.record_notification_throttled(source) + raise NotificationThrottledError( + user_id, + self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, + self.settings.NOTIF_THROTTLE_WINDOW_HOURS, + ) + async def create_notification( self, user_id: str, @@ -137,14 +169,7 @@ async def create_notification( if not tags: raise NotificationValidationError("tags must be a non-empty list") if scheduled_for is not None: - if scheduled_for < datetime.now(UTC): - raise NotificationValidationError("scheduled_for must be in the future") - max_days = self.settings.NOTIF_MAX_SCHEDULE_DAYS - max_schedule = datetime.now(UTC) + timedelta(days=max_days) - if scheduled_for > max_schedule: - raise NotificationValidationError( - f"scheduled_for cannot exceed {max_days} days from now" - ) + self._validate_scheduled_time(scheduled_for) self.logger.info( f"Creating notification for user {user_id}", user_id=user_id, @@ -154,25 +179,7 @@ async def create_notification( scheduled=scheduled_for is not None, ) - # Check throttling - if self.settings.ENVIRONMENT != "test" and await self._throttle_cache.check_throttle( - user_id, - severity, - window_hours=self.settings.NOTIF_THROTTLE_WINDOW_HOURS, - 
max_per_hour=self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, - ): - error_msg = ( - f"Notification rate limit exceeded for user {user_id}. " - f"Max {self.settings.NOTIF_THROTTLE_MAX_PER_HOUR} " - f"per {self.settings.NOTIF_THROTTLE_WINDOW_HOURS} hour(s)" - ) - self.logger.warning(error_msg) - self.metrics.record_notification_throttled("general") - raise NotificationThrottledError( - user_id, - self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, - self.settings.NOTIF_THROTTLE_WINDOW_HOURS, - ) + await self._check_throttle(user_id, severity, "general") # Create notification create_data = DomainNotificationCreate( @@ -290,13 +297,9 @@ async def _create_system_for_user( ) -> str: try: if not cfg.throttle_exempt: - throttled = await self._throttle_cache.check_throttle( - user_id, - cfg.severity, - window_hours=self.settings.NOTIF_THROTTLE_WINDOW_HOURS, - max_per_hour=self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, - ) - if throttled: + try: + await self._check_throttle(user_id, cfg.severity, "system") + except NotificationThrottledError: return "throttled" await self.create_notification( diff --git a/backend/app/services/saga/__init__.py b/backend/app/services/saga/__init__.py index 045d66aa..7c085bc8 100644 --- a/backend/app/services/saga/__init__.py +++ b/backend/app/services/saga/__init__.py @@ -1,5 +1,5 @@ from app.domain.enums import SagaState -from app.domain.saga import SagaConfig, SagaInstance +from app.domain.saga import SagaConfig from app.services.saga.execution_saga import ( AllocateResourcesStep, CreatePodStep, @@ -17,7 +17,6 @@ "SagaService", "SagaConfig", "SagaState", - "SagaInstance", "SagaContext", "SagaStep", "CompensationStep", diff --git a/backend/workers/bootstrap.py b/backend/workers/bootstrap.py new file mode 100644 index 00000000..a2e3000c --- /dev/null +++ b/backend/workers/bootstrap.py @@ -0,0 +1,83 @@ +import asyncio +from collections.abc import Awaitable, Callable +from typing import Any + +import structlog +from app.core.logging import setup_log_exporter, 
setup_logger +from app.db.docs import ALL_DOCUMENTS +from app.settings import Settings +from beanie import init_beanie +from dishka import AsyncContainer +from dishka.integrations.faststream import setup_dishka +from faststream import FastStream +from faststream.kafka import KafkaBroker +from pymongo import AsyncMongoClient + + +def run_worker( + worker_name: str, + config_override: str, + container_factory: Callable[[Settings], AsyncContainer], + register_handlers: Callable[[KafkaBroker], None] | None = None, + on_startup: Callable[[AsyncContainer, KafkaBroker, structlog.stdlib.BoundLogger], Awaitable[None]] | None = None, + on_shutdown: Callable[[], Awaitable[None]] | None = None, +) -> None: + """Boot a worker with standardised init sequence. + + Parameters + ---------- + worker_name: + Human-readable name used in log messages. + config_override: + TOML filename passed to ``Settings(override_path=...)``. + container_factory: + Dishka container factory — receives ``Settings``, returns ``AsyncContainer``. + register_handlers: + Optional callback to register ``@broker.subscriber`` handlers **before** + ``setup_dishka`` is called (required for subscriber auto-injection). + on_startup: + Optional async callback executed **after** the broker is ready. + Receives ``(container, broker, logger)`` so it can resolve services + and wire up APScheduler jobs, K8s setup, etc. + on_shutdown: + Optional async callback executed on FastStream shutdown **before** + the container is closed. Use for scheduler teardown etc. 
+ """ + settings = Settings(override_path=config_override) + + logger = setup_logger(settings.LOG_LEVEL) + setup_log_exporter(settings, logger) + + logger.info(f"Starting {worker_name}...") + + async def _run() -> None: + client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True) + await init_beanie( + database=client.get_default_database(default=settings.DATABASE_NAME), + document_models=ALL_DOCUMENTS, + ) + logger.info("MongoDB initialized via Beanie") + + container = container_factory(settings) + + broker: KafkaBroker = await container.get(KafkaBroker) + + if register_handlers is not None: + register_handlers(broker) + setup_dishka(container, broker=broker, auto_inject=True) + + startup_hooks: list[Callable[[], Awaitable[None]]] = [] + shutdown_hooks: list[Callable[[], Awaitable[None]]] = [] + + if on_startup is not None: + startup_hooks.append(lambda: on_startup(container, broker, logger)) + + if on_shutdown is not None: + shutdown_hooks.append(on_shutdown) + shutdown_hooks.append(container.close) + + app = FastStream(broker, on_startup=startup_hooks, on_shutdown=shutdown_hooks) + await app.run() + logger.info(f"{worker_name} shutdown complete") + + asyncio.run(_run()) diff --git a/backend/workers/run_dlq_processor.py b/backend/workers/run_dlq_processor.py index 19a09dea..d22b641f 100644 --- a/backend/workers/run_dlq_processor.py +++ b/backend/workers/run_dlq_processor.py @@ -1,71 +1,44 @@ -import asyncio -from typing import Any - +import structlog from app.core.container import create_dlq_processor_container -from app.core.logging import setup_log_exporter, setup_logger -from app.db.docs import ALL_DOCUMENTS from app.dlq.manager import DLQManager -from app.settings import Settings from apscheduler.schedulers.asyncio import AsyncIOScheduler -from beanie import init_beanie -from dishka.integrations.faststream import setup_dishka -from faststream import FastStream +from dishka import AsyncContainer from faststream.kafka 
import KafkaBroker -from pymongo import AsyncMongoClient - - -def main() -> None: - """Main entry point for DLQ processor worker.""" - settings = Settings(override_path="config.dlq-processor.toml") - logger = setup_logger(settings.LOG_LEVEL) - setup_log_exporter(settings, logger) +from workers.bootstrap import run_worker - logger.info("Starting DLQ Processor worker...") +_scheduler = AsyncIOScheduler() - async def run() -> None: - # Initialize Beanie with tz_aware client (so MongoDB returns aware datetimes) - client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True) - await init_beanie( - database=client.get_default_database(default=settings.DATABASE_NAME), - document_models=ALL_DOCUMENTS, - ) - logger.info("MongoDB initialized via Beanie") - # Create DI container - container = create_dlq_processor_container(settings) +async def _on_startup( + container: AsyncContainer, broker: KafkaBroker, logger: structlog.stdlib.BoundLogger +) -> None: + manager = await container.get(DLQManager) + _scheduler.add_job( + manager.process_monitoring_cycle, + trigger="interval", + seconds=10, + id="dlq_monitor_retries", + max_instances=1, + misfire_grace_time=60, + ) + _scheduler.start() + logger.info("DLQ Processor initialized (APScheduler interval=10s)") - # Get broker from DI - broker: KafkaBroker = await container.get(KafkaBroker) - # Set up DI integration (no subscribers — DLQ manager uses APScheduler, - # broker is only needed for publishing retry/status events) - setup_dishka(container, broker=broker, auto_inject=True) +async def _on_shutdown() -> None: + _scheduler.shutdown(wait=False) - scheduler = AsyncIOScheduler() - async def init_dlq() -> None: - manager = await container.get(DLQManager) - scheduler.add_job( - manager.process_monitoring_cycle, - trigger="interval", - seconds=10, - id="dlq_monitor_retries", - max_instances=1, - misfire_grace_time=60, - ) - scheduler.start() - logger.info("DLQ Processor initialized (APScheduler 
interval=10s)") - - async def shutdown() -> None: - scheduler.shutdown(wait=False) - await container.close() - - app = FastStream(broker, on_startup=[init_dlq], on_shutdown=[shutdown]) - await app.run() - logger.info("DLQ Processor shutdown complete") - - asyncio.run(run()) +def main() -> None: + """Main entry point for DLQ processor worker.""" + run_worker( + worker_name="DLQ Processor", + config_override="config.dlq-processor.toml", + container_factory=create_dlq_processor_container, + on_startup=_on_startup, + on_shutdown=_on_shutdown, + ) if __name__ == "__main__": diff --git a/backend/workers/run_event_replay.py b/backend/workers/run_event_replay.py index 0ce4279d..032cc1a4 100644 --- a/backend/workers/run_event_replay.py +++ b/backend/workers/run_event_replay.py @@ -1,66 +1,45 @@ -import asyncio -from typing import Any - +import structlog from app.core.container import create_event_replay_container -from app.core.logging import setup_log_exporter, setup_logger -from app.db.docs import ALL_DOCUMENTS from app.services.event_replay import EventReplayService -from app.settings import Settings from apscheduler.schedulers.asyncio import AsyncIOScheduler -from beanie import init_beanie -from dishka.integrations.faststream import setup_dishka -from faststream import FastStream +from dishka import AsyncContainer from faststream.kafka import KafkaBroker -from pymongo import AsyncMongoClient - - -def main() -> None: - """Main entry point for event replay service""" - settings = Settings(override_path="config.event-replay.toml") - - logger = setup_logger(settings.LOG_LEVEL) - setup_log_exporter(settings, logger) - logger.info("Starting Event Replay Service...") +from workers.bootstrap import run_worker - async def run() -> None: - # Initialize Beanie with tz_aware client (so MongoDB returns aware datetimes) - client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True) - await init_beanie( - 
database=client.get_default_database(default=settings.DATABASE_NAME), - document_models=ALL_DOCUMENTS, - ) - logger.info("MongoDB initialized via Beanie") +_scheduler = AsyncIOScheduler() - container = create_event_replay_container(settings) - broker: KafkaBroker = await container.get(KafkaBroker) - setup_dishka(container, broker=broker, auto_inject=True) - scheduler = AsyncIOScheduler() +async def _on_startup( + container: AsyncContainer, broker: KafkaBroker, logger: structlog.stdlib.BoundLogger +) -> None: + service = await container.get(EventReplayService) + _scheduler.add_job( + service.cleanup_old_sessions, + trigger="interval", + hours=6, + kwargs={"older_than_hours": 48}, + id="replay_cleanup_old_sessions", + max_instances=1, + misfire_grace_time=300, + ) + _scheduler.start() + logger.info("Event replay service initialized (APScheduler interval=6h)") - async def init_replay() -> None: - service = await container.get(EventReplayService) - scheduler.add_job( - service.cleanup_old_sessions, - trigger="interval", - hours=6, - kwargs={"older_than_hours": 48}, - id="replay_cleanup_old_sessions", - max_instances=1, - misfire_grace_time=300, - ) - scheduler.start() - logger.info("Event replay service initialized (APScheduler interval=6h)") - async def shutdown() -> None: - scheduler.shutdown(wait=False) - await container.close() +async def _on_shutdown() -> None: + _scheduler.shutdown(wait=False) - app = FastStream(broker, on_startup=[init_replay], on_shutdown=[shutdown]) - await app.run() - logger.info("EventReplayService shutdown complete") - asyncio.run(run()) +def main() -> None: + """Main entry point for event replay service""" + run_worker( + worker_name="EventReplayService", + config_override="config.event-replay.toml", + container_factory=create_event_replay_container, + on_startup=_on_startup, + on_shutdown=_on_shutdown, + ) if __name__ == "__main__": diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 4e47d5b6..a25be7b6 
100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -1,58 +1,32 @@ -import asyncio -from typing import Any +import structlog from app.core.container import create_k8s_worker_container -from app.core.logging import setup_log_exporter, setup_logger -from app.db.docs import ALL_DOCUMENTS from app.events.handlers import register_k8s_worker_subscriber from app.services.k8s_worker import KubernetesWorker -from app.settings import Settings -from beanie import init_beanie -from dishka.integrations.faststream import setup_dishka -from faststream import FastStream +from dishka import AsyncContainer from faststream.kafka import KafkaBroker -from pymongo import AsyncMongoClient +from workers.bootstrap import run_worker -def main() -> None: - """Main entry point for Kubernetes worker""" - settings = Settings(override_path="config.k8s-worker.toml") - - logger = setup_logger(settings.LOG_LEVEL) - setup_log_exporter(settings, logger) - - logger.info("Starting KubernetesWorker...") - - async def run() -> None: - # Initialize Beanie with tz_aware client (so MongoDB returns aware datetimes) - client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True) - await init_beanie( - database=client.get_default_database(default=settings.DATABASE_NAME), - document_models=ALL_DOCUMENTS, - ) - logger.info("MongoDB initialized via Beanie") - # Create DI container - container = create_k8s_worker_container(settings) +async def _on_startup( + container: AsyncContainer, broker: KafkaBroker, logger: structlog.stdlib.BoundLogger +) -> None: + worker = await container.get(KubernetesWorker) + await worker.ensure_namespace_security() + await worker.ensure_image_pre_puller_daemonset() + logger.info("KubernetesWorker initialized with namespace security and pre-puller daemonset") - # Get broker from DI - broker: KafkaBroker = await container.get(KafkaBroker) - # Register subscriber and set up DI integration - 
register_k8s_worker_subscriber(broker) - setup_dishka(container, broker=broker, auto_inject=True) - - async def init_k8s_worker() -> None: - worker = await container.get(KubernetesWorker) - await worker.ensure_namespace_security() - await worker.ensure_image_pre_puller_daemonset() - logger.info("KubernetesWorker initialized with namespace security and pre-puller daemonset") - - app = FastStream(broker, on_startup=[init_k8s_worker], on_shutdown=[container.close]) - await app.run() - logger.info("KubernetesWorker shutdown complete") - - asyncio.run(run()) +def main() -> None: + """Main entry point for Kubernetes worker""" + run_worker( + worker_name="KubernetesWorker", + config_override="config.k8s-worker.toml", + container_factory=create_k8s_worker_container, + register_handlers=register_k8s_worker_subscriber, + on_startup=_on_startup, + ) if __name__ == "__main__": diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index c89d3f54..139d6b99 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -1,94 +1,68 @@ -import asyncio -from typing import Any - +import structlog from app.core.container import create_pod_monitor_container -from app.core.logging import setup_log_exporter, setup_logger from app.core.metrics import KubernetesMetrics -from app.db.docs import ALL_DOCUMENTS from app.services.pod_monitor import ErrorType, PodMonitor -from app.settings import Settings from apscheduler.schedulers.asyncio import AsyncIOScheduler -from beanie import init_beanie -from dishka.integrations.faststream import setup_dishka -from faststream import FastStream +from dishka import AsyncContainer from faststream.kafka import KafkaBroker from kubernetes_asyncio.client.rest import ApiException -from pymongo import AsyncMongoClient + +from workers.bootstrap import run_worker + +_scheduler = AsyncIOScheduler() + + +async def _on_startup( + container: AsyncContainer, broker: KafkaBroker, logger: 
structlog.stdlib.BoundLogger +) -> None: + monitor = await container.get(PodMonitor) + kubernetes_metrics = await container.get(KubernetesMetrics) + + async def _watch_cycle() -> None: + error_type: ErrorType | None = None + try: + await monitor.watch_pod_events() + except ApiException as e: + if e.status == 410: + logger.warning("Resource version expired, resetting watch cursor") + monitor._last_resource_version = None + error_type = ErrorType.RESOURCE_VERSION_EXPIRED + else: + logger.error("API error in watch", status=e.status, reason=e.reason) + error_type = ErrorType.API_ERROR + except Exception: + logger.error("Unexpected error in watch", exc_info=True) + error_type = ErrorType.UNEXPECTED + + if error_type is not None: + kubernetes_metrics.record_pod_monitor_watch_error(error_type) + kubernetes_metrics.increment_pod_monitor_watch_reconnects() + + _scheduler.add_job( + _watch_cycle, + trigger="interval", + seconds=5, + id="pod_monitor_watch", + max_instances=1, + misfire_grace_time=60, + ) + _scheduler.start() + logger.info("PodMonitor initialized (APScheduler interval=5s)") + + +async def _on_shutdown() -> None: + _scheduler.shutdown(wait=False) def main() -> None: """Main entry point for pod monitor worker""" - settings = Settings(override_path="config.pod-monitor.toml") - - logger = setup_logger(settings.LOG_LEVEL) - setup_log_exporter(settings, logger) - - logger.info("Starting PodMonitor worker...") - - async def run() -> None: - # Initialize Beanie with tz_aware client (so MongoDB returns aware datetimes) - client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True) - await init_beanie( - database=client.get_default_database(default=settings.DATABASE_NAME), - document_models=ALL_DOCUMENTS, - ) - logger.info("MongoDB initialized via Beanie") - - # Create DI container - container = create_pod_monitor_container(settings) - - # Get broker from DI (PodMonitor publishes events via KafkaEventService) - broker: KafkaBroker = 
await container.get(KafkaBroker) - - # Set up DI integration (no subscribers for pod monitor - it only publishes) - setup_dishka(container, broker=broker, auto_inject=True) - - scheduler = AsyncIOScheduler() - - async def init_monitor() -> None: - monitor = await container.get(PodMonitor) - kubernetes_metrics = await container.get(KubernetesMetrics) - - async def _watch_cycle() -> None: - error_type: ErrorType | None = None - try: - await monitor.watch_pod_events() - except ApiException as e: - if e.status == 410: - logger.warning("Resource version expired, resetting watch cursor") - monitor._last_resource_version = None - error_type = ErrorType.RESOURCE_VERSION_EXPIRED - else: - logger.error("API error in watch", status=e.status, reason=e.reason) - error_type = ErrorType.API_ERROR - except Exception: - logger.error("Unexpected error in watch", exc_info=True) - error_type = ErrorType.UNEXPECTED - - if error_type is not None: - kubernetes_metrics.record_pod_monitor_watch_error(error_type) - kubernetes_metrics.increment_pod_monitor_watch_reconnects() - - scheduler.add_job( - _watch_cycle, - trigger="interval", - seconds=5, - id="pod_monitor_watch", - max_instances=1, - misfire_grace_time=60, - ) - scheduler.start() - logger.info("PodMonitor initialized (APScheduler interval=5s)") - - async def shutdown() -> None: - scheduler.shutdown(wait=False) - await container.close() - - app = FastStream(broker, on_startup=[init_monitor], on_shutdown=[shutdown]) - await app.run() - logger.info("PodMonitor shutdown complete") - - asyncio.run(run()) + run_worker( + worker_name="PodMonitor", + config_override="config.pod-monitor.toml", + container_factory=create_pod_monitor_container, + on_startup=_on_startup, + on_shutdown=_on_shutdown, + ) if __name__ == "__main__": diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index f0952b75..e5cf626e 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py 
@@ -1,51 +1,17 @@ -import asyncio -from typing import Any - from app.core.container import create_result_processor_container -from app.core.logging import setup_log_exporter, setup_logger -from app.db.docs import ALL_DOCUMENTS from app.events.handlers import register_result_processor_subscriber -from app.settings import Settings -from beanie import init_beanie -from dishka.integrations.faststream import setup_dishka -from faststream import FastStream -from faststream.kafka import KafkaBroker -from pymongo import AsyncMongoClient + +from workers.bootstrap import run_worker def main() -> None: """Main entry point for result processor worker""" - settings = Settings(override_path="config.result-processor.toml") - - logger = setup_logger(settings.LOG_LEVEL) - setup_log_exporter(settings, logger) - - logger.info("Starting ResultProcessor worker...") - - async def run() -> None: - # Initialize Beanie with tz_aware client (so MongoDB returns aware datetimes) - client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True) - await init_beanie( - database=client.get_default_database(default=settings.DATABASE_NAME), - document_models=ALL_DOCUMENTS, - ) - logger.info("MongoDB initialized via Beanie") - - # Create DI container - container = create_result_processor_container(settings) - - # Get broker from DI - broker: KafkaBroker = await container.get(KafkaBroker) - - # Register subscriber and set up DI integration - register_result_processor_subscriber(broker) - setup_dishka(container, broker=broker, auto_inject=True) - - app = FastStream(broker, on_shutdown=[container.close]) - await app.run() - logger.info("ResultProcessor shutdown complete") - - asyncio.run(run()) + run_worker( + worker_name="ResultProcessor", + config_override="config.result-processor.toml", + container_factory=create_result_processor_container, + register_handlers=register_result_processor_subscriber, + ) if __name__ == "__main__": diff --git 
a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 3a38fbb3..c73bdf91 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -1,80 +1,54 @@ -import asyncio -from typing import Any - +import structlog from app.core.container import create_saga_orchestrator_container -from app.core.logging import setup_log_exporter, setup_logger -from app.db.docs import ALL_DOCUMENTS from app.events.handlers import register_saga_subscriber from app.services.saga import SagaOrchestrator -from app.settings import Settings from apscheduler.schedulers.asyncio import AsyncIOScheduler -from beanie import init_beanie -from dishka.integrations.faststream import setup_dishka -from faststream import FastStream +from dishka import AsyncContainer from faststream.kafka import KafkaBroker -from pymongo import AsyncMongoClient - - -def main() -> None: - """Main entry point for saga orchestrator worker""" - settings = Settings(override_path="config.saga-orchestrator.toml") - logger = setup_logger(settings.LOG_LEVEL) - setup_log_exporter(settings, logger) +from workers.bootstrap import run_worker - logger.info("Starting Saga Orchestrator worker...") +_scheduler = AsyncIOScheduler() - async def run() -> None: - # Initialize Beanie with tz_aware client (so MongoDB returns aware datetimes) - client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL, tz_aware=True) - await init_beanie( - database=client.get_default_database(default=settings.DATABASE_NAME), - document_models=ALL_DOCUMENTS, - ) - logger.info("MongoDB initialized via Beanie") - # Create DI container - container = create_saga_orchestrator_container(settings) +async def _on_startup( + container: AsyncContainer, broker: KafkaBroker, logger: structlog.stdlib.BoundLogger +) -> None: + orchestrator = await container.get(SagaOrchestrator) + _scheduler.add_job( + orchestrator.check_timeouts, + trigger="interval", + seconds=30, + 
id="saga_check_timeouts", + max_instances=1, + misfire_grace_time=60, + ) + _scheduler.add_job( + orchestrator.try_schedule_from_queue, + trigger="interval", + seconds=10, + id="saga_try_schedule", + max_instances=1, + misfire_grace_time=30, + ) + _scheduler.start() + logger.info("SagaOrchestrator initialized (APScheduler: timeouts=30s, scheduling=10s)") - # Get broker from DI - broker: KafkaBroker = await container.get(KafkaBroker) - # Register subscriber and set up DI integration - register_saga_subscriber(broker) - setup_dishka(container, broker=broker, auto_inject=True) +async def _on_shutdown() -> None: + _scheduler.shutdown(wait=False) - scheduler = AsyncIOScheduler() - async def init_saga() -> None: - orchestrator = await container.get(SagaOrchestrator) - scheduler.add_job( - orchestrator.check_timeouts, - trigger="interval", - seconds=30, - id="saga_check_timeouts", - max_instances=1, - misfire_grace_time=60, - ) - scheduler.add_job( - orchestrator.try_schedule_from_queue, - trigger="interval", - seconds=10, - id="saga_try_schedule", - max_instances=1, - misfire_grace_time=30, - ) - scheduler.start() - logger.info("SagaOrchestrator initialized (APScheduler: timeouts=30s, scheduling=10s)") - - async def shutdown() -> None: - scheduler.shutdown(wait=False) - await container.close() - - app = FastStream(broker, on_startup=[init_saga], on_shutdown=[shutdown]) - await app.run() - logger.info("SagaOrchestrator shutdown complete") - - asyncio.run(run()) +def main() -> None: + """Main entry point for saga orchestrator worker""" + run_worker( + worker_name="SagaOrchestrator", + config_override="config.saga-orchestrator.toml", + container_factory=create_saga_orchestrator_container, + register_handlers=register_saga_subscriber, + on_startup=_on_startup, + on_shutdown=_on_shutdown, + ) if __name__ == "__main__": diff --git a/docs/architecture/event-system-design.md b/docs/architecture/event-system-design.md index 025abc47..9834ae89 100644 --- 
a/docs/architecture/event-system-design.md +++ b/docs/architecture/event-system-design.md @@ -41,6 +41,28 @@ The unified approach addresses these issues: - **Storage-ready**: Events include storage fields (`stored_at`, `ttl_expires_at`) that MongoDB uses - **1:1 topic mapping**: Topic name = `EventType` value — no mapping layer needed +## Shared field mixins + +Some event classes share identical field sets. Rather than duplicating fields, shared groups are extracted into Pydantic `BaseModel` mixins that concrete events inherit alongside `BaseEvent` via multiple inheritance: + +```python +class ExecutionSpec(BaseModel): + """Shared fields for execution request and pod command events.""" + execution_id: str + script: str + language: str + # ... resource limits, runtime config, priority + +class ExecutionRequestedEvent(BaseEvent, ExecutionSpec): + event_type: Literal[EventType.EXECUTION_REQUESTED] = EventType.EXECUTION_REQUESTED + +class CreatePodCommandEvent(BaseEvent, ExecutionSpec): + event_type: Literal[EventType.CREATE_POD_COMMAND] = EventType.CREATE_POD_COMMAND + saga_id: str +``` + +Mixins inherit from `BaseModel` (not `BaseEvent`) so they don't appear in `BaseEvent.__subclasses__()` and aren't flagged by the orphan-event-class test. Subclasses can override mixin fields with different defaults (e.g., `CreatePodCommandEvent` makes `runtime_command` optional with `Field(default_factory=list)`). + ## How discriminated unions work When events come back from MongoDB, we need to deserialize them into the correct Python class. A document with `event_type: "execution_completed"` should become an `ExecutionCompletedEvent` instance, not a generic dict. diff --git a/docs/architecture/lifecycle.md b/docs/architecture/lifecycle.md index 4d7f138d..273959f5 100644 --- a/docs/architecture/lifecycle.md +++ b/docs/architecture/lifecycle.md @@ -36,35 +36,48 @@ Explicit try/finally keeps teardown visible. 
The broker stops before the contain ## Worker entrypoints -Workers use FastStream's `on_startup` / `on_shutdown` callbacks instead of a context manager: +All six workers share a common bootstrap sequence: load settings, init logger, connect to MongoDB via Beanie, create a Dishka DI container, retrieve the Kafka broker, wire up Dishka, and run FastStream. The shared `run_worker()` function in `workers/bootstrap.py` handles this boilerplate. Each worker only provides its unique configuration and optional hooks: ```python -async def init_service() -> None: - worker = await container.get(SomeService) - # optional: set up APScheduler, pre-warm caches, etc. +from workers.bootstrap import run_worker + +def main() -> None: + run_worker( + worker_name="ResultProcessor", + config_override="config.result-processor.toml", + container_factory=create_result_processor_container, + register_handlers=register_result_processor_subscriber, + ) +``` + +Workers that need custom startup logic (APScheduler jobs, K8s pre-warming) pass `on_startup` and `on_shutdown` callbacks: -async def shutdown() -> None: - scheduler.shutdown(wait=False) # if APScheduler was used - await container.close() +```python +async def _on_startup(container, broker, logger) -> None: + service = await container.get(SomeService) + scheduler.add_job(service.periodic_task, trigger="interval", seconds=10, ...) + scheduler.start() -app = FastStream(broker, on_startup=[init_service], on_shutdown=[shutdown]) -await app.run() +async def _on_shutdown() -> None: + scheduler.shutdown(wait=False) + +run_worker(..., on_startup=_on_startup, on_shutdown=_on_shutdown) ``` -FastStream calls `on_startup` after the broker connects and `on_shutdown` when the process receives a signal. The worker blocks on `app.run()` until termination. +FastStream calls `on_startup` after the broker connects and `on_shutdown` when the process receives a signal. The worker blocks on `app.run()` until termination. 
The container is always closed last (appended automatically by `run_worker`). All six workers follow this pattern: | Worker | Startup hook | Has APScheduler | |--------|-------------|-----------------| -| k8s-worker | Pre-pull daemonset | No | +| k8s-worker | Pre-pull daemonset, namespace security | No | | pod-monitor | Start watch loop | Yes | | result-processor | — | No | | saga-orchestrator | Start saga scheduler | Yes | | dlq-processor | Start monitoring cycle | Yes | | event-replay | Start replay scheduler | Yes | -Workers without startup work pass only `on_shutdown=[container.close]`. +Workers without startup work pass only `register_handlers` and rely on `run_worker`'s default shutdown (container close). ## Why stateless services @@ -82,10 +95,12 @@ Write a plain class that takes its dependencies in `__init__` via Dishka. Expose ## Building new workers -Follow the existing pattern in `backend/workers/`: +Use the shared `run_worker()` function from `workers/bootstrap.py`: + +1. Create a container factory in `app/core/container.py` for the new worker +2. Create a `config..toml` override file +3. Optionally register Kafka subscriber handlers +4. Optionally define `on_startup` / `on_shutdown` callbacks for APScheduler or other setup +5. Call `run_worker()` with the worker name, config override, container factory, and optional hooks -1. Create a `Settings` with the worker's override TOML -2. Init Beanie, create the worker's DI container, get the broker -3. Register subscribers and set up Dishka on the broker -4. Define `init_*` and `shutdown` callbacks -5. Create `FastStream(broker, on_startup=[...], on_shutdown=[...])` and call `app.run()` +See any existing `workers/run_*.py` for a working example. The simplest is `run_result_processor.py` (no startup hooks), the most complex is `run_pod_monitor.py` (APScheduler + K8s watch loop). 
diff --git a/docs/architecture/services-overview.md b/docs/architecture/services-overview.md index b7858712..646495d0 100644 --- a/docs/architecture/services-overview.md +++ b/docs/architecture/services-overview.md @@ -48,7 +48,7 @@ The saga_service.py provides read-model access for saga state and guardrails lik ## Deployed workers -These services run outside the API container for isolation and horizontal scaling. Each has a small run_*.py entry and a dedicated Dockerfile in backend/workers/. +These services run outside the API container for isolation and horizontal scaling. Each has a small `run_*.py` entry in `backend/workers/` that calls the shared `run_worker()` bootstrap function from `workers/bootstrap.py`. The bootstrap handles Settings, Beanie, DI container, broker, and FastStream setup; each worker only provides its config override, container factory, and optional startup/shutdown hooks. The Saga Orchestrator is a stateful choreographer for execution lifecycle. It subscribes to EXECUTION_EVENTS and internal saga topics, manages the Redis-backed execution queue via `ExecutionQueueService` for priority scheduling and per-user limits, publishes SAGA_COMMANDS (CreatePodCommandEvent, DeletePodCommandEvent), rebuilds saga state from events, and issues commands only when transitions are valid and not yet executed. On failures, timeouts, or cancellations it publishes compensating commands and finalizes the saga. diff --git a/docs/components/workers/index.md b/docs/components/workers/index.md index 737ff096..b142f3e4 100644 --- a/docs/components/workers/index.md +++ b/docs/components/workers/index.md @@ -33,6 +33,10 @@ graph LR | [DLQ Processor](dlq_processor.md) | Retries failed messages from the dead letter queue | `run_dlq_processor.py` | All entry points live in [`backend/workers/`](https://github.com/HardMax71/Integr8sCode/tree/main/backend/workers). 
+The shared bootstrap logic (Settings, Beanie, DI container, broker, FastStream) is in +[`bootstrap.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/workers/bootstrap.py) — +each `run_*.py` calls `run_worker()` with its specific config and optional hooks. +See [Service Lifecycle](../../architecture/lifecycle.md#worker-entrypoints) for details. ## Running locally From 26299a33deea4ff003aef21bd0c466725e651a67 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 5 Mar 2026 00:29:32 +0100 Subject: [PATCH 2/7] chore: second group of duplications removed --- backend/app/db/docs/__init__.py | 4 -- backend/app/db/docs/user_settings.py | 31 ++------- backend/app/schemas_pydantic/user_settings.py | 16 +---- backend/app/services/user_settings_service.py | 11 +++ backend/tests/load/plot_report.py | 69 ++++++++++--------- 5 files changed, 53 insertions(+), 78 deletions(-) diff --git a/backend/app/db/docs/__init__.py b/backend/app/db/docs/__init__.py index 1b0c82c2..9cb52346 100644 --- a/backend/app/db/docs/__init__.py +++ b/backend/app/db/docs/__init__.py @@ -21,8 +21,6 @@ from app.db.docs.saved_script import SavedScriptDocument from app.db.docs.user import UserDocument from app.db.docs.user_settings import ( - EditorSettings, - NotificationSettings, UserSettingsDocument, UserSettingsSnapshotDocument, ) @@ -60,8 +58,6 @@ # User Settings "UserSettingsDocument", "UserSettingsSnapshotDocument", - "NotificationSettings", - "EditorSettings", # Saga "SagaDocument", # DLQ diff --git a/backend/app/db/docs/user_settings.py b/backend/app/db/docs/user_settings.py index 014094de..a53f9c09 100644 --- a/backend/app/db/docs/user_settings.py +++ b/backend/app/db/docs/user_settings.py @@ -2,16 +2,13 @@ from typing import Any from beanie import Document, Indexed -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field from app.domain.enums import NotificationChannel, Theme class NotificationSettings(BaseModel): - 
"""User notification preferences (embedded document). - - Copied from user_settings.py NotificationSettings. - """ + """User notification preferences (embedded document).""" model_config = ConfigDict(from_attributes=True) @@ -23,10 +20,7 @@ class NotificationSettings(BaseModel): class EditorSettings(BaseModel): - """Code editor preferences (embedded document). - - Copied from user_settings.py EditorSettings. - """ + """Code editor preferences (embedded document).""" model_config = ConfigDict(from_attributes=True) @@ -36,26 +30,9 @@ class EditorSettings(BaseModel): word_wrap: bool = True show_line_numbers: bool = True - @field_validator("font_size") - @classmethod - def validate_font_size(cls, v: int) -> int: - if v < 8 or v > 32: - raise ValueError("Font size must be between 8 and 32") - return v - - @field_validator("tab_size") - @classmethod - def validate_tab_size(cls, v: int) -> int: - if v not in (2, 4, 8): - raise ValueError("Tab size must be 2, 4, or 8") - return v - class UserSettingsDocument(Document): - """Complete user settings model. - - Copied from UserSettings schema. 
- """ + """Complete user settings model.""" user_id: Indexed(str, unique=True) # type: ignore[valid-type] theme: Theme = Theme.AUTO diff --git a/backend/app/schemas_pydantic/user_settings.py b/backend/app/schemas_pydantic/user_settings.py index a2ee7e1c..29210b21 100644 --- a/backend/app/schemas_pydantic/user_settings.py +++ b/backend/app/schemas_pydantic/user_settings.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from typing import Any -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field from app.domain.enums import EventType, NotificationChannel, Theme @@ -29,20 +29,6 @@ class EditorSettings(BaseModel): word_wrap: bool = True show_line_numbers: bool = True - @field_validator("font_size") - @classmethod - def validate_font_size(cls, v: int) -> int: - if v < 8 or v > 32: - raise ValueError("Font size must be between 8 and 32") - return v - - @field_validator("tab_size") - @classmethod - def validate_tab_size(cls, v: int) -> int: - if v not in (2, 4, 8): - raise ValueError("Tab size must be 2, 4, or 8") - return v - class UserSettings(BaseModel): """Complete user settings model""" diff --git a/backend/app/services/user_settings_service.py b/backend/app/services/user_settings_service.py index 82abf731..cb409e6f 100644 --- a/backend/app/services/user_settings_service.py +++ b/backend/app/services/user_settings_service.py @@ -8,6 +8,7 @@ from app.db import UserSettingsRepository from app.domain.enums import EventType, Theme from app.domain.events import EventMetadata, UserSettingsUpdatedEvent +from app.domain.exceptions import ValidationError from app.domain.user import ( DomainEditorSettings, DomainNotificationSettings, @@ -86,6 +87,9 @@ async def update_user_settings( self, user_id: str, updates: DomainUserSettingsUpdate, reason: str | None = None ) -> DomainUserSettings: """Upsert provided fields into current settings, publish minimal event, and cache.""" + if updates.editor is not None: 
+ self._validate_editor_settings(updates.editor) + current = await self.get_user_settings(user_id) changes = {k: v for k, v in dataclasses.asdict(updates).items() if v is not None} @@ -218,6 +222,13 @@ def _build_settings(data: dict[str, Any]) -> DomainUserSettings: filtered["editor"] = DomainEditorSettings(**filtered["editor"]) return DomainUserSettings(**filtered) + @staticmethod + def _validate_editor_settings(editor: DomainEditorSettings) -> None: + if not 8 <= editor.font_size <= 32: + raise ValidationError("Font size must be between 8 and 32") + if editor.tab_size not in (2, 4, 8): + raise ValidationError("Tab size must be 2, 4, or 8") + async def invalidate_cache(self, user_id: str) -> None: """Invalidate cached settings for a user.""" if self._cache.pop(user_id, None) is not None: diff --git a/backend/tests/load/plot_report.py b/backend/tests/load/plot_report.py index b415e15e..95411e24 100644 --- a/backend/tests/load/plot_report.py +++ b/backend/tests/load/plot_report.py @@ -3,10 +3,13 @@ import argparse import json from pathlib import Path -from typing import Dict, List, Tuple, TypedDict +from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, TypedDict import matplotlib.pyplot as plt +if TYPE_CHECKING: + from matplotlib.axes import Axes + class LatencyStats(TypedDict, total=False): p50: int @@ -75,6 +78,28 @@ def _top_endpoints(report: ReportDict, top_n: int = 10) -> List[Tuple[str, Endpo return items[:top_n] +def _save_bar_chart( + labels: List[str], + title: str, + ylabel: str, + out_path: Path, + plot_bars: Callable[[Axes, range], None], +) -> Path: + x = range(len(labels)) + fig, ax = plt.subplots(figsize=(max(10, len(labels) * 0.6), 5)) + plot_bars(ax, x) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.set_xticks(list(x)) + ax.set_xticklabels(labels, rotation=45, ha="right") + ax.grid(True, axis="y", alpha=0.3) + ax.legend() + fig.tight_layout() + fig.savefig(out_path) + plt.close(fig) + return out_path + + def 
plot_endpoint_latency(report: ReportDict, out_dir: Path, top_n: int = 10) -> Path: data = _top_endpoints(report, top_n) if not data: @@ -86,24 +111,14 @@ def plot_endpoint_latency(report: ReportDict, out_dir: Path, top_n: int = 10) -> p90 = [v.get("latency_ms_success", empty_latency).get("p90", 0) for _, v in data] p99 = [v.get("latency_ms_success", empty_latency).get("p99", 0) for _, v in data] - x = range(len(labels)) width = 0.25 - fig, ax = plt.subplots(figsize=(max(10, len(labels) * 0.6), 5)) - ax.bar([i - width for i in x], p50, width=width, label="p50", color="#22c55e") - ax.bar(x, p90, width=width, label="p90", color="#eab308") - ax.bar([i + width for i in x], p99, width=width, label="p99", color="#ef4444") - ax.set_ylabel("Latency (ms)") - ax.set_title("Success Latency by Endpoint (Top N)") - ax.set_xticks(list(x)) - ax.set_xticklabels(labels, rotation=45, ha="right") - ax.grid(True, axis="y", alpha=0.3) - ax.legend() - out_path = out_dir / "endpoint_latency.png" - fig.tight_layout() - fig.savefig(out_path) - plt.close(fig) - return out_path + def bars(ax: Axes, x: range) -> None: + ax.bar([i - width for i in x], p50, width=width, label="p50", color="#22c55e") + ax.bar(x, p90, width=width, label="p90", color="#eab308") + ax.bar([i + width for i in x], p99, width=width, label="p99", color="#ef4444") + + return _save_bar_chart(labels, "Success Latency by Endpoint (Top N)", "Latency (ms)", out_dir / "endpoint_latency.png", bars) def plot_endpoint_throughput(report: ReportDict, out_dir: Path, top_n: int = 10) -> Path: @@ -116,23 +131,13 @@ def plot_endpoint_throughput(report: ReportDict, out_dir: Path, top_n: int = 10) errors = [v.get("errors", 0) for _, v in data] successes = [t - e for t, e in zip(total, errors)] - x = range(len(labels)) width = 0.45 - fig, ax = plt.subplots(figsize=(max(10, len(labels) * 0.6), 5)) - ax.bar(x, successes, width=width, label="Success", color="#22c55e") - ax.bar(x, errors, width=width, bottom=successes, label="Errors", 
color="#ef4444") - ax.set_ylabel("Requests") - ax.set_title("Endpoint Throughput (Top N)") - ax.set_xticks(list(x)) - ax.set_xticklabels(labels, rotation=45, ha="right") - ax.grid(True, axis="y", alpha=0.3) - ax.legend() - out_path = out_dir / "endpoint_throughput.png" - fig.tight_layout() - fig.savefig(out_path) - plt.close(fig) - return out_path + def bars(ax: Axes, x: range) -> None: + ax.bar(x, successes, width=width, label="Success", color="#22c55e") + ax.bar(x, errors, width=width, bottom=successes, label="Errors", color="#ef4444") + + return _save_bar_chart(labels, "Endpoint Throughput (Top N)", "Requests", out_dir / "endpoint_throughput.png", bars) def generate_plots(report_path: str | Path, output_dir: str | Path | None = None) -> List[Path]: From bdc4e18a8369a621bd15e54dcd6a95f38f08d751 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 5 Mar 2026 00:56:53 +0100 Subject: [PATCH 3/7] chore: isolated frontend/unit tests --- frontend/vitest.config.ts | 2 +- frontend/vitest.setup.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts index e52fb2dc..54bbf668 100644 --- a/frontend/vitest.config.ts +++ b/frontend/vitest.config.ts @@ -13,7 +13,7 @@ export default defineConfig({ pool: 'threads', maxWorkers: 8, minWorkers: 2, - isolate: false, + isolate: true, css: false, testTimeout: 10_000, environment: 'jsdom', diff --git a/frontend/vitest.setup.ts b/frontend/vitest.setup.ts index c63ef93b..cb40d677 100644 --- a/frontend/vitest.setup.ts +++ b/frontend/vitest.setup.ts @@ -98,7 +98,7 @@ Element.prototype.animate = vi.fn().mockImplementation(() => { return mock as unknown as Animation; }); -// Reset storage and DOM between every test (required for isolate: false) +// Reset storage and DOM between every test beforeEach(() => { Object.keys(localStorageStore).forEach(key => delete localStorageStore[key]); Object.keys(sessionStorageStore).forEach(key => delete sessionStorageStore[key]); From 
22b09bab9006755c318493b743d4fe1693986c06 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 5 Mar 2026 01:08:16 +0100 Subject: [PATCH 4/7] chore: in frontend tests, replaced relative imports with $ --- backend/workers/bootstrap.py | 1 + frontend/src/components/__tests__/Header.test.ts | 6 +++--- .../components/__tests__/NotificationCenter.test.ts | 6 +++--- frontend/src/lib/__tests__/user-settings.test.ts | 10 +++++----- .../src/routes/admin/__tests__/AdminEvents.test.ts | 6 +++--- .../src/routes/admin/__tests__/AdminExecutions.test.ts | 6 +++--- frontend/src/routes/admin/__tests__/AdminSagas.test.ts | 6 +++--- .../src/routes/admin/__tests__/AdminSettings.test.ts | 2 +- frontend/src/routes/admin/__tests__/AdminUsers.test.ts | 2 +- frontend/src/stores/__tests__/auth.test.ts | 6 +++--- frontend/src/stores/__tests__/theme.test.ts | 4 ++-- 11 files changed, 28 insertions(+), 27 deletions(-) diff --git a/backend/workers/bootstrap.py b/backend/workers/bootstrap.py index a2e3000c..43e39ffa 100644 --- a/backend/workers/bootstrap.py +++ b/backend/workers/bootstrap.py @@ -75,6 +75,7 @@ async def _run() -> None: if on_shutdown is not None: shutdown_hooks.append(on_shutdown) shutdown_hooks.append(container.close) + shutdown_hooks.append(client.aclose) app = FastStream(broker, on_startup=startup_hooks, on_shutdown=shutdown_hooks) await app.run() diff --git a/frontend/src/components/__tests__/Header.test.ts b/frontend/src/components/__tests__/Header.test.ts index e74cd8b2..52957ad9 100644 --- a/frontend/src/components/__tests__/Header.test.ts +++ b/frontend/src/components/__tests__/Header.test.ts @@ -22,12 +22,12 @@ const mocks = vi.hoisted(() => ({ mockToggleTheme: vi.fn(), })); -vi.mock('../../stores/auth.svelte', () => ({ +vi.mock('$stores/auth.svelte', () => ({ get authStore() { return mocks.mockAuthStore; }, })); -vi.mock('../../stores/theme.svelte', () => ({ +vi.mock('$stores/theme.svelte', () => ({ get themeStore() { return mocks.mockThemeStore; }, @@ -35,7 +35,7 @@ 
vi.mock('../../stores/theme.svelte', () => ({ return mocks.mockToggleTheme; }, })); -vi.mock('../NotificationCenter.svelte', () => { +vi.mock('$components/NotificationCenter.svelte', () => { const M = function () { return {}; } as any; diff --git a/frontend/src/components/__tests__/NotificationCenter.test.ts b/frontend/src/components/__tests__/NotificationCenter.test.ts index 52931f3e..08874fef 100644 --- a/frontend/src/components/__tests__/NotificationCenter.test.ts +++ b/frontend/src/components/__tests__/NotificationCenter.test.ts @@ -41,17 +41,17 @@ const mocks = vi.hoisted(() => ({ })); vi.mock('@mateothegreat/svelte5-router', () => ({ goto: mocks.mockGoto })); -vi.mock('../../stores/auth.svelte', () => ({ +vi.mock('$stores/auth.svelte', () => ({ get authStore() { return mocks.mockAuthStore; }, })); -vi.mock('../../stores/notificationStore.svelte', () => ({ +vi.mock('$stores/notificationStore.svelte', () => ({ get notificationStore() { return mocks.mockNotificationStore; }, })); -vi.mock('../../lib/notifications/stream.svelte', () => ({ +vi.mock('$lib/notifications/stream.svelte', () => ({ notificationStream: { connect: mocks.mockStreamConnect, disconnect: mocks.mockStreamDisconnect, diff --git a/frontend/src/lib/__tests__/user-settings.test.ts b/frontend/src/lib/__tests__/user-settings.test.ts index 98756932..58a1bba8 100644 --- a/frontend/src/lib/__tests__/user-settings.test.ts +++ b/frontend/src/lib/__tests__/user-settings.test.ts @@ -3,7 +3,7 @@ import { describe, it, expect, beforeEach, vi, afterEach } from 'vitest'; const mockGetUserSettings = vi.fn(); const mockUpdateUserSettings = vi.fn(); -vi.mock('../api', () => ({ +vi.mock('$lib/api', () => ({ getUserSettingsApiV1UserSettingsGet: (...args: unknown[]) => (mockGetUserSettings as (...a: unknown[]) => unknown)(...args), updateUserSettingsApiV1UserSettingsPut: (...args: unknown[]) => @@ -12,13 +12,13 @@ vi.mock('../api', () => ({ const mockSetUserSettings = vi.fn(); 
-vi.mock('../../stores/userSettings.svelte', () => ({ +vi.mock('$stores/userSettings.svelte', () => ({ setUserSettings: (settings: unknown) => mockSetUserSettings(settings), })); const mockSetTheme = vi.fn(); -vi.mock('../../stores/theme.svelte', () => ({ +vi.mock('$stores/theme.svelte', () => ({ setTheme: (theme: string) => mockSetTheme(theme), })); @@ -26,11 +26,11 @@ const mockAuthStore = { isAuthenticated: true as boolean | null, }; -vi.mock('../../stores/auth.svelte', () => ({ +vi.mock('$stores/auth.svelte', () => ({ authStore: mockAuthStore, })); -vi.mock('../api-interceptors', () => ({})); +vi.mock('$lib/api-interceptors', () => ({})); describe('user-settings', () => { beforeEach(async () => { diff --git a/frontend/src/routes/admin/__tests__/AdminEvents.test.ts b/frontend/src/routes/admin/__tests__/AdminEvents.test.ts index 5a642a4b..7c59022d 100644 --- a/frontend/src/routes/admin/__tests__/AdminEvents.test.ts +++ b/frontend/src/routes/admin/__tests__/AdminEvents.test.ts @@ -26,7 +26,7 @@ const mocks = vi.hoisted(() => ({ })); // Mock API module -vi.mock('../../../lib/api', () => ({ +vi.mock('$lib/api', () => ({ browseEventsApiV1AdminEventsBrowsePost: (...args: unknown[]) => mocks.browseEventsApiV1AdminEventsBrowsePost(...args), getEventStatsApiV1AdminEventsStatsGet: (...args: unknown[]) => mocks.getEventStatsApiV1AdminEventsStatsGet(...args), @@ -51,10 +51,10 @@ vi.mock('svelte-sonner', () => ({ }, })); -vi.mock('../../../lib/api-interceptors'); +vi.mock('$lib/api-interceptors'); // Simple mock for AdminLayout -vi.mock('../AdminLayout.svelte', async () => { +vi.mock('$routes/admin/AdminLayout.svelte', async () => { const { default: MockLayout } = await import('$routes/admin/__tests__/mocks/MockAdminLayout.svelte'); return { default: MockLayout }; }); diff --git a/frontend/src/routes/admin/__tests__/AdminExecutions.test.ts b/frontend/src/routes/admin/__tests__/AdminExecutions.test.ts index 75dc1844..8c5f9d52 100644 --- 
a/frontend/src/routes/admin/__tests__/AdminExecutions.test.ts +++ b/frontend/src/routes/admin/__tests__/AdminExecutions.test.ts @@ -9,7 +9,7 @@ const mocks = vi.hoisted(() => ({ addToast: vi.fn(), })); -vi.mock('../../../lib/api', () => ({ +vi.mock('$lib/api', () => ({ listExecutionsApiV1AdminExecutionsGet: (...args: unknown[]) => mocks.listExecutionsApiV1AdminExecutionsGet(...args), updatePriorityApiV1AdminExecutionsExecutionIdPriorityPut: (...args: unknown[]) => mocks.updatePriorityApiV1AdminExecutionsExecutionIdPriorityPut(...args), @@ -26,12 +26,12 @@ vi.mock('svelte-sonner', () => ({ }, })); -vi.mock('../../../lib/api-interceptors', async (importOriginal) => { +vi.mock('$lib/api-interceptors', async (importOriginal) => { const actual = (await importOriginal()) as Record; return { ...actual }; }); -vi.mock('../AdminLayout.svelte', async () => { +vi.mock('$routes/admin/AdminLayout.svelte', async () => { const { default: MockLayout } = await import('$routes/admin/__tests__/mocks/MockAdminLayout.svelte'); return { default: MockLayout }; }); diff --git a/frontend/src/routes/admin/__tests__/AdminSagas.test.ts b/frontend/src/routes/admin/__tests__/AdminSagas.test.ts index ca9294d2..d1d84295 100644 --- a/frontend/src/routes/admin/__tests__/AdminSagas.test.ts +++ b/frontend/src/routes/admin/__tests__/AdminSagas.test.ts @@ -8,13 +8,13 @@ const mocks = vi.hoisted(() => ({ getSagaStatusApiV1SagasSagaIdGet: vi.fn(), })); -vi.mock('../../../lib/api', () => ({ +vi.mock('$lib/api', () => ({ listSagasApiV1SagasGet: (...args: unknown[]) => mocks.listSagasApiV1SagasGet(...args), getSagaStatusApiV1SagasSagaIdGet: (...args: unknown[]) => mocks.getSagaStatusApiV1SagasSagaIdGet(...args), })); -vi.mock('../../../lib/api-interceptors'); -vi.mock('../AdminLayout.svelte', async () => { +vi.mock('$lib/api-interceptors'); +vi.mock('$routes/admin/AdminLayout.svelte', async () => { const { default: MockLayout } = await import('$routes/admin/__tests__/mocks/MockAdminLayout.svelte'); return { 
default: MockLayout }; }); diff --git a/frontend/src/routes/admin/__tests__/AdminSettings.test.ts b/frontend/src/routes/admin/__tests__/AdminSettings.test.ts index 5f076994..5d3ca5e9 100644 --- a/frontend/src/routes/admin/__tests__/AdminSettings.test.ts +++ b/frontend/src/routes/admin/__tests__/AdminSettings.test.ts @@ -34,7 +34,7 @@ const mocks = vi.hoisted(() => ({ }, })); -vi.mock('../../../lib/api', () => ({ +vi.mock('$lib/api', () => ({ getSystemSettingsApiV1AdminSettingsGet: (...args: unknown[]) => mocks.getSystemSettingsApiV1AdminSettingsGet(...args), updateSystemSettingsApiV1AdminSettingsPut: (...args: unknown[]) => diff --git a/frontend/src/routes/admin/__tests__/AdminUsers.test.ts b/frontend/src/routes/admin/__tests__/AdminUsers.test.ts index 16f889e5..4e248b23 100644 --- a/frontend/src/routes/admin/__tests__/AdminUsers.test.ts +++ b/frontend/src/routes/admin/__tests__/AdminUsers.test.ts @@ -43,7 +43,7 @@ vi.mock('$lib/formatters', () => ({ })); // Simple mock for AdminLayout that just renders children -vi.mock('../AdminLayout.svelte', async () => { +vi.mock('$routes/admin/AdminLayout.svelte', async () => { const { default: MockLayout } = await import('$routes/admin/__tests__/mocks/MockAdminLayout.svelte'); return { default: MockLayout }; }); diff --git a/frontend/src/stores/__tests__/auth.test.ts b/frontend/src/stores/__tests__/auth.test.ts index 14ed36c4..b7247e51 100644 --- a/frontend/src/stores/__tests__/auth.test.ts +++ b/frontend/src/stores/__tests__/auth.test.ts @@ -5,7 +5,7 @@ const mockLoginApi = vi.fn(); const mockLogoutApi = vi.fn(); const mockGetProfileApi = vi.fn(); -vi.mock('../../lib/api', () => ({ +vi.mock('$lib/api', () => ({ loginApiV1AuthLoginPost: (...args: unknown[]) => mockLoginApi(...args), logoutApiV1AuthLogoutPost: (...args: unknown[]) => mockLogoutApi(...args), getCurrentUserProfileApiV1AuthMeGet: (...args: unknown[]) => mockGetProfileApi(...args), @@ -13,7 +13,7 @@ vi.mock('../../lib/api', () => ({ // Mock clearUserSettings 
(static import in auth.svelte.ts) const mockClearUserSettings = vi.fn(); -vi.mock('../userSettings.svelte', () => ({ +vi.mock('$stores/userSettings.svelte', () => ({ clearUserSettings: () => mockClearUserSettings(), setUserSettings: vi.fn(), userSettingsStore: { settings: null, editorSettings: {} }, @@ -21,7 +21,7 @@ vi.mock('../userSettings.svelte', () => ({ // Mock loadUserSettings (dynamic import in auth.svelte.ts) const mockLoadUserSettings = vi.fn(); -vi.mock('../../lib/user-settings', () => ({ +vi.mock('$lib/user-settings', () => ({ loadUserSettings: () => mockLoadUserSettings(), })); diff --git a/frontend/src/stores/__tests__/theme.test.ts b/frontend/src/stores/__tests__/theme.test.ts index da7d65b0..643db4d6 100644 --- a/frontend/src/stores/__tests__/theme.test.ts +++ b/frontend/src/stores/__tests__/theme.test.ts @@ -1,11 +1,11 @@ import { describe, it, expect, beforeEach, vi, afterEach } from 'vitest'; // Mock the dynamic imports before importing the theme module -vi.mock('../../lib/user-settings', () => ({ +vi.mock('$lib/user-settings', () => ({ saveUserSettings: vi.fn().mockResolvedValue(true), })); -vi.mock('../auth.svelte', () => ({ +vi.mock('$stores/auth.svelte', () => ({ authStore: { isAuthenticated: false, }, From 659b46f994605852ac009d102a3a50cef04a3d11 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 5 Mar 2026 01:27:19 +0100 Subject: [PATCH 5/7] chore: in frontend, xxStore files: moving autoRefresh as an optional functionality instead of starting timers all the time --- .../stores/__tests__/eventsStore.test.ts | 11 ++++++++-- .../stores/__tests__/executionsStore.test.ts | 11 +--------- .../admin/stores/__tests__/sagasStore.test.ts | 13 ++++++++--- .../lib/admin/stores/eventsStore.svelte.ts | 20 +++++++++-------- .../admin/stores/executionsStore.svelte.ts | 20 +++++++++-------- .../src/lib/admin/stores/sagasStore.svelte.ts | 22 ++++++++++--------- 6 files changed, 54 insertions(+), 43 deletions(-) diff --git 
a/frontend/src/lib/admin/stores/__tests__/eventsStore.test.ts b/frontend/src/lib/admin/stores/__tests__/eventsStore.test.ts index b3f73353..19fd2645 100644 --- a/frontend/src/lib/admin/stores/__tests__/eventsStore.test.ts +++ b/frontend/src/lib/admin/stores/__tests__/eventsStore.test.ts @@ -99,6 +99,12 @@ describe('EventsStore', () => { }); function createStore() { + teardown = effect_root(() => { + store = createEventsStore({ autoRefresh: false }); + }); + } + + function createStoreWithAutoRefresh() { teardown = effect_root(() => { store = createEventsStore(); }); @@ -108,6 +114,7 @@ describe('EventsStore', () => { vi.unstubAllGlobals(); store?.cleanup(); teardown?.(); + vi.clearAllTimers(); }); describe('initial state', () => { @@ -386,7 +393,7 @@ describe('EventsStore', () => { describe('auto-refresh', () => { it('fires loadAll on 30s interval', async () => { - createStore(); + createStoreWithAutoRefresh(); vi.clearAllMocks(); await vi.advanceTimersByTimeAsync(30000); @@ -397,7 +404,7 @@ describe('EventsStore', () => { }); it('stops on teardown', async () => { - createStore(); + createStoreWithAutoRefresh(); await vi.advanceTimersByTimeAsync(30000); expect(mocks.browseEventsApiV1AdminEventsBrowsePost).toHaveBeenCalled(); diff --git a/frontend/src/lib/admin/stores/__tests__/executionsStore.test.ts b/frontend/src/lib/admin/stores/__tests__/executionsStore.test.ts index df292230..bf4d2c91 100644 --- a/frontend/src/lib/admin/stores/__tests__/executionsStore.test.ts +++ b/frontend/src/lib/admin/stores/__tests__/executionsStore.test.ts @@ -56,19 +56,10 @@ describe('ExecutionsStore', () => { setupDefaultMocks(); }); - /** - * Creates a store with auto-refresh timers immediately cleared. - * The $effect fires synchronously inside effect_root, starting a - * setInterval. We clear all timers, reset mocks, and re-apply defaults - * so individual tests control timing explicitly. 
- */ function createStore() { teardown = effect_root(() => { - store = createExecutionsStore(); + store = createExecutionsStore({ autoRefresh: false }); }); - vi.clearAllTimers(); - vi.clearAllMocks(); - setupDefaultMocks(); } function createStoreWithAutoRefresh() { diff --git a/frontend/src/lib/admin/stores/__tests__/sagasStore.test.ts b/frontend/src/lib/admin/stores/__tests__/sagasStore.test.ts index 13ff151d..030b423a 100644 --- a/frontend/src/lib/admin/stores/__tests__/sagasStore.test.ts +++ b/frontend/src/lib/admin/stores/__tests__/sagasStore.test.ts @@ -29,6 +29,12 @@ describe('SagasStore', () => { }); function createStore() { + teardown = effect_root(() => { + store = createSagasStore({ autoRefresh: false }); + }); + } + + function createStoreWithAutoRefresh() { teardown = effect_root(() => { store = createSagasStore(); }); @@ -36,6 +42,7 @@ describe('SagasStore', () => { afterEach(() => { teardown?.(); + vi.clearAllTimers(); }); describe('initial state', () => { @@ -190,7 +197,7 @@ describe('SagasStore', () => { describe('auto-refresh', () => { it('fires loadSagas on interval', async () => { - createStore(); + createStoreWithAutoRefresh(); vi.clearAllMocks(); await vi.advanceTimersByTimeAsync(5000); @@ -206,7 +213,7 @@ describe('SagasStore', () => { data: { sagas, total: 1 }, }); - createStore(); + createStoreWithAutoRefresh(); await store.loadExecutionSagas('exec-target'); vi.clearAllMocks(); @@ -224,7 +231,7 @@ describe('SagasStore', () => { }); it('stops when refreshEnabled set to false', async () => { - createStore(); + createStoreWithAutoRefresh(); await vi.advanceTimersByTimeAsync(5000); expect(mocks.listSagasApiV1SagasGet).toHaveBeenCalled(); diff --git a/frontend/src/lib/admin/stores/eventsStore.svelte.ts b/frontend/src/lib/admin/stores/eventsStore.svelte.ts index ee2715c7..9fe9c964 100644 --- a/frontend/src/lib/admin/stores/eventsStore.svelte.ts +++ b/frontend/src/lib/admin/stores/eventsStore.svelte.ts @@ -39,13 +39,15 @@ class EventsStore { 
pagination = createPaginationState({ initialPageSize: 10 }); - constructor() { - $effect(() => { - const id = setInterval(() => this.loadAll(), 30_000); - return () => { - clearInterval(id); - }; - }); + constructor({ autoRefresh = true }: { autoRefresh?: boolean } = {}) { + if (autoRefresh) { + $effect(() => { + const id = setInterval(() => this.loadAll(), 30_000); + return () => { + clearInterval(id); + }; + }); + } } async loadAll(): Promise { @@ -235,6 +237,6 @@ class EventsStore { } } -export function createEventsStore(): EventsStore { - return new EventsStore(); +export function createEventsStore(options?: { autoRefresh?: boolean }): EventsStore { + return new EventsStore(options); } diff --git a/frontend/src/lib/admin/stores/executionsStore.svelte.ts b/frontend/src/lib/admin/stores/executionsStore.svelte.ts index aed104df..d7e9cf07 100644 --- a/frontend/src/lib/admin/stores/executionsStore.svelte.ts +++ b/frontend/src/lib/admin/stores/executionsStore.svelte.ts @@ -23,13 +23,15 @@ class ExecutionsStore { pagination = createPaginationState({ initialPageSize: 20 }); - constructor() { - $effect(() => { - const id = setInterval(() => this.loadData(), 5000); - return () => { - clearInterval(id); - }; - }); + constructor({ autoRefresh = true }: { autoRefresh?: boolean } = {}) { + if (autoRefresh) { + $effect(() => { + const id = setInterval(() => this.loadData(), 5000); + return () => { + clearInterval(id); + }; + }); + } } async loadData(): Promise { @@ -80,6 +82,6 @@ class ExecutionsStore { } } -export function createExecutionsStore(): ExecutionsStore { - return new ExecutionsStore(); +export function createExecutionsStore(options?: { autoRefresh?: boolean }): ExecutionsStore { + return new ExecutionsStore(options); } diff --git a/frontend/src/lib/admin/stores/sagasStore.svelte.ts b/frontend/src/lib/admin/stores/sagasStore.svelte.ts index 376f928d..557c6488 100644 --- a/frontend/src/lib/admin/stores/sagasStore.svelte.ts +++ 
b/frontend/src/lib/admin/stores/sagasStore.svelte.ts @@ -20,14 +20,16 @@ class SagasStore { pagination = createPaginationState({ initialPageSize: 10 }); - constructor() { - $effect(() => { - if (!this.refreshEnabled) return; - const id = setInterval(() => this.loadSagas(), this.refreshRate * 1000); - return () => { - clearInterval(id); - }; - }); + constructor({ autoRefresh = true }: { autoRefresh?: boolean } = {}) { + if (autoRefresh) { + $effect(() => { + if (!this.refreshEnabled) return; + const id = setInterval(() => this.loadSagas(), this.refreshRate * 1000); + return () => { + clearInterval(id); + }; + }); + } } async loadSagas(): Promise { @@ -77,6 +79,6 @@ class SagasStore { } } -export function createSagasStore(): SagasStore { - return new SagasStore(); +export function createSagasStore(options?: { autoRefresh?: boolean }): SagasStore { + return new SagasStore(options); } From 99f2c0fc58d10d75835353e54d7da8df61e9707b Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 5 Mar 2026 12:48:01 +0100 Subject: [PATCH 6/7] chore: streamlined sse service, no sse bus --- backend/app/core/providers.py | 30 ++-- backend/app/events/handlers.py | 6 +- .../services/event_replay/replay_service.py | 8 +- backend/app/services/notification_service.py | 8 +- backend/app/services/sse/__init__.py | 3 +- backend/app/services/sse/redis_bus.py | 98 ----------- backend/app/services/sse/sse_service.py | 123 +++++++++---- backend/tests/e2e/conftest.py | 40 +++-- .../tests/e2e/core/test_dishka_lifespan.py | 10 +- .../notifications/test_notification_sse.py | 19 +- .../sse/test_partitioned_event_router.py | 30 ++-- .../tests/unit/services/sse/test_redis_bus.py | 164 ----------------- .../unit/services/sse/test_sse_publish.py | 149 ++++++++++++++++ .../unit/services/sse/test_sse_service.py | 166 ++++++++---------- docs/architecture/lifecycle.md | 5 +- 15 files changed, 397 insertions(+), 462 deletions(-) delete mode 100644 backend/app/services/sse/redis_bus.py delete mode 100644 
backend/tests/unit/services/sse/test_redis_bus.py create mode 100644 backend/tests/unit/services/sse/test_sse_publish.py diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index b4219689..5eb21237 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -63,7 +63,7 @@ from app.services.runtime_settings import RuntimeSettingsLoader from app.services.saga import SagaOrchestrator, SagaService from app.services.saved_script_service import SavedScriptService -from app.services.sse import SSERedisBus, SSEService +from app.services.sse import SSEService from app.services.user_settings_service import UserSettingsService from app.settings import Settings @@ -131,6 +131,7 @@ async def get_redis_client( decode_responses=settings.REDIS_DECODE_RESPONSES, socket_connect_timeout=5, socket_timeout=5, + socket_keepalive=True, ) # Test connection await client.ping() # type: ignore[misc] # redis-py returns Awaitable[bool] | bool @@ -351,22 +352,19 @@ class SSEProvider(Provider): scope = Scope.APP @provide - def get_sse_redis_bus( + def get_sse_service( self, redis_client: redis.Redis, + execution_repository: ExecutionRepository, logger: structlog.stdlib.BoundLogger, connection_metrics: ConnectionMetrics, - ) -> SSERedisBus: - return SSERedisBus(redis_client, logger, connection_metrics) - - @provide - def get_sse_service( - self, - bus: SSERedisBus, - execution_repository: ExecutionRepository, - logger: structlog.stdlib.BoundLogger, ) -> SSEService: - return SSEService(bus=bus, execution_repository=execution_repository, logger=logger) + return SSEService( + redis_client=redis_client, + execution_repository=execution_repository, + logger=logger, + connection_metrics=connection_metrics, + ) class AuthProvider(Provider): @@ -483,7 +481,7 @@ def get_notification_service( self, notification_repository: NotificationRepository, kafka_event_service: KafkaEventService, - sse_redis_bus: SSERedisBus, + sse_service: SSEService, settings: Settings, 
logger: structlog.stdlib.BoundLogger, notification_metrics: NotificationMetrics, @@ -491,7 +489,7 @@ def get_notification_service( return NotificationService( notification_repository=notification_repository, event_service=kafka_event_service, - sse_bus=sse_redis_bus, + sse_service=sse_service, settings=settings, logger=logger, notification_metrics=notification_metrics, @@ -731,12 +729,12 @@ def get_event_replay_service( kafka_producer: UnifiedProducer, replay_metrics: ReplayMetrics, logger: structlog.stdlib.BoundLogger, - sse_bus: SSERedisBus, + sse_service: SSEService, ) -> EventReplayService: return EventReplayService( repository=replay_repository, producer=kafka_producer, replay_metrics=replay_metrics, logger=logger, - sse_bus=sse_bus, + sse_service=sse_service, ) diff --git a/backend/app/events/handlers.py b/backend/app/events/handlers.py index b2af375b..94406aff 100644 --- a/backend/app/events/handlers.py +++ b/backend/app/events/handlers.py @@ -25,7 +25,7 @@ from app.services.notification_service import NotificationService from app.services.result_processor import ResultProcessor from app.services.saga import SagaOrchestrator -from app.services.sse import SSERedisBus +from app.services.sse import SSEService from app.settings import Settings _sse_field_names: frozenset[str] = frozenset(f.name for f in dataclasses.fields(SSEExecutionEventData)) @@ -261,14 +261,14 @@ def register_sse_subscriber(broker: KafkaBroker, settings: Settings) -> None: ) async def on_sse_event( body: DomainEvent, - sse_bus: FromDishka[SSERedisBus], + sse_service: FromDishka[SSEService], ) -> None: execution_id = getattr(body, "execution_id", None) if execution_id: sse_data = SSEExecutionEventData(**{ k: v for k, v in body.model_dump().items() if k in _sse_field_names }) - await sse_bus.publish_event(execution_id, sse_data) + await sse_service.publish_event(execution_id, sse_data) def register_notification_subscriber(broker: KafkaBroker) -> None: diff --git 
a/backend/app/services/event_replay/replay_service.py b/backend/app/services/event_replay/replay_service.py index 3e6316ef..09602edd 100644 --- a/backend/app/services/event_replay/replay_service.py +++ b/backend/app/services/event_replay/replay_service.py @@ -27,7 +27,7 @@ ) from app.domain.sse import DomainReplaySSEPayload from app.events import UnifiedProducer -from app.services.sse.redis_bus import SSERedisBus +from app.services.sse import SSEService class EventReplayService: @@ -37,7 +37,7 @@ def __init__( producer: UnifiedProducer, replay_metrics: ReplayMetrics, logger: structlog.stdlib.BoundLogger, - sse_bus: SSERedisBus, + sse_service: SSEService, ) -> None: self._sessions: dict[str, ReplaySessionState] = {} self._schedulers: dict[str, AsyncIOScheduler] = {} @@ -49,7 +49,7 @@ def __init__( self.logger = logger self._file_locks: dict[str, asyncio.Lock] = {} self._metrics = replay_metrics - self._sse_bus = sse_bus + self._sse_service = sse_service async def create_session_from_config(self, config: ReplayConfig) -> ReplayOperationResult: try: @@ -429,6 +429,6 @@ async def _publish_replay_status(self, session: ReplaySessionState) -> None: completed_at=session.completed_at, errors=session.errors, ) - await self._sse_bus.publish_replay_status(session.session_id, payload) + await self._sse_service.publish_replay_status(session.session_id, payload) except Exception as e: self.logger.error("Failed to publish replay status to SSE", error=str(e)) diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 52df3a92..fbc4efd5 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -33,7 +33,7 @@ ) from app.domain.sse import DomainNotificationSSEPayload from app.services.kafka_event_service import KafkaEventService -from app.services.sse import SSERedisBus +from app.services.sse import SSEService from app.settings import Settings # Constants @@ -100,7 +100,7 @@ def 
__init__( self, notification_repository: NotificationRepository, event_service: KafkaEventService, - sse_bus: SSERedisBus, + sse_service: SSEService, settings: Settings, logger: structlog.stdlib.BoundLogger, notification_metrics: NotificationMetrics, @@ -109,7 +109,7 @@ def __init__( self.event_service = event_service self.metrics = notification_metrics self.settings = settings - self.sse_bus = sse_bus + self.sse_service = sse_service self.logger = logger self._throttle_cache = ThrottleCache() @@ -578,7 +578,7 @@ async def _publish_notification_sse(self, notification: DomainNotification) -> N action_url=notification.action_url, created_at=notification.created_at, ) - await self.sse_bus.publish_notification(notification.user_id, payload) + await self.sse_service.publish_notification(notification.user_id, payload) # --8<-- [start:should_skip_notification] async def _should_skip_notification( diff --git a/backend/app/services/sse/__init__.py b/backend/app/services/sse/__init__.py index cd9f7bac..afdb5a71 100644 --- a/backend/app/services/sse/__init__.py +++ b/backend/app/services/sse/__init__.py @@ -1,4 +1,3 @@ -from app.services.sse.redis_bus import SSERedisBus from app.services.sse.sse_service import SSEService -__all__ = ["SSERedisBus", "SSEService"] +__all__ = ["SSEService"] diff --git a/backend/app/services/sse/redis_bus.py b/backend/app/services/sse/redis_bus.py deleted file mode 100644 index 7141f516..00000000 --- a/backend/app/services/sse/redis_bus.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations - -from collections.abc import AsyncGenerator -from datetime import datetime, timezone - -import redis.asyncio as redis -import structlog -from pydantic import TypeAdapter - -from app.core.metrics import ConnectionMetrics -from app.domain.sse import DomainNotificationSSEPayload, DomainReplaySSEPayload, SSEExecutionEventData - -_sse_event_adapter = TypeAdapter(SSEExecutionEventData) -_notif_payload_adapter = 
TypeAdapter(DomainNotificationSSEPayload) -_replay_adapter = TypeAdapter(DomainReplaySSEPayload) - - -class SSERedisBus: - """Redis-backed pub/sub bus for SSE event fan-out across workers.""" - - def __init__( - self, - redis_client: redis.Redis, - logger: structlog.stdlib.BoundLogger, - connection_metrics: ConnectionMetrics, - exec_prefix: str = "sse:exec:", - notif_prefix: str = "sse:notif:", - replay_prefix: str = "sse:replay:", - ) -> None: - self._redis = redis_client - self.logger = logger - self._metrics = connection_metrics - self._exec_prefix = exec_prefix - self._notif_prefix = notif_prefix - self._replay_prefix = replay_prefix - - def _exec_channel(self, execution_id: str) -> str: - return f"{self._exec_prefix}{execution_id}" - - def _notif_channel(self, user_id: str) -> str: - return f"{self._notif_prefix}{user_id}" - - async def publish_event(self, execution_id: str, event: SSEExecutionEventData) -> None: - await self._redis.publish(self._exec_channel(execution_id), _sse_event_adapter.dump_json(event)) - - async def publish_notification(self, user_id: str, notification: DomainNotificationSSEPayload) -> None: - await self._redis.publish(self._notif_channel(user_id), _notif_payload_adapter.dump_json(notification)) - - async def listen_execution(self, execution_id: str) -> AsyncGenerator[SSEExecutionEventData, None]: - start = datetime.now(timezone.utc) - self._metrics.increment_sse_connections("executions") - self.logger.info("SSE execution stream opened", execution_id=execution_id) - try: - async with self._redis.pubsub(ignore_subscribe_messages=True) as pubsub: - await pubsub.subscribe(self._exec_channel(execution_id)) - async for message in pubsub.listen(): - yield _sse_event_adapter.validate_json(message["data"]) - finally: - duration = (datetime.now(timezone.utc) - start).total_seconds() - self._metrics.record_sse_connection_duration(duration, "executions") - self._metrics.decrement_sse_connections("executions") - self.logger.info("SSE execution 
stream closed", execution_id=execution_id) - - async def listen_notifications(self, user_id: str) -> AsyncGenerator[DomainNotificationSSEPayload, None]: - start = datetime.now(timezone.utc) - self._metrics.increment_sse_connections("notifications") - self.logger.info("SSE notification stream opened", user_id=user_id) - try: - async with self._redis.pubsub(ignore_subscribe_messages=True) as pubsub: - await pubsub.subscribe(self._notif_channel(user_id)) - async for message in pubsub.listen(): - yield _notif_payload_adapter.validate_json(message["data"]) - finally: - duration = (datetime.now(timezone.utc) - start).total_seconds() - self._metrics.record_sse_connection_duration(duration, "notifications") - self._metrics.decrement_sse_connections("notifications") - self.logger.info("SSE notification stream closed", user_id=user_id) - - def _replay_channel(self, session_id: str) -> str: - return f"{self._replay_prefix}{session_id}" - - async def publish_replay_status(self, session_id: str, status: DomainReplaySSEPayload) -> None: - await self._redis.publish(self._replay_channel(session_id), _replay_adapter.dump_json(status)) - - async def listen_replay(self, session_id: str) -> AsyncGenerator[DomainReplaySSEPayload, None]: - start = datetime.now(timezone.utc) - self._metrics.increment_sse_connections("replay") - self.logger.info("SSE replay stream opened", session_id=session_id) - try: - async with self._redis.pubsub(ignore_subscribe_messages=True) as pubsub: - await pubsub.subscribe(self._replay_channel(session_id)) - async for message in pubsub.listen(): - yield _replay_adapter.validate_json(message["data"]) - finally: - duration = (datetime.now(timezone.utc) - start).total_seconds() - self._metrics.record_sse_connection_duration(duration, "replay") - self._metrics.decrement_sse_connections("replay") - self.logger.info("SSE replay stream closed", session_id=session_id) diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py index 
14a98c93..97164b48 100644 --- a/backend/app/services/sse/sse_service.py +++ b/backend/app/services/sse/sse_service.py @@ -1,10 +1,14 @@ +import asyncio import dataclasses from collections.abc import AsyncGenerator -from typing import Any +from datetime import datetime, timezone +from typing import Any, TypeVar +import redis.asyncio as redis import structlog from pydantic import TypeAdapter +from app.core.metrics import ConnectionMetrics from app.db import ExecutionRepository from app.domain.enums import EventType, SSEControlEvent, UserRole from app.domain.exceptions import ForbiddenError @@ -12,35 +16,72 @@ from app.domain.execution.models import DomainExecution from app.domain.sse import DomainNotificationSSEPayload, DomainReplaySSEPayload, SSEExecutionEventData from app.infrastructure.kafka.topics import EXECUTION_PIPELINE_TERMINAL_EVENT_TYPES -from app.services.sse.redis_bus import SSERedisBus _exec_adapter = TypeAdapter(SSEExecutionEventData) _notif_adapter = TypeAdapter(DomainNotificationSSEPayload) _replay_adapter = TypeAdapter(DomainReplaySSEPayload) +T = TypeVar("T") + class SSEService: - """SSE service — transforms bus events and DB state into SSE wire format.""" + """SSE service — publishes events via Redis Streams and transforms them into SSE wire format.""" + + _MAXLEN = 100 + _STREAM_TTL = 600 def __init__( self, - bus: SSERedisBus, + redis_client: redis.Redis, execution_repository: ExecutionRepository, logger: structlog.stdlib.BoundLogger, + connection_metrics: ConnectionMetrics, + exec_prefix: str = "sse:exec:", + notif_prefix: str = "sse:notif:", + replay_prefix: str = "sse:replay:", + poll_interval: float = 0.5, ) -> None: - self._bus = bus + self._redis = redis_client self._execution_repository = execution_repository self._logger = logger + self._metrics = connection_metrics + self._exec_prefix = exec_prefix + self._notif_prefix = notif_prefix + self._replay_prefix = replay_prefix + self._poll_interval = poll_interval + + async def _xpublish(self, 
key: str, data: bytes) -> None: + await self._redis.xadd(key, {"d": data}, maxlen=self._MAXLEN, approximate=True) + await self._redis.expire(key, self._STREAM_TTL) + + async def publish_event(self, execution_id: str, event: SSEExecutionEventData) -> None: + await self._xpublish(f"{self._exec_prefix}{execution_id}", _exec_adapter.dump_json(event)) + + async def publish_notification(self, user_id: str, notification: DomainNotificationSSEPayload) -> None: + await self._xpublish(f"{self._notif_prefix}{user_id}", _notif_adapter.dump_json(notification)) + + async def publish_replay_status(self, session_id: str, status: DomainReplaySSEPayload) -> None: + await self._xpublish(f"{self._replay_prefix}{session_id}", _replay_adapter.dump_json(status)) + + async def _read_after(self, key: str, last_id: str) -> list[tuple[str, bytes]]: + result = await self._redis.xread({key: last_id}, count=100) + if not result: + return [] + return [(mid, fields[b"d"]) for _, msgs in result for mid, fields in msgs] + + async def _poll_stream(self, key: str, adapter: TypeAdapter[T]) -> AsyncGenerator[T, None]: + last_id = "0-0" + while True: + batch = await self._read_after(key, last_id) + for msg_id, raw in batch: + last_id = msg_id + yield adapter.validate_json(raw) + if not batch: + await asyncio.sleep(self._poll_interval) async def create_execution_stream( - self, execution_id: str, user_id: str, user_role: UserRole + self, execution_id: str, user_id: str, user_role: UserRole, ) -> AsyncGenerator[dict[str, Any], None]: - """Eagerly validate access then return the event stream generator. - - Raises ExecutionNotFoundError or ForbiddenError before any streaming - begins, so FastAPI's exception handlers can return a proper HTTP error - response instead of crashing inside the SSE task group. 
- """ execution = await self._execution_repository.get_execution(execution_id) if not execution: raise ExecutionNotFoundError(execution_id) @@ -49,45 +90,55 @@ async def create_execution_stream( return self._execution_pipeline(execution) async def _execution_pipeline( - self, execution: DomainExecution + self, execution: DomainExecution, ) -> AsyncGenerator[dict[str, Any], None]: execution_id = execution.execution_id - yield {"data": _exec_adapter.dump_json(SSEExecutionEventData( - event_type=SSEControlEvent.STATUS, - execution_id=execution_id, - timestamp=execution.updated_at, - status=execution.status, - )).decode()} - async for event in self._bus.listen_execution(execution_id): - if event.event_type == EventType.RESULT_STORED: - result = await self._execution_repository.get_execution_result(execution_id) - event = dataclasses.replace(event, result=result) - self._logger.info("SSE event", execution_id=execution_id, event_type=event.event_type) - yield {"data": _exec_adapter.dump_json(event).decode()} - if event.event_type in EXECUTION_PIPELINE_TERMINAL_EVENT_TYPES: - return + start = datetime.now(timezone.utc) + self._metrics.increment_sse_connections("executions") + try: + yield {"data": _exec_adapter.dump_json(SSEExecutionEventData( + event_type=SSEControlEvent.STATUS, + execution_id=execution_id, + timestamp=execution.updated_at, + status=execution.status, + )).decode()} + async for event in self._poll_stream(f"{self._exec_prefix}{execution_id}", _exec_adapter): + if event.event_type == EventType.RESULT_STORED: + result = await self._execution_repository.get_execution_result(execution_id) + event = dataclasses.replace(event, result=result) + self._logger.info("SSE event", execution_id=execution_id, event_type=event.event_type) + yield {"data": _exec_adapter.dump_json(event).decode()} + if event.event_type in EXECUTION_PIPELINE_TERMINAL_EVENT_TYPES: + return + finally: + duration = (datetime.now(timezone.utc) - start).total_seconds() + 
self._metrics.record_sse_connection_duration(duration, "executions") + self._metrics.decrement_sse_connections("executions") async def create_notification_stream(self, user_id: str) -> AsyncGenerator[dict[str, Any], None]: - async for payload in self._bus.listen_notifications(user_id): - yield {"event": "notification", "data": _notif_adapter.dump_json(payload).decode()} + start = datetime.now(timezone.utc) + self._metrics.increment_sse_connections("notifications") + try: + async for payload in self._poll_stream(f"{self._notif_prefix}{user_id}", _notif_adapter): + yield {"event": "notification", "data": _notif_adapter.dump_json(payload).decode()} + finally: + duration = (datetime.now(timezone.utc) - start).total_seconds() + self._metrics.record_sse_connection_duration(duration, "notifications") + self._metrics.decrement_sse_connections("notifications") async def create_replay_stream( - self, initial_status: DomainReplaySSEPayload + self, initial_status: DomainReplaySSEPayload, ) -> AsyncGenerator[dict[str, Any], None]: - """Return the replay event stream generator. - - Caller (route) handles validation and initial DB fetch. 
- """ return self._replay_pipeline(initial_status) async def _replay_pipeline( - self, initial_status: DomainReplaySSEPayload + self, initial_status: DomainReplaySSEPayload, ) -> AsyncGenerator[dict[str, Any], None]: session_id = initial_status.session_id yield {"data": _replay_adapter.dump_json(initial_status).decode()} if initial_status.status.is_terminal: return - async for status in self._bus.listen_replay(session_id): + async for status in self._poll_stream(f"{self._replay_prefix}{session_id}", _replay_adapter): self._logger.info("SSE replay event", session_id=session_id, status=status.status) yield {"data": _replay_adapter.dump_json(status).decode()} if status.status.is_terminal: diff --git a/backend/tests/e2e/conftest.py b/backend/tests/e2e/conftest.py index f9cd621f..0fecec12 100644 --- a/backend/tests/e2e/conftest.py +++ b/backend/tests/e2e/conftest.py @@ -19,6 +19,17 @@ _sse_adapter = TypeAdapter(SSEExecutionEventData) +async def _read_stream( + redis_client: redis.Redis, + key: str, + after: str = "0-0", +) -> list[tuple[str, dict[bytes, bytes]]]: + result = await redis_client.xread({key: after}, count=100) + if not result: + return [] + return [(mid, fields) for _, msgs in result for mid, fields in msgs] + + async def wait_for_sse_event( redis_client: redis.Redis, execution_id: str, @@ -26,19 +37,20 @@ async def wait_for_sse_event( *, timeout: float = 15.0, ) -> SSEExecutionEventData: - """Subscribe to execution's Redis SSE channel and await matching event. + """Poll execution's Redis Stream and await matching event. - The SSE bridge publishes all execution lifecycle events to - sse:exec:{execution_id}. Pure event-driven — no polling. + Reads from "0-0" on each iteration so late subscribers never miss events. 
""" - channel = f"sse:exec:{execution_id}" + key = f"sse:exec:{execution_id}" + last_id = "0-0" async with asyncio.timeout(timeout): - async with redis_client.pubsub(ignore_subscribe_messages=True) as pubsub: - await pubsub.subscribe(channel) - async for message in pubsub.listen(): - event = _sse_adapter.validate_json(message["data"]) + while True: + for msg_id, fields in await _read_stream(redis_client, key, last_id): + last_id = msg_id + event = _sse_adapter.validate_json(fields[b"d"]) if predicate(event): return event + await asyncio.sleep(0.1) raise AssertionError("unreachable") @@ -76,19 +88,19 @@ async def wait_for_notification( *, timeout: float = 30.0, ) -> None: - """Wait for a notification on the user's SSE channel. + """Wait for a notification on the user's Redis Stream. The notification service publishes to sse:notif:{user_id} only after persisting to MongoDB, so receiving a message is a correct readiness signal — unlike RESULT_STORED which comes from an independent consumer group with no ordering guarantee. 
""" - channel = f"sse:notif:{user_id}" + key = f"sse:notif:{user_id}" async with asyncio.timeout(timeout): - async with redis_client.pubsub(ignore_subscribe_messages=True) as pubsub: - await pubsub.subscribe(channel) - async for _message in pubsub.listen(): - return # first message = notification persisted + while True: + if await _read_stream(redis_client, key): + return + await asyncio.sleep(0.1) @pytest.fixture diff --git a/backend/tests/e2e/core/test_dishka_lifespan.py b/backend/tests/e2e/core/test_dishka_lifespan.py index 6059017c..c89f18be 100644 --- a/backend/tests/e2e/core/test_dishka_lifespan.py +++ b/backend/tests/e2e/core/test_dishka_lifespan.py @@ -3,7 +3,7 @@ import pytest import redis.asyncio as aioredis from app.db.docs import UserDocument -from app.services.sse import SSERedisBus +from app.services.sse import SSEService from app.settings import Settings from dishka import AsyncContainer from fastapi import FastAPI @@ -70,8 +70,8 @@ async def test_redis_connected(self, scope: AsyncContainer) -> None: assert pong is True @pytest.mark.asyncio - async def test_sse_redis_bus_available(self, scope: AsyncContainer) -> None: - """SSE Redis bus is available after lifespan.""" - bus = await scope.get(SSERedisBus) - assert bus is not None + async def test_sse_service_available(self, scope: AsyncContainer) -> None: + """SSE service is available after lifespan.""" + svc = await scope.get(SSEService) + assert svc is not None diff --git a/backend/tests/e2e/notifications/test_notification_sse.py b/backend/tests/e2e/notifications/test_notification_sse.py index 977376b9..b267ec3e 100644 --- a/backend/tests/e2e/notifications/test_notification_sse.py +++ b/backend/tests/e2e/notifications/test_notification_sse.py @@ -4,8 +4,10 @@ import pytest from app.domain.enums import NotificationChannel, NotificationSeverity from app.domain.notification import DomainSubscriptionUpdate +from app.domain.sse import DomainNotificationSSEPayload from app.services.notification_service 
import NotificationService -from app.services.sse import SSERedisBus +from app.services.sse import SSEService +from app.services.sse.sse_service import _notif_adapter from dishka import AsyncContainer pytestmark = [pytest.mark.e2e, pytest.mark.redis] @@ -14,17 +16,14 @@ @pytest.mark.asyncio async def test_in_app_notification_published_to_sse(scope: AsyncContainer) -> None: svc: NotificationService = await scope.get(NotificationService) - bus: SSERedisBus = await scope.get(SSERedisBus) + sse: SSEService = await scope.get(SSEService) user_id = f"notif-user-{uuid4().hex[:8]}" # Enable IN_APP subscription for the user to allow delivery await svc.update_subscription(user_id, NotificationChannel.IN_APP, DomainSubscriptionUpdate(enabled=True)) - # Start generator (subscription happens on first __anext__) and publish concurrently. - # By the time create_notification fires, the subscribe is already established. - messages = bus.listen_notifications(user_id) - pub_task = asyncio.create_task(svc.create_notification( + await svc.create_notification( user_id=user_id, subject="Hello", body="World", @@ -32,9 +31,11 @@ async def test_in_app_notification_published_to_sse(scope: AsyncContainer) -> No action_url="/api/v1/notifications", severity=NotificationSeverity.MEDIUM, channel=NotificationChannel.IN_APP, - )) - msg = await asyncio.wait_for(messages.__anext__(), timeout=5.0) - await pub_task + ) + + # Read back from stream + gen = sse._poll_stream(f"sse:notif:{user_id}", _notif_adapter) + msg: DomainNotificationSSEPayload = await asyncio.wait_for(gen.__anext__(), timeout=5.0) assert msg is not None assert msg.subject == "Hello" diff --git a/backend/tests/e2e/services/sse/test_partitioned_event_router.py b/backend/tests/e2e/services/sse/test_partitioned_event_router.py index 5e44d3d3..3f5b4739 100644 --- a/backend/tests/e2e/services/sse/test_partitioned_event_router.py +++ b/backend/tests/e2e/services/sse/test_partitioned_event_router.py @@ -8,7 +8,8 @@ from app.core.metrics 
import ConnectionMetrics from app.domain.enums import EventType from app.domain.sse import SSEExecutionEventData -from app.services.sse import SSERedisBus +from app.services.sse import SSEService +from app.services.sse.sse_service import _exec_adapter from app.settings import Settings pytestmark = [pytest.mark.e2e, pytest.mark.redis] @@ -16,15 +17,25 @@ _test_logger = structlog.get_logger("test.services.sse.partitioned_event_router_integration") +class _FakeExecRepo: + async def get_execution(self, execution_id: str) -> None: # noqa: ARG002 + return None + + async def get_execution_result(self, execution_id: str) -> None: # noqa: ARG002 + return None + + @pytest.mark.asyncio -async def test_bus_routes_event_to_redis(redis_client: redis.Redis, test_settings: Settings) -> None: +async def test_service_routes_event_to_redis_stream(redis_client: redis.Redis, test_settings: Settings) -> None: suffix = uuid4().hex[:6] - bus = SSERedisBus( - redis_client, + svc = SSEService( + redis_client=redis_client, + execution_repository=_FakeExecRepo(), # type: ignore[arg-type] logger=_test_logger, connection_metrics=MagicMock(spec=ConnectionMetrics), exec_prefix=f"sse:exec:{suffix}:", notif_prefix=f"sse:notif:{suffix}:", + poll_interval=0.05, ) execution_id = f"e-{uuid4().hex[:8]}" @@ -33,12 +44,11 @@ async def test_bus_routes_event_to_redis(redis_client: redis.Redis, test_setting execution_id=execution_id, ) - # Start generator (subscription happens on first __anext__) and publish concurrently. - # By the time publish fires (~1 Redis RTT), the subscribe is already established. 
- messages = bus.listen_execution(execution_id) - pub_task = asyncio.create_task(bus.publish_event(execution_id, ev)) - msg = await asyncio.wait_for(messages.__anext__(), timeout=2.0) - await pub_task + await svc.publish_event(execution_id, ev) + + # Read back from stream + gen = svc._poll_stream(f"sse:exec:{suffix}:{execution_id}", _exec_adapter) + msg = await asyncio.wait_for(gen.__anext__(), timeout=2.0) assert msg is not None assert msg.event_type == ev.event_type diff --git a/backend/tests/unit/services/sse/test_redis_bus.py b/backend/tests/unit/services/sse/test_redis_bus.py deleted file mode 100644 index b839d248..00000000 --- a/backend/tests/unit/services/sse/test_redis_bus.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import annotations - -import asyncio -import structlog -from collections.abc import AsyncGenerator -from datetime import datetime, timezone -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -import redis.asyncio as redis_async -from app.core.metrics import ConnectionMetrics -from app.domain.enums import EventType, NotificationSeverity, NotificationStatus, ReplayStatus -from app.domain.sse import DomainNotificationSSEPayload, DomainReplaySSEPayload, SSEExecutionEventData -from app.services.sse import SSERedisBus -from app.services.sse.redis_bus import _sse_event_adapter - -pytestmark = pytest.mark.unit - -_test_logger = structlog.get_logger("test.services.sse.redis_bus") - - -class _FakePubSub: - def __init__(self) -> None: - self.subscribed: set[str] = set() - self._queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue() - self.closed = False - - async def __aenter__(self) -> _FakePubSub: - return self - - async def __aexit__(self, *_: object) -> None: - await self.aclose() - - async def subscribe(self, channel: str) -> None: - self.subscribed.add(channel) - - async def push(self, channel: str, payload: str | bytes) -> None: - self._queue.put_nowait({"data": payload, "channel": channel}) - - async 
def listen(self) -> AsyncGenerator[dict[str, Any], None]: - while True: - msg = await self._queue.get() - if msg is None: - return - yield msg - - async def aclose(self) -> None: - self.closed = True - - -class _FakeRedis: - """Fake Redis for testing - used in place of real Redis. - - Note: SSERedisBus uses duck-typing so this works without inheritance. - """ - - def __init__(self) -> None: - self.published: list[tuple[str, str]] = [] - self._pubsub = _FakePubSub() - - async def publish(self, channel: str, payload: str) -> int: - self.published.append((channel, payload)) - return 1 - - def pubsub(self, ignore_subscribe_messages: bool = False) -> _FakePubSub: # noqa: ARG002 - return self._pubsub - - -@pytest.mark.asyncio -async def test_publish_and_subscribe_round_trip() -> None: - r = _FakeRedis() - bus = SSERedisBus(cast(redis_async.Redis, r), logger=_test_logger, connection_metrics=MagicMock(spec=ConnectionMetrics)) - - # Publish event as SSEExecutionEventData (field projection happens in handlers.py) - evt = SSEExecutionEventData( - event_type=EventType.EXECUTION_COMPLETED, - execution_id="exec-1", - ) - await bus.publish_event("exec-1", evt) - assert r.published, "nothing published" - ch, payload = r.published[-1] - assert ch.endswith("exec-1") - - # Push message into fake pubsub queue before iterating (subscription is lazy) - await r._pubsub.push(ch, payload) - - # listen_execution is an async generator — no await needed - messages = bus.listen_execution("exec-1") - msg = await asyncio.wait_for(messages.__anext__(), timeout=2.0) - # Subscription happened inside __anext__ - assert "sse:exec:exec-1" in r._pubsub.subscribed - assert msg.event_type == EventType.EXECUTION_COMPLETED - assert msg.execution_id == "exec-1" - - # A second valid message passes through cleanly - good_payload = _sse_event_adapter.dump_json(SSEExecutionEventData( - event_type=EventType.EXECUTION_COMPLETED, - execution_id="exec-1", - )) - await r._pubsub.push(ch, good_payload) - msg2 = await 
asyncio.wait_for(messages.__anext__(), timeout=2.0) - assert msg2.event_type == EventType.EXECUTION_COMPLETED - - - -@pytest.mark.asyncio -async def test_notifications_channels() -> None: - r = _FakeRedis() - bus = SSERedisBus(cast(redis_async.Redis, r), logger=_test_logger, connection_metrics=MagicMock(spec=ConnectionMetrics)) - - notif = DomainNotificationSSEPayload( - notification_id="n1", - severity=NotificationSeverity.LOW, - status=NotificationStatus.PENDING, - tags=[], - subject="test", - body="body", - action_url="", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - await bus.publish_notification("user-1", notif) - ch, payload = r.published[-1] - assert ch.endswith("user-1") - - # Push message before iterating (subscription is lazy) - await r._pubsub.push(ch, payload) - - messages = bus.listen_notifications("user-1") - got = await asyncio.wait_for(messages.__anext__(), timeout=2.0) - # Subscription happened inside __anext__ - assert "sse:notif:user-1" in r._pubsub.subscribed - assert got.notification_id == "n1" - - - -@pytest.mark.asyncio -async def test_replay_publish_and_subscribe_round_trip() -> None: - r = _FakeRedis() - bus = SSERedisBus(cast(redis_async.Redis, r), logger=_test_logger, connection_metrics=MagicMock(spec=ConnectionMetrics)) - - status = DomainReplaySSEPayload( - session_id="sess-1", - status=ReplayStatus.RUNNING, - total_events=10, - replayed_events=3, - failed_events=0, - skipped_events=0, - replay_id="replay-1", - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ) - await bus.publish_replay_status("sess-1", status) - assert r.published, "nothing published" - ch, payload = r.published[-1] - assert ch.endswith("sess-1") - - await r._pubsub.push(ch, payload) - - messages = bus.listen_replay("sess-1") - got = await asyncio.wait_for(messages.__anext__(), timeout=2.0) - assert "sse:replay:sess-1" in r._pubsub.subscribed - assert got.session_id == "sess-1" - assert got.status == ReplayStatus.RUNNING - assert 
got.replayed_events == 3 diff --git a/backend/tests/unit/services/sse/test_sse_publish.py b/backend/tests/unit/services/sse/test_sse_publish.py new file mode 100644 index 00000000..67dcaa35 --- /dev/null +++ b/backend/tests/unit/services/sse/test_sse_publish.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import asyncio +import structlog +from datetime import datetime, timezone +from typing import Any +from unittest.mock import MagicMock + +import pytest +from app.core.metrics import ConnectionMetrics +from app.domain.enums import EventType, NotificationSeverity, NotificationStatus, ReplayStatus +from app.domain.sse import DomainNotificationSSEPayload, DomainReplaySSEPayload, SSEExecutionEventData +from app.services.sse import SSEService +from app.services.sse.sse_service import _exec_adapter, _notif_adapter, _replay_adapter + +pytestmark = pytest.mark.unit + +_test_logger = structlog.get_logger("test.services.sse.publish") + + +class _FakeRedis: + """Fake Redis with Streams support (XADD / XREAD / EXPIRE).""" + + def __init__(self) -> None: + self._streams: dict[str, list[tuple[str, dict[bytes, bytes]]]] = {} + self._counter = 0 + + async def xadd(self, key: str, fields: dict[str, bytes], **_kw: Any) -> str: + self._counter += 1 + msg_id = f"{self._counter}-0" + encoded = {k.encode() if isinstance(k, str) else k: v for k, v in fields.items()} + self._streams.setdefault(key, []).append((msg_id, encoded)) + return msg_id + + async def xread(self, streams: dict[str, str], **_kw: Any) -> list[tuple[str, list[tuple[str, dict[bytes, bytes]]]]] | None: + result: list[tuple[str, list[tuple[str, dict[bytes, bytes]]]]] = [] + for key, after in streams.items(): + after_seq = int(after.split("-")[0]) + msgs = [(mid, f) for mid, f in self._streams.get(key, []) if int(mid.split("-")[0]) > after_seq] + if msgs: + result.append((key, msgs)) + return result or None + + async def expire(self, key: str, seconds: int) -> bool: # noqa: ARG002 + return True + + +class 
_FakeExecRepo: + async def get_execution(self, execution_id: str) -> None: # noqa: ARG002 + return None + + async def get_execution_result(self, execution_id: str) -> None: # noqa: ARG002 + return None + + +def _make_service(fake_redis: _FakeRedis) -> SSEService: + return SSEService( + redis_client=fake_redis, # type: ignore[arg-type] + execution_repository=_FakeExecRepo(), # type: ignore[arg-type] + logger=_test_logger, + connection_metrics=MagicMock(spec=ConnectionMetrics), + poll_interval=0.01, + ) + + +@pytest.mark.asyncio +async def test_publish_event_writes_to_stream() -> None: + r = _FakeRedis() + svc = _make_service(r) + + evt = SSEExecutionEventData( + event_type=EventType.EXECUTION_COMPLETED, + execution_id="exec-1", + ) + await svc.publish_event("exec-1", evt) + + stream = r._streams.get("sse:exec:exec-1") + assert stream, "nothing written to stream" + _, fields = stream[0] + parsed = _exec_adapter.validate_json(fields[b"d"]) + assert parsed.event_type == EventType.EXECUTION_COMPLETED + assert parsed.execution_id == "exec-1" + + +@pytest.mark.asyncio +async def test_publish_and_poll_round_trip() -> None: + r = _FakeRedis() + svc = _make_service(r) + + evt = SSEExecutionEventData( + event_type=EventType.EXECUTION_COMPLETED, + execution_id="exec-1", + ) + await svc.publish_event("exec-1", evt) + + gen = svc._poll_stream("sse:exec:exec-1", _exec_adapter) + msg = await asyncio.wait_for(gen.__anext__(), timeout=2.0) + assert msg.event_type == EventType.EXECUTION_COMPLETED + assert msg.execution_id == "exec-1" + + +@pytest.mark.asyncio +async def test_notification_publish_round_trip() -> None: + r = _FakeRedis() + svc = _make_service(r) + + notif = DomainNotificationSSEPayload( + notification_id="n1", + severity=NotificationSeverity.LOW, + status=NotificationStatus.PENDING, + tags=[], + subject="test", + body="body", + action_url="", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + await svc.publish_notification("user-1", notif) + + stream = 
r._streams.get("sse:notif:user-1") + assert stream, "nothing written to stream" + _, fields = stream[0] + parsed = _notif_adapter.validate_json(fields[b"d"]) + assert parsed.notification_id == "n1" + + +@pytest.mark.asyncio +async def test_replay_publish_round_trip() -> None: + r = _FakeRedis() + svc = _make_service(r) + + status = DomainReplaySSEPayload( + session_id="sess-1", + status=ReplayStatus.RUNNING, + total_events=10, + replayed_events=3, + failed_events=0, + skipped_events=0, + replay_id="replay-1", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + await svc.publish_replay_status("sess-1", status) + + stream = r._streams.get("sse:replay:sess-1") + assert stream, "nothing written to stream" + _, fields = stream[0] + parsed = _replay_adapter.validate_json(fields[b"d"]) + assert parsed.session_id == "sess-1" + assert parsed.status == ReplayStatus.RUNNING + assert parsed.replayed_events == 3 diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index b3a9177f..7547f125 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -1,16 +1,18 @@ import asyncio import json import structlog -from collections.abc import AsyncGenerator from datetime import datetime, timezone from typing import Any +from unittest.mock import MagicMock import pytest +from app.core.metrics import ConnectionMetrics from app.domain.enums import EventType, ExecutionStatus, NotificationSeverity, NotificationStatus, ReplayStatus, SSEControlEvent, UserRole from app.domain.execution.models import DomainExecution, ExecutionResultDomain from app.domain.sse import DomainNotificationSSEPayload, DomainReplaySSEPayload, SSEExecutionEventData from app.services.sse import SSEService +from app.services.sse.sse_service import _exec_adapter, _notif_adapter pytestmark = pytest.mark.unit @@ -19,55 +21,31 @@ _NOW = datetime(2025, 1, 1, tzinfo=timezone.utc) -class 
_FakeBus: - """Fake SSERedisBus backed by asyncio queues.""" +class _FakeRedis: + """Fake Redis with Streams support (XADD / XREAD / EXPIRE).""" def __init__(self) -> None: - self._exec_q: asyncio.Queue[SSEExecutionEventData | None] = asyncio.Queue() - self._notif_q: asyncio.Queue[DomainNotificationSSEPayload | None] = asyncio.Queue() - self._replay_q: asyncio.Queue[DomainReplaySSEPayload | None] = asyncio.Queue() - self.exec_closed = False - self.notif_closed = False - self.replay_closed = False - - async def push_exec(self, event: SSEExecutionEventData | None) -> None: - await self._exec_q.put(event) - - async def push_notif(self, payload: DomainNotificationSSEPayload | None) -> None: - await self._notif_q.put(payload) - - async def push_replay(self, status: DomainReplaySSEPayload | None) -> None: - await self._replay_q.put(status) - - async def listen_execution(self, execution_id: str) -> AsyncGenerator[SSEExecutionEventData, None]: # noqa: ARG002 - try: - while True: - item = await self._exec_q.get() - if item is None: - return - yield item - finally: - self.exec_closed = True - - async def listen_notifications(self, user_id: str) -> AsyncGenerator[DomainNotificationSSEPayload, None]: # noqa: ARG002 - try: - while True: - item = await self._notif_q.get() - if item is None: - return - yield item - finally: - self.notif_closed = True - - async def listen_replay(self, session_id: str) -> AsyncGenerator[DomainReplaySSEPayload, None]: # noqa: ARG002 - try: - while True: - item = await self._replay_q.get() - if item is None: - return - yield item - finally: - self.replay_closed = True + self._streams: dict[str, list[tuple[str, dict[bytes, bytes]]]] = {} + self._counter = 0 + + async def xadd(self, key: str, fields: dict[str, bytes], **_kw: Any) -> str: + self._counter += 1 + msg_id = f"{self._counter}-0" + encoded = {k.encode() if isinstance(k, str) else k: v for k, v in fields.items()} + self._streams.setdefault(key, []).append((msg_id, encoded)) + return msg_id + + 
async def xread(self, streams: dict[str, str], **_kw: Any) -> list[tuple[str, list[tuple[str, dict[bytes, bytes]]]]] | None: + result: list[tuple[str, list[tuple[str, dict[bytes, bytes]]]]] = [] + for key, after in streams.items(): + after_seq = int(after.split("-")[0]) + msgs = [(mid, f) for mid, f in self._streams.get(key, []) if int(mid.split("-")[0]) > after_seq] + if msgs: + result.append((key, msgs)) + return result or None + + async def expire(self, key: str, seconds: int) -> bool: # noqa: ARG002 + return True class _FakeExecRepo: @@ -93,55 +71,53 @@ def _decode(evt: dict[str, Any]) -> dict[str, Any]: return result -def _make_service(bus: _FakeBus, exec_repo: _FakeExecRepo = _FakeExecRepo()) -> SSEService: +def _make_service(fake_redis: _FakeRedis, exec_repo: _FakeExecRepo = _FakeExecRepo()) -> SSEService: return SSEService( - bus=bus, # type: ignore[arg-type] + redis_client=fake_redis, # type: ignore[arg-type] execution_repository=exec_repo, # type: ignore[arg-type] logger=_test_logger, + connection_metrics=MagicMock(spec=ConnectionMetrics), + poll_interval=0.01, ) @pytest.mark.asyncio async def test_execution_stream_prepends_status_from_db() -> None: execution = DomainExecution(execution_id="exec-1", status=ExecutionStatus.RUNNING) - bus = _FakeBus() - svc = _make_service(bus, _FakeExecRepo(execution=execution)) + fake_redis = _FakeRedis() + svc = _make_service(fake_redis, _FakeExecRepo(execution=execution)) agen = await svc.create_execution_stream("exec-1", user_id="u1", user_role=UserRole.USER) - # Signal end of live stream so the generator can finish - await bus.push_exec(None) - # First item must be the STATUS prepended from DB - stat = await agen.__anext__() + stat = await asyncio.wait_for(agen.__anext__(), timeout=2.0) data = _decode(stat) assert data["event_type"] == "status" assert data["execution_id"] == "exec-1" - with pytest.raises(StopAsyncIteration): - await agen.__anext__() - @pytest.mark.asyncio async def 
test_execution_stream_closes_on_terminal_event() -> None: - bus = _FakeBus() - svc = _make_service(bus, _FakeExecRepo(execution=DomainExecution(execution_id="exec-1", status=ExecutionStatus.RUNNING))) + fake_redis = _FakeRedis() + svc = _make_service(fake_redis, _FakeExecRepo(execution=DomainExecution(execution_id="exec-1", status=ExecutionStatus.RUNNING))) agen = await svc.create_execution_stream("exec-1", user_id="u1", user_role=UserRole.USER) # DB status prepend is always yielded first - stat = await agen.__anext__() + stat = await asyncio.wait_for(agen.__anext__(), timeout=2.0) assert _decode(stat)["event_type"] == "status" - await bus.push_exec(SSEExecutionEventData( + # Push a terminal event into the stream + await fake_redis.xadd("sse:exec:exec-1", {"d": _exec_adapter.dump_json(SSEExecutionEventData( event_type=EventType.EXECUTION_FAILED, execution_id="exec-1", - )) - failed = await agen.__anext__() + ))}) + + failed = await asyncio.wait_for(agen.__anext__(), timeout=2.0) assert _decode(failed)["event_type"] == EventType.EXECUTION_FAILED with pytest.raises(StopAsyncIteration): - await agen.__anext__() + await asyncio.wait_for(agen.__anext__(), timeout=2.0) @pytest.mark.asyncio @@ -153,37 +129,37 @@ async def test_execution_stream_enriches_result_stored() -> None: stdout="out", stderr="", ) - bus = _FakeBus() - svc = _make_service(bus, _FakeExecRepo(execution=DomainExecution(execution_id="exec-2", status=ExecutionStatus.RUNNING), result=result)) + fake_redis = _FakeRedis() + svc = _make_service(fake_redis, _FakeExecRepo(execution=DomainExecution(execution_id="exec-2", status=ExecutionStatus.RUNNING), result=result)) agen = await svc.create_execution_stream("exec-2", user_id="u1", user_role=UserRole.USER) # Consume DB status prepend - await agen.__anext__() + await asyncio.wait_for(agen.__anext__(), timeout=2.0) - await bus.push_exec(SSEExecutionEventData( + await fake_redis.xadd("sse:exec:exec-2", {"d": _exec_adapter.dump_json(SSEExecutionEventData( 
event_type=EventType.RESULT_STORED, execution_id="exec-2", - )) - evt = await agen.__anext__() + ))}) + + evt = await asyncio.wait_for(agen.__anext__(), timeout=2.0) data = _decode(evt) assert data["event_type"] == EventType.RESULT_STORED assert data["result"] is not None assert data["result"]["execution_id"] == "exec-2" with pytest.raises(StopAsyncIteration): - await agen.__anext__() + await asyncio.wait_for(agen.__anext__(), timeout=2.0) @pytest.mark.asyncio -async def test_notification_stream_yields_notification_and_cleans_up() -> None: +async def test_notification_stream_yields_notification() -> None: """Notification stream yields {"event": "notification", "data": ...} for each message.""" - bus = _FakeBus() - svc = _make_service(bus) - - agen = svc.create_notification_stream(user_id="u1") + fake_redis = _FakeRedis() + svc = _make_service(fake_redis) - await bus.push_notif(DomainNotificationSSEPayload( + # Push notification into stream before starting the generator + await fake_redis.xadd("sse:notif:u1", {"d": _notif_adapter.dump_json(DomainNotificationSSEPayload( notification_id="n1", severity=NotificationSeverity.LOW, status=NotificationStatus.PENDING, @@ -192,7 +168,9 @@ async def test_notification_stream_yields_notification_and_cleans_up() -> None: body="b", action_url="", created_at=_NOW, - )) + ))}) + + agen = svc.create_notification_stream(user_id="u1") notif = await asyncio.wait_for(agen.__anext__(), timeout=2.0) assert notif["event"] == "notification" @@ -202,12 +180,11 @@ async def test_notification_stream_yields_notification_and_cleans_up() -> None: assert data["channel"] == "in_app" - @pytest.mark.asyncio async def test_replay_stream_yields_initial_then_live() -> None: """Replay pipeline yields initial status from DB then streams live updates.""" - bus = _FakeBus() - svc = _make_service(bus) + fake_redis = _FakeRedis() + svc = _make_service(fake_redis) initial = DomainReplaySSEPayload( session_id="sess-1", @@ -223,14 +200,15 @@ async def 
test_replay_stream_yields_initial_then_live() -> None: agen = await svc.create_replay_stream(initial) # First item is the initial status - first = await agen.__anext__() + first = await asyncio.wait_for(agen.__anext__(), timeout=2.0) data = _decode(first) assert data["session_id"] == "sess-1" assert data["status"] == "running" assert data["replayed_events"] == 0 - # Push a live update - await bus.push_replay(DomainReplaySSEPayload( + # Push a live update into the stream + from app.services.sse.sse_service import _replay_adapter + await fake_redis.xadd("sse:replay:sess-1", {"d": _replay_adapter.dump_json(DomainReplaySSEPayload( session_id="sess-1", status=ReplayStatus.RUNNING, total_events=5, @@ -239,14 +217,14 @@ async def test_replay_stream_yields_initial_then_live() -> None: skipped_events=0, replay_id="replay-1", created_at=_NOW, - )) + ))}) second = await asyncio.wait_for(agen.__anext__(), timeout=2.0) data2 = _decode(second) assert data2["replayed_events"] == 3 # Push terminal status - await bus.push_replay(DomainReplaySSEPayload( + await fake_redis.xadd("sse:replay:sess-1", {"d": _replay_adapter.dump_json(DomainReplaySSEPayload( session_id="sess-1", status=ReplayStatus.COMPLETED, total_events=5, @@ -255,21 +233,21 @@ async def test_replay_stream_yields_initial_then_live() -> None: skipped_events=0, replay_id="replay-1", created_at=_NOW, - )) + ))}) third = await asyncio.wait_for(agen.__anext__(), timeout=2.0) data3 = _decode(third) assert data3["status"] == "completed" with pytest.raises(StopAsyncIteration): - await agen.__anext__() + await asyncio.wait_for(agen.__anext__(), timeout=2.0) @pytest.mark.asyncio async def test_replay_stream_terminal_initial_closes_immediately() -> None: """If the initial replay status is terminal, the stream closes after yielding it.""" - bus = _FakeBus() - svc = _make_service(bus) + fake_redis = _FakeRedis() + svc = _make_service(fake_redis) initial = DomainReplaySSEPayload( session_id="sess-2", @@ -284,9 +262,9 @@ async def 
test_replay_stream_terminal_initial_closes_immediately() -> None: agen = await svc.create_replay_stream(initial) - first = await agen.__anext__() + first = await asyncio.wait_for(agen.__anext__(), timeout=2.0) data = _decode(first) assert data["status"] == "completed" with pytest.raises(StopAsyncIteration): - await agen.__anext__() + await asyncio.wait_for(agen.__anext__(), timeout=2.0) diff --git a/docs/architecture/lifecycle.md b/docs/architecture/lifecycle.md index 273959f5..d04a20cc 100644 --- a/docs/architecture/lifecycle.md +++ b/docs/architecture/lifecycle.md @@ -99,8 +99,7 @@ Use the shared `run_worker()` function from `workers/bootstrap.py`: 1. Create a container factory in `app/core/container.py` for the new worker 2. Create a `config..toml` override file -3. Optionally register Kafka subscriber handlers -4. Optionally define `on_startup` / `on_shutdown` callbacks for APScheduler or other setup -5. Call `run_worker()` with the worker name, config override, container factory, and optional hooks +3. Register Kafka subscriber handlers and/or define `on_startup` / `on_shutdown` callbacks for APScheduler or other setup as needed +4. Call `run_worker()` with the worker name, config override, container factory, and optional hooks See any existing `workers/run_*.py` for a working example. The simplest is `run_result_processor.py` (no startup hooks), the most complex is `run_pod_monitor.py` (APScheduler + K8s watch loop). 
From d024b8a40de8f229a144188cd6b2b61729bd4e20 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 5 Mar 2026 12:58:33 +0100 Subject: [PATCH 7/7] chore: no sync for backend/unit tests --- .github/workflows/stack-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stack-tests.yml b/.github/workflows/stack-tests.yml index eee041c3..ccc42c7e 100644 --- a/.github/workflows/stack-tests.yml +++ b/.github/workflows/stack-tests.yml @@ -58,7 +58,7 @@ jobs: timeout-minutes: 5 run: | cd backend - uv run pytest tests/unit -v -rs \ + uv run --no-sync pytest tests/unit -v -rs \ --durations=0 \ --cov=app \ --cov-report=xml --cov-report=term