Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/shiny-pets-report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@workflow/ai": patch
---

Fix `collectUIMessages` option failing in workflow context
29 changes: 28 additions & 1 deletion packages/ai/src/agent/do-stream-step.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type {
import {
type FinishReason,
gateway,
generateId,
type StepResult,
type StopCondition,
type ToolChoice,
Expand Down Expand Up @@ -60,6 +61,11 @@ export interface DoStreamStepOptions {
experimental_telemetry?: TelemetrySettings;
transforms?: Array<StreamTextTransform<ToolSet>>;
responseFormat?: LanguageModelV2CallOptions['responseFormat'];
/**
* If true, collects and returns all UIMessageChunks written to the stream.
* This is used by DurableAgent when collectUIMessages is enabled.
*/
collectUIChunks?: boolean;
}

/**
Expand Down Expand Up @@ -157,6 +163,8 @@ export async function doStreamStep(
const toolCalls: LanguageModelV2ToolCall[] = [];
const chunks: LanguageModelV2StreamPart[] = [];
const includeRawChunks = options?.includeRawChunks ?? false;
const collectUIChunks = options?.collectUIChunks ?? false;
const uiChunks: UIMessageChunk[] = [];

// Build the stream pipeline
let stream: ReadableStream<LanguageModelV2StreamPart> = result.stream;
Expand Down Expand Up @@ -203,6 +211,9 @@ export async function doStreamStep(
if (options?.sendStart) {
controller.enqueue({
type: 'start',
// Note that if useChat is used client-side, useChat will generate a different
// messageId. It's hard to work around this.
messageId: generateId(),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Client-side would do sendStart: false in that case, no?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the user can decide to do that, but would then have to handle the start themselves, which I think is fine. The ID seems unlikely to cause issues unless message persistence is mixed between client and server side

});
}
controller.enqueue({
Expand Down Expand Up @@ -443,10 +454,26 @@ export async function doStreamStep(
},
})
)
.pipeThrough(
// Optionally collect UIMessageChunks for later conversion to UIMessage[]
new TransformStream<UIMessageChunk, UIMessageChunk>({
transform: (chunk, controller) => {
if (collectUIChunks) {
uiChunks.push(chunk);
}
controller.enqueue(chunk);
},
})
)
.pipeTo(writable, { preventClose: true });

const step = chunksToStep(chunks, toolCalls, conversationPrompt, finish);
return { toolCalls, finish, step };
return {
toolCalls,
finish,
step,
uiChunks: collectUIChunks ? uiChunks : undefined,
};
}

/**
Expand Down
23 changes: 10 additions & 13 deletions packages/ai/src/agent/durable-agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,8 @@ describe('DurableAgent', () => {
timestamp: new Date(),
},
warnings: [],
};
// We're missing some properties that aren't relevant for the test
} as unknown as StepResult<any>;
const mockMessages: LanguageModelV2Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'test' }] },
];
Expand Down Expand Up @@ -1296,7 +1297,8 @@ describe('DurableAgent', () => {
timestamp: new Date(),
},
warnings: [],
};
// We're missing some properties that aren't relevant for the test
} as unknown as StepResult<any>;
const finalMessages: LanguageModelV2Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'test' }] },
{ role: 'assistant', content: [{ type: 'text', text: 'Hello' }] },
Expand Down Expand Up @@ -1584,29 +1586,26 @@ describe('DurableAgent', () => {
expect(result.uiMessages).toBeUndefined();
});

it('should use accumulator writable when collectUIMessages is true', async () => {
it('should pass collectUIChunks to streamTextIterator when collectUIMessages is true', async () => {
const mockModel = createMockModel();

const agent = new DurableAgent({
model: async () => mockModel,
tools: {},
});

const writtenChunks: unknown[] = [];
const mockWritable = new WritableStream({
write: (chunk) => {
writtenChunks.push(chunk);
},
write: vi.fn(),
close: vi.fn(),
});

const { streamTextIterator } = await import('./stream-text-iterator.js');
let capturedWritable: unknown;
let capturedCollectUIChunks: boolean | undefined;
const mockIterator = {
next: vi.fn().mockResolvedValueOnce({ done: true, value: [] }),
};
vi.mocked(streamTextIterator).mockImplementation((opts) => {
capturedWritable = opts.writable;
capturedCollectUIChunks = opts.collectUIChunks;
return mockIterator as unknown as MockIterator;
});

Expand All @@ -1616,10 +1615,8 @@ describe('DurableAgent', () => {
collectUIMessages: true,
});

// When collectUIMessages is true, the writable passed to streamTextIterator
// should be the accumulator's writable (not the original)
expect(capturedWritable).toBeDefined();
expect(capturedWritable).not.toBe(mockWritable);
// When collectUIMessages is true, collectUIChunks should be passed to streamTextIterator
expect(capturedCollectUIChunks).toBe(true);

// uiMessages should be defined (even if empty, since we're mocking)
expect(result.uiMessages).toBeDefined();
Expand Down
95 changes: 69 additions & 26 deletions packages/ai/src/agent/durable-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
type LanguageModelUsage,
type ModelMessage,
Output,
readUIMessageStream,
type StepResult,
type StopCondition,
type StreamTextOnStepFinishCallback,
Expand All @@ -25,7 +26,6 @@ import { convertToLanguageModelPrompt, standardizePrompt } from 'ai/internal';
import { FatalError } from 'workflow';
import { streamTextIterator } from './stream-text-iterator.js';
import type { CompatibleLanguageModel } from './types.js';
import { UIMessageAccumulator } from './ui-message-accumulator.js';

// Re-export for consumers
export type { CompatibleLanguageModel } from './types.js';
Expand Down Expand Up @@ -730,16 +730,14 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
};
}

// Set up UIMessage accumulator if requested
const accumulator = options.collectUIMessages
? new UIMessageAccumulator(options.writable)
: null;
const effectiveWritable = accumulator?.writable ?? options.writable;
// Track collected UI chunks if collectUIMessages is enabled
const collectUIChunks = options.collectUIMessages ?? false;
const allUIChunks: UIMessageChunk[] = [];

const iterator = streamTextIterator({
model: this.model,
tools: effectiveTools as ToolSet,
writable: effectiveWritable,
writable: options.writable,
prompt: modelPrompt,
stopConditions: options.stopWhen,
maxSteps: options.maxSteps,
Expand All @@ -756,6 +754,7 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
| StreamTextTransform<ToolSet>
| Array<StreamTextTransform<ToolSet>>,
responseFormat: options.experimental_output?.responseFormat,
collectUIChunks,
});

// Track the final conversation messages from the iterator
Expand All @@ -780,6 +779,7 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
messages: iterMessages,
step,
context,
uiChunks,
} = result.value;
if (step) {
// The step result is compatible with StepResult<TTools> since we're using the same tools
Expand All @@ -789,6 +789,10 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
if (context !== undefined) {
experimentalContext = context;
}
// Collect UI chunks if enabled
if (uiChunks && uiChunks.length > 0) {
allUIChunks.push(...uiChunks);
}

// Only execute tools if there are tool calls
if (toolCalls.length > 0) {
Expand Down Expand Up @@ -833,23 +837,8 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
const sendFinish = options.sendFinish ?? true;
const preventClose = options.preventClose ?? false;

// Handle stream closing with special care for accumulator
if (accumulator) {
// When using accumulator, we need to:
// 1. Write finish chunk through accumulator (if sendFinish is true) so it's captured
// 2. Close the accumulator's writable to signal completion
// 3. Handle the original stream's close separately (finish already forwarded through accumulator)
if (sendFinish) {
await writeFinishChunk(effectiveWritable);
}
// Always close the accumulator's writable so getMessages() can complete
await effectiveWritable.close();
// Now close the original stream (sendFinish=false since finish already written through accumulator)
if (!preventClose) {
await closeStream(options.writable, preventClose, false);
}
} else if (sendFinish || !preventClose) {
// No accumulator - use standard close logic
// Handle stream closing
if (sendFinish || !preventClose) {
await closeStream(options.writable, preventClose, sendFinish);
}

Expand Down Expand Up @@ -898,8 +887,9 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
}

// Collect accumulated UI messages if requested
const uiMessages = accumulator
? await accumulator.getMessages()
// This requires a step function since it performs stream operations
const uiMessages = collectUIChunks
? await convertChunksToUIMessages(allUIChunks)
: undefined;

return {
Expand Down Expand Up @@ -956,6 +946,59 @@ async function closeStream(
}
}

/**
* Convert UIMessageChunks to UIMessage[] using the AI SDK's readUIMessageStream.
* This must be a step function because it performs stream operations.
*
* @param chunks - The collected UIMessageChunks to convert
* @returns The accumulated UIMessage array
*/
/**
 * Convert UIMessageChunks to UIMessage[] using the AI SDK's readUIMessageStream.
 * This must be a step function because it performs stream operations.
 *
 * @param chunks - The collected UIMessageChunks to convert
 * @returns The accumulated UIMessage array
 */
async function convertChunksToUIMessages(
  chunks: UIMessageChunk[]
): Promise<UIMessage[]> {
  'use step';

  if (chunks.length === 0) {
    return [];
  }

  // Wrap the collected chunks in a ReadableStream: the AI SDK only exposes
  // UIMessageChunk[] -> UIMessage[] conversion as a streaming operation.
  const chunkStream = new ReadableStream<UIMessageChunk>({
    start(controller) {
      for (const chunk of chunks) {
        controller.enqueue(chunk);
      }
      controller.close();
    },
  });

  // Hand the chunk stream to the AI SDK, which assembles messages from it.
  const messageStream = readUIMessageStream({
    stream: chunkStream,
    onError: (error) => {
      console.error('Error processing UI message chunks:', error);
    },
  });

  // readUIMessageStream yields progressively-updated snapshots of each
  // message as it is built. Keep only the latest snapshot per message id;
  // Map preserves first-seen insertion order, matching the original
  // update-in-place behavior.
  const latestById = new Map<string, UIMessage>();
  for await (const message of messageStream) {
    latestById.set(message.id, message);
  }

  return [...latestById.values()];
}

async function executeTool(
toolCall: LanguageModelV2ToolCall,
tools: ToolSet,
Expand Down
Loading
Loading