// base-openai-compatible-provider.ts
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import type { ModelInfo } from "@roo-code/types"
import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
import { XmlMatcher } from "../../utils/xml-matcher"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import { handleOpenAIError } from "./utils/openai-error-handler"
import { calculateApiCostOpenAI } from "../../shared/cost"
import { getApiRequestTimeout } from "./utils/timeout-config"

type BaseOpenAiCompatibleProviderOptions<ModelName extends string> = ApiHandlerOptions & {
	providerName: string
	baseURL: string
	defaultProviderModelId: ModelName
	providerModels: Record<ModelName, ModelInfo>
	defaultTemperature?: number
}
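
// A minimal sketch of the options a concrete provider might pass. The provider
// name, URL, model id, and ModelInfo value below are hypothetical, not part of
// this file:
//
//   const exampleOptions = {
//       providerName: "Example",
//       baseURL: "https://api.example.com/v1",
//       defaultProviderModelId: "example-chat-model" as const,
//       providerModels: { "example-chat-model": exampleModelInfo },
//       defaultTemperature: 0.7,
//       apiKey: "sk-...",
//   }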

export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
	extends BaseProvider
	implements SingleCompletionHandler
{
	protected readonly providerName: string
	protected readonly baseURL: string
	protected readonly defaultTemperature: number
	protected readonly defaultProviderModelId: ModelName
	protected readonly providerModels: Record<ModelName, ModelInfo>
	protected readonly options: ApiHandlerOptions

	protected client: OpenAI

	// Abort controller for cancelling ongoing requests
	private abortController?: AbortController

	constructor({
		providerName,
		baseURL,
		defaultProviderModelId,
		providerModels,
		defaultTemperature,
		...options
	}: BaseOpenAiCompatibleProviderOptions<ModelName>) {
		super()

		this.providerName = providerName
		this.baseURL = baseURL
		this.defaultProviderModelId = defaultProviderModelId
		this.providerModels = providerModels
		this.defaultTemperature = defaultTemperature ?? 0
		this.options = options

		if (!this.options.apiKey) {
			throw new Error("API key is required")
		}

		this.client = new OpenAI({
			baseURL,
			apiKey: this.options.apiKey,
			defaultHeaders: DEFAULT_HEADERS,
			timeout: getApiRequestTimeout(),
		})
	}

	protected createStream(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		metadata?: ApiHandlerCreateMessageMetadata,
		requestOptions?: OpenAI.RequestOptions,
	) {
		const { id: model, info } = this.getModel()

		// Centralized cap: clamp to 20% of the context window (unless provider-specific exceptions apply)
		const max_tokens =
			getModelMaxOutputTokens({
				modelId: model,
				model: info,
				settings: this.options,
				format: "openai",
			}) ?? undefined

		const temperature = this.options.modelTemperature ?? info.defaultTemperature ?? this.defaultTemperature

		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
			model,
			max_tokens,
			temperature,
			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
			stream: true,
			stream_options: { include_usage: true },
			...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
			...(metadata?.toolProtocol === "native" && {
				parallel_tool_calls: metadata.parallelToolCalls ?? false,
			}),
		}

		// Add thinking parameter if reasoning is enabled and model supports it
		if (this.options.enableReasoningEffort && info.supportsReasoningBinary) {
			;(params as any).thinking = { type: "enabled" }
		}

		try {
			// Merge abort signal with any existing request options
			const mergedOptions: OpenAI.RequestOptions = {
				...requestOptions,
				signal: this.abortController?.signal,
			}
			return this.client.chat.completions.create(params, mergedOptions)
		} catch (error) {
			throw handleOpenAIError(error, this.providerName)
		}
	}
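
	// Worked example of the cap mentioned above (numbers are illustrative): for a
	// model advertising a 128,000-token context window with no provider-specific
	// exception, getModelMaxOutputTokens would clamp max_tokens to roughly
	// 128,000 * 0.2 = 25,600 tokens, leaving the rest of the window for input.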

	override async *createMessage(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		metadata?: ApiHandlerCreateMessageMetadata,
	): ApiStream {
		// Create AbortController for cancellation
		this.abortController = new AbortController()

		try {
			const stream = await this.createStream(systemPrompt, messages, metadata)

			// Route <think>…</think> spans to "reasoning" chunks; everything else streams as "text"
			const matcher = new XmlMatcher(
				"think",
				(chunk) =>
					({
						type: chunk.matched ? "reasoning" : "text",
						text: chunk.data,
					}) as const,
			)

			let lastUsage: OpenAI.CompletionUsage | undefined
			const activeToolCallIds = new Set<string>()

			for await (const chunk of stream) {
				// Check if request was aborted
				if (this.abortController?.signal.aborted) {
					break
				}

				// Check for provider-specific error responses (e.g., MiniMax base_resp)
				const chunkAny = chunk as any
				if (chunkAny.base_resp?.status_code && chunkAny.base_resp.status_code !== 0) {
					throw new Error(
						`${this.providerName} API Error (${chunkAny.base_resp.status_code}): ${chunkAny.base_resp.status_msg || "Unknown error"}`,
					)
				}

				const delta = chunk.choices?.[0]?.delta
				const finishReason = chunk.choices?.[0]?.finish_reason

				if (delta?.content) {
					for (const processedChunk of matcher.update(delta.content)) {
						yield processedChunk
					}
				}

				if (delta) {
					// Some providers stream reasoning under "reasoning_content", others under "reasoning"
					for (const key of ["reasoning_content", "reasoning"] as const) {
						if (key in delta) {
							const reasoning_content = ((delta as any)[key] as string | undefined) || ""
							if (reasoning_content.trim()) {
								yield { type: "reasoning", text: reasoning_content }
							}
							break
						}
					}
				}

				// Emit raw tool call chunks - NativeToolCallParser handles state management
				if (delta?.tool_calls) {
					for (const toolCall of delta.tool_calls) {
						if (toolCall.id) {
							activeToolCallIds.add(toolCall.id)
						}

						yield {
							type: "tool_call_partial",
							index: toolCall.index,
							id: toolCall.id,
							name: toolCall.function?.name,
							arguments: toolCall.function?.arguments,
						}
					}
				}

				// Emit tool_call_end events when finish_reason is "tool_calls"
				// This ensures tool calls are finalized even if the stream doesn't properly close
				if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
					for (const id of activeToolCallIds) {
						yield { type: "tool_call_end", id }
					}
					activeToolCallIds.clear()
				}

				if (chunk.usage) {
					lastUsage = chunk.usage
				}
			}

			if (lastUsage) {
				yield this.processUsageMetrics(lastUsage, this.getModel().info)
			}

			// Process any remaining content
			for (const processedChunk of matcher.final()) {
				yield processedChunk
			}
		} finally {
			this.abortController = undefined
		}
	}
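
	// A minimal consumption sketch (hypothetical caller code, assuming a concrete
	// subclass named ExampleProvider; logReasoning and recordCost are invented
	// helpers, and chunk types mirror the yields above):
	//
	//   const handler = new ExampleProvider(exampleOptions)
	//   for await (const chunk of handler.createMessage(systemPrompt, messages)) {
	//       if (chunk.type === "text") process.stdout.write(chunk.text)
	//       else if (chunk.type === "reasoning") logReasoning(chunk.text)
	//       else if (chunk.type === "usage") recordCost(chunk.totalCost)
	//   }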

	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
		const inputTokens = usage?.prompt_tokens || 0
		const outputTokens = usage?.completion_tokens || 0
		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

		const { totalCost } = modelInfo
			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
			: { totalCost: 0 }

		return {
			type: "usage",
			inputTokens,
			outputTokens,
			cacheWriteTokens: cacheWriteTokens || undefined,
			cacheReadTokens: cacheReadTokens || undefined,
			totalCost,
		}
	}
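
	// Illustrative mapping (values are made up): a usage payload of
	//   { prompt_tokens: 1200, completion_tokens: 300,
	//     prompt_tokens_details: { cached_tokens: 800 } }
	// yields inputTokens = 1200, outputTokens = 300, cacheReadTokens = 800, and
	// cacheWriteTokens falls back to undefined (0 || undefined), with totalCost
	// computed from the model's pricing.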

	async completePrompt(prompt: string): Promise<string> {
		// Create AbortController for cancellation
		this.abortController = new AbortController()

		try {
			const { id: modelId, info: modelInfo } = this.getModel()

			const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
				model: modelId,
				messages: [{ role: "user", content: prompt }],
			}

			// Add thinking parameter if reasoning is enabled and model supports it
			if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
				;(params as any).thinking = { type: "enabled" }
			}

			const response = await this.client.chat.completions.create(params, {
				signal: this.abortController.signal,
			})

			// Check for provider-specific error responses (e.g., MiniMax base_resp)
			const responseAny = response as any
			if (responseAny.base_resp?.status_code && responseAny.base_resp.status_code !== 0) {
				throw new Error(
					`${this.providerName} API Error (${responseAny.base_resp.status_code}): ${responseAny.base_resp.status_msg || "Unknown error"}`,
				)
			}

			return response.choices?.[0]?.message.content || ""
		} catch (error) {
			throw handleOpenAIError(error, this.providerName)
		} finally {
			this.abortController = undefined
		}
	}
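
	// Hypothetical usage of the non-streaming path above (assuming a concrete
	// subclass instance named `handler`):
	//
	//   const summary = await handler.completePrompt("Summarize this diff in one line.")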

	override getModel() {
		const id =
			this.options.apiModelId && this.options.apiModelId in this.providerModels
				? (this.options.apiModelId as ModelName)
				: this.defaultProviderModelId

		return { id, info: this.providerModels[id] }
	}
}
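
// A minimal sketch of how this base class is meant to be extended, assuming no
// further abstract members need overriding. Everything below is illustrative:
// the provider name, base URL, model id, and ModelInfo values are invented and
// are not part of this file or of any real provider.
const exampleModels = {
	// Hypothetical model entry; real ModelInfo objects carry fuller metadata
	"example-chat-model": {
		contextWindow: 128_000,
		supportsPromptCache: false,
	} as ModelInfo,
}

export class ExampleProvider extends BaseOpenAiCompatibleProvider<keyof typeof exampleModels> {
	constructor(options: ApiHandlerOptions) {
		super({
			...options,
			providerName: "Example", // hypothetical
			baseURL: "https://api.example.com/v1", // hypothetical
			defaultProviderModelId: "example-chat-model",
			providerModels: exampleModels,
			defaultTemperature: 0.7,
		})
	}
}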