Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ import (
"log/slog"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"time"

"github.com/coder/agentapi/lib/screentracker"
"github.com/mattn/go-isatty"
"github.com/spf13/cobra"
"github.com/spf13/viper"
Expand Down Expand Up @@ -103,6 +106,44 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
}
}

// Get the variables related to state management
stateFile := viper.GetString(StateFile)
loadState := false
saveState := false

// Validate state file configuration
if stateFile != "" {
if !viper.IsSet(LoadState) {
loadState = true
} else {
loadState = viper.GetBool(LoadState)
}

if !viper.IsSet(SaveState) {
saveState = true
} else {
saveState = viper.GetBool(SaveState)
}
} else {
if viper.IsSet(LoadState) && viper.GetBool(LoadState) {
return xerrors.Errorf("--load-state requires --state-file to be set")
}
if viper.IsSet(SaveState) && viper.GetBool(SaveState) {
return xerrors.Errorf("--save-state requires --state-file to be set")
}
}

pidFile := viper.GetString(PidFile)

// Write PID file if configured
if pidFile != "" {
if err := writePIDFile(pidFile, logger); err != nil {
return xerrors.Errorf("failed to write PID file: %w", err)
}
// Ensure PID file is cleaned up on exit
defer cleanupPIDFile(pidFile, logger)
}

printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
var process *termexec.Process
if printOpenAPI {
Expand All @@ -128,14 +169,21 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
InitialPrompt: initialPrompt,
StatePersistenceConfig: screentracker.StatePersistenceConfig{
StateFile: stateFile,
LoadState: loadState,
SaveState: saveState,
},
})

if err != nil {
return xerrors.Errorf("failed to create server: %w", err)
}
if printOpenAPI {
fmt.Println(srv.GetOpenAPI())
return nil
}
handleSignals(ctx, logger, srv, process)
logger.Info("Starting server on port", "port", port)
processExitCh := make(chan error, 1)
go func() {
Expand All @@ -147,11 +195,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
processExitCh <- xerrors.Errorf("failed to wait for process: %w", err)
}
}
if err := srv.Stop(ctx); err != nil {
logger.Error("Failed to stop server", "error", err)
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Stop(shutdownCtx); err != nil {
logger.Error("Failed to stop server after process exit", "error", err)
}
}()
if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed {
if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
return xerrors.Errorf("failed to start server: %w", err)
}
select {
Expand All @@ -171,6 +221,35 @@ var agentNames = (func() []string {
return names
})()

// writePIDFile writes the current process ID to the specified file
func writePIDFile(pidFile string, logger *slog.Logger) error {
pid := os.Getpid()
pidContent := fmt.Sprintf("%d\n", pid)

// Create directory if it doesn't exist
dir := filepath.Dir(pidFile)
if err := os.MkdirAll(dir, 0o755); err != nil {
return xerrors.Errorf("failed to create PID file directory: %w", err)
}

// Write PID file
if err := os.WriteFile(pidFile, []byte(pidContent), 0o644); err != nil {
return xerrors.Errorf("failed to write PID file: %w", err)
}

logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid)
return nil
}

// cleanupPIDFile removes the PID file if it exists
func cleanupPIDFile(pidFile string, logger *slog.Logger) {
if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) {
logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err)
} else if err == nil {
logger.Info("Removed PID file", "pidFile", pidFile)
}
}

type flagSpec struct {
name string
shorthand string
Expand All @@ -190,6 +269,10 @@ const (
FlagAllowedOrigins = "allowed-origins"
FlagExit = "exit"
FlagInitialPrompt = "initial-prompt"
StateFile = "state-file"
LoadState = "load-state"
SaveState = "save-state"
PidFile = "pid-file"
)

func CreateServerCmd() *cobra.Command {
Expand Down Expand Up @@ -228,6 +311,10 @@ func CreateServerCmd() *cobra.Command {
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
{FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"},
{StateFile, "s", "", "Path to file for saving/loading server state", "string"},
{LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"},
{SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"},
{PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"},
}

for _, spec := range flagSpecs {
Expand Down
214 changes: 214 additions & 0 deletions cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"fmt"
"io"
"log/slog"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -477,6 +479,218 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
}
}

func TestServerCmd_StatePersistenceFlags(t *testing.T) {
// NOTE: These tests use --exit flag to test flag parsing and defaults.
// Runtime validation that happens in runServer (e.g., "--load-state requires --state-file")
// would call os.Exit(1) which terminates the test process, so those validations
// are tested through integration/E2E tests instead.

t.Run("state-file with defaults", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--exit", "dummy-command"})
err := serverCmd.Execute()
require.NoError(t, err)

assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile))
// load-state and save-state default to true when state-file is set (validated in runServer)
})

t.Run("state-file with explicit load-state=false", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--load-state=false", "--exit", "dummy-command"})
err := serverCmd.Execute()
require.NoError(t, err)

assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile))
assert.Equal(t, false, viper.GetBool(LoadState))
})

t.Run("state-file with explicit save-state=false", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--save-state=false", "--exit", "dummy-command"})
err := serverCmd.Execute()
require.NoError(t, err)

assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile))
assert.Equal(t, false, viper.GetBool(SaveState))
})

t.Run("state-file with explicit load-state=true and save-state=true", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{
"--state-file", "/tmp/state.json",
"--load-state=true",
"--save-state=true",
"--exit", "dummy-command",
})
err := serverCmd.Execute()
require.NoError(t, err)

assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile))
assert.Equal(t, true, viper.GetBool(LoadState))
assert.Equal(t, true, viper.GetBool(SaveState))
})

t.Run("load-state flag can be parsed", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{"--load-state", "--exit", "dummy-command"})
err := serverCmd.Execute()
require.NoError(t, err)

// Flag is parsed correctly (validation happens in runServer)
assert.Equal(t, true, viper.GetBool(LoadState))
})

t.Run("save-state flag can be parsed", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{"--save-state", "--exit", "dummy-command"})
err := serverCmd.Execute()
require.NoError(t, err)

// Flag is parsed correctly (validation happens in runServer)
assert.Equal(t, true, viper.GetBool(SaveState))
})

t.Run("pid-file can be set independently", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{"--pid-file", "/tmp/server.pid", "--exit", "dummy-command"})
err := serverCmd.Execute()
require.NoError(t, err)

assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile))
})

t.Run("state-file and pid-file can be set together", func(t *testing.T) {
isolateViper(t)

serverCmd := CreateServerCmd()
setupCommandOutput(t, serverCmd)
serverCmd.SetArgs([]string{
"--state-file", "/tmp/state.json",
"--pid-file", "/tmp/server.pid",
"--exit", "dummy-command",
})
err := serverCmd.Execute()
require.NoError(t, err)

assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile))
assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile))
})
}

func TestPIDFileOperations(t *testing.T) {
discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil))

t.Run("writePIDFile creates file with process ID", func(t *testing.T) {
tmpDir := t.TempDir()
pidFile := tmpDir + "/test.pid"

err := writePIDFile(pidFile, discardLogger)
require.NoError(t, err)

// Verify file exists
_, err = os.Stat(pidFile)
require.NoError(t, err)

// Verify content contains current PID
data, err := os.ReadFile(pidFile)
require.NoError(t, err)

expectedPID := fmt.Sprintf("%d\n", os.Getpid())
assert.Equal(t, expectedPID, string(data))
})

t.Run("writePIDFile creates directory if not exists", func(t *testing.T) {
tmpDir := t.TempDir()
pidFile := tmpDir + "/nested/deep/test.pid"

err := writePIDFile(pidFile, discardLogger)
require.NoError(t, err)

// Verify file exists
_, err = os.Stat(pidFile)
require.NoError(t, err)

// Verify directory was created
_, err = os.Stat(tmpDir + "/nested/deep")
require.NoError(t, err)
})

t.Run("writePIDFile overwrites existing file", func(t *testing.T) {
tmpDir := t.TempDir()
pidFile := tmpDir + "/test.pid"

// Write initial PID file
err := os.WriteFile(pidFile, []byte("12345\n"), 0o644)
require.NoError(t, err)

// Overwrite with current PID
err = writePIDFile(pidFile, discardLogger)
require.NoError(t, err)

// Verify content is updated
data, err := os.ReadFile(pidFile)
require.NoError(t, err)

expectedPID := fmt.Sprintf("%d\n", os.Getpid())
assert.Equal(t, expectedPID, string(data))
})

t.Run("cleanupPIDFile removes file", func(t *testing.T) {
tmpDir := t.TempDir()
pidFile := tmpDir + "/test.pid"

// Create PID file
err := os.WriteFile(pidFile, []byte("12345\n"), 0o644)
require.NoError(t, err)

// Cleanup
cleanupPIDFile(pidFile, discardLogger)

// Verify file is removed
_, err = os.Stat(pidFile)
assert.True(t, os.IsNotExist(err))
})

t.Run("cleanupPIDFile handles non-existent file", func(t *testing.T) {
tmpDir := t.TempDir()
pidFile := tmpDir + "/nonexistent.pid"

// Should not panic or error
cleanupPIDFile(pidFile, discardLogger)
})

t.Run("cleanupPIDFile handles directory removal error gracefully", func(t *testing.T) {
// Create a file in a protected directory (this is system-dependent)
// Just verify it doesn't panic when it can't remove the file
pidFile := "/this/should/not/exist/test.pid"

// Should not panic
cleanupPIDFile(pidFile, discardLogger)
})
}

func TestServerCmd_AllowedOrigins(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading