diff --git a/cmd/server/server.go b/cmd/server/server.go index 6d5cdec..c20a833 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -8,9 +8,12 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "sort" "strings" + "time" + "github.com/coder/agentapi/lib/screentracker" "github.com/mattn/go-isatty" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -103,6 +106,44 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } + // Get the variables related to state management + stateFile := viper.GetString(StateFile) + loadState := false + saveState := false + + // Validate state file configuration + if stateFile != "" { + if !viper.IsSet(LoadState) { + loadState = true + } else { + loadState = viper.GetBool(LoadState) + } + + if !viper.IsSet(SaveState) { + saveState = true + } else { + saveState = viper.GetBool(SaveState) + } + } else { + if viper.IsSet(LoadState) && viper.GetBool(LoadState) { + return xerrors.Errorf("--load-state requires --state-file to be set") + } + if viper.IsSet(SaveState) && viper.GetBool(SaveState) { + return xerrors.Errorf("--save-state requires --state-file to be set") + } + } + + pidFile := viper.GetString(PidFile) + + // Write PID file if configured + if pidFile != "" { + if err := writePIDFile(pidFile, logger); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + // Ensure PID file is cleaned up on exit + defer cleanupPIDFile(pidFile, logger) + } + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -128,7 +169,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, + StatePersistenceConfig: screentracker.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: loadState, + SaveState: saveState, + }, }) + if err != nil { return xerrors.Errorf("failed to create server: %w", err) } @@ -136,6 +183,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } + handleSignals(ctx, logger, srv, process) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { @@ -147,11 +195,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) } } - if err := srv.Stop(ctx); err != nil { - logger.Error("Failed to stop server", "error", err) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop server after process exit", "error", err) } }() - if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { return xerrors.Errorf("failed to start server: %w", err) } select { @@ -171,6 +221,35 @@ var agentNames = (func() []string { return names })() +// writePIDFile writes the current process ID to the specified file +func writePIDFile(pidFile string, logger *slog.Logger) error { + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(pidFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Write PID file + if err := os.WriteFile(pidFile, []byte(pidContent), 0o644); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func cleanupPIDFile(pidFile string, logger *slog.Logger) { + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err) + } else if err == nil { + logger.Info("Removed PID file", "pidFile", pidFile) + } +} + type flagSpec struct { name string shorthand string @@ -190,6 +269,10 @@ const ( FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" FlagInitialPrompt = "initial-prompt" + StateFile = "state-file" + LoadState = "load-state" + SaveState = "save-state" + PidFile = "pid-file" ) func CreateServerCmd() *cobra.Command { @@ -228,6 +311,10 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {StateFile, "s", "", "Path to file for saving/loading server state", "string"}, + {LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, + {SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, + {PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, } for _, spec := range flagSpecs { diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index bd07fc6..4affad0 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -2,6 +2,8 @@ package server import ( "fmt" + "io" + "log/slog" "os" "strings" "testing" @@ -477,6 +479,218 @@ func TestServerCmd_AllowedHosts(t *testing.T) { } } +func TestServerCmd_StatePersistenceFlags(t *testing.T) { + // NOTE: These tests use --exit flag to test flag parsing and defaults. + // Runtime validation that happens in runServer (e.g., "--load-state requires --state-file") + // would call os.Exit(1) which terminates the test process, so those validations + // are tested through integration/E2E tests instead. + + t.Run("state-file with defaults", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + // load-state and save-state default to true when state-file is set (validated in runServer) + }) + + t.Run("state-file with explicit load-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--load-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, false, viper.GetBool(LoadState)) + }) + + t.Run("state-file with explicit save-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--save-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, false, viper.GetBool(SaveState)) + }) + + t.Run("state-file with explicit load-state=true and save-state=true", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--load-state=true", + "--save-state=true", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, true, viper.GetBool(LoadState)) + assert.Equal(t, true, viper.GetBool(SaveState)) + }) + + t.Run("load-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--load-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(LoadState)) + }) + + t.Run("save-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--save-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(SaveState)) + }) + + t.Run("pid-file can be set independently", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--pid-file", "/tmp/server.pid", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + }) + + t.Run("state-file and pid-file can be set together", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--pid-file", "/tmp/server.pid", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + }) +} + +func TestPIDFileOperations(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("writePIDFile creates file with process ID", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify content contains current PID + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("writePIDFile creates directory if not exists", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nested/deep/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify directory was created + _, err = os.Stat(tmpDir + "/nested/deep") + require.NoError(t, err) + }) + + t.Run("writePIDFile overwrites existing file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Write initial PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Overwrite with current PID + err = writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify content is updated + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("cleanupPIDFile removes file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Create PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Cleanup + cleanupPIDFile(pidFile, discardLogger) + + // Verify file is removed + _, err = os.Stat(pidFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("cleanupPIDFile handles non-existent file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nonexistent.pid" + + // Should not panic or error + cleanupPIDFile(pidFile, discardLogger) + }) + + t.Run("cleanupPIDFile handles directory removal error gracefully", func(t *testing.T) { + // Create a file in a protected directory (this is system-dependent) + // Just verify it doesn't panic when it can't remove the file + pidFile := "/this/should/not/exist/test.pid" + + // Should not panic + cleanupPIDFile(pidFile, discardLogger) + }) +} + func TestServerCmd_AllowedOrigins(t *testing.T) { tests := []struct { name string diff --git a/cmd/server/signals.go b/cmd/server/signals.go new file mode 100644 index 0000000..e66554e --- /dev/null +++ b/cmd/server/signals.go @@ -0,0 +1,36 @@ +package server + +import ( + "context" + "log/slog" + "os" + "time" + + "github.com/coder/agentapi/lib/httpapi" + "github.com/coder/agentapi/lib/termexec" +) + +// performGracefulShutdown handles the common shutdown logic for all platforms. +// It saves state, stops the HTTP server, closes the process, and exits. +func performGracefulShutdown(sig os.Signal, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { + logger.Info("Received shutdown signal, initiating graceful shutdown", "signal", sig) + + // Save state + if err := srv.SaveState(sig.String()); err != nil { + logger.Error("Failed to save state during shutdown", "signal", sig, "error", err) + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop HTTP server", "signal", sig, "error", err) + } + + // Close the process + if err := process.Close(logger, 5*time.Second); err != nil { + logger.Error("Failed to close process cleanly", "signal", sig, "error", err) + } + + // Exit cleanly + os.Exit(0) +} diff --git a/cmd/server/signals_unix.go b/cmd/server/signals_unix.go new file mode 100644 index 0000000..fe6b469 --- /dev/null +++ b/cmd/server/signals_unix.go @@ -0,0 +1,46 @@ +//go:build unix + +package server + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/httpapi" + "github.com/coder/agentapi/lib/termexec" +) + +// handleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, stop server, then close the process +// - SIGUSR1: save conversation state without exiting +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + performGracefulShutdown(sig, logger, srv, process) + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + defer signal.Stop(saveOnlyCh) + for { + select { + case <-saveOnlyCh: + logger.Info("Received SIGUSR1, saving state without exiting") + if err := srv.SaveState("SIGUSR1"); err != nil { + logger.Error("Failed to save state on SIGUSR1", "error", err) + } + case <-ctx.Done(): + return + } + } + }() +} diff --git a/cmd/server/signals_windows.go b/cmd/server/signals_windows.go new file mode 100644 index 0000000..52d9061 --- /dev/null +++ b/cmd/server/signals_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package server + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/httpapi" + "github.com/coder/agentapi/lib/termexec" +) + +// handleSignals sets up signal handlers for Windows. +// Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT only on Windows) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + performGracefulShutdown(sig, logger, srv, process) + }() +} diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 906a3a4..dac1549 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -137,7 +137,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { } } -// Assumes that only the last message can change or new messages can be added. +// UpdateMessagesAndEmitChanges assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { e.mu.Lock() diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 956cfb8..29f0dce 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -40,6 +40,7 @@ type Server struct { port int srv *http.Server mu sync.RWMutex + stopOnce sync.Once logger *slog.Logger conversation st.Conversation agentio *termexec.Process @@ -97,14 +98,15 @@ func (s *Server) GetOpenAPI() string { const snapshotInterval = 25 * time.Millisecond type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string - Clock quartz.Clock + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + Clock quartz.Clock + StatePersistenceConfig st.StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -253,16 +255,17 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } conversation := st.NewPTY(ctx, st.PTYConversationConfig{ - AgentType: config.AgentType, - AgentIO: config.Process, - Clock: config.Clock, - SnapshotInterval: snapshotInterval, - ScreenStabilityLength: 2 * time.Second, - FormatMessage: formatMessage, - ReadyForInitialPrompt: isAgentReadyForInitialPrompt, - FormatToolCall: formatToolCall, - InitialPrompt: initialPrompt, - Logger: logger, + AgentType: config.AgentType, + AgentIO: config.Process, + Clock: config.Clock, + SnapshotInterval: snapshotInterval, + ScreenStabilityLength: 2 * time.Second, + FormatMessage: formatMessage, + ReadyForInitialPrompt: isAgentReadyForInitialPrompt, + FormatToolCall: formatToolCall, + InitialPrompt: initialPrompt, + Logger: logger, + StatePersistenceConfig: config.StatePersistenceConfig, }, emitter) // Create temporary directory for uploads @@ -588,15 +591,18 @@ func (s *Server) Start() error { return s.srv.ListenAndServe() } -// Stop gracefully stops the HTTP server +// Stop gracefully stops the HTTP server. It is safe to call multiple times. func (s *Server) Stop(ctx context.Context) error { - // Clean up temporary directory - s.cleanupTempDir() + var err error + s.stopOnce.Do(func() { + // Clean up temporary directory + s.cleanupTempDir() - if s.srv != nil { - return s.srv.Shutdown(ctx) - } - return nil + if s.srv != nil { + err = s.srv.Shutdown(ctx) + } + }) + return err } // cleanupTempDir removes the temporary directory and all its contents @@ -608,6 +614,14 @@ func (s *Server) cleanupTempDir() { } } +func (s *Server) SaveState(source string) error { + if err := s.conversation.SaveState(); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + return err + } + return nil +} + // registerStaticFileRoutes sets up routes for serving static files func (s *Server) registerStaticFileRoutes() { chatHandler := FileServerWithIndexFallback(s.chatBasePath) diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index c8e8b23..82fc671 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" @@ -956,3 +957,36 @@ func TestServer_UploadFiles_Errors(t *testing.T) { require.Contains(t, string(body), "file size exceeds 10MB limit") }) } + +func TestServer_Stop_Idempotency(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, + }) + require.NoError(t, err) + + // First call to Stop should succeed + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = srv.Stop(stopCtx) + require.NoError(t, err) + + // Second call to Stop should also succeed (no-op) + stopCtx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + err = srv.Stop(stopCtx2) + require.NoError(t, err) + + // Third call to Stop should also succeed (no-op) + stopCtx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel3() + err = srv.Stop(stopCtx3) + require.NoError(t, err) +} diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 1620304..c8d95b6 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,10 +4,7 @@ import ( "context" "fmt" "os" - "os/signal" "strings" - "syscall" - "time" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -45,16 +42,5 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return nil, err } } - - // Handle SIGINT (Ctrl+C) and send it to the process - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - go func() { - <-signalCh - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Error closing process", "error", err) - } - }() - return process, nil } diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 8299faa..f921d16 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -63,6 +63,7 @@ type Conversation interface { Start(context.Context) Status() ConversationStatus Text() string + SaveState() error } // Emitter receives conversation state updates. @@ -78,3 +79,9 @@ type ConversationMessage struct { Role ConversationRole Time time.Time } + +type StatePersistenceConfig struct { + StateFile string + LoadState bool + SaveState bool +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 2728377..e5b6feb 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -2,8 +2,11 @@ package screentracker import ( "context" + "encoding/json" "fmt" "log/slog" + "os" + "path/filepath" "strings" "sync" "time" @@ -26,6 +29,12 @@ type MessagePartText struct { Hidden bool } +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` +} + var _ MessagePart = &MessagePartText{} func (p MessagePartText) Do(writer AgentIO) error { @@ -67,8 +76,9 @@ type PTYConversationConfig struct { // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready - InitialPrompt []MessagePart - Logger *slog.Logger + InitialPrompt []MessagePart + Logger *slog.Logger + StatePersistenceConfig StatePersistenceConfig } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -107,6 +117,14 @@ type PTYConversation struct { stableSignal chan struct{} // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message toolCallMessageSet map[string]bool + // dirty tracks whether the conversation state has changed since the last save + dirty bool + // firstStableSnapshot is the conversation history rolled out by the agent in case of a resume (given that the agent supports it) + firstStableSnapshot string + // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state + userSentMessageAfterLoadState bool + // loadStateSuccessful indicates whether conversation state was successfully restored from file. + loadStateSuccessful bool // initialPromptReady is set to true when ReadyForInitialPrompt returns true. // Checked inline in the snapshot loop on each tick. initialPromptReady bool @@ -140,9 +158,13 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PT Time: cfg.Clock.Now(), }, }, - outboundQueue: make(chan outboundMessage, 1), - stableSignal: make(chan struct{}, 1), - toolCallMessageSet: make(map[string]bool), + outboundQueue: make(chan outboundMessage, 1), + stableSignal: make(chan struct{}, 1), + toolCallMessageSet: make(map[string]bool), + dirty: false, + firstStableSnapshot: "", + userSentMessageAfterLoadState: false, + loadStateSuccessful: false, } // If we have an initial prompt, enqueue it if len(cfg.InitialPrompt) > 0 { @@ -169,6 +191,12 @@ func (c *PTYConversation) Start(ctx context.Context) { if !c.initialPromptReady && c.cfg.ReadyForInitialPrompt(screen) { c.initialPromptReady = true } + + if c.initialPromptReady && !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { + _ = c.loadState() + c.loadStateSuccessful = true + } + if c.initialPromptReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() { select { case c.stableSignal <- struct{}{}: @@ -245,6 +273,9 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } + if c.loadStateSuccessful { + agentMessage = c.adjustScreenAfterStateLoad(agentMessage) + } if c.cfg.FormatToolCall != nil { agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) } @@ -274,6 +305,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1] = conversationMessage } c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + + c.dirty = true } // caller MUST hold c.lock @@ -350,6 +383,8 @@ func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...Messa Role: ConversationRoleUser, Time: now, }) + c.userSentMessageAfterLoadState = true + c.lock.Unlock() return nil } @@ -497,3 +532,140 @@ func (c *PTYConversation) Text() string { } return snapshots[len(snapshots)-1].screen } + +func (c *PTYConversation) SaveState() error { + conversation := c.Messages() + + c.lock.Lock() + defer c.lock.Unlock() + + stateFile := c.cfg.StatePersistenceConfig.StateFile + saveState := c.cfg.StatePersistenceConfig.SaveState + + if !saveState { + c.cfg.Logger.Info("") + return nil + } + + // Skip if not dirty + if !c.dirty { + c.cfg.Logger.Info("Skipping state save: no changes since last save") + return nil + } + + // Serialize initial prompt from message parts + var initialPromptStr string + if len(c.cfg.InitialPrompt) > 0 { + var sb strings.Builder + for _, part := range c.cfg.InitialPrompt { + sb.WriteString(part.String()) + } + initialPromptStr = sb.String() + } + + // Use atomic write: write to temp file, then rename to target path + data, err := json.MarshalIndent(AgentState{ + Version: 1, + Messages: conversation, + InitialPrompt: initialPromptStr, + }, "", " ") + if err != nil { + return xerrors.Errorf("failed to marshal state: %w", err) + } + + // Create directory if it doesn't exist + dir := filepath.Dir(stateFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create state directory: %w", err) + } + + // Write to temp file + tempFile := stateFile + ".tmp" + if err := os.WriteFile(tempFile, data, 0o644); err != nil { + return xerrors.Errorf("failed to write temp state file: %w", err) + } + + // Atomic rename + if err := os.Rename(tempFile, stateFile); err != nil { + return xerrors.Errorf("failed to rename state file: %w", err) + } + + // Clear dirty flag after successful save + c.dirty = false + + c.cfg.Logger.Info(fmt.Sprintf("State saved successfully to: %s", stateFile)) + + return nil +} + +// LoadState loads the state, this method assumes that caller holds the Lock +func (c *PTYConversation) loadState() error { + stateFile := c.cfg.StatePersistenceConfig.StateFile + loadState := c.cfg.StatePersistenceConfig.LoadState + + if !loadState { + return nil + } + + // Check if file exists + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) + return nil + } + + // Read state file + data, err := os.ReadFile(stateFile) + if err != nil { + c.cfg.Logger.Warn("Failed to load state file", "path", stateFile, "err", err) + return xerrors.Errorf("failed to read state file: %w", err) + } + + if len(data) == 0 { + c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) + return nil + } + + var agentState AgentState + if err := json.Unmarshal(data, &agentState); err != nil { + c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) + return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) + } + + //c.cfg.initialPromptSent = agentState.InitialPromptSent + c.cfg.InitialPrompt = []MessagePart{MessagePartText{ + Content: agentState.InitialPrompt, + Alias: "", + Hidden: false, + }} + c.messages = agentState.Messages + + // Store the first stable snapshot for filtering later + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) > 0 { + c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") + } + + c.loadStateSuccessful = true + c.dirty = false + + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) + return nil +} + +func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { + + if c.firstStableSnapshot == "" { + return screen + } + + newScreen := strings.Replace(screen, c.firstStableSnapshot, "", 1) + + // Before the first user message after loading state, return the last message from the loaded state. + // This prevents computing incorrect diffs from the restored screen, as the agent's message should + // remain stable until the user continues the conversation. + if c.userSentMessageAfterLoadState == false { + newScreen = "\n" + c.messages[len(c.messages)-1].Message + } + + return newScreen +} diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index 19b4511..67ff139 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,9 +2,11 @@ package screentracker_test import ( "context" + "encoding/json" "fmt" "io" "log/slog" + "os" "sync" "testing" "time" @@ -446,6 +448,357 @@ func TestMessages(t *testing.T) { }) } +func TestStatePersistence(t *testing.T) { + t.Run("SaveState creates file with correct structure", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + // Create temp directory for state file + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate some conversation + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Save state + err := c.SaveState() + require.NoError(t, err) + + // Read and verify the saved file + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.Equal(t, 1, agentState.Version) + assert.Equal(t, "test prompt", agentState.InitialPrompt) + assert.NotEmpty(t, agentState.Messages) + }) + + t.Run("SaveState skips when not configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + err := c.SaveState() + require.NoError(t, err) + + // File should not be created + _, err = os.Stat(stateFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("SaveState honors dirty flag", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate conversation and save + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + err := c.SaveState() + require.NoError(t, err) + + // Get file modification time + info1, err := os.Stat(stateFile) + require.NoError(t, err) + modTime1 := info1.ModTime() + + // Save again without changes - file should not be modified + err = c.SaveState() + require.NoError(t, err) + + info2, err := os.Stat(stateFile) + require.NoError(t, err) + modTime2 := info2.ModTime() + + // File modification time should be the same (dirty flag prevents save) + assert.Equal(t, modTime1, modTime2) + }) + + t.Run("SaveState creates directory if not exists", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nested/deep/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + // Verify file and directory were created + _, err = os.Stat(stateFile) + assert.NoError(t, err) + }) + + t.Run("LoadState restores conversation from file", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with test data + testState := st.AgentState{ + Version: 1, + InitialPrompt: "restored prompt", + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message 1", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "user message 1", Role: st.ConversationRoleUser, Time: time.Now()}, + {Id: 2, Message: "agent message 2", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with LoadState enabled + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until agent is ready and state is loaded + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Verify messages were restored + messages := c.Messages() + assert.Len(t, messages, 3) + assert.Equal(t, "agent message 1", messages[0].Message) + assert.Equal(t, "user message 1", messages[1].Message) + // The last agent message may have adjustments from adjustScreenAfterStateLoad + assert.Contains(t, messages[2].Message, "agent message 2") + }) + + t.Run("LoadState handles missing file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nonexistent.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles empty file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/empty.json" + + // Create empty file + err := os.WriteFile(stateFile, []byte(""), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles corrupted JSON gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/corrupted.json" + + // Create corrupted JSON file + err := os.WriteFile(stateFile, []byte("{invalid json}"), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic - logs warning and continues + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) +} + func TestInitialPromptReadiness(t *testing.T) { discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil))