From 78579bd52ed5db6358aa7aaa840e99e1112e003d Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 12 Feb 2026 18:31:09 +0000 Subject: [PATCH 01/17] Replace OnSnapshot callback with Emitter interface - Add Emitter interface to screentracker package - Remove OnSnapshot from PTYConversationConfig, accept Emitter in NewPTY - Rename EventEmitter methods: EmitMessages, EmitStatus, EmitScreen - Accept agentType at NewEventEmitter construction instead of per-call - Update server.go wiring, all tests pass --- lib/httpapi/events.go | 12 ++++++------ lib/httpapi/events_test.go | 19 +++++++++---------- lib/httpapi/server.go | 14 +++----------- lib/screentracker/conversation.go | 7 +++++++ lib/screentracker/pty_conversation.go | 15 +++++++-------- lib/screentracker/pty_conversation_test.go | 20 +++++++++++++------- 6 files changed, 45 insertions(+), 42 deletions(-) diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 73eff07b..b69c53a4 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -86,11 +86,12 @@ func convertStatus(status st.ConversationStatus) AgentStatus { // Listeners must actively drain the channel, so it's important to // set this to a value that is large enough to handle the expected // number of events. -func NewEventEmitter(subscriptionBufSize int) *EventEmitter { +func NewEventEmitter(subscriptionBufSize int, agentType mf.AgentType) *EventEmitter { return &EventEmitter{ mu: sync.Mutex{}, messages: make([]st.ConversationMessage, 0), status: AgentStatusRunning, + agentType: agentType, chans: make(map[int]chan Event), chanIdx: 0, subscriptionBufSize: subscriptionBufSize, @@ -122,7 +123,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { // Assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. 
-func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) { +func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { e.mu.Lock() defer e.mu.Unlock() @@ -149,7 +150,7 @@ func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.Conversatio e.messages = newMessages } -func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatus, agentType mf.AgentType) { +func (e *EventEmitter) EmitStatus(newStatus st.ConversationStatus) { e.mu.Lock() defer e.mu.Unlock() @@ -158,12 +159,11 @@ func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatu return } - e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: agentType}) + e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: e.agentType}) e.status = newAgentStatus - e.agentType = agentType } -func (e *EventEmitter) UpdateScreenAndEmitChanges(newScreen string) { +func (e *EventEmitter) EmitScreen(newScreen string) { e.mu.Lock() defer e.mu.Unlock() diff --git a/lib/httpapi/events_test.go b/lib/httpapi/events_test.go index 46ccea56..ac0d4548 100644 --- a/lib/httpapi/events_test.go +++ b/lib/httpapi/events_test.go @@ -5,14 +5,13 @@ import ( "testing" "time" - mf "github.com/coder/agentapi/lib/msgfmt" st "github.com/coder/agentapi/lib/screentracker" "github.com/stretchr/testify/assert" ) func TestEventEmitter(t *testing.T) { t.Run("single-subscription", func(t *testing.T) { - emitter := NewEventEmitter(10) + emitter := NewEventEmitter(10, "") _, ch, stateEvents := emitter.Subscribe() assert.Empty(t, ch) assert.Equal(t, []Event{ @@ -27,7 +26,7 @@ func TestEventEmitter(t *testing.T) { }, stateEvents) now := time.Now() - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }) newEvent := <-ch @@ -36,7 +35,7 @@ func 
TestEventEmitter(t *testing.T) { Payload: MessageUpdateBody{Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }, newEvent) - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world! (updated)", Role: st.ConversationRoleUser, Time: now}, {Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now}, }) @@ -52,16 +51,16 @@ func TestEventEmitter(t *testing.T) { Payload: MessageUpdateBody{Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now}, }, newEvent) - emitter.UpdateStatusAndEmitChanges(st.ConversationStatusStable, mf.AgentTypeAider) + emitter.EmitStatus(st.ConversationStatusStable) newEvent = <-ch assert.Equal(t, Event{ Type: EventTypeStatusChange, - Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: mf.AgentTypeAider}, + Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: ""}, }, newEvent) }) t.Run("multiple-subscriptions", func(t *testing.T) { - emitter := NewEventEmitter(10) + emitter := NewEventEmitter(10, "") channels := make([]<-chan Event, 0, 10) for i := 0; i < 10; i++ { _, ch, _ := emitter.Subscribe() @@ -69,7 +68,7 @@ func TestEventEmitter(t *testing.T) { } now := time.Now() - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }) for _, ch := range channels { @@ -82,10 +81,10 @@ func TestEventEmitter(t *testing.T) { }) t.Run("close-channel", func(t *testing.T) { - emitter := NewEventEmitter(1) + emitter := NewEventEmitter(1, "") _, ch, _ := emitter.Subscribe() for i := range 5 { - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: i, Message: fmt.Sprintf("Hello, world! 
%d", i), Role: st.ConversationRoleUser, Time: time.Now()}, }) } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e43315bf..775fdcf8 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -244,7 +244,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - emitter := NewEventEmitter(1024) + emitter := NewEventEmitter(1024, config.AgentType) // Format initial prompt into message parts if provided var initialPrompt []st.MessagePart @@ -262,16 +262,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { ReadyForInitialPrompt: isAgentReadyForInitialPrompt, FormatToolCall: formatToolCall, InitialPrompt: initialPrompt, - // OnSnapshot uses a callback rather than passing the emitter directly - // to keep the screentracker package decoupled from httpapi concerns. - // This preserves clean package boundaries and avoids import cycles. - OnSnapshot: func(status st.ConversationStatus, messages []st.ConversationMessage, screen string) { - emitter.UpdateStatusAndEmitChanges(status, config.AgentType) - emitter.UpdateMessagesAndEmitChanges(messages) - emitter.UpdateScreenAndEmitChanges(screen) - }, - Logger: logger, - }) + Logger: logger, + }, emitter) // Create temporary directory for uploads tempDir, err := os.MkdirTemp("", "agentapi-uploads-") diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 9e6b856f..8299faa1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -65,6 +65,13 @@ type Conversation interface { Text() string } +// Emitter receives conversation state updates. 
+type Emitter interface { + EmitMessages([]ConversationMessage) + EmitStatus(ConversationStatus) + EmitScreen(string) +} + type ConversationMessage struct { Id int Message string diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index ff0c7eed..086e1254 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -68,8 +68,6 @@ type PTYConversationConfig struct { FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready InitialPrompt []MessagePart - // OnSnapshot is called after each snapshot with current status, messages, and screen content - OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) Logger *slog.Logger } @@ -86,7 +84,8 @@ func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { // PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication. // It uses a combination of polling and diffs to detect changes in the screen. 
type PTYConversation struct { - cfg PTYConversationConfig + cfg PTYConversationConfig + emitter Emitter // How many stable snapshots are required to consider the screen stable stableSnapshotsThreshold int snapshotBuffer *RingBuffer[screenSnapshot] @@ -115,13 +114,14 @@ type PTYConversation struct { var _ Conversation = &PTYConversation{} -func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { +func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation { if cfg.Clock == nil { cfg.Clock = quartz.NewReal() } threshold := cfg.getStableSnapshotsThreshold() c := &PTYConversation{ cfg: cfg, + emitter: emitter, stableSnapshotsThreshold: threshold, snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), messages: []ConversationMessage{ @@ -139,9 +139,6 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { if len(cfg.InitialPrompt) > 0 { c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil} } - if c.cfg.OnSnapshot == nil { - c.cfg.OnSnapshot = func(ConversationStatus, []ConversationMessage, string) {} - } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } } @@ -173,7 +170,9 @@ func (c *PTYConversation) Start(ctx context.Context) { } c.lock.Unlock() - c.cfg.OnSnapshot(status, messages, screen) + c.emitter.EmitStatus(status) + c.emitter.EmitMessages(messages) + c.emitter.EmitScreen(screen) return nil }, "snapshot") diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index eaa4a69e..dc4fc464 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -49,6 +49,12 @@ func (a *testAgent) setScreen(s string) { a.screen = s } +type testEmitter struct{} + +func (testEmitter) EmitMessages([]st.ConversationMessage) {} +func (testEmitter) EmitStatus(st.ConversationStatus) {} +func (testEmitter) EmitScreen(string) {} + // advanceFor 
is a shorthand for advanceUntil with a time-based condition. func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { t.Helper() @@ -125,7 +131,7 @@ func statusTest(t *testing.T, params statusTestParams) { params.cfg.AgentIO = agent params.cfg.Logger = slog.New(slog.NewTextHandler(io.Discard, nil)) - c := st.NewPTY(ctx, params.cfg) + c := st.NewPTY(ctx, params.cfg, &testEmitter{}) c.Start(ctx) assert.Equal(t, st.ConversationStatusInitializing, c.Status()) @@ -233,7 +239,7 @@ func TestMessages(t *testing.T) { agent = a } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) return c, agent, mClock @@ -460,7 +466,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Take a snapshot with "loading...". Threshold is 1 (stability 0 / interval 1s = 0 + 1 = 1). @@ -488,7 +494,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Agent not ready initially. @@ -524,7 +530,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Status is "changing" while waiting for readiness. @@ -564,7 +570,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) advanceFor(ctx, t, mClock, 1*time.Second) @@ -586,7 +592,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Fill buffer to reach stability with "ready" screen. 
From bc79eda0d227663c174bc9c15e609006e443aaed Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 12 Feb 2026 21:46:39 +0000 Subject: [PATCH 02/17] refactor: use functional options for NewEventEmitter --- lib/httpapi/events.go | 34 +++++++++++++++++++++++----------- lib/httpapi/events_test.go | 6 +++--- lib/httpapi/server.go | 2 +- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index b69c53a4..7b1e2d40 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -81,21 +81,33 @@ func convertStatus(status st.ConversationStatus) AgentStatus { } } -// subscriptionBufSize is the size of the buffer for each subscription. -// Once the buffer is full, the channel will be closed. -// Listeners must actively drain the channel, so it's important to -// set this to a value that is large enough to handle the expected -// number of events. -func NewEventEmitter(subscriptionBufSize int, agentType mf.AgentType) *EventEmitter { - return &EventEmitter{ - mu: sync.Mutex{}, +const defaultSubscriptionBufSize = 1024 + +type EventEmitterOption func(*EventEmitter) + +func WithSubscriptionBufSize(size int) EventEmitterOption { + return func(e *EventEmitter) { + e.subscriptionBufSize = size + } +} + +func WithAgentType(agentType mf.AgentType) EventEmitterOption { + return func(e *EventEmitter) { + e.agentType = agentType + } +} + +func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter { + e := &EventEmitter{ messages: make([]st.ConversationMessage, 0), status: AgentStatusRunning, - agentType: agentType, chans: make(map[int]chan Event), - chanIdx: 0, - subscriptionBufSize: subscriptionBufSize, + subscriptionBufSize: defaultSubscriptionBufSize, + } + for _, opt := range opts { + opt(e) } + return e } // Assumes the caller holds the lock. 
diff --git a/lib/httpapi/events_test.go b/lib/httpapi/events_test.go index ac0d4548..a1d024c4 100644 --- a/lib/httpapi/events_test.go +++ b/lib/httpapi/events_test.go @@ -11,7 +11,7 @@ import ( func TestEventEmitter(t *testing.T) { t.Run("single-subscription", func(t *testing.T) { - emitter := NewEventEmitter(10, "") + emitter := NewEventEmitter(WithSubscriptionBufSize(10)) _, ch, stateEvents := emitter.Subscribe() assert.Empty(t, ch) assert.Equal(t, []Event{ @@ -60,7 +60,7 @@ func TestEventEmitter(t *testing.T) { }) t.Run("multiple-subscriptions", func(t *testing.T) { - emitter := NewEventEmitter(10, "") + emitter := NewEventEmitter(WithSubscriptionBufSize(10)) channels := make([]<-chan Event, 0, 10) for i := 0; i < 10; i++ { _, ch, _ := emitter.Subscribe() @@ -81,7 +81,7 @@ func TestEventEmitter(t *testing.T) { }) t.Run("close-channel", func(t *testing.T) { - emitter := NewEventEmitter(1, "") + emitter := NewEventEmitter(WithSubscriptionBufSize(1)) _, ch, _ := emitter.Subscribe() for i := range 5 { emitter.EmitMessages([]st.ConversationMessage{ diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 775fdcf8..956cfb8a 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -244,7 +244,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - emitter := NewEventEmitter(1024, config.AgentType) + emitter := NewEventEmitter(WithAgentType(config.AgentType)) // Format initial prompt into message parts if provided var initialPrompt []st.MessagePart From 0fa554d044ee1a7e86645726a6c5755b8ba867cb Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 12 Feb 2026 18:31:09 +0000 Subject: [PATCH 03/17] Replace OnSnapshot callback with Emitter interface - Add Emitter interface to screentracker package - Remove OnSnapshot from PTYConversationConfig, accept Emitter in NewPTY - Rename EventEmitter methods: EmitMessages, EmitStatus, EmitScreen - Accept agentType at NewEventEmitter 
construction instead of per-call - Update server.go wiring, all tests pass --- lib/httpapi/events.go | 42 ++++++++++++++-------- lib/httpapi/events_test.go | 19 +++++----- lib/httpapi/server.go | 14 ++------ lib/screentracker/conversation.go | 7 ++++ lib/screentracker/pty_conversation.go | 15 ++++---- lib/screentracker/pty_conversation_test.go | 20 +++++++---- 6 files changed, 66 insertions(+), 51 deletions(-) diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 73eff07b..7b1e2d40 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -81,20 +81,33 @@ func convertStatus(status st.ConversationStatus) AgentStatus { } } -// subscriptionBufSize is the size of the buffer for each subscription. -// Once the buffer is full, the channel will be closed. -// Listeners must actively drain the channel, so it's important to -// set this to a value that is large enough to handle the expected -// number of events. -func NewEventEmitter(subscriptionBufSize int) *EventEmitter { - return &EventEmitter{ - mu: sync.Mutex{}, +const defaultSubscriptionBufSize = 1024 + +type EventEmitterOption func(*EventEmitter) + +func WithSubscriptionBufSize(size int) EventEmitterOption { + return func(e *EventEmitter) { + e.subscriptionBufSize = size + } +} + +func WithAgentType(agentType mf.AgentType) EventEmitterOption { + return func(e *EventEmitter) { + e.agentType = agentType + } +} + +func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter { + e := &EventEmitter{ messages: make([]st.ConversationMessage, 0), status: AgentStatusRunning, chans: make(map[int]chan Event), - chanIdx: 0, - subscriptionBufSize: subscriptionBufSize, + subscriptionBufSize: defaultSubscriptionBufSize, + } + for _, opt := range opts { + opt(e) } + return e } // Assumes the caller holds the lock. @@ -122,7 +135,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { // Assumes that only the last message can change or new messages can be added. 
// If a new message is injected between existing messages (identified by Id), the behavior is undefined. -func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) { +func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { e.mu.Lock() defer e.mu.Unlock() @@ -149,7 +162,7 @@ func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.Conversatio e.messages = newMessages } -func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatus, agentType mf.AgentType) { +func (e *EventEmitter) EmitStatus(newStatus st.ConversationStatus) { e.mu.Lock() defer e.mu.Unlock() @@ -158,12 +171,11 @@ func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatu return } - e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: agentType}) + e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: e.agentType}) e.status = newAgentStatus - e.agentType = agentType } -func (e *EventEmitter) UpdateScreenAndEmitChanges(newScreen string) { +func (e *EventEmitter) EmitScreen(newScreen string) { e.mu.Lock() defer e.mu.Unlock() diff --git a/lib/httpapi/events_test.go b/lib/httpapi/events_test.go index 46ccea56..a1d024c4 100644 --- a/lib/httpapi/events_test.go +++ b/lib/httpapi/events_test.go @@ -5,14 +5,13 @@ import ( "testing" "time" - mf "github.com/coder/agentapi/lib/msgfmt" st "github.com/coder/agentapi/lib/screentracker" "github.com/stretchr/testify/assert" ) func TestEventEmitter(t *testing.T) { t.Run("single-subscription", func(t *testing.T) { - emitter := NewEventEmitter(10) + emitter := NewEventEmitter(WithSubscriptionBufSize(10)) _, ch, stateEvents := emitter.Subscribe() assert.Empty(t, ch) assert.Equal(t, []Event{ @@ -27,7 +26,7 @@ func TestEventEmitter(t *testing.T) { }, stateEvents) now := time.Now() - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ 
{Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }) newEvent := <-ch @@ -36,7 +35,7 @@ func TestEventEmitter(t *testing.T) { Payload: MessageUpdateBody{Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }, newEvent) - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world! (updated)", Role: st.ConversationRoleUser, Time: now}, {Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now}, }) @@ -52,16 +51,16 @@ func TestEventEmitter(t *testing.T) { Payload: MessageUpdateBody{Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now}, }, newEvent) - emitter.UpdateStatusAndEmitChanges(st.ConversationStatusStable, mf.AgentTypeAider) + emitter.EmitStatus(st.ConversationStatusStable) newEvent = <-ch assert.Equal(t, Event{ Type: EventTypeStatusChange, - Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: mf.AgentTypeAider}, + Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: ""}, }, newEvent) }) t.Run("multiple-subscriptions", func(t *testing.T) { - emitter := NewEventEmitter(10) + emitter := NewEventEmitter(WithSubscriptionBufSize(10)) channels := make([]<-chan Event, 0, 10) for i := 0; i < 10; i++ { _, ch, _ := emitter.Subscribe() @@ -69,7 +68,7 @@ func TestEventEmitter(t *testing.T) { } now := time.Now() - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }) for _, ch := range channels { @@ -82,10 +81,10 @@ func TestEventEmitter(t *testing.T) { }) t.Run("close-channel", func(t *testing.T) { - emitter := NewEventEmitter(1) + emitter := NewEventEmitter(WithSubscriptionBufSize(1)) _, ch, _ := emitter.Subscribe() for i := range 5 { - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ 
{Id: i, Message: fmt.Sprintf("Hello, world! %d", i), Role: st.ConversationRoleUser, Time: time.Now()}, }) } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e43315bf..956cfb8a 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -244,7 +244,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - emitter := NewEventEmitter(1024) + emitter := NewEventEmitter(WithAgentType(config.AgentType)) // Format initial prompt into message parts if provided var initialPrompt []st.MessagePart @@ -262,16 +262,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { ReadyForInitialPrompt: isAgentReadyForInitialPrompt, FormatToolCall: formatToolCall, InitialPrompt: initialPrompt, - // OnSnapshot uses a callback rather than passing the emitter directly - // to keep the screentracker package decoupled from httpapi concerns. - // This preserves clean package boundaries and avoids import cycles. - OnSnapshot: func(status st.ConversationStatus, messages []st.ConversationMessage, screen string) { - emitter.UpdateStatusAndEmitChanges(status, config.AgentType) - emitter.UpdateMessagesAndEmitChanges(messages) - emitter.UpdateScreenAndEmitChanges(screen) - }, - Logger: logger, - }) + Logger: logger, + }, emitter) // Create temporary directory for uploads tempDir, err := os.MkdirTemp("", "agentapi-uploads-") diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 9e6b856f..8299faa1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -65,6 +65,13 @@ type Conversation interface { Text() string } +// Emitter receives conversation state updates. 
+type Emitter interface { + EmitMessages([]ConversationMessage) + EmitStatus(ConversationStatus) + EmitScreen(string) +} + type ConversationMessage struct { Id int Message string diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index ff0c7eed..086e1254 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -68,8 +68,6 @@ type PTYConversationConfig struct { FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready InitialPrompt []MessagePart - // OnSnapshot is called after each snapshot with current status, messages, and screen content - OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) Logger *slog.Logger } @@ -86,7 +84,8 @@ func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { // PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication. // It uses a combination of polling and diffs to detect changes in the screen. 
type PTYConversation struct { - cfg PTYConversationConfig + cfg PTYConversationConfig + emitter Emitter // How many stable snapshots are required to consider the screen stable stableSnapshotsThreshold int snapshotBuffer *RingBuffer[screenSnapshot] @@ -115,13 +114,14 @@ type PTYConversation struct { var _ Conversation = &PTYConversation{} -func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { +func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation { if cfg.Clock == nil { cfg.Clock = quartz.NewReal() } threshold := cfg.getStableSnapshotsThreshold() c := &PTYConversation{ cfg: cfg, + emitter: emitter, stableSnapshotsThreshold: threshold, snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), messages: []ConversationMessage{ @@ -139,9 +139,6 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { if len(cfg.InitialPrompt) > 0 { c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil} } - if c.cfg.OnSnapshot == nil { - c.cfg.OnSnapshot = func(ConversationStatus, []ConversationMessage, string) {} - } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } } @@ -173,7 +170,9 @@ func (c *PTYConversation) Start(ctx context.Context) { } c.lock.Unlock() - c.cfg.OnSnapshot(status, messages, screen) + c.emitter.EmitStatus(status) + c.emitter.EmitMessages(messages) + c.emitter.EmitScreen(screen) return nil }, "snapshot") diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index eaa4a69e..dc4fc464 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -49,6 +49,12 @@ func (a *testAgent) setScreen(s string) { a.screen = s } +type testEmitter struct{} + +func (testEmitter) EmitMessages([]st.ConversationMessage) {} +func (testEmitter) EmitStatus(st.ConversationStatus) {} +func (testEmitter) EmitScreen(string) {} + // advanceFor 
is a shorthand for advanceUntil with a time-based condition. func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { t.Helper() @@ -125,7 +131,7 @@ func statusTest(t *testing.T, params statusTestParams) { params.cfg.AgentIO = agent params.cfg.Logger = slog.New(slog.NewTextHandler(io.Discard, nil)) - c := st.NewPTY(ctx, params.cfg) + c := st.NewPTY(ctx, params.cfg, &testEmitter{}) c.Start(ctx) assert.Equal(t, st.ConversationStatusInitializing, c.Status()) @@ -233,7 +239,7 @@ func TestMessages(t *testing.T) { agent = a } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) return c, agent, mClock @@ -460,7 +466,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Take a snapshot with "loading...". Threshold is 1 (stability 0 / interval 1s = 0 + 1 = 1). @@ -488,7 +494,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Agent not ready initially. @@ -524,7 +530,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Status is "changing" while waiting for readiness. @@ -564,7 +570,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) advanceFor(ctx, t, mClock, 1*time.Second) @@ -586,7 +592,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Fill buffer to reach stability with "ready" screen. 
From 1cd4dca3e30ed308af84ebd52e4301c0d0ac31bc Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 12 Feb 2026 22:27:31 +0000 Subject: [PATCH 04/17] address PR review comments --- lib/httpapi/events.go | 9 ++++- lib/screentracker/pty_conversation.go | 11 ++++++- lib/screentracker/pty_conversation_test.go | 38 +++++++++++----------- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 7b1e2d40..9d64f72f 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -87,7 +87,11 @@ type EventEmitterOption func(*EventEmitter) func WithSubscriptionBufSize(size int) EventEmitterOption { return func(e *EventEmitter) { - e.subscriptionBufSize = size + if size <= 0 { + e.subscriptionBufSize = defaultSubscriptionBufSize + } else { + e.subscriptionBufSize = size + } } } @@ -150,6 +154,9 @@ func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { newMsg = newMessages[i] } if oldMsg != newMsg { + if i >= len(newMessages) { + continue + } e.notifyChannels(EventTypeMessageUpdate, MessageUpdateBody{ Id: newMessages[i].Id, Role: newMessages[i].Role, diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 086e1254..27283775 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -68,7 +68,7 @@ type PTYConversationConfig struct { FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready InitialPrompt []MessagePart - Logger *slog.Logger + Logger *slog.Logger } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -114,10 +114,19 @@ type PTYConversation struct { var _ Conversation = &PTYConversation{} +type noopEmitter struct{} + +func (noopEmitter) EmitMessages([]ConversationMessage) {} +func (noopEmitter) EmitStatus(ConversationStatus) {} +func (noopEmitter) EmitScreen(string) {} + func NewPTY(ctx context.Context, cfg 
PTYConversationConfig, emitter Emitter) *PTYConversation { if cfg.Clock == nil { cfg.Clock = quartz.NewReal() } + if emitter == nil { + emitter = noopEmitter{} + } threshold := cfg.getStableSnapshotsThreshold() c := &PTYConversation{ cfg: cfg, diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index dc4fc464..19b4511b 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -20,8 +20,8 @@ const testTimeout = 10 * time.Second // testAgent is a goroutine-safe mock implementation of AgentIO. type testAgent struct { - mu sync.Mutex - screen string + mu sync.Mutex + screen string // onWrite is called during Write to simulate the agent reacting to // terminal input (e.g., changing the screen), which unblocks // writeStabilize's polling loops. @@ -53,7 +53,7 @@ type testEmitter struct{} func (testEmitter) EmitMessages([]st.ConversationMessage) {} func (testEmitter) EmitStatus(st.ConversationStatus) {} -func (testEmitter) EmitScreen(string) {} +func (testEmitter) EmitScreen(string) {} // advanceFor is a shorthand for advanceUntil with a time-based condition. 
func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { @@ -226,11 +226,11 @@ func TestMessages(t *testing.T) { mClock := quartz.NewMock(t) mClock.Set(now) cfg := st.PTYConversationConfig{ - Clock: mClock, - AgentIO: agent, - SnapshotInterval: 100 * time.Millisecond, - ScreenStabilityLength: 200 * time.Millisecond, - Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + Clock: mClock, + AgentIO: agent, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } for _, opt := range opts { opt(&cfg) @@ -519,15 +519,15 @@ func TestInitialPromptReadiness(t *testing.T) { agent.screen = fmt.Sprintf("__write_%d", writeCounter) } cfg := st.PTYConversationConfig{ - Clock: mClock, - SnapshotInterval: 1 * time.Second, - ScreenStabilityLength: 0, - AgentIO: agent, + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, ReadyForInitialPrompt: func(message string) bool { return message == "ready" }, - InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "initial prompt here"}}, - Logger: discardLogger, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "initial prompt here"}}, + Logger: discardLogger, } c := st.NewPTY(ctx, cfg, &testEmitter{}) @@ -585,11 +585,11 @@ func TestInitialPromptReadiness(t *testing.T) { mClock := quartz.NewMock(t) agent := &testAgent{screen: "ready"} cfg := st.PTYConversationConfig{ - Clock: mClock, - SnapshotInterval: 1 * time.Second, - ScreenStabilityLength: 2 * time.Second, // threshold = 3 - AgentIO: agent, - Logger: discardLogger, + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 2 * time.Second, // threshold = 3 + AgentIO: agent, + Logger: discardLogger, } c := st.NewPTY(ctx, cfg, &testEmitter{}) From a116d6311cd8c3c165c3be7fa6416279aa025e8b Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 12 Feb 2026 22:45:56 +0000 
Subject: [PATCH 05/17] use uint for subscription buffer size --- lib/httpapi/events.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 9d64f72f..906a3a42 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -64,7 +64,7 @@ type EventEmitter struct { agentType mf.AgentType chans map[int]chan Event chanIdx int - subscriptionBufSize int + subscriptionBufSize uint screen string } @@ -81,13 +81,13 @@ func convertStatus(status st.ConversationStatus) AgentStatus { } } -const defaultSubscriptionBufSize = 1024 +const defaultSubscriptionBufSize uint = 1024 type EventEmitterOption func(*EventEmitter) -func WithSubscriptionBufSize(size int) EventEmitterOption { +func WithSubscriptionBufSize(size uint) EventEmitterOption { return func(e *EventEmitter) { - if size <= 0 { + if size == 0 { e.subscriptionBufSize = defaultSubscriptionBufSize } else { e.subscriptionBufSize = size From 8c8da69b20d7687f938e2d4e1e9f4e4012d2408d Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 4 Feb 2026 17:21:21 +0000 Subject: [PATCH 06/17] feat: add experimental ACP mode (--experimental-acp) Add support for Agent Control Protocol (ACP) as an alternative to terminal emulation. ACP uses JSON-RPC over stdin/stdout pipes. 
- Introduce AgentIO interface to abstract PTY vs ACP transports - Add ACPConversation implementing Conversation interface - Add --experimental-acp flag (mutually exclusive with --print-openapi) - Add e2e test with mock ACP agent - Block `attach` when using --experimental-acp (no terminal) - Update chat UI to show ACP tool calls Other changes: - chat: Fix redundant draft filtering from finally block Created using Mux (Opus 4.5) --- chat/src/components/chat-provider.tsx | 3 - cmd/attach/attach.go | 39 +++ cmd/server/server.go | 93 ++++-- e2e/acp_echo.go | 145 ++++++++++ e2e/echo_test.go | 28 ++ e2e/testdata/acp_basic.json | 6 + go.mod | 1 + go.sum | 2 + lib/acp/doc.go | 8 + lib/httpapi/server.go | 64 +++-- lib/httpapi/server_test.go | 18 +- lib/httpapi/setup.go | 55 ++++ openapi.json | 4 + x/acpio/acp_conversation.go | 230 +++++++++++++++ x/acpio/acp_conversation_test.go | 397 ++++++++++++++++++++++++++ x/acpio/acpio.go | 221 ++++++++++++++ 16 files changed, 1255 insertions(+), 59 deletions(-) create mode 100644 e2e/acp_echo.go create mode 100644 e2e/testdata/acp_basic.json create mode 100644 lib/acp/doc.go create mode 100644 x/acpio/acp_conversation.go create mode 100644 x/acpio/acp_conversation_test.go create mode 100644 x/acpio/acpio.go diff --git a/chat/src/components/chat-provider.tsx b/chat/src/components/chat-provider.tsx index 21a2ee3f..3cbef3ab 100644 --- a/chat/src/components/chat-provider.tsx +++ b/chat/src/components/chat-provider.tsx @@ -304,9 +304,6 @@ export function ChatProvider({ children }: PropsWithChildren) { }); } finally { if (type === "user") { - setMessages((prevMessages) => - prevMessages.filter((m) => !isDraftMessage(m)) - ); setLoading(false); } } diff --git a/cmd/attach/attach.go b/cmd/attach/attach.go index 17516398..512d70f4 100644 --- a/cmd/attach/attach.go +++ b/cmd/attach/attach.go @@ -129,7 +129,46 @@ func WriteRawInputOverHTTP(ctx context.Context, url string, msg string) error { return nil } + +// statusResponse is used to parse the 
/status endpoint response. +// The ACPMode field may not be present on older servers. +type statusResponse struct { + Status string `json:"status"` + AgentType string `json:"agent_type"` + ACPMode bool `json:"acp_mode"` +} + +func checkACPMode(remoteUrl string) error { + resp, err := http.Get(remoteUrl + "/status") + if err != nil { + return xerrors.Errorf("failed to check server status: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // Server doesn't support /status or had an error, continue anyway + return nil + } + + var status statusResponse + if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { + // Can't parse response, continue anyway + return nil + } + + if status.ACPMode { + return xerrors.New("attach is not supported in ACP mode. The server is running with --experimental-acp which uses JSON-RPC instead of terminal emulation.") + } + + return nil +} + func runAttach(remoteUrl string) error { + // Check if server is running in ACP mode (attach not supported) + if err := checkACPMode(remoteUrl); err != nil { + return err + } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() stdin := int(os.Stdin.Fd()) diff --git a/cmd/server/server.go b/cmd/server/server.go index 6d5cdec3..f34ac72a 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -19,6 +19,7 @@ import ( "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" "github.com/coder/agentapi/lib/msgfmt" + st "github.com/coder/agentapi/lib/screentracker" "github.com/coder/agentapi/lib/termexec" ) @@ -104,11 +105,33 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } printOpenAPI := viper.GetBool(FlagPrintOpenAPI) + experimentalACP := viper.GetBool(FlagExperimentalACP) + + if printOpenAPI && experimentalACP { + return xerrors.Errorf("flags --%s and --%s are mutually exclusive", FlagPrintOpenAPI, FlagExperimentalACP) + } + + var agentIO st.AgentIO + var transport = "pty" var 
process *termexec.Process + var acpWait func() error + if printOpenAPI { - process = nil + agentIO = nil + } else if experimentalACP { + acpResult, err := httpapi.SetupACP(ctx, httpapi.SetupACPConfig{ + Program: agent, + ProgramArgs: argsToPass[1:], + }) + if err != nil { + return xerrors.Errorf("failed to setup ACP: %w", err) + } + acpIO := acpResult.AgentIO + acpWait = acpResult.Wait + agentIO = acpIO + transport = "acp" } else { - process, err = httpapi.SetupProcess(ctx, httpapi.SetupProcessConfig{ + proc, err := httpapi.SetupProcess(ctx, httpapi.SetupProcessConfig{ Program: agent, ProgramArgs: argsToPass[1:], TerminalWidth: termWidth, @@ -118,11 +141,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er if err != nil { return xerrors.Errorf("failed to setup process: %w", err) } + process = proc + agentIO = proc } port := viper.GetInt(FlagPort) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: agentType, - Process: process, + AgentIO: agentIO, + Transport: transport, Port: port, ChatBasePath: viper.GetString(FlagChatBasePath), AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), @@ -138,19 +164,34 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) - go func() { - defer close(processExitCh) - if err := process.Wait(); err != nil { - if errors.Is(err, termexec.ErrNonZeroExitCode) { - processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err) - } else { - processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) + // Wait for process exit in PTY mode + if process != nil { + go func() { + defer close(processExitCh) + if err := process.Wait(); err != nil { + if errors.Is(err, termexec.ErrNonZeroExitCode) { + processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err) + } else { + 
processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) + } } - } - if err := srv.Stop(ctx); err != nil { - logger.Error("Failed to stop server", "error", err) - } - }() + if err := srv.Stop(ctx); err != nil { + logger.Error("Failed to stop server", "error", err) + } + }() + } + // Wait for process exit in ACP mode + if acpWait != nil { + go func() { + defer close(processExitCh) + if err := acpWait(); err != nil { + processExitCh <- xerrors.Errorf("ACP process exited: %w", err) + } + if err := srv.Stop(ctx); err != nil { + logger.Error("Failed to stop server", "error", err) + } + }() + } if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { return xerrors.Errorf("failed to start server: %w", err) } @@ -180,16 +221,17 @@ type flagSpec struct { } const ( - FlagType = "type" - FlagPort = "port" - FlagPrintOpenAPI = "print-openapi" - FlagChatBasePath = "chat-base-path" - FlagTermWidth = "term-width" - FlagTermHeight = "term-height" - FlagAllowedHosts = "allowed-hosts" - FlagAllowedOrigins = "allowed-origins" - FlagExit = "exit" - FlagInitialPrompt = "initial-prompt" + FlagType = "type" + FlagPort = "port" + FlagPrintOpenAPI = "print-openapi" + FlagChatBasePath = "chat-base-path" + FlagTermWidth = "term-width" + FlagTermHeight = "term-height" + FlagAllowedHosts = "allowed-hosts" + FlagAllowedOrigins = "allowed-origins" + FlagExit = "exit" + FlagInitialPrompt = "initial-prompt" + FlagExperimentalACP = "experimental-acp" ) func CreateServerCmd() *cobra.Command { @@ -228,6 +270,7 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. 
Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {FlagExperimentalACP, "", false, "Use experimental ACP transport instead of PTY", "bool"}, } for _, spec := range flagSpecs { diff --git a/e2e/acp_echo.go b/e2e/acp_echo.go new file mode 100644 index 00000000..a7986e75 --- /dev/null +++ b/e2e/acp_echo.go @@ -0,0 +1,145 @@ +//go:build ignore + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/signal" + "strings" + + acp "github.com/coder/acp-go-sdk" +) + +// ScriptEntry defines a single entry in the test script. +type ScriptEntry struct { + ExpectMessage string `json:"expectMessage"` + ThinkDurationMS int64 `json:"thinkDurationMS"` + ResponseMessage string `json:"responseMessage"` +} + +// acpEchoAgent implements the ACP Agent interface for testing. 
+type acpEchoAgent struct { + script []ScriptEntry + scriptIndex int + conn *acp.AgentSideConnection + sessionID acp.SessionId +} + +var _ acp.Agent = (*acpEchoAgent)(nil) + +func main() { + if len(os.Args) != 2 { + fmt.Fprintln(os.Stderr, "Usage: acp_echo ") + os.Exit(1) + } + + script, err := loadScript(os.Args[1]) + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading script: %v\n", err) + os.Exit(1) + } + + if len(script) == 0 { + fmt.Fprintln(os.Stderr, "Script is empty") + os.Exit(1) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt) + go func() { + <-sigCh + os.Exit(0) + }() + + agent := &acpEchoAgent{ + script: script, + } + + conn := acp.NewAgentSideConnection(agent, os.Stdout, os.Stdin) + agent.conn = conn + + <-conn.Done() +} + +func (a *acpEchoAgent) Initialize(_ context.Context, _ acp.InitializeRequest) (acp.InitializeResponse, error) { + return acp.InitializeResponse{ + ProtocolVersion: acp.ProtocolVersionNumber, + AgentCapabilities: acp.AgentCapabilities{}, + }, nil +} + +func (a *acpEchoAgent) Authenticate(_ context.Context, _ acp.AuthenticateRequest) (acp.AuthenticateResponse, error) { + return acp.AuthenticateResponse{}, nil +} + +func (a *acpEchoAgent) Cancel(_ context.Context, _ acp.CancelNotification) error { + return nil +} + +func (a *acpEchoAgent) NewSession(_ context.Context, _ acp.NewSessionRequest) (acp.NewSessionResponse, error) { + a.sessionID = "test-session" + return acp.NewSessionResponse{ + SessionId: a.sessionID, + }, nil +} + +func (a *acpEchoAgent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) { + // Extract text from prompt + var promptText string + for _, block := range params.Prompt { + if block.Text != nil { + promptText = block.Text.Text + break + } + } + promptText = strings.TrimSpace(promptText) + + if a.scriptIndex >= len(a.script) { + return acp.PromptResponse{ + StopReason: acp.StopReasonEndTurn, + }, nil + } + + entry := a.script[a.scriptIndex] + 
expected := strings.TrimSpace(entry.ExpectMessage) + + // Empty ExpectMessage matches any prompt + if expected != "" && expected != promptText { + return acp.PromptResponse{}, fmt.Errorf("expected message %q but got %q", expected, promptText) + } + + a.scriptIndex++ + + // Send response via session update + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ + SessionId: params.SessionId, + Update: acp.UpdateAgentMessageText(entry.ResponseMessage), + }); err != nil { + return acp.PromptResponse{}, err + } + + return acp.PromptResponse{ + StopReason: acp.StopReasonEndTurn, + }, nil +} + +func (a *acpEchoAgent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (acp.SetSessionModeResponse, error) { + return acp.SetSessionModeResponse{}, nil +} + +func loadScript(scriptPath string) ([]ScriptEntry, error) { + data, err := os.ReadFile(scriptPath) + if err != nil { + return nil, fmt.Errorf("failed to read script file: %w", err) + } + + var script []ScriptEntry + if err := json.Unmarshal(data, &script); err != nil { + return nil, fmt.Errorf("failed to parse script JSON: %w", err) + } + + return script, nil +} diff --git a/e2e/echo_test.go b/e2e/echo_test.go index 765521cf..fd44d32a 100644 --- a/e2e/echo_test.go +++ b/e2e/echo_test.go @@ -100,6 +100,34 @@ func TestE2E(t *testing.T) { require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content)) require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content)) }) + + t.Run("acp_basic", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + script, apiClient := setup(ctx, t, &params{ + cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { + return binaryPath, []string{ + "server", + fmt.Sprintf("--port=%d", serverPort), + "--experimental-acp", + "--", "go", "run", filepath.Join(cwd, "acp_echo.go"), scriptFilePath, + } + }, + }) 
+ messageReq := agentapisdk.PostMessageParams{ + Content: "This is a test message.", + Type: agentapisdk.MessageTypeUser, + } + _, err := apiClient.PostMessage(ctx, messageReq) + require.NoError(t, err, "Failed to send message via SDK") + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "post message")) + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages via SDK") + require.Len(t, msgResp.Messages, 2) + require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[0].Content)) + require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[1].Content)) + }) } type params struct { diff --git a/e2e/testdata/acp_basic.json b/e2e/testdata/acp_basic.json new file mode 100644 index 00000000..22dd8d98 --- /dev/null +++ b/e2e/testdata/acp_basic.json @@ -0,0 +1,6 @@ +[ + { + "expectMessage": "This is a test message.", + "responseMessage": "Echo: This is a test message." + } +] diff --git a/go.mod b/go.mod index 35aee152..5c1a486f 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/ActiveState/termtest/xpty v0.6.0 github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d github.com/charmbracelet/bubbletea v1.3.4 + github.com/coder/acp-go-sdk v0.6.3 github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225 github.com/coder/quartz v0.1.2 github.com/danielgtaylor/huma/v2 v2.32.0 diff --git a/go.sum b/go.sum index a98ca56d..1d70abb9 100644 --- a/go.sum +++ b/go.sum @@ -163,6 +163,8 @@ github.com/ckaznocha/intrange v0.3.1 h1:j1onQyXvHUsPWujDH6WIjhyH26gkRt/txNlV7Lsp github.com/ckaznocha/intrange v0.3.1/go.mod h1:QVepyz1AkUoFQkpEqksSYpNpUo3c5W7nWh/s6SHIJJk= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/coder/acp-go-sdk v0.6.3 h1:LsXQytehdjKIYJnoVWON/nf7mqbiarnyuyE3rrjBsXQ= 
+github.com/coder/acp-go-sdk v0.6.3/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko= github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225 h1:tRIViZ5JRmzdOEo5wUWngaGEFBG8OaE1o2GIHN5ujJ8= github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225/go.mod h1:rNLVpYgEVeu1Zk29K64z6Od8RBP9DwqCu9OfCzh8MR4= github.com/coder/paralleltestctx v0.0.1 h1:eauyehej1XYTGwgzGWMTjeRIVgOpU6XLPNVb2oi6kDs= diff --git a/lib/acp/doc.go b/lib/acp/doc.go new file mode 100644 index 00000000..4a29389e --- /dev/null +++ b/lib/acp/doc.go @@ -0,0 +1,8 @@ +// Package acp provides Agent Control Protocol (ACP) support for agentapi. +package acp + +import ( + // Import ACP SDK for go.mod dependency tracking. + // This will be used by actual implementations. + _ "github.com/coder/acp-go-sdk" +) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 956cfb8a..d1c8cafe 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -24,6 +24,7 @@ import ( mf "github.com/coder/agentapi/lib/msgfmt" st "github.com/coder/agentapi/lib/screentracker" "github.com/coder/agentapi/lib/termexec" + "github.com/coder/agentapi/x/acpio" "github.com/coder/quartz" "github.com/danielgtaylor/huma/v2" "github.com/danielgtaylor/huma/v2/adapters/humachi" @@ -42,12 +43,13 @@ type Server struct { mu sync.RWMutex logger *slog.Logger conversation st.Conversation - agentio *termexec.Process + agentio st.AgentIO agentType mf.AgentType emitter *EventEmitter chatBasePath string tempDir string clock quartz.Clock + transport string } func (s *Server) NormalizeSchema(schema any) any { @@ -98,7 +100,8 @@ const snapshotInterval = 25 * time.Millisecond type ServerConfig struct { AgentType mf.AgentType - Process *termexec.Process + AgentIO st.AgentIO + Transport string Port int ChatBasePath string AllowedHosts []string @@ -252,18 +255,34 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { initialPrompt = FormatMessage(config.AgentType, config.InitialPrompt) } - 
conversation := st.NewPTY(ctx, st.PTYConversationConfig{ - AgentType: config.AgentType, - AgentIO: config.Process, - Clock: config.Clock, - SnapshotInterval: snapshotInterval, - ScreenStabilityLength: 2 * time.Second, - FormatMessage: formatMessage, - ReadyForInitialPrompt: isAgentReadyForInitialPrompt, - FormatToolCall: formatToolCall, - InitialPrompt: initialPrompt, - Logger: logger, - }, emitter) + // Create appropriate conversation based on transport type + var conversation st.Conversation + if config.Transport == "acp" { + // For ACP, cast AgentIO to *acpio.ACPAgentIO + acpIO, ok := config.AgentIO.(*acpio.ACPAgentIO) + if !ok { + return nil, fmt.Errorf("ACP transport requires ACPAgentIO") + } + conversation = acpio.NewACPConversation(acpIO, logger, initialPrompt, emitter, config.Clock) + } else { + // Default to PTY transport + proc, ok := config.AgentIO.(*termexec.Process) + if !ok && config.AgentIO != nil { + return nil, fmt.Errorf("PTY transport requires termexec.Process") + } + conversation = st.NewPTY(ctx, st.PTYConversationConfig{ + AgentType: config.AgentType, + AgentIO: proc, + Clock: config.Clock, + SnapshotInterval: snapshotInterval, + ScreenStabilityLength: 2 * time.Second, + FormatMessage: formatMessage, + ReadyForInitialPrompt: isAgentReadyForInitialPrompt, + FormatToolCall: formatToolCall, + InitialPrompt: initialPrompt, + Logger: logger, + }, emitter) + } // Create temporary directory for uploads tempDir, err := os.MkdirTemp("", "agentapi-uploads-") @@ -278,24 +297,25 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { port: config.Port, conversation: conversation, logger: logger, - agentio: config.Process, + agentio: config.AgentIO, agentType: config.AgentType, emitter: emitter, chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), tempDir: tempDir, clock: config.Clock, + transport: config.Transport, } // Register API routes s.registerRoutes() - // Start the conversation polling loop if we have a process. 
- // Process is nil only when --print-openapi is used (no agent runs). - // The process is already running at this point - termexec.StartProcess() - // blocks until the PTY is created and the process is active. Agent - // readiness (waiting for the prompt) is handled asynchronously inside - // conversation.Start() via ReadyForInitialPrompt. - if config.Process != nil { + // Start the conversation polling loop if we have an agent IO. + // AgentIO is nil only when --print-openapi is used (no agent runs). + // For PTY transport, the process is already running at this point - + // termexec.StartProcess() blocks until the PTY is created and the process + // is active. Agent readiness (waiting for the prompt) is handled + // asynchronously inside conversation.Start() via ReadyForInitialPrompt. + if config.AgentIO != nil { s.conversation.Start(ctx) } diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index c8e8b23c..712df95e 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -29,7 +29,7 @@ func TestOpenAPISchema(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, @@ -78,7 +78,7 @@ func TestServer_redirectToChat(t *testing.T) { tCtx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(tCtx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: tc.chatBasePath, AllowedHosts: []string{"*"}, @@ -242,7 +242,7 @@ func TestServer_AllowedHosts(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 
0, ChatBasePath: "/chat", AllowedHosts: tc.allowedHosts, @@ -325,7 +325,7 @@ func TestServer_CORSPreflightWithHosts(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: tc.allowedHosts, @@ -484,7 +484,7 @@ func TestServer_CORSOrigins(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing @@ -564,7 +564,7 @@ func TestServer_CORSPreflightOrigins(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing @@ -615,7 +615,7 @@ func TestServer_SSEMiddleware_Events(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, @@ -662,7 +662,7 @@ func TestServer_UploadFiles(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, @@ -817,7 +817,7 @@ func TestServer_UploadFiles_Errors(t *testing.T) { ctx := logctx.WithLogger(context.Background(), 
slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 16203041..41565ba1 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/signal" "strings" "syscall" @@ -12,6 +13,7 @@ import ( "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" "github.com/coder/agentapi/lib/termexec" + "github.com/coder/agentapi/x/acpio" ) type SetupProcessConfig struct { @@ -58,3 +60,56 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return process, nil } + +type SetupACPConfig struct { + Program string + ProgramArgs []string +} + +// SetupACPResult contains the result of setting up an ACP process. +type SetupACPResult struct { + AgentIO *acpio.ACPAgentIO + Wait func() error // Calls cmd.Wait() and returns exit error +} + +func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, error) { + logger := logctx.From(ctx) + + args := config.ProgramArgs + logger.Info(fmt.Sprintf("Running (ACP): %s %s", config.Program, strings.Join(args, " "))) + + cmd := exec.CommandContext(ctx, config.Program, args...) 
+ stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start process: %w", err) + } + + agentIO, err := acpio.NewWithPipes(ctx, stdin, stdout, logger) + if err != nil { + _ = cmd.Process.Kill() + return nil, fmt.Errorf("failed to initialize ACP connection: %w", err) + } + + go func() { + <-ctx.Done() + logger.Info("Context done, closing ACP agent") + _ = stdin.Close() + _ = stdout.Close() + _ = cmd.Process.Kill() + }() + + return &SetupACPResult{ + AgentIO: agentIO, + Wait: cmd.Wait, + }, nil +} + diff --git a/openapi.json b/openapi.json index dda817cc..f0554e43 100644 --- a/openapi.json +++ b/openapi.json @@ -265,6 +265,10 @@ "readOnly": true, "type": "string" }, + "acp_mode": { + "description": "Whether the server is running in ACP mode.", + "type": "boolean" + }, "agent_type": { "description": "Type of the agent being used by the server.", "type": "string" diff --git a/x/acpio/acp_conversation.go b/x/acpio/acp_conversation.go new file mode 100644 index 00000000..24c4baea --- /dev/null +++ b/x/acpio/acp_conversation.go @@ -0,0 +1,230 @@ +package acpio + +import ( + "context" + "log/slog" + "slices" + "strings" + "sync" + + st "github.com/coder/agentapi/lib/screentracker" + "github.com/coder/quartz" +) + +// Compile-time assertion that ACPConversation implements st.Conversation +var _ st.Conversation = (*ACPConversation)(nil) + +// ChunkableAgentIO extends AgentIO with chunk callback support for streaming responses. +// This interface is what ACPConversation needs from its AgentIO implementation. +type ChunkableAgentIO interface { + st.AgentIO + SetOnChunk(fn func(chunk string)) +} + +// ACPConversation tracks conversations with ACP-based agents. 
+// Unlike PTY-based Conversation, ACP has blocking writes where the +// response is complete when Write() returns. +type ACPConversation struct { + mu sync.Mutex + agentIO ChunkableAgentIO + messages []st.ConversationMessage + prompting bool // true while agent is processing + streamingResponse strings.Builder + logger *slog.Logger + emitter st.Emitter + initialPrompt []st.MessagePart + clock quartz.Clock +} + +// noopEmitter is a no-op implementation of Emitter for when no emitter is provided. +type noopEmitter struct{} + +func (noopEmitter) EmitMessages([]st.ConversationMessage) {} +func (noopEmitter) EmitStatus(st.ConversationStatus) {} +func (noopEmitter) EmitScreen(string) {} + +// NewACPConversation creates a new ACPConversation. +// If emitter is provided, it will receive events when messages/status/screen change. +// If clock is nil, a real clock will be used. +func NewACPConversation(agentIO ChunkableAgentIO, logger *slog.Logger, initialPrompt []st.MessagePart, emitter st.Emitter, clock quartz.Clock) *ACPConversation { + if logger == nil { + logger = slog.Default() + } + if clock == nil { + clock = quartz.NewReal() + } + if emitter == nil { + emitter = noopEmitter{} + } + c := &ACPConversation{ + agentIO: agentIO, + logger: logger, + initialPrompt: initialPrompt, + emitter: emitter, + clock: clock, + } + return c +} + +// Messages returns the conversation history. +func (c *ACPConversation) Messages() []st.ConversationMessage { + c.mu.Lock() + defer c.mu.Unlock() + return slices.Clone(c.messages) +} + +// Send sends a message to the agent asynchronously. +// It returns immediately after recording the user message and starts +// the agent request in a background goroutine. Returns an error if +// a message is already being processed. 
+func (c *ACPConversation) Send(messageParts ...st.MessagePart) error { + message := "" + for _, part := range messageParts { + message += part.String() + } + message = strings.TrimSpace(message) + + if message == "" { + return st.ErrMessageValidationEmpty + } + + // Check if already prompting and set state atomically + c.mu.Lock() + if c.prompting { + c.mu.Unlock() + return st.ErrMessageValidationChanging + } + c.messages = append(c.messages, st.ConversationMessage{ + Id: len(c.messages), + Role: st.ConversationRoleUser, + Message: message, + Time: c.clock.Now(), + }) + // Add placeholder for streaming agent response + c.messages = append(c.messages, st.ConversationMessage{ + Id: len(c.messages), + Role: st.ConversationRoleAgent, + Message: "", + Time: c.clock.Now(), + }) + c.streamingResponse.Reset() + c.prompting = true + status := c.statusLocked() + c.mu.Unlock() + + // Emit status change to "running" before starting the prompt + c.emitter.EmitStatus(status) + + c.logger.Debug("ACPConversation sending message", "message", message) + + // Run the blocking write in a goroutine so HTTP returns immediately + go c.executePrompt(messageParts) + + return nil +} + +// Start sets up chunk handling and sends the initial prompt if provided. +func (c *ACPConversation) Start(ctx context.Context) { + // Wire up the chunk callback for streaming + c.agentIO.SetOnChunk(c.handleChunk) + + // Send initial prompt if provided + if len(c.initialPrompt) > 0 { + err := c.Send(c.initialPrompt...) + if err != nil { + c.logger.Error("ACPConversation failed to send initial prompt", "error", err) + } + } else { + // No initial prompt means we start in stable state + c.emitter.EmitStatus(c.Status()) + } +} + +// Status returns the current conversation status. +func (c *ACPConversation) Status() st.ConversationStatus { + c.mu.Lock() + defer c.mu.Unlock() + return c.statusLocked() +} + +// statusLocked returns the status without acquiring the lock (caller must hold lock). 
+func (c *ACPConversation) statusLocked() st.ConversationStatus { + if c.prompting { + return st.ConversationStatusChanging // agent is processing + } + return st.ConversationStatusStable +} + +// Text returns the current streaming response text. +func (c *ACPConversation) Text() string { + c.mu.Lock() + defer c.mu.Unlock() + return c.streamingResponse.String() +} + +// handleChunk is called for each streaming chunk from the agent. +func (c *ACPConversation) handleChunk(chunk string) { + c.mu.Lock() + c.streamingResponse.WriteString(chunk) + // Update the last message (the streaming agent response) + if len(c.messages) > 0 { + c.messages[len(c.messages)-1].Message = c.streamingResponse.String() + } + messages := slices.Clone(c.messages) + status := c.statusLocked() + screen := c.streamingResponse.String() + c.mu.Unlock() + + c.emitter.EmitMessages(messages) + c.emitter.EmitStatus(status) + c.emitter.EmitScreen(screen) +} + +// executePrompt runs the actual agent request in background +func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) { + var err error + for _, part := range messageParts { + if partErr := part.Do(c.agentIO); partErr != nil { + err = partErr + break + } + } + + c.mu.Lock() + c.prompting = false + + if err != nil { + c.logger.Error("ACPConversation message failed", "error", err) + // Remove the empty streaming message on error + if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == st.ConversationRoleAgent && + c.messages[len(c.messages)-1].Message == "" { + c.messages = c.messages[:len(c.messages)-1] + } + messages := slices.Clone(c.messages) + status := c.statusLocked() + screen := c.streamingResponse.String() + c.mu.Unlock() + + c.emitter.EmitMessages(messages) + c.emitter.EmitStatus(status) + c.emitter.EmitScreen(screen) + return + } + + // Final response should already be in the last message via streaming + // but ensure it's finalized + response := c.streamingResponse.String() + if len(c.messages) > 0 && 
c.messages[len(c.messages)-1].Role == st.ConversationRoleAgent { + c.messages[len(c.messages)-1].Message = strings.TrimSpace(response) + } + messages := slices.Clone(c.messages) + status := c.statusLocked() + screen := c.streamingResponse.String() + c.mu.Unlock() + + c.emitter.EmitMessages(messages) + c.emitter.EmitStatus(status) + c.emitter.EmitScreen(screen) + + c.logger.Debug("ACPConversation message complete", "responseLen", len(response)) +} diff --git a/x/acpio/acp_conversation_test.go b/x/acpio/acp_conversation_test.go new file mode 100644 index 00000000..c83b559f --- /dev/null +++ b/x/acpio/acp_conversation_test.go @@ -0,0 +1,397 @@ +package acpio_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/coder/quartz" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/agentapi/lib/screentracker" + "github.com/coder/agentapi/x/acpio" +) + +// mockAgentIO implements acpio.ChunkableAgentIO for testing. +// It provides a channel-based synchronization mechanism to test ACPConversation +// without relying on time.Sleep. +type mockAgentIO struct { + mu sync.Mutex + written []byte + screenContent string + onChunkFn func(chunk string) + + // Control behavior + writeErr error + // writeBlock is a channel that, if non-nil, will cause Write to block until closed. + // This allows tests to control when the write completes. + writeBlock chan struct{} + // writeStarted is closed when Write begins (before blocking on writeBlock). + // This allows tests to synchronize on the write starting. + writeStarted chan struct{} +} + +// mockEmitter implements screentracker.Emitter for testing. 
+type mockEmitter struct { + mu sync.Mutex + messagesCalls int + statusCalls int + screenCalls int + lastMessages []screentracker.ConversationMessage + lastStatus screentracker.ConversationStatus + lastScreen string +} + +func newMockEmitter() *mockEmitter { + return &mockEmitter{} +} + +func (m *mockEmitter) EmitMessages(messages []screentracker.ConversationMessage) { + m.mu.Lock() + defer m.mu.Unlock() + m.messagesCalls++ + m.lastMessages = messages +} + +func (m *mockEmitter) EmitStatus(status screentracker.ConversationStatus) { + m.mu.Lock() + defer m.mu.Unlock() + m.statusCalls++ + m.lastStatus = status +} + +func (m *mockEmitter) EmitScreen(screen string) { + m.mu.Lock() + defer m.mu.Unlock() + m.screenCalls++ + m.lastScreen = screen +} + +func (m *mockEmitter) TotalCalls() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.messagesCalls + m.statusCalls + m.screenCalls +} + +func newMockAgentIO() *mockAgentIO { + return &mockAgentIO{} +} + +func (m *mockAgentIO) Write(data []byte) (int, error) { + // Signal that write has started + m.mu.Lock() + started := m.writeStarted + block := m.writeBlock + m.mu.Unlock() + + if started != nil { + close(started) + } + + // Block if configured to do so (for testing concurrent sends) + if block != nil { + <-block + } + + m.mu.Lock() + defer m.mu.Unlock() + if m.writeErr != nil { + return 0, m.writeErr + } + m.written = append(m.written, data...) + return len(data), nil +} + +func (m *mockAgentIO) ReadScreen() string { + m.mu.Lock() + defer m.mu.Unlock() + return m.screenContent +} + +func (m *mockAgentIO) SetOnChunk(fn func(chunk string)) { + m.mu.Lock() + defer m.mu.Unlock() + m.onChunkFn = fn +} + +// SimulateChunks simulates the agent sending streaming chunks. +// This triggers the onChunk callback as if the agent was responding. 
+func (m *mockAgentIO) SimulateChunks(chunks ...string) { + m.mu.Lock() + fn := m.onChunkFn + m.mu.Unlock() + for _, chunk := range chunks { + if fn != nil { + fn(chunk) + } + } +} + +// GetWritten returns all data written to the agent. +func (m *mockAgentIO) GetWritten() []byte { + m.mu.Lock() + defer m.mu.Unlock() + return append([]byte(nil), m.written...) +} + +// BlockWrite sets up blocking for the next Write call and returns: +// - started: a channel that is closed when Write begins +// - done: a channel to close to unblock the Write +func (m *mockAgentIO) BlockWrite() (started chan struct{}, done chan struct{}) { + m.mu.Lock() + defer m.mu.Unlock() + m.writeStarted = make(chan struct{}) + m.writeBlock = make(chan struct{}) + return m.writeStarted, m.writeBlock +} + +func Test_NewACPConversation(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + + require.NotNil(t, conv) +} + +func Test_Messages_InitiallyEmpty(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + + messages := conv.Messages() + + assert.Empty(t, messages) +} + +func Test_Status_InitiallyStable(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + + status := conv.Status() + + assert.Equal(t, screentracker.ConversationStatusStable, status) +} + +func Test_Send_AddsUserMessage(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Set up blocking to synchronize with the goroutine + started, done := mock.BlockWrite() + + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv.Start(context.Background()) + + err := conv.Send(screentracker.MessagePartText{Content: "hello"}) + require.NoError(t, err) + + // Wait for the write goroutine to start + <-started + + messages := conv.Messages() + require.Len(t, 
messages, 2) // user message + placeholder agent message + + assert.Equal(t, screentracker.ConversationRoleUser, messages[0].Role) + assert.Equal(t, "hello", messages[0].Message) + assert.Equal(t, screentracker.ConversationRoleAgent, messages[1].Role) + + // Unblock the write to let the test complete cleanly + close(done) +} + +func Test_Send_RejectsEmptyMessage(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + + err := conv.Send(screentracker.MessagePartText{Content: ""}) + + assert.ErrorIs(t, err, screentracker.ErrMessageValidationEmpty) +} + +func Test_Send_RejectsDuplicateSend(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Block the write so it doesn't complete immediately + started, done := mock.BlockWrite() + + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv.Start(context.Background()) + + // First send should succeed + err := conv.Send(screentracker.MessagePartText{Content: "first"}) + require.NoError(t, err) + + // Wait for the write to start (ensuring we're in "prompting" state) + <-started + + // Second send while first is processing should fail + err = conv.Send(screentracker.MessagePartText{Content: "second"}) + assert.ErrorIs(t, err, screentracker.ErrMessageValidationChanging) + + // Unblock the write to let the test complete cleanly + close(done) +} + +func Test_Status_ChangesWhileProcessing(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Block the write so we can observe status changes + started, done := mock.BlockWrite() + + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv.Start(context.Background()) + + // Initially stable + assert.Equal(t, screentracker.ConversationStatusStable, conv.Status()) + + // Send a message + err := conv.Send(screentracker.MessagePartText{Content: "test"}) + require.NoError(t, err) + + // Wait for write to start + <-started + + // 
Status should be changing while processing + assert.Equal(t, screentracker.ConversationStatusChanging, conv.Status()) + + // Unblock the write + close(done) + + // Give the goroutine a chance to complete (status update happens after Write returns) + require.Eventually(t, func() bool { + return conv.Status() == screentracker.ConversationStatusStable + }, 100*time.Millisecond, 5*time.Millisecond, "status should return to stable") +} + +func Test_Text_ReturnsStreamingContent(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Block the write so we can simulate streaming during processing + started, done := mock.BlockWrite() + + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv.Start(context.Background()) + + // Initially empty + assert.Equal(t, "", conv.Text()) + + // Send a message + err := conv.Send(screentracker.MessagePartText{Content: "question"}) + require.NoError(t, err) + + // Wait for write to start + <-started + + // Simulate streaming chunks from agent + mock.SimulateChunks("Hello", " ", "world!") + + // Text should contain the streamed content + assert.Equal(t, "Hello world!", conv.Text()) + + // The last message should also be updated + messages := conv.Messages() + require.Len(t, messages, 2) + assert.Equal(t, "Hello world!", messages[1].Message) + + // Unblock the write to let the test complete cleanly + close(done) +} + +func Test_Emitter_CalledOnChanges(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Block the write so we can simulate chunks during processing + started, done := mock.BlockWrite() + + emitter := newMockEmitter() + + conv := acpio.NewACPConversation(mock, nil, nil, emitter, mClock) + conv.Start(context.Background()) + + // Send a message + err := conv.Send(screentracker.MessagePartText{Content: "test"}) + require.NoError(t, err) + + // Wait for write to start + <-started + + // Simulate chunks - each should trigger emitter calls + mock.SimulateChunks("chunk1") + 
mock.SimulateChunks("chunk2") + + emitter.mu.Lock() + messagesCallsBeforeComplete := emitter.messagesCalls + emitter.mu.Unlock() + + // Should have emit calls from chunks (each chunk emits messages, status, and screen) + assert.Equal(t, 2, messagesCallsBeforeComplete) + + // Unblock the write to complete processing + close(done) + + // Wait for completion emit + require.Eventually(t, func() bool { + emitter.mu.Lock() + c := emitter.messagesCalls + emitter.mu.Unlock() + return c >= 3 // 2 from chunks + 1 from completion + }, 100*time.Millisecond, 5*time.Millisecond, "should receive completion emit") +} + +func Test_InitialPrompt_SentOnStart(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Set up blocking to synchronize with the initial prompt send + started, done := mock.BlockWrite() + + initialPrompt := []screentracker.MessagePart{ + screentracker.MessagePartText{Content: "initial prompt"}, + } + + conv := acpio.NewACPConversation(mock, nil, initialPrompt, nil, mClock) + conv.Start(context.Background()) + + // Wait for write to start (initial prompt is being sent) + <-started + + // Should have user message from initial prompt + messages := conv.Messages() + require.GreaterOrEqual(t, len(messages), 1) + assert.Equal(t, screentracker.ConversationRoleUser, messages[0].Role) + assert.Equal(t, "initial prompt", messages[0].Message) + + // Unblock the write to let the test complete cleanly + close(done) +} + +func Test_Messages_AreCopied(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Set up blocking to synchronize + started, done := mock.BlockWrite() + + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv.Start(context.Background()) + + err := conv.Send(screentracker.MessagePartText{Content: "test"}) + require.NoError(t, err) + + // Wait for write to start + <-started + + // Get messages and modify + messages := conv.Messages() + require.Len(t, messages, 2) + messages[0].Message = "modified" + 
+ // Original should be unchanged + originalMessages := conv.Messages() + assert.Equal(t, "test", originalMessages[0].Message) + + // Unblock the write to let the test complete cleanly + close(done) +} diff --git a/x/acpio/acpio.go b/x/acpio/acpio.go new file mode 100644 index 00000000..6b8fc194 --- /dev/null +++ b/x/acpio/acpio.go @@ -0,0 +1,221 @@ +package acpio + +import ( + "context" + "fmt" + "io" + "log/slog" + "strings" + "sync" + + acp "github.com/coder/acp-go-sdk" + st "github.com/coder/agentapi/lib/screentracker" +) + +// Compile-time assertion that ACPAgentIO implements st.AgentIO +var _ st.AgentIO = (*ACPAgentIO)(nil) + +// ACPAgentIO implements screentracker.AgentIO using the ACP protocol +type ACPAgentIO struct { + ctx context.Context + conn *acp.ClientSideConnection + sessionID acp.SessionId + mu sync.RWMutex + response strings.Builder + logger *slog.Logger + onChunk func(chunk string) // called on each streaming chunk +} + +// acpClient implements acp.Client to handle callbacks from the agent +type acpClient struct { + agentIO *ACPAgentIO +} + +var _ acp.Client = (*acpClient)(nil) + +func (c *acpClient) SessionUpdate(ctx context.Context, params acp.SessionNotification) error { + c.agentIO.logger.Debug("SessionUpdate received", + "sessionId", params.SessionId, + "hasAgentMessageChunk", params.Update.AgentMessageChunk != nil) + + if params.Update.AgentMessageChunk != nil { + if text := params.Update.AgentMessageChunk.Content.Text; text != nil { + c.agentIO.logger.Debug("AgentMessageChunk text", + "text", text.Text, + "textLen", len(text.Text)) + c.agentIO.mu.Lock() + c.agentIO.response.WriteString(text.Text) + onChunk := c.agentIO.onChunk + c.agentIO.mu.Unlock() + if onChunk != nil { + onChunk(text.Text) + } + } + } + + // Handle tool calls - format as text and append to response + if params.Update.ToolCall != nil { + tc := params.Update.ToolCall + formatted := fmt.Sprintf("\n[Tool: %s] %s\n", tc.Kind, tc.Title) + c.agentIO.mu.Lock() + 
c.agentIO.response.WriteString(formatted) + onChunk := c.agentIO.onChunk + c.agentIO.mu.Unlock() + if onChunk != nil { + onChunk(formatted) + } + } + + if params.Update.ToolCallUpdate != nil { + tcu := params.Update.ToolCallUpdate + var formatted string + if tcu.Status != nil { + formatted = fmt.Sprintf("[Tool Status: %s]\n", *tcu.Status) + } + if formatted != "" { + c.agentIO.mu.Lock() + c.agentIO.response.WriteString(formatted) + onChunk := c.agentIO.onChunk + c.agentIO.mu.Unlock() + if onChunk != nil { + onChunk(formatted) + } + } + } + + return nil +} + +func (c *acpClient) RequestPermission(ctx context.Context, params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { + // Auto-approve all permissions for Phase 1 + return acp.RequestPermissionResponse{ + Outcome: acp.RequestPermissionOutcome{ + Selected: &acp.RequestPermissionOutcomeSelected{OptionId: "allow"}, + }, + }, nil +} + +func (c *acpClient) ReadTextFile(ctx context.Context, params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { + return acp.ReadTextFileResponse{}, nil +} + +func (c *acpClient) WriteTextFile(ctx context.Context, params acp.WriteTextFileRequest) (acp.WriteTextFileResponse, error) { + return acp.WriteTextFileResponse{}, nil +} + +func (c *acpClient) CreateTerminal(ctx context.Context, params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { + return acp.CreateTerminalResponse{}, nil +} + +func (c *acpClient) KillTerminalCommand(ctx context.Context, params acp.KillTerminalCommandRequest) (acp.KillTerminalCommandResponse, error) { + return acp.KillTerminalCommandResponse{}, nil +} + +func (c *acpClient) TerminalOutput(ctx context.Context, params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { + return acp.TerminalOutputResponse{}, nil +} + +func (c *acpClient) ReleaseTerminal(ctx context.Context, params acp.ReleaseTerminalRequest) (acp.ReleaseTerminalResponse, error) { + return acp.ReleaseTerminalResponse{}, nil +} + +func 
(c *acpClient) WaitForTerminalExit(ctx context.Context, params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { + return acp.WaitForTerminalExitResponse{}, nil +} + +// SetOnChunk sets a callback that will be called for each streaming chunk. +func (a *ACPAgentIO) SetOnChunk(fn func(chunk string)) { + a.mu.Lock() + defer a.mu.Unlock() + a.onChunk = fn +} + +// NewWithPipes creates an ACPAgentIO connected via the provided pipes +func NewWithPipes(ctx context.Context, toAgent io.Writer, fromAgent io.Reader, logger *slog.Logger) (*ACPAgentIO, error) { + if logger == nil { + logger = slog.Default() + } + agentIO := &ACPAgentIO{ctx: ctx, logger: logger} + client := &acpClient{agentIO: agentIO} + + conn := acp.NewClientSideConnection(client, toAgent, fromAgent) + agentIO.conn = conn + + logger.Debug("Initializing ACP connection") + + // Initialize the connection + initResp, err := conn.Initialize(ctx, acp.InitializeRequest{ + ProtocolVersion: acp.ProtocolVersionNumber, + ClientCapabilities: acp.ClientCapabilities{}, + }) + if err != nil { + logger.Error("Failed to initialize ACP connection", "error", err) + return nil, err + } + logger.Debug("ACP initialized", "protocolVersion", initResp.ProtocolVersion) + + // Create a session + sessResp, err := conn.NewSession(ctx, acp.NewSessionRequest{ + Cwd: "/tmp", + McpServers: []acp.McpServer{}, + }) + if err != nil { + logger.Error("Failed to create ACP session", "error", err) + return nil, err + } + agentIO.sessionID = sessResp.SessionId + logger.Debug("ACP session created", "sessionId", sessResp.SessionId) + + return agentIO, nil +} + +// Write sends a message to the agent via ACP prompt +func (a *ACPAgentIO) Write(data []byte) (int, error) { + text := string(data) + + // Strip bracketed paste escape sequences if present + text = strings.TrimPrefix(text, "\x1b[200~") + text = strings.TrimSuffix(text, "\x1b[201~") + + // Strip terminal hack sequences (x\b pattern used for Claude Code compatibility) + 
text = strings.TrimPrefix(text, "x\b") + + text = strings.TrimSpace(text) + + // Don't send empty prompts + if text == "" { + a.logger.Debug("Ignoring empty prompt", "rawDataLen", len(data)) + return len(data), nil + } + + // Clear previous response + a.mu.Lock() + a.response.Reset() + a.mu.Unlock() + + a.logger.Debug("Sending prompt", + "sessionId", a.sessionID, + "text", text, + "textLen", len(text), + "rawDataLen", len(data)) + + resp, err := a.conn.Prompt(a.ctx, acp.PromptRequest{ + SessionId: a.sessionID, + Prompt: []acp.ContentBlock{acp.TextBlock(text)}, + }) + if err != nil { + a.logger.Error("Prompt failed", "error", err) + return 0, err + } + + a.logger.Debug("Prompt completed", "stopReason", resp.StopReason) + + return len(data), nil +} + +// ReadScreen returns the accumulated agent response +func (a *ACPAgentIO) ReadScreen() string { + a.mu.RLock() + defer a.mu.RUnlock() + return a.response.String() +} From e9d8dd3ef0a984923434edfe1e4442b8f7a88f4f Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 13 Feb 2026 23:51:43 +0000 Subject: [PATCH 07/17] fixup! feat: add experimental ACP mode (--experimental-acp) --- cmd/attach/attach.go | 10 +++------- openapi.json | 4 ---- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/cmd/attach/attach.go b/cmd/attach/attach.go index 512d70f4..86fa69b6 100644 --- a/cmd/attach/attach.go +++ b/cmd/attach/attach.go @@ -129,9 +129,7 @@ func WriteRawInputOverHTTP(ctx context.Context, url string, msg string) error { return nil } - // statusResponse is used to parse the /status endpoint response. -// The ACPMode field may not be present on older servers. 
type statusResponse struct { Status string `json:"status"` AgentType string `json:"agent_type"` @@ -143,17 +141,15 @@ func checkACPMode(remoteUrl string) error { if err != nil { return xerrors.Errorf("failed to check server status: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - // Server doesn't support /status or had an error, continue anyway - return nil + return xerrors.Errorf("unexpected %d response from server: %s", resp.StatusCode, resp.Status) } var status statusResponse if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { - // Can't parse response, continue anyway - return nil + return xerrors.Errorf("failed to decode server status: %w", err) } if status.ACPMode { diff --git a/openapi.json b/openapi.json index f0554e43..dda817cc 100644 --- a/openapi.json +++ b/openapi.json @@ -265,10 +265,6 @@ "readOnly": true, "type": "string" }, - "acp_mode": { - "description": "Whether the server is running in ACP mode.", - "type": "boolean" - }, "agent_type": { "description": "Type of the agent being used by the server.", "type": "string" From 1c3c9aae83f054365b7d506d34c31da66825d32c Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 17 Feb 2026 16:26:28 +0000 Subject: [PATCH 08/17] self-review --- cmd/attach/attach.go | 4 ++-- lib/httpapi/models.go | 1 + lib/httpapi/server.go | 1 + openapi.json | 5 +++++ x/acpio/acp_conversation.go | 6 +++++- x/acpio/acp_conversation_test.go | 28 ++++++++++++++++++++++++++++ 6 files changed, 42 insertions(+), 3 deletions(-) diff --git a/cmd/attach/attach.go b/cmd/attach/attach.go index 86fa69b6..747b0a5c 100644 --- a/cmd/attach/attach.go +++ b/cmd/attach/attach.go @@ -133,7 +133,7 @@ func WriteRawInputOverHTTP(ctx context.Context, url string, msg string) error { type statusResponse struct { Status string `json:"status"` AgentType string `json:"agent_type"` - ACPMode bool `json:"acp_mode"` + Backend string `json:"backend"` } func 
checkACPMode(remoteUrl string) error { @@ -152,7 +152,7 @@ func checkACPMode(remoteUrl string) error { return xerrors.Errorf("failed to decode server status: %w", err) } - if status.ACPMode { + if status.Backend == "acp" { return xerrors.New("attach is not supported in ACP mode. The server is running with --experimental-acp which uses JSON-RPC instead of terminal emulation.") } diff --git a/lib/httpapi/models.go b/lib/httpapi/models.go index 7ed52c43..76ba824d 100644 --- a/lib/httpapi/models.go +++ b/lib/httpapi/models.go @@ -38,6 +38,7 @@ type StatusResponse struct { Body struct { Status AgentStatus `json:"status" doc:"Current agent status. 'running' means that the agent is processing a message, 'stable' means that the agent is idle and waiting for input."` AgentType mf.AgentType `json:"agent_type" doc:"Type of the agent being used by the server."` + Backend string `json:"backend" doc:"Backend transport being used ('acp' or 'pty')."` } } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index d1c8cafe..ebcf6e2c 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -437,6 +437,7 @@ func (s *Server) getStatus(ctx context.Context, input *struct{}) (*StatusRespons resp := &StatusResponse{} resp.Body.Status = agentStatus resp.Body.AgentType = s.agentType + resp.Body.Backend = s.transport return resp, nil } diff --git a/openapi.json b/openapi.json index dda817cc..12c4590b 100644 --- a/openapi.json +++ b/openapi.json @@ -269,6 +269,10 @@ "description": "Type of the agent being used by the server.", "type": "string" }, + "backend": { + "description": "Backend transport being used ('acp' or 'pty').", + "type": "string" + }, "status": { "$ref": "#/components/schemas/AgentStatus", "description": "Current agent status. 'running' means that the agent is processing a message, 'stable' means that the agent is idle and waiting for input." 
@@ -276,6 +280,7 @@ }, "required": [ "agent_type", + "backend", "status" ], "type": "object" diff --git a/x/acpio/acp_conversation.go b/x/acpio/acp_conversation.go index 24c4baea..1c052ddd 100644 --- a/x/acpio/acp_conversation.go +++ b/x/acpio/acp_conversation.go @@ -82,7 +82,11 @@ func (c *ACPConversation) Send(messageParts ...st.MessagePart) error { for _, part := range messageParts { message += part.String() } - message = strings.TrimSpace(message) + + // Validate whitespace BEFORE trimming (match PTY behavior) + if message != strings.TrimSpace(message) { + return st.ErrMessageValidationWhitespace + } if message == "" { return st.ErrMessageValidationEmpty diff --git a/x/acpio/acp_conversation_test.go b/x/acpio/acp_conversation_test.go index c83b559f..0ec3c220 100644 --- a/x/acpio/acp_conversation_test.go +++ b/x/acpio/acp_conversation_test.go @@ -212,6 +212,34 @@ func Test_Send_RejectsEmptyMessage(t *testing.T) { assert.ErrorIs(t, err, screentracker.ErrMessageValidationEmpty) } +func Test_Send_RejectsWhitespace(t *testing.T) { + tests := []struct { + name string + content string + }{ + {"leading space", " hello"}, + {"trailing space", "hello "}, + {"leading newline", "\nhello"}, + {"trailing newline", "hello\n"}, + {"both sides", " hello "}, + {"newlines both sides", "\nhello\n"}, + {"leading tab", "\thello"}, + {"trailing tab", "hello\t"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + + err := conv.Send(screentracker.MessagePartText{Content: tt.content}) + + assert.ErrorIs(t, err, screentracker.ErrMessageValidationWhitespace) + }) + } +} + func Test_Send_RejectsDuplicateSend(t *testing.T) { mClock := quartz.NewMock(t) mock := newMockAgentIO() From 35f4f5f6a892ba341af4a0a6614382e466ec6e62 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 17 Feb 2026 17:47:48 +0000 Subject: [PATCH 09/17] more self-review fixes 
--- cmd/server/server.go | 11 ++++++----- lib/httpapi/setup.go | 27 ++++++++++++++++++++------- x/acpio/acpio.go | 19 ++++++++++++++++--- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index f34ac72a..15c14af7 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -114,12 +114,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er var agentIO st.AgentIO var transport = "pty" var process *termexec.Process - var acpWait func() error + var acpResult *httpapi.SetupACPResult if printOpenAPI { agentIO = nil } else if experimentalACP { - acpResult, err := httpapi.SetupACP(ctx, httpapi.SetupACPConfig{ + var err error + acpResult, err = httpapi.SetupACP(ctx, httpapi.SetupACPConfig{ Program: agent, ProgramArgs: argsToPass[1:], }) @@ -127,7 +128,6 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er return xerrors.Errorf("failed to setup ACP: %w", err) } acpIO := acpResult.AgentIO - acpWait = acpResult.Wait agentIO = acpIO transport = "acp" } else { @@ -181,10 +181,11 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er }() } // Wait for process exit in ACP mode - if acpWait != nil { + if acpResult != nil { go func() { defer close(processExitCh) - if err := acpWait(); err != nil { + defer close(acpResult.Done) // Signal cleanup goroutine to exit + if err := acpResult.Wait(); err != nil { processExitCh <- xerrors.Errorf("ACP process exited: %w", err) } if err := srv.Stop(ctx); err != nil { diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 41565ba1..9e672760 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -69,7 +69,8 @@ type SetupACPConfig struct { // SetupACPResult contains the result of setting up an ACP process. 
type SetupACPResult struct { AgentIO *acpio.ACPAgentIO - Wait func() error // Calls cmd.Wait() and returns exit error + Wait func() error // Calls cmd.Wait() and returns exit error + Done chan struct{} // Close this when Wait() returns to clean up goroutine } func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, error) { @@ -93,23 +94,35 @@ func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, erro return nil, fmt.Errorf("failed to start process: %w", err) } - agentIO, err := acpio.NewWithPipes(ctx, stdin, stdout, logger) + agentIO, err := acpio.NewWithPipes(ctx, stdin, stdout, logger, os.Getwd) if err != nil { _ = cmd.Process.Kill() return nil, fmt.Errorf("failed to initialize ACP connection: %w", err) } + done := make(chan struct{}) go func() { - <-ctx.Done() - logger.Info("Context done, closing ACP agent") - _ = stdin.Close() - _ = stdout.Close() - _ = cmd.Process.Kill() + select { + case <-ctx.Done(): + logger.Info("Context done, closing ACP agent") + _ = stdin.Close() + _ = stdout.Close() + // Try graceful shutdown first + _ = cmd.Process.Signal(syscall.SIGTERM) + // Force kill after timeout + time.AfterFunc(5*time.Second, func() { + _ = cmd.Process.Kill() + }) + case <-done: + // Process exited normally, nothing to clean up + return + } }() return &SetupACPResult{ AgentIO: agentIO, Wait: cmd.Wait, + Done: done, }, nil } diff --git a/x/acpio/acpio.go b/x/acpio/acpio.go index 6b8fc194..bd6d6bfd 100644 --- a/x/acpio/acpio.go +++ b/x/acpio/acpio.go @@ -7,6 +7,7 @@ import ( "log/slog" "strings" "sync" + "time" acp "github.com/coder/acp-go-sdk" st "github.com/coder/agentapi/lib/screentracker" @@ -15,6 +16,9 @@ import ( // Compile-time assertion that ACPAgentIO implements st.AgentIO var _ st.AgentIO = (*ACPAgentIO)(nil) +// DefaultPromptTimeout is the maximum time to wait for an agent response. 
+const DefaultPromptTimeout = 5 * time.Minute + // ACPAgentIO implements screentracker.AgentIO using the ACP protocol type ACPAgentIO struct { ctx context.Context @@ -131,7 +135,7 @@ func (a *ACPAgentIO) SetOnChunk(fn func(chunk string)) { } // NewWithPipes creates an ACPAgentIO connected via the provided pipes -func NewWithPipes(ctx context.Context, toAgent io.Writer, fromAgent io.Reader, logger *slog.Logger) (*ACPAgentIO, error) { +func NewWithPipes(ctx context.Context, toAgent io.Writer, fromAgent io.Reader, logger *slog.Logger, getwd func() (string, error)) (*ACPAgentIO, error) { if logger == nil { logger = slog.Default() } @@ -155,8 +159,13 @@ func NewWithPipes(ctx context.Context, toAgent io.Writer, fromAgent io.Reader, l logger.Debug("ACP initialized", "protocolVersion", initResp.ProtocolVersion) // Create a session + cwd, err := getwd() + if err != nil { + logger.Error("Failed to get working directory", "error", err) + return nil, err + } sessResp, err := conn.NewSession(ctx, acp.NewSessionRequest{ - Cwd: "/tmp", + Cwd: cwd, McpServers: []acp.McpServer{}, }) if err != nil { @@ -199,7 +208,11 @@ func (a *ACPAgentIO) Write(data []byte) (int, error) { "textLen", len(text), "rawDataLen", len(data)) - resp, err := a.conn.Prompt(a.ctx, acp.PromptRequest{ + // Use a timeout to prevent hanging indefinitely + promptCtx, cancel := context.WithTimeout(a.ctx, DefaultPromptTimeout) + defer cancel() + + resp, err := a.conn.Prompt(promptCtx, acp.PromptRequest{ SessionId: a.sessionID, Prompt: []acp.ContentBlock{acp.TextBlock(text)}, }) From 0e5c2a0edef85fd74d75c83d69c5dbc98a29c6a0 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 17 Feb 2026 22:44:57 +0000 Subject: [PATCH 10/17] fix: remove partial agent messages on error in ACP conversation Previously, when Send() failed, only empty agent messages were removed. If the agent had streamed partial content before the error, that partial message would incorrectly remain in the conversation. 
Now we remove the agent message on error regardless of whether it has content, ensuring the conversation state stays consistent. --- x/acpio/acp_conversation.go | 5 ++-- x/acpio/acp_conversation_test.go | 49 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/x/acpio/acp_conversation.go b/x/acpio/acp_conversation.go index 1c052ddd..29d3198d 100644 --- a/x/acpio/acp_conversation.go +++ b/x/acpio/acp_conversation.go @@ -199,9 +199,8 @@ func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) { if err != nil { c.logger.Error("ACPConversation message failed", "error", err) - // Remove the empty streaming message on error - if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == st.ConversationRoleAgent && - c.messages[len(c.messages)-1].Message == "" { + // Remove the agent's streaming message on error (may be empty or partial) + if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == st.ConversationRoleAgent { c.messages = c.messages[:len(c.messages)-1] } messages := slices.Clone(c.messages) diff --git a/x/acpio/acp_conversation_test.go b/x/acpio/acp_conversation_test.go index 0ec3c220..37e9b868 100644 --- a/x/acpio/acp_conversation_test.go +++ b/x/acpio/acp_conversation_test.go @@ -423,3 +423,52 @@ func Test_Messages_AreCopied(t *testing.T) { // Unblock the write to let the test complete cleanly close(done) } + +func Test_ErrorRemovesPartialMessage(t *testing.T) { + mClock := quartz.NewMock(t) + mock := newMockAgentIO() + // Block the write so we can simulate partial content before error + started, done := mock.BlockWrite() + + conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv.Start(context.Background()) + + // Send a message + err := conv.Send(screentracker.MessagePartText{Content: "test"}) + require.NoError(t, err) + + // Wait for write to start + <-started + + // Should have user message + placeholder agent message + messages := conv.Messages() + require.Len(t, messages, 
2) + assert.Equal(t, screentracker.ConversationRoleUser, messages[0].Role) + assert.Equal(t, screentracker.ConversationRoleAgent, messages[1].Role) + + // Simulate the agent streaming partial content before the error + mock.SimulateChunks("partial ", "response ", "content") + + // Verify partial content is in the agent message + messages = conv.Messages() + require.Len(t, messages, 2) + assert.Equal(t, "partial response content", messages[1].Message) + + // Now configure the mock to return an error and unblock + mock.mu.Lock() + mock.writeErr = assert.AnError + mock.mu.Unlock() + close(done) + + // Wait for the conversation to stabilize after the error + require.Eventually(t, func() bool { + return conv.Status() == screentracker.ConversationStatusStable + }, 100*time.Millisecond, 5*time.Millisecond, "status should return to stable") + + // The partial agent message should be removed on error. + // Only the user message should remain. + messages = conv.Messages() + require.Len(t, messages, 1, "partial agent message should be removed on error") + assert.Equal(t, screentracker.ConversationRoleUser, messages[0].Role) + assert.Equal(t, "test", messages[0].Message) +} From eeea49ad9feb0e47be03b283980b108e6201807b Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 17 Feb 2026 22:44:48 +0000 Subject: [PATCH 11/17] Send SIGTERM before closing pipes in ACP shutdown Reorder the shutdown sequence to send SIGTERM first, giving the agent a chance to shutdown gracefully before losing I/O. 
--- lib/httpapi/setup.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 9e672760..0cbd761d 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -105,10 +105,11 @@ func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, erro select { case <-ctx.Done(): logger.Info("Context done, closing ACP agent") - _ = stdin.Close() - _ = stdout.Close() // Try graceful shutdown first _ = cmd.Process.Signal(syscall.SIGTERM) + // Then close pipes + _ = stdin.Close() + _ = stdout.Close() // Force kill after timeout time.AfterFunc(5*time.Second, func() { _ = cmd.Process.Kill() From 01c1b7aab4e7036ee6aaed382c523edc23cddd3b Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 17 Feb 2026 22:46:11 +0000 Subject: [PATCH 12/17] Clean up draft messages in sendMessage error paths --- chat/src/components/chat-provider.tsx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat/src/components/chat-provider.tsx b/chat/src/components/chat-provider.tsx index 3cbef3ab..b84aa989 100644 --- a/chat/src/components/chat-provider.tsx +++ b/chat/src/components/chat-provider.tsx @@ -280,6 +280,7 @@ export function ChatProvider({ children }: PropsWithChildren) { }); if (!response.ok) { + setMessages((prev) => prev.filter((m) => !isDraftMessage(m))); const errorData = await response.json() as APIErrorModel; console.error("Failed to send message:", errorData); const detail = errorData.detail; @@ -296,6 +297,7 @@ export function ChatProvider({ children }: PropsWithChildren) { } } catch (error) { + setMessages((prev) => prev.filter((m) => !isDraftMessage(m))); console.error("Error sending message:", error); const message = getErrorMessage(error) From 8a44a8f21aa9a19c1399ede9a7588763136a35c9 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 17 Feb 2026 23:03:05 +0000 Subject: [PATCH 13/17] refactor: move draft cleanup to finally block in sendMessage --- chat/src/components/chat-provider.tsx | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chat/src/components/chat-provider.tsx b/chat/src/components/chat-provider.tsx index b84aa989..55db1e2b 100644 --- a/chat/src/components/chat-provider.tsx +++ b/chat/src/components/chat-provider.tsx @@ -280,7 +280,6 @@ export function ChatProvider({ children }: PropsWithChildren) { }); if (!response.ok) { - setMessages((prev) => prev.filter((m) => !isDraftMessage(m))); const errorData = await response.json() as APIErrorModel; console.error("Failed to send message:", errorData); const detail = errorData.detail; @@ -297,7 +296,6 @@ export function ChatProvider({ children }: PropsWithChildren) { } } catch (error) { - setMessages((prev) => prev.filter((m) => !isDraftMessage(m))); console.error("Error sending message:", error); const message = getErrorMessage(error) @@ -305,6 +303,8 @@ export function ChatProvider({ children }: PropsWithChildren) { description: message, }); } finally { + // Remove optimistic draft message if still present (may have been replaced by server response via SSE). 
+ setMessages((prev) => prev.filter((m) => !isDraftMessage(m))); if (type === "user") { setLoading(false); } From fe02d076e3c53ac7a7c293ed04e2ffc5cc29a717 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 17 Feb 2026 23:04:56 +0000 Subject: [PATCH 14/17] Use quartz.Clock instead of time.AfterFunc in SetupACP - Add Clock field to SetupACPConfig - Default to quartz.NewReal() if Clock is nil - Replace time.AfterFunc with config.Clock.AfterFunc for testability --- lib/httpapi/setup.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 0cbd761d..0ee992f1 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -14,6 +14,7 @@ import ( mf "github.com/coder/agentapi/lib/msgfmt" "github.com/coder/agentapi/lib/termexec" "github.com/coder/agentapi/x/acpio" + "github.com/coder/quartz" ) type SetupProcessConfig struct { @@ -64,6 +65,7 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro type SetupACPConfig struct { Program string ProgramArgs []string + Clock quartz.Clock } // SetupACPResult contains the result of setting up an ACP process. 
@@ -76,6 +78,10 @@ type SetupACPResult struct { func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, error) { logger := logctx.From(ctx) + if config.Clock == nil { + config.Clock = quartz.NewReal() + } + args := config.ProgramArgs logger.Info(fmt.Sprintf("Running (ACP): %s %s", config.Program, strings.Join(args, " "))) @@ -111,7 +117,7 @@ func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, erro _ = stdin.Close() _ = stdout.Close() // Force kill after timeout - time.AfterFunc(5*time.Second, func() { + config.Clock.AfterFunc(5*time.Second, func() { _ = cmd.Process.Kill() }) case <-done: From 97f7a175e2f6a3aea468a05de1e8c61dec5afc34 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 18 Feb 2026 08:47:03 +0000 Subject: [PATCH 15/17] acpio: check parent ctx and return early --- x/acpio/acpio.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/x/acpio/acpio.go b/x/acpio/acpio.go index bd6d6bfd..77db963e 100644 --- a/x/acpio/acpio.go +++ b/x/acpio/acpio.go @@ -208,6 +208,11 @@ func (a *ACPAgentIO) Write(data []byte) (int, error) { "textLen", len(text), "rawDataLen", len(data)) + // Ensure the context has not been cancelled before writing prompt + if err := a.ctx.Err(); err != nil { + a.logger.Debug("Aborting write", "error", err) + return 0, err + } // Use a timeout to prevent hanging indefinitely promptCtx, cancel := context.WithTimeout(a.ctx, DefaultPromptTimeout) defer cancel() From 9982505325d1c7c725ea0ce976c9078d61b3face Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 18 Feb 2026 12:43:04 +0000 Subject: [PATCH 16/17] fix: add missing return after context cancellation cleanup The cleanup goroutine in SetupACPProcess was leaking because it didn't return after handling ctx.Done(). After scheduling the kill timer, the goroutine would block forever on the done channel. 
--- lib/httpapi/setup.go | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 0ee992f1..447b10b7 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -120,6 +120,7 @@ func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, erro config.Clock.AfterFunc(5*time.Second, func() { _ = cmd.Process.Kill() }) + return case <-done: // Process exited normally, nothing to clean up return From 8f3d2783d634204a410881d16e92b5a77bd49467 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 18 Feb 2026 12:45:02 +0000 Subject: [PATCH 17/17] acpio: Thread context through ACPConversation for cancellation - Add ctx and cancel fields to ACPConversation struct - Update NewACPConversation to accept context.Context as first param - Create cancellable context in constructor using context.WithCancel - Add Stop() method that calls c.cancel() - Check c.ctx.Err() in executePrompt before each message part - Update server.go to pass ctx to NewACPConversation - Update all test calls to pass context.Background() --- lib/httpapi/server.go | 2 +- x/acpio/acp_conversation.go | 16 +++++++++++++++- x/acpio/acp_conversation_test.go | 26 +++++++++++++------------- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index ebcf6e2c..1f255101 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -263,7 +263,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { if !ok { return nil, fmt.Errorf("ACP transport requires ACPAgentIO") } - conversation = acpio.NewACPConversation(acpIO, logger, initialPrompt, emitter, config.Clock) + conversation = acpio.NewACPConversation(ctx, acpIO, logger, initialPrompt, emitter, config.Clock) } else { // Default to PTY transport proc, ok := config.AgentIO.(*termexec.Process) diff --git a/x/acpio/acp_conversation.go b/x/acpio/acp_conversation.go index 29d3198d..f58e7ff3 100644 --- 
a/x/acpio/acp_conversation.go +++ b/x/acpio/acp_conversation.go @@ -26,6 +26,8 @@ type ChunkableAgentIO interface { // response is complete when Write() returns. type ACPConversation struct { mu sync.Mutex + ctx context.Context + cancel context.CancelFunc agentIO ChunkableAgentIO messages []st.ConversationMessage prompting bool // true while agent is processing @@ -46,7 +48,7 @@ func (noopEmitter) EmitScreen(string) {} // NewACPConversation creates a new ACPConversation. // If emitter is provided, it will receive events when messages/status/screen change. // If clock is nil, a real clock will be used. -func NewACPConversation(agentIO ChunkableAgentIO, logger *slog.Logger, initialPrompt []st.MessagePart, emitter st.Emitter, clock quartz.Clock) *ACPConversation { +func NewACPConversation(ctx context.Context, agentIO ChunkableAgentIO, logger *slog.Logger, initialPrompt []st.MessagePart, emitter st.Emitter, clock quartz.Clock) *ACPConversation { if logger == nil { logger = slog.Default() } @@ -56,7 +58,10 @@ func NewACPConversation(agentIO ChunkableAgentIO, logger *slog.Logger, initialPr if emitter == nil { emitter = noopEmitter{} } + ctx, cancel := context.WithCancel(ctx) c := &ACPConversation{ + ctx: ctx, + cancel: cancel, agentIO: agentIO, logger: logger, initialPrompt: initialPrompt, @@ -159,6 +164,11 @@ func (c *ACPConversation) statusLocked() st.ConversationStatus { return st.ConversationStatusStable } +// Stop cancels any in-progress operations. +func (c *ACPConversation) Stop() { + c.cancel() +} + // Text returns the current streaming response text. 
func (c *ACPConversation) Text() string { c.mu.Lock() @@ -188,6 +198,10 @@ func (c *ACPConversation) handleChunk(chunk string) { func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) { var err error for _, part := range messageParts { + if c.ctx.Err() != nil { + err = c.ctx.Err() + break + } if partErr := part.Do(c.agentIO); partErr != nil { err = partErr break diff --git a/x/acpio/acp_conversation_test.go b/x/acpio/acp_conversation_test.go index 37e9b868..0811bc3b 100644 --- a/x/acpio/acp_conversation_test.go +++ b/x/acpio/acp_conversation_test.go @@ -151,7 +151,7 @@ func Test_NewACPConversation(t *testing.T) { mClock := quartz.NewMock(t) mock := newMockAgentIO() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) require.NotNil(t, conv) } @@ -159,7 +159,7 @@ func Test_NewACPConversation(t *testing.T) { func Test_Messages_InitiallyEmpty(t *testing.T) { mClock := quartz.NewMock(t) mock := newMockAgentIO() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) messages := conv.Messages() @@ -169,7 +169,7 @@ func Test_Messages_InitiallyEmpty(t *testing.T) { func Test_Status_InitiallyStable(t *testing.T) { mClock := quartz.NewMock(t) mock := newMockAgentIO() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) status := conv.Status() @@ -182,7 +182,7 @@ func Test_Send_AddsUserMessage(t *testing.T) { // Set up blocking to synchronize with the goroutine started, done := mock.BlockWrite() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) conv.Start(context.Background()) err := conv.Send(screentracker.MessagePartText{Content: "hello"}) @@ -205,7 +205,7 @@ func 
Test_Send_AddsUserMessage(t *testing.T) { func Test_Send_RejectsEmptyMessage(t *testing.T) { mClock := quartz.NewMock(t) mock := newMockAgentIO() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) err := conv.Send(screentracker.MessagePartText{Content: ""}) @@ -231,7 +231,7 @@ func Test_Send_RejectsWhitespace(t *testing.T) { t.Run(tt.name, func(t *testing.T) { mClock := quartz.NewMock(t) mock := newMockAgentIO() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) err := conv.Send(screentracker.MessagePartText{Content: tt.content}) @@ -246,7 +246,7 @@ func Test_Send_RejectsDuplicateSend(t *testing.T) { // Block the write so it doesn't complete immediately started, done := mock.BlockWrite() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) conv.Start(context.Background()) // First send should succeed @@ -270,7 +270,7 @@ func Test_Status_ChangesWhileProcessing(t *testing.T) { // Block the write so we can observe status changes started, done := mock.BlockWrite() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) conv.Start(context.Background()) // Initially stable @@ -301,7 +301,7 @@ func Test_Text_ReturnsStreamingContent(t *testing.T) { // Block the write so we can simulate streaming during processing started, done := mock.BlockWrite() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) conv.Start(context.Background()) // Initially empty @@ -337,7 +337,7 @@ func Test_Emitter_CalledOnChanges(t *testing.T) { emitter := newMockEmitter() - conv := acpio.NewACPConversation(mock, nil, nil, 
emitter, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, emitter, mClock) conv.Start(context.Background()) // Send a message @@ -380,7 +380,7 @@ func Test_InitialPrompt_SentOnStart(t *testing.T) { screentracker.MessagePartText{Content: "initial prompt"}, } - conv := acpio.NewACPConversation(mock, nil, initialPrompt, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, initialPrompt, nil, mClock) conv.Start(context.Background()) // Wait for write to start (initial prompt is being sent) @@ -402,7 +402,7 @@ func Test_Messages_AreCopied(t *testing.T) { // Set up blocking to synchronize started, done := mock.BlockWrite() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) conv.Start(context.Background()) err := conv.Send(screentracker.MessagePartText{Content: "test"}) @@ -430,7 +430,7 @@ func Test_ErrorRemovesPartialMessage(t *testing.T) { // Block the write so we can simulate partial content before error started, done := mock.BlockWrite() - conv := acpio.NewACPConversation(mock, nil, nil, nil, mClock) + conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock) conv.Start(context.Background()) // Send a message