diff --git a/centml/cli/main.py b/centml/cli/main.py index b1ecc73..fb5b932 100644 --- a/centml/cli/main.py +++ b/centml/cli/main.py @@ -2,6 +2,7 @@ from centml.cli.login import login, logout from centml.cli.cluster import ls, get, delete, pause, resume +from centml.cli.shell import shell, exec_cmd @click.group() @@ -47,6 +48,8 @@ def ccluster(): ccluster.add_command(delete) ccluster.add_command(pause) ccluster.add_command(resume) +ccluster.add_command(shell) +ccluster.add_command(exec_cmd, name="exec") cli.add_command(ccluster, name="cluster") diff --git a/centml/cli/shell.py b/centml/cli/shell.py new file mode 100644 index 0000000..72944c5 --- /dev/null +++ b/centml/cli/shell.py @@ -0,0 +1,58 @@ +"""CLI commands for interactive shell and command execution in deployment pods.""" + +import asyncio +import sys + +import click + +from centml.cli.cluster import handle_exception +from centml.sdk import auth +from centml.sdk.api import get_centml_client +from centml.sdk.config import settings +from centml.sdk.shell import ShellError +from centml.sdk.shell.session import build_ws_url, exec_session, interactive_session, resolve_pod + + +@click.command(help="Open an interactive shell to a deployment pod") +@click.argument("deployment_id", type=int) +@click.option("--pod", default=None, help="Specific pod name (auto-selects first running pod)") +@click.option("--shell", "shell_type", default=None, type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@handle_exception +def shell(deployment_id, pod, shell_type): + if not sys.stdin.isatty(): + raise click.ClickException("Interactive shell requires a terminal (TTY)") + + with get_centml_client() as cclient: + try: + pod_name, warning = resolve_pod(cclient, deployment_id, pod) + except ShellError as exc: + raise click.ClickException(str(exc)) from exc + if warning: + click.echo(f"{warning} Use --pod to specify a different pod.", err=True) + + ws_url = build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + token = auth.get_centml_token() + exit_code = asyncio.run(interactive_session(ws_url, token)) + sys.exit(exit_code) + + +@click.command(help="Execute a command in a deployment pod", context_settings={"ignore_unknown_options": True}) +@click.argument("deployment_id", type=int) +@click.argument("command", nargs=-1, required=True, type=click.UNPROCESSED) +@click.option("--pod", default=None, help="Specific pod name") +@click.option("--shell", "shell_type", default=None, type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@handle_exception +def exec_cmd(deployment_id, command, pod, shell_type): + with get_centml_client() as cclient: + try: + pod_name, warning = resolve_pod(cclient, deployment_id, pod) + except ShellError as exc: + raise click.ClickException(str(exc)) from exc + if warning: + click.echo(f"{warning} Use --pod to specify a different pod.", err=True) + + ws_url = build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + token = auth.get_centml_token() + cmd_str = " ".join(command) + exit_code = asyncio.run(exec_session(ws_url, token, cmd_str)) + sys.exit(exit_code) diff --git a/centml/sdk/api.py b/centml/sdk/api.py index e1e11d3..20dfa99 100644 --- a/centml/sdk/api.py +++ b/centml/sdk/api.py @@ -27,6 +27,9 @@ def get(self, depl_type): def get_status(self, id): return self._api.get_deployment_status_deployments_status_deployment_id_get(id) + def get_status_v3(self, deployment_id): + return self._api.get_deployment_status_v3_deployments_status_v3_deployment_id_get(deployment_id) + def get_inference(self, id): """Get Inference deployment details - automatically handles both V2 and V3 deployments""" # Try V3 first (recommended), fallback to V2 if deployment is V2 diff --git a/centml/sdk/shell/__init__.py b/centml/sdk/shell/__init__.py new file mode 100644 index 0000000..d1dc8ce --- /dev/null +++ b/centml/sdk/shell/__init__.py @@ -0,0 +1,34 @@ +"""SDK shell module -- reusable shell/exec session logic (no Click dependency).""" + +from centml.sdk.shell.exceptions import NoPodAvailableError, PodNotFoundError, ShellError +from centml.sdk.shell.renderer import char_to_sgr, color_sgr, pyte_extract_text, render_dirty +from centml.sdk.shell.session import ( + BEGIN_MARKER, + END_MARKER, + PRINTF_BEGIN, + PRINTF_END, + build_ws_url, + exec_session, + forward_io, + interactive_session, + resolve_pod, +) + +__all__ = [ + "ShellError", + "NoPodAvailableError", + "PodNotFoundError", + "color_sgr", + "char_to_sgr", + "render_dirty", + "pyte_extract_text", + "build_ws_url", + "resolve_pod", + "forward_io", + "interactive_session", + "exec_session", + "BEGIN_MARKER", + "END_MARKER", + "PRINTF_BEGIN", + "PRINTF_END", +] diff --git a/centml/sdk/shell/exceptions.py b/centml/sdk/shell/exceptions.py new file mode 100644 index 0000000..93286f0 --- /dev/null +++ b/centml/sdk/shell/exceptions.py @@ -0,0 +1,13 @@ +"""SDK exceptions for shell operations (no Click dependency).""" + + +class ShellError(Exception): + """Base exception for shell operations.""" + + +class NoPodAvailableError(ShellError): + """No running pods found for the deployment.""" + + +class PodNotFoundError(ShellError): + """Specified pod not found among running pods.""" diff --git a/centml/sdk/shell/renderer.py b/centml/sdk/shell/renderer.py new file mode 100644 index 0000000..4b630d6 --- /dev/null +++ b/centml/sdk/shell/renderer.py @@ -0,0 +1,130 @@ +"""Pyte terminal screen renderer -- converts pyte's in-memory buffer to ANSI.""" + +_PYTE_FG_TO_SGR = { + "default": "39", + "black": "30", + "red": "31", + "green": "32", + "brown": "33", + "blue": "34", + "magenta": "35", + "cyan": "36", + "white": "37", + "brightblack": "90", + "brightred": "91", + "brightgreen": "92", + "brightbrown": "93", + "brightblue": "94", + "brightmagenta": "95", + "brightcyan": "96", + "brightwhite": "97", +} + +_PYTE_BG_TO_SGR = { + "default": "49", + "black": "40", + "red": "41", + "green": "42", + "brown": "43", + "blue": "44", + "magenta": "45", + "cyan": "46", + "white": "47", + "brightblack": "100", + "brightred": "101", + "brightgreen": "102", + "brightbrown": "103", + "brightblue": "104", + "brightmagenta": "105", + "brightcyan": "106", + "brightwhite": "107", +} + + +def color_sgr(color, is_bg=False): + """Convert a pyte color value to an SGR parameter string.""" + table = _PYTE_BG_TO_SGR if is_bg else _PYTE_FG_TO_SGR + if color in table: + default_val = "49" if is_bg else "39" + code = table[color] + return code if code != default_val else "" + # 6-char hex -> truecolor + if len(color) == 6: + try: + r, g, b = int(color[:2], 16), int(color[2:4], 16), int(color[4:], 16) + prefix = "48" if is_bg else "38" + return f"{prefix};2;{r};{g};{b}" + except ValueError: + return "" + return "" + + +def char_to_sgr(char): + """Build the ANSI SGR parameter string for a pyte Char's attributes.""" + parts = [] + if char.bold: + parts.append("1") + if char.italics: + parts.append("3") + if char.underscore: + parts.append("4") + if char.blink: + parts.append("5") + if char.reverse: + parts.append("7") + if char.strikethrough: + parts.append("9") + fg = color_sgr(char.fg, is_bg=False) + if fg: + parts.append(fg) + bg = color_sgr(char.bg, is_bg=True) + if bg: + parts.append(bg) + return ";".join(parts) + + +def render_dirty(screen, output): + """Render only the dirty lines from the pyte Screen to the terminal. + + Args: + screen: pyte.Screen instance. + output: Writable binary stream (e.g. sys.stdout.buffer). + """ + parts = [] + for row in sorted(screen.dirty): + # Position cursor at row (1-based), column 1; clear line. + parts.append(f"\033[{row + 1};1H\033[2K") + prev_sgr = "" + line_chars = [] + for col in range(screen.columns): + char = screen.buffer[row][col] + if char.data == "": + continue + sgr = char_to_sgr(char) + if sgr != prev_sgr: + line_chars.append(f"\033[0m\033[{sgr}m" if sgr else "\033[0m") + prev_sgr = sgr + line_chars.append(char.data) + text = "".join(line_chars).rstrip() + parts.append(text) + # Reset attributes, position cursor. + parts.append("\033[0m") + parts.append(f"\033[{screen.cursor.y + 1};{screen.cursor.x + 1}H") + if screen.cursor.hidden: + parts.append("\033[?25l") + else: + parts.append("\033[?25h") + screen.dirty.clear() + output.write("".join(parts).encode("utf-8")) + output.flush() + + +def pyte_extract_text(line_stream, line_screen, text): + """Feed text through a single-row pyte screen and return visible characters. + + More robust than regex ANSI stripping: pyte interprets all VT100/VT220 + sequences including OSC, cursor repositioning, and truecolor escapes. + """ + line_screen.reset() + line_stream.feed(text) + return "".join(line_screen.buffer[0][col].data for col in range(line_screen.columns)).rstrip() diff --git a/centml/sdk/shell/session.py b/centml/sdk/shell/session.py new file mode 100644 index 0000000..51fb9da --- /dev/null +++ b/centml/sdk/shell/session.py @@ -0,0 +1,289 @@ +"""WebSocket session logic for shell and exec commands (no Click dependency).""" + +import asyncio +import json +import shutil +import signal +import sys +import termios +import tty +import urllib.parse +from typing import Optional, Tuple + +import pyte +import websockets + +from centml.sdk import PodStatus +from centml.sdk.shell.exceptions import NoPodAvailableError, PodNotFoundError +from centml.sdk.shell.renderer import pyte_extract_text, render_dirty + +BEGIN_MARKER = "__CENTML_BEGIN_5f3a__" +END_MARKER = "__CENTML_END_5f3a__" + +# printf octal \137 = underscore. The decoded output matches BEGIN/END_MARKER, +# but the literal command text does NOT, so shell echo won't trigger false matches. +PRINTF_BEGIN = r"\137\137CENTML_BEGIN_5f3a\137\137" +PRINTF_END = r"\137\137CENTML_END_5f3a\137\137" + + +def build_ws_url(api_url, deployment_id, pod_name, shell_type=None): + """Build the WebSocket URL for a terminal connection.""" + parsed = urllib.parse.urlparse(api_url) + ws_scheme = "wss" if parsed.scheme == "https" else "ws" + ws_base = parsed._replace(scheme=ws_scheme).geturl() + url = f"{ws_base}/deployments/{deployment_id}/terminal?pod={urllib.parse.quote(pod_name)}" + if shell_type: + url += f"&shell={urllib.parse.quote(shell_type)}" + return url + + +def resolve_pod(cclient, deployment_id, pod_name=None) -> Tuple[str, Optional[str]]: + """Resolve which pod to connect to. + + Args: + cclient: CentMLClient instance. + deployment_id: The deployment ID. + pod_name: Optional specific pod name to target. + + Returns: + Tuple of (pod_name, optional_warning_message). + + Raises: + NoPodAvailableError: If no running pods found. + PodNotFoundError: If specified pod not found among running pods. + """ + status = cclient.get_status_v3(deployment_id) + running_pods = [] + for revision in status.revision_pod_details_list or []: + for pod in revision.pod_details_list or []: + if pod.status == PodStatus.RUNNING and pod.name: + running_pods.append(pod.name) + + if not running_pods: + raise NoPodAvailableError(f"No running pods found for deployment {deployment_id}") + + if pod_name is not None: + if pod_name not in running_pods: + pods_list = ", ".join(running_pods) + raise PodNotFoundError(f"Pod '{pod_name}' not found. Available running pods: {pods_list}") + return pod_name, None + + warning = None + if len(running_pods) > 1: + warning = f"Multiple running pods found, connecting to {running_pods[0]}." + return running_pods[0], warning + + +async def forward_io(ws, screen, stream, shutdown): + """Bidirectional forwarding between local stdin/stdout and WebSocket. + + Output flows through a pyte terminal emulator so that cursor + addressing, line wrapping, and colors are rendered correctly + regardless of the remote PTY dimensions. + + The platform API proxy sends a close frame (code=1000) when the + remote shell exits, so _read_ws terminates via ConnectionClosed. + + Args: + ws: WebSocket connection. + screen: pyte.Screen instance sized to the local terminal. + stream: pyte.Stream attached to *screen*. + shutdown: asyncio.Event set by signal handlers to request exit. + + Returns: + The exit code (always 0 for interactive sessions). + """ + loop = asyncio.get_running_loop() + stdin_fd = sys.stdin.fileno() + stdin_closed = asyncio.Event() + + async def _read_ws(): + try: + while True: + raw_msg = await ws.recv() + msg = json.loads(raw_msg) + data = msg.get("data", "") + if data: + stream.feed(data.replace("\n", "\r\n")) + render_dirty(screen, sys.stdout.buffer) + elif msg.get("error"): + stream.feed(f"Error: {msg['error']}\r\n") + render_dirty(screen, sys.stdout.buffer) + except websockets.ConnectionClosed: + return + + async def _read_stdin(): + read_queue = asyncio.Queue() + + def _on_stdin_ready(): + data = sys.stdin.buffer.read1(4096) + if data: + read_queue.put_nowait(data) + else: + stdin_closed.set() + + loop.add_reader(stdin_fd, _on_stdin_ready) + try: + while not stdin_closed.is_set() and not shutdown.is_set(): + try: + data = await asyncio.wait_for(read_queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + try: + await ws.send( + json.dumps( + { + "operation": "stdin", + "data": data.decode("utf-8", errors="replace"), + "rows": screen.lines, + "cols": screen.columns, + } + ) + ) + except websockets.ConnectionClosed: + return + finally: + loop.remove_reader(stdin_fd) + + async def _watch_shutdown(): + while not shutdown.is_set(): + await asyncio.sleep(0.2) + + task_ws = asyncio.create_task(_read_ws()) + task_stdin = asyncio.create_task(_read_stdin()) + task_shutdown = asyncio.create_task(_watch_shutdown()) + tasks = [task_ws, task_stdin, task_shutdown] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + for t in pending: + t.cancel() + for t in pending: + try: + await t + except (asyncio.CancelledError, Exception): + pass + for t in done: + if t.exception() is not None: + raise t.exception() + return 0 + + +async def interactive_session(ws_url, token): + """Run an interactive terminal session over WebSocket. + + Enters raw mode, forwards I/O bidirectionally, and restores terminal + on exit. SIGTERM and SIGHUP are caught to ensure terminal settings + are always restored. + """ + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + cols, rows = shutil.get_terminal_size() + + screen = pyte.Screen(cols, rows) + stream = pyte.Stream(screen) + + # Switch to alternate screen buffer (disables scrollback) and clear. + sys.stdout.buffer.write(b"\033[?1049h\033[2J\033[H") + sys.stdout.buffer.flush() + + loop = asyncio.get_running_loop() + + shutdown = asyncio.Event() + loop.add_signal_handler(signal.SIGTERM, shutdown.set) + loop.add_signal_handler(signal.SIGHUP, shutdown.set) + + headers = {"Authorization": f"Bearer {token}"} + async with websockets.connect(ws_url, additional_headers=headers, close_timeout=2) as ws: + + def _send_resize(): + c, r = shutil.get_terminal_size() + screen.resize(r, c) + screen.dirty.update(range(r)) + asyncio.ensure_future(ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c}))) + + loop.add_signal_handler(signal.SIGWINCH, _send_resize) + + await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) + try: + exit_code = await forward_io(ws, screen, stream, shutdown) + finally: + loop.remove_signal_handler(signal.SIGWINCH) + + return exit_code + finally: + loop.remove_signal_handler(signal.SIGTERM) + loop.remove_signal_handler(signal.SIGHUP) + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + # Leave alternate screen buffer, restore cursor and attributes. + sys.stdout.buffer.write(b"\033[?1049l\033[?25h\033[0m") + sys.stdout.buffer.flush() + + +async def exec_session(ws_url, token, command): + """Execute a command in a pod and return its exit code. + + Does not enter raw mode -- output is pipe-friendly. + Suppresses shell echo and uses markers to capture only command output. + """ + cols, rows = shutil.get_terminal_size(fallback=(80, 24)) + # Single-row screen for interpreting escape sequences in marker detection. + line_screen = pyte.Screen(cols, 1) + line_stream = pyte.Stream(line_screen) + headers = {"Authorization": f"Bearer {token}"} + + async with websockets.connect(ws_url, additional_headers=headers, close_timeout=2) as ws: + await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) + + # Suppress echo/bracketed-paste, emit begin marker, run command, + # emit end marker with exit code, then exit. + # Markers use printf octal escapes so the literal marker string + # doesn't appear in the command echo. + wrapped = ( + f"stty -echo 2>/dev/null; printf '\\033[?2004l';" + f" printf '{PRINTF_BEGIN}\\n';" + f" {command};" + f" __ec=$?;" + f" printf '\\n{PRINTF_END}:%d\\n' \"$__ec\";" + f" exit $__ec\n" + ) + + await ws.send(json.dumps({"operation": "stdin", "data": wrapped})) + + exit_code = 0 + buffer = "" + is_capturing = False + is_done = False + try: + async for raw_msg in ws: + msg = json.loads(raw_msg) + if msg.get("data"): + buffer += msg["data"] + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + clean = pyte_extract_text(line_stream, line_screen, line.rstrip("\r")) + if BEGIN_MARKER in clean: + is_capturing = True + continue + if END_MARKER in clean: + parts = clean.split(END_MARKER + ":") + if len(parts) > 1: + try: + exit_code = int(parts[1].strip()) + except ValueError: + pass + is_done = True + break + if is_capturing: + sys.stdout.write(line + "\n") + sys.stdout.flush() + elif msg.get("error"): + sys.stderr.write(f"Error: {msg['error']}\n") + return 1 + if is_done: + break + except websockets.ConnectionClosed: + pass + return exit_code diff --git a/requirements.txt b/requirements.txt index c3b4961..9e79a4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ cryptography==44.0.1 prometheus-client>=0.20.0 scipy>=1.6.0 scikit-learn>=1.5.1 +websockets>=13.0 +pyte>=0.8.0 platform-api-python-client==4.6.0 diff --git a/tests/conftest.py b/tests/conftest.py index f3de342..1ab0b0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,11 +3,22 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" +# Tests that require PyTorch at import time -- skip during sanity runs +# where PyTorch is not installed. +_PYTORCH_TEST_FILES = ["test_backend.py", "test_helpers.py", "test_server.py"] + +collect_ignore = [] + def pytest_addoption(parser): parser.addoption("--sanity", action="store_true", help="Run sanity tests (exclude 'gpu' tests)") +def pytest_configure(config): + if config.getoption("--sanity", default=False): + collect_ignore.extend(_PYTORCH_TEST_FILES) + + def pytest_collection_modifyitems(config, items): if config.getoption("--sanity"): skip_gpu = pytest.mark.skip(reason="Skipping GPU tests for sanity run") diff --git a/tests/test_cli_shell.py b/tests/test_cli_shell.py new file mode 100644 index 0000000..54fcc6a --- /dev/null +++ b/tests/test_cli_shell.py @@ -0,0 +1,151 @@ +"""Tests for centml.cli.shell -- thin Click command wrappers.""" + +from unittest.mock import MagicMock, patch + +from centml.sdk.shell.exceptions import NoPodAvailableError, PodNotFoundError + +# =========================================================================== +# Click commands +# =========================================================================== + + +class TestShellCommand: + def test_rejects_non_tty(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + runner = CliRunner() + result = runner.invoke(shell, ["123"]) + assert result.exit_code != 0 + assert "terminal" in result.output.lower() or "tty" in result.output.lower() + + def test_shell_option_forwarded(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + with ( + patch("centml.cli.shell.resolve_pod", return_value=("pod-a", None)), + patch("centml.cli.shell.get_centml_client") as mock_ctx, + patch("centml.cli.shell.auth") as mock_auth, + patch("centml.cli.shell.settings") as mock_settings, + patch("centml.cli.shell.asyncio") as mock_asyncio, + patch("centml.cli.shell.sys") as mock_sys, + ): + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_sys.stdin.isatty.return_value = True + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(shell, ["123", "--shell", "bash"]) + + mock_asyncio.run.assert_called_once() + + def test_pod_option_forwarded(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + with ( + patch("centml.cli.shell.resolve_pod") as mock_resolve, + patch("centml.cli.shell.get_centml_client") as mock_ctx, + patch("centml.cli.shell.auth") as mock_auth, + patch("centml.cli.shell.settings") as mock_settings, + patch("centml.cli.shell.asyncio") as mock_asyncio, + patch("centml.cli.shell.sys") as mock_sys, + ): + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_resolve.return_value = ("my-pod", None) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_sys.stdin.isatty.return_value = True + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(shell, ["123", "--pod", "my-pod"]) + + mock_resolve.assert_called_once() + + def test_shell_error_converts_to_click_exception(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + with ( + patch("centml.cli.shell.resolve_pod", side_effect=NoPodAvailableError("No running pods found")), + patch("centml.cli.shell.get_centml_client") as mock_ctx, + patch("centml.cli.shell.sys") as mock_sys, + ): + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_sys.stdin.isatty.return_value = True + + runner = CliRunner() + result = runner.invoke(shell, ["123"]) + + assert result.exit_code != 0 + assert "No running pods" in result.output + + +class TestExecCommand: + def test_passes_command(self): + from centml.cli.shell import exec_cmd + from click.testing import CliRunner + + with ( + patch("centml.cli.shell.resolve_pod", return_value=("pod-a", None)), + patch("centml.cli.shell.get_centml_client") as mock_ctx, + patch("centml.cli.shell.auth") as mock_auth, + patch("centml.cli.shell.settings") as mock_settings, + patch("centml.cli.shell.asyncio") as mock_asyncio, + ): + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(exec_cmd, ["123", "--", "ls", "-la"]) + + mock_asyncio.run.assert_called_once() + + def test_shell_option_forwarded(self): + from centml.cli.shell import exec_cmd + from click.testing import CliRunner + + with ( + patch("centml.cli.shell.resolve_pod", return_value=("pod-a", None)), + patch("centml.cli.shell.get_centml_client") as mock_ctx, + patch("centml.cli.shell.auth") as mock_auth, + patch("centml.cli.shell.settings") as mock_settings, + patch("centml.cli.shell.asyncio") as mock_asyncio, + ): + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) + + mock_asyncio.run.assert_called_once() + + def test_shell_error_converts_to_click_exception(self): + from centml.cli.shell import exec_cmd + from click.testing import CliRunner + + with ( + patch("centml.cli.shell.resolve_pod", side_effect=PodNotFoundError("Pod 'x' not found")), + patch("centml.cli.shell.get_centml_client") as mock_ctx, + ): + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + + runner = CliRunner() + result = runner.invoke(exec_cmd, ["123", "--", "ls"]) + + assert result.exit_code != 0 + assert "Pod 'x' not found" in result.output diff --git a/tests/test_sdk_shell_renderer.py b/tests/test_sdk_shell_renderer.py new file mode 100644 index 0000000..dee918d --- /dev/null +++ b/tests/test_sdk_shell_renderer.py @@ -0,0 +1,146 @@ +"""Tests for centml.sdk.shell.renderer -- pyte rendering utilities.""" + +import io + +import pyte + +from centml.sdk.shell.renderer import char_to_sgr, color_sgr, pyte_extract_text, render_dirty + +# =========================================================================== +# color_sgr +# =========================================================================== + + +class TestColorSgr: + def test_named_fg_color(self): + assert color_sgr("red", is_bg=False) == "31" + + def test_named_bg_color(self): + assert color_sgr("blue", is_bg=True) == "44" + + def test_default_fg_returns_empty(self): + assert color_sgr("default", is_bg=False) == "" + + def test_default_bg_returns_empty(self): + assert color_sgr("default", is_bg=True) == "" + + def test_hex_truecolor_fg(self): + assert color_sgr("ff0000", is_bg=False) == "38;2;255;0;0" + + def test_hex_truecolor_bg(self): + assert color_sgr("00ff00", is_bg=True) == "48;2;0;255;0" + + def test_invalid_hex_returns_empty(self): + assert color_sgr("zzzzzz", is_bg=False) == "" + + def test_unknown_name_returns_empty(self): + assert color_sgr("nope", is_bg=False) == "" + + +# =========================================================================== +# char_to_sgr +# =========================================================================== + + +class TestCharToSgr: + def test_default_attrs_returns_empty(self): + char = pyte.screens.Char(" ", "default", "default", False, False, False, False, False, False) + assert char_to_sgr(char) == "" + + def test_bold_red_fg(self): + char = pyte.screens.Char("x", "red", "default", True, False, False, False, False, False) + sgr = char_to_sgr(char) + assert "1" in sgr.split(";") + assert "31" in sgr.split(";") + + def test_bg_color(self): + char = pyte.screens.Char("x", "default", "blue", False, False, False, False, False, False) + sgr = char_to_sgr(char) + assert "44" in sgr.split(";") + + def test_256_color_fg(self): + char = pyte.screens.Char("x", "ff0000", "default", False, False, False, False, False, False) + sgr = char_to_sgr(char) + assert "38;2;255;0;0" in sgr + + def test_combined_attrs(self): + char = pyte.screens.Char("x", "green", "white", True, True, True, False, False, False) + sgr = char_to_sgr(char) + parts = sgr.split(";") + assert "1" in parts # bold + assert "3" in parts # italics + assert "4" in parts # underscore + assert "32" in parts # green fg + assert "47" in parts # white bg + + +# =========================================================================== +# render_dirty +# =========================================================================== + + +class TestRenderDirty: + def test_renders_simple_text(self): + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + screen.dirty.clear() + stream.feed("hello") + buf = io.BytesIO() + render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + assert "hello" in output + assert len(screen.dirty) == 0 + + def test_clears_dirty_after_render(self): + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + screen.dirty.clear() + stream.feed("test") + assert len(screen.dirty) > 0 + render_dirty(screen, io.BytesIO()) + assert len(screen.dirty) == 0 + + def test_cursor_position_in_output(self): + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + stream.feed("abc") + buf = io.BytesIO() + render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + # Cursor should be at row 1, col 4 (1-based: after "abc") + assert "\033[1;4H" in output + + def test_renders_only_dirty_lines(self): + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + stream.feed("line0\r\nline1\r\nline2") + # Render to clear dirty + render_dirty(screen, io.BytesIO()) + # Now modify only line 0 + stream.feed("\033[1;1Hchanged") + buf = io.BytesIO() + render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + assert "changed" in output + # line1 and line2 should NOT be re-rendered + assert "line1" not in output + assert "line2" not in output + + +# =========================================================================== +# pyte_extract_text +# =========================================================================== + + +class TestPyteExtractText: + def test_strips_ansi(self): + screen = pyte.Screen(80, 1) + stream = pyte.Stream(screen) + result = pyte_extract_text(stream, screen, "\x1b[32mhello\x1b[0m") + assert result == "hello" + + def test_plain_text(self): + screen = pyte.Screen(80, 1) + stream = pyte.Stream(screen) + result = pyte_extract_text(stream, screen, "plain text") + assert result == "plain text" diff --git a/tests/test_sdk_shell_session.py b/tests/test_sdk_shell_session.py new file mode 100644 index 0000000..3fbe7c2 --- /dev/null +++ b/tests/test_sdk_shell_session.py @@ -0,0 +1,517 @@ +"""Tests for centml.sdk.shell.session -- WebSocket session logic.""" + +import asyncio +import io +import json +import os +import signal +import urllib.parse +from unittest.mock import AsyncMock, MagicMock, patch + +import pyte +import pytest + +from platform_api_python_client import PodStatus, PodDetails, RevisionPodDetails + +from centml.sdk.shell.exceptions import NoPodAvailableError, PodNotFoundError +from centml.sdk.shell.session import ( + BEGIN_MARKER, + END_MARKER, + build_ws_url, + exec_session, + forward_io, + interactive_session, + resolve_pod, +) + + +def _async_iter_from_list(items): + """Create an async iterator from a list of items.""" + + async def _aiter(): + for item in items: + yield item + + return _aiter() + + +# --------------------------------------------------------------------------- +# Helpers to build mock status responses +# --------------------------------------------------------------------------- + + +def _make_pod(name, status=PodStatus.RUNNING): + pod = MagicMock(spec=PodDetails) + pod.name = name + pod.status = status + return pod + + +def _make_revision(pods): + rev = MagicMock(spec=RevisionPodDetails) + rev.pod_details_list = pods + return rev + + +def _make_status_response(revisions): + resp = MagicMock() + resp.revision_pod_details_list = revisions + return resp + + +# =========================================================================== +# build_ws_url +# =========================================================================== + + +class TestBuildWsUrl: + def test_https_to_wss(self): + url = build_ws_url("https://api.centml.com", 123, "my-pod-abc") + parsed = urllib.parse.urlparse(url) + assert parsed.scheme == "wss" + assert parsed.netloc == "api.centml.com" + + def test_http_to_ws(self): + url = build_ws_url("http://localhost:16000", 42, "pod-1") + parsed = urllib.parse.urlparse(url) + assert parsed.scheme == "ws" + assert parsed.netloc == "localhost:16000" + + def test_contains_deployment_id_and_pod(self): + url = build_ws_url("https://api.centml.com", 99, "pod-xyz") + assert "/deployments/99/terminal" in url + assert "pod=pod-xyz" in url + + def test_with_shell(self): + url = build_ws_url("https://api.centml.com", 1, "p", shell_type="bash") + assert "shell=bash" in url + + def test_without_shell(self): + url = build_ws_url("https://api.centml.com", 1, "p") + assert "shell=" not in url + + def test_encodes_pod_name(self): + url = build_ws_url("https://api.centml.com", 1, "pod name/special") + assert "pod%20name" in url or "pod+name" in url + + +# =========================================================================== +# resolve_pod +# =========================================================================== + + +class TestResolvePod: + def test_selects_first_running(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING), _make_pod("pod-b", PodStatus.RUNNING)])] + ) + pod_name, _ = resolve_pod(cclient, 1) + assert pod_name == "pod-a" + + def test_raises_no_running_pods(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-err", PodStatus.ERROR)])] + ) + with pytest.raises(NoPodAvailableError, match="No running pods"): + resolve_pod(cclient, 1) + + def test_raises_specified_pod_not_found(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING)])] + ) + with pytest.raises(PodNotFoundError, match="pod-missing"): + resolve_pod(cclient, 1, pod_name="pod-missing") + + def test_returns_specified_pod(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING), _make_pod("pod-b", PodStatus.RUNNING)])] + ) + pod_name, warning = resolve_pod(cclient, 1, pod_name="pod-b") + assert pod_name == "pod-b" + assert warning is None + + def test_empty_revision_list(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([]) + with pytest.raises(NoPodAvailableError, match="No running pods"): + resolve_pod(cclient, 1) + + def test_none_revision_list(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response(None) + cclient.get_status_v3.return_value.revision_pod_details_list = None + with pytest.raises(NoPodAvailableError, match="No running pods"): + resolve_pod(cclient, 1) + + def test_skips_pods_without_name(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod(None, PodStatus.RUNNING), _make_pod("pod-real", PodStatus.RUNNING)])] + ) + pod_name, _ = resolve_pod(cclient, 1) + assert pod_name == "pod-real" + + def test_multiple_revisions(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision([_make_pod("pod-old", PodStatus.ERROR)]), + _make_revision([_make_pod("pod-new", PodStatus.RUNNING)]), + ] + ) + pod_name, _ = resolve_pod(cclient, 1) + assert pod_name == "pod-new" + + def test_multiple_running_pods_returns_warning(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING), _make_pod("pod-b", PodStatus.RUNNING)])] + ) + pod_name, warning = resolve_pod(cclient, 1) + assert pod_name == "pod-a" + assert warning is not None + assert "Multiple running pods" in warning + assert "--pod" not in warning + + def test_single_running_pod_no_warning(self): + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING)])] + ) + pod_name, warning = resolve_pod(cclient, 1) + assert pod_name == "pod-a" + assert warning is None + + +# =========================================================================== +# exec_session +# =========================================================================== + + +class TestExecSession: + def test_sends_resize_and_wrapped_command(self): + ws = AsyncMock() + messages = [ + json.dumps({"data": f"noise\n{BEGIN_MARKER}\nhello world\n{END_MARKER}:0\n"}), + json.dumps({"Code": 0}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.sdk.shell.session.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + ) + + exit_code = asyncio.run(exec_session("wss://test/ws", "fake-token", "ls -la")) + + assert exit_code == 0 + assert ws.send.call_count == 2 + resize_msg = json.loads(ws.send.call_args_list[0][0][0]) + assert resize_msg["operation"] == "resize" + cmd_msg = json.loads(ws.send.call_args_list[1][0][0]) + assert cmd_msg["operation"] == "stdin" + assert "ls -la" in cmd_msg["data"] + assert "stty -echo" in cmd_msg["data"] + # Markers use printf octal escapes, so the literal marker + # should NOT appear in the command (prevents echo false-match). + assert BEGIN_MARKER not in cmd_msg["data"] + assert "CENTML_BEGIN" in cmd_msg["data"] + + def test_returns_nonzero_exit_code_from_marker(self): + ws = AsyncMock() + messages = [json.dumps({"data": f"{BEGIN_MARKER}\n{END_MARKER}:42\n"}), json.dumps({"Code": 42})] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.sdk.shell.session.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + ) + + exit_code = asyncio.run(exec_session("wss://test/ws", "fake-token", "false")) + + assert exit_code == 42 + + def test_error_message_returns_one(self): + ws = AsyncMock() + messages = [json.dumps({"error": "something went wrong"})] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.sdk.shell.session.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + ) + + exit_code = asyncio.run(exec_session("wss://test/ws", "fake-token", "bad")) + + assert exit_code == 1 + + def test_filters_noise_before_marker(self): + """Only output between BEGIN and END markers is written to stdout.""" + ws = AsyncMock() + messages = [ + json.dumps({"data": f"prompt$ command\n{BEGIN_MARKER}\nreal output\n{END_MARKER}:0\n"}), + json.dumps({"Code": 0}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + captured = [] + with ( + patch("centml.sdk.shell.session.websockets") as mock_ws_mod, + patch("centml.sdk.shell.session.sys") as mock_sys, + ): + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + ) + mock_sys.stdout.write = lambda s: captured.append(s) + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = MagicMock() + + exit_code = asyncio.run(exec_session("wss://test/ws", "fake-token", "echo test")) + + assert exit_code == 0 + output = "".join(captured) + assert "real output" in output + assert "prompt$" not in output + + def test_connection_closed_returns_zero(self): + """Graceful exit when server closes connection without Code message.""" + import websockets as _ws_lib + + ws = AsyncMock() + + async def _raise_closed(): + yield json.dumps({"data": "partial\n"}) + raise _ws_lib.ConnectionClosed(None, None) + + ws.__aiter__ = MagicMock(return_value=_raise_closed()) + + with patch("centml.sdk.shell.session.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + ) + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + exit_code = asyncio.run(exec_session("wss://test/ws", "fake-token", "exit")) + + assert exit_code == 0 + + def test_handles_ansi_around_markers(self): + """Markers wrapped in ANSI codes are still detected via pyte.""" + ws = AsyncMock() + # Markers surrounded by ANSI color codes. + data = f"\x1b[32m{BEGIN_MARKER}\x1b[0m\noutput\n\x1b[32m{END_MARKER}:0\x1b[0m\n" + messages = [json.dumps({"data": data}), json.dumps({"Code": 0})] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + captured = [] + with ( + patch("centml.sdk.shell.session.websockets") as mock_ws_mod, + patch("centml.sdk.shell.session.sys") as mock_sys, + ): + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + ) + mock_sys.stdout.write = lambda s: captured.append(s) + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = MagicMock() + + exit_code = asyncio.run(exec_session("wss://test/ws", "fake-token", "echo test")) + + assert exit_code == 0 + output = "".join(captured) + assert "output" in output + + +# =========================================================================== +# forward_io -- exit detection and shutdown +# =========================================================================== + + +class TestForwardIo: + """Tests for forward_io WebSocket forwarding. + + Uses a real pipe fd so ``loop.add_reader`` works without OS errors. + The server sends a close frame (code=1000) when the shell exits, + so forward_io relies on ConnectionClosed to terminate cleanly. + """ + + def _run_forward_io(self, ws, shutdown=None): + """Helper: run forward_io with a real pipe fd standing in for stdin.""" + import websockets as _ws_lib + + screen = pyte.Screen(80, 24) + stream = pyte.Stream(screen) + if shutdown is None: + shutdown = asyncio.Event() + + read_fd, write_fd = os.pipe() + os.close(write_fd) + try: + with ( + patch("centml.sdk.shell.session.sys") as mock_sys, + patch("centml.sdk.shell.session.websockets") as mock_ws_mod, + ): + mock_sys.stdin.fileno.return_value = read_fd + mock_sys.stdin.buffer.read1 = lambda n: b"" + mock_sys.stdout.buffer = io.BytesIO() + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + return asyncio.run(forward_io(ws, screen, stream, shutdown)) + finally: + os.close(read_fd) + + def test_connection_closed_returns_zero(self): + """ConnectionClosed (server close frame) returns 0.""" + import websockets as _ws_lib + + ws = AsyncMock() + ws.recv = AsyncMock(side_effect=_ws_lib.ConnectionClosed(None, None)) + + assert self._run_forward_io(ws) == 0 + + def test_data_then_close_returns_zero(self): + """Normal data followed by server close frame returns 0.""" + import websockets as _ws_lib + + ws = AsyncMock() + ws.recv = AsyncMock(side_effect=[json.dumps({"data": "hello\r\n"}), _ws_lib.ConnectionClosed(None, None)]) + + assert self._run_forward_io(ws) == 0 + + def test_shutdown_event_exits(self): + """shutdown event causes forward_io to exit.""" + import websockets as _ws_lib + + ws = AsyncMock() + + # recv that blocks until cancelled (simulates open WS with no data) + async def _block_recv(): + await asyncio.sleep(999) + + ws.recv = _block_recv + + screen = pyte.Screen(80, 24) + stream = pyte.Stream(screen) + shutdown = asyncio.Event() + + read_fd, write_fd = os.pipe() + os.close(write_fd) + try: + with ( + patch("centml.sdk.shell.session.sys") as mock_sys, + patch("centml.sdk.shell.session.websockets") as mock_ws_mod, + ): + mock_sys.stdin.fileno.return_value = read_fd + mock_sys.stdin.buffer.read1 = lambda n: b"" + mock_sys.stdout.buffer = io.BytesIO() + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + async def _run(): + async def _set_shutdown(): + await asyncio.sleep(0.1) + shutdown.set() + + asyncio.create_task(_set_shutdown()) + return await forward_io(ws, screen, stream, shutdown) + + assert asyncio.run(_run()) == 0 + finally: + os.close(read_fd) + + +# =========================================================================== +# interactive_session -- terminal restore +# =========================================================================== + + +class TestInteractiveSessionTerminalRestore: + def test_restores_terminal_on_exception(self): + with ( + patch("centml.sdk.shell.session.sys") as mock_sys, + patch("centml.sdk.shell.session.termios") as mock_termios, + patch("centml.sdk.shell.session.tty"), + patch("centml.sdk.shell.session.websockets") as mock_ws_mod, + ): + mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() + mock_termios.tcgetattr.return_value = ["old_settings"] + + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(side_effect=ConnectionRefusedError("fail")), + __aexit__=AsyncMock(return_value=False), + ) + ) + + with pytest.raises(ConnectionRefusedError): + asyncio.run(interactive_session("wss://test/ws", "fake-token")) + + mock_termios.tcsetattr.assert_called_once() + restore_call = mock_termios.tcsetattr.call_args + assert restore_call[0][2] == ["old_settings"] + + +# =========================================================================== +# interactive_session -- signal handling +# =========================================================================== + + +class TestInteractiveSessionSignals: + """Tests for SIGTERM/SIGHUP restoring terminal settings.""" + + def test_sigterm_restores_terminal(self): + signal_handlers = {} + + def _fake_add_signal_handler(sig, handler): + signal_handlers[sig] = handler + + def _fake_remove_signal_handler(sig): + signal_handlers.pop(sig, None) + + async def _fake_forward_io(ws, screen, stream, shutdown): + if signal.SIGTERM in signal_handlers: + signal_handlers[signal.SIGTERM]() + return 0 + + with ( + patch("centml.sdk.shell.session.sys") as mock_sys, + patch("centml.sdk.shell.session.termios") as mock_termios, + patch("centml.sdk.shell.session.tty"), + patch("centml.sdk.shell.session.websockets") as mock_ws_mod, + patch("centml.sdk.shell.session.forward_io", side_effect=_fake_forward_io), + ): + mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() + mock_termios.tcgetattr.return_value = ["old_settings"] + + mock_ws = AsyncMock() + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_ws), __aexit__=AsyncMock(return_value=False) + ) + ) + + def _patched_run(coro): + loop = asyncio.new_event_loop() + loop.add_signal_handler = _fake_add_signal_handler + loop.remove_signal_handler = _fake_remove_signal_handler + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + with patch("centml.sdk.shell.session.asyncio") as mock_asyncio_mod: + mock_asyncio_mod.get_running_loop.return_value = MagicMock( + add_signal_handler=_fake_add_signal_handler, remove_signal_handler=_fake_remove_signal_handler + ) + mock_asyncio_mod.Event = asyncio.Event + mock_asyncio_mod.create_task = asyncio.ensure_future + + _patched_run(interactive_session("wss://test/ws", "fake-token")) + + mock_termios.tcsetattr.assert_called_once() + assert mock_termios.tcsetattr.call_args[0][2] == ["old_settings"]