diff --git a/typescript/src/chat.ts b/typescript/src/chat.ts
index 78f24d0..fdc0824 100644
--- a/typescript/src/chat.ts
+++ b/typescript/src/chat.ts
@@ -16,15 +16,16 @@ const messageSchema = z.object({
 });
 
 export const chatRequestSchema = z.object({
-  messages: z.array(messageSchema),
-  temperature: z.number().optional(),
-  max_tokens: z.number().optional(),
+  messages: z.array(messageSchema),
+  temperature: z.number().optional(),
+  max_tokens: z.number().optional(),
 });
 
 type ChatRequest = z.infer<typeof chatRequestSchema>;
 
 interface ChatResponse {
   model: string;
   content: string;
+  confidence?: number;
 }
 export interface Chunk {
@@ -45,29 +46,36 @@ export async function getChatResponse(
   return {
     model: "gpt-4o",
+    confidence: response.confidence,
     content,
   };
 }
 
-export async function streamChatResponse(
-  chatRequest: ChatRequest
-): Promise<{ model: string; stream: AsyncIterable<Chunk> }> {
+export async function streamChatResponse(chatRequest: ChatRequest): Promise<{
+  model: string;
+  confidence?: number;
+  stream: AsyncIterable<Chunk>;
+}> {
   try {
-    const stream = await openAiChatCompletion({
+    const { confidence, stream } = await openAiChatCompletion({
       model: "gpt-4o",
       messages: chatRequest.messages,
       stream: true,
       temperature: chatRequest.temperature,
       max_tokens: chatRequest.max_tokens,
     });
-    return { model: "gpt-4o", stream: chunksFromOpenAiStream(stream) };
+    return {
+      model: "gpt-4o",
+      confidence,
+      stream: chunksFromOpenAiStream(stream),
+    };
   } catch (e) {
     console.warn("Error streaming chat response from OpenAI", e);
     const systemMessage = chatRequest.messages.find(
       (message) => message.role === "system"
     );
-    const iterator = await bedrockChatCompletion({
+    const { confidence, stream } = await bedrockChatCompletion({
       modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
       system: systemMessage
         ? [
@@ -90,11 +98,12 @@ export async function streamChatResponse(
       inferenceConfig: {
         temperature: chatRequest.temperature,
         maxTokens: chatRequest.max_tokens,
-      }
+      },
     });
     return {
       model: "claude-3-sonnet",
-      stream: chunksFromBedrockStream(iterator),
+      confidence,
+      stream: chunksFromBedrockStream(stream),
     };
   }
 }
diff --git a/typescript/src/stubs/stub-bedrock-client.ts b/typescript/src/stubs/stub-bedrock-client.ts
index c247f1d..a19a748 100644
--- a/typescript/src/stubs/stub-bedrock-client.ts
+++ b/typescript/src/stubs/stub-bedrock-client.ts
@@ -18,8 +18,12 @@ async function* stubGenerator(): AsyncGenerator<ConverseStreamOutput> {
   }
 }
 
-export function bedrockChatCompletion(
-  _: ConverseStreamCommandInput
-): Promise<AsyncIterable<ConverseStreamOutput>> {
-  return Promise.resolve(stubGenerator());
+export function bedrockChatCompletion(_: ConverseStreamCommandInput): {
+  confidence: number;
+  stream: AsyncIterable<ConverseStreamOutput>;
+} {
+  return {
+    confidence: (Math.floor(Math.random() * 10) + 1) / 10,
+    stream: stubGenerator(),
+  };
 }
diff --git a/typescript/src/stubs/stub-openai-client.ts b/typescript/src/stubs/stub-openai-client.ts
index 67d7bdd..1e49754 100644
--- a/typescript/src/stubs/stub-openai-client.ts
+++ b/typescript/src/stubs/stub-openai-client.ts
@@ -3,18 +3,22 @@ import { Stream } from "openai/streaming";
 
 const chunks = ["Hello ", "from ", "OpenAI!"];
 
-async function* stubGenerator(model: string): AsyncGenerator<ChatCompletionChunk> {
+async function* stubGenerator(
+  model: string
+): AsyncGenerator<ChatCompletionChunk> {
   for (const [index, output] of chunks.entries()) {
     yield {
       id: index.toString(),
-      choices: [{
-        index: 0,
-        delta: {
-          role: "assistant",
-          content: output,
+      choices: [
+        {
+          index: 0,
+          delta: {
+            role: "assistant",
+            content: output,
+          },
+          finish_reason: null,
         },
-        finish_reason: null,
-      }],
+      ],
       model: model,
       object: "chat.completion.chunk",
       created: Date.now(),
@@ -24,6 +28,12 @@ async function* stubGenerator(model: string): AsyncGenerator<ChatCompletionChunk> {
 
 export function openAiChatCompletion(
   input: ChatCompletionCreateParamsStreaming
-): Promise<Stream<ChatCompletionChunk>> {
-  return Promise.resolve(new Stream(() => stubGenerator(input.model), new AbortController()));
+): {
+  confidence: number;
+  stream: Stream<ChatCompletionChunk>;
+} {
+  return {
+    confidence: (Math.floor(Math.random() * 10) + 1) / 10,
+    stream: new Stream(() => stubGenerator(input.model), new AbortController()),
+  };
 }