// lm-studio.ts
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import axios from "axios"
import { type ModelInfo, openAiModelInfoSaneDefaults, LMSTUDIO_DEFAULT_TEMPERATURE } from "@roo-code/types"
import type { ApiHandlerOptions } from "../../shared/api"
import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
import { TagMatcher } from "../../utils/tag-matcher"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { convertToZAiFormat } from "../transform/zai-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { getModelsFromCache } from "./fetchers/modelCache"
import { getApiRequestTimeout } from "./utils/timeout-config"
import { handleOpenAIError } from "./utils/openai-error-handler"
import { detectGlmModel, logGlmDetection, type GlmModelConfig } from "./utils/glm-model-detection"
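
/**
 * Provider handler for LM Studio's OpenAI-compatible local server
 * (default http://localhost:1234/v1, authenticated with a "noop"
 * placeholder key). Adds GLM-specific message handling and optional
 * speculative decoding via a draft model.
 */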
export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
	protected options: ApiHandlerOptions
	private client: OpenAI
	private readonly providerName = "LM Studio"
	private glmConfig: GlmModelConfig | null = null

	constructor(options: ApiHandlerOptions) {
		super()
		this.options = options

		// LM Studio uses "noop" as a placeholder API key
		const apiKey = "noop"
		this.client = new OpenAI({
			baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
			apiKey: apiKey,
			timeout: getApiRequestTimeout(),
		})

		// Detect GLM model on construction if model ID is available
		const modelId = this.options.lmStudioModelId || ""
		if (modelId) {
			this.glmConfig = detectGlmModel(modelId)
			logGlmDetection(this.providerName, modelId, this.glmConfig)
		}
	}
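
	/**
	 * Streams a chat completion from the local LM Studio server, yielding
	 * text, reasoning, tool-call, and usage chunks as an ApiStream.
	 */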
	override async *createMessage(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		metadata?: ApiHandlerCreateMessageMetadata,
	): ApiStream {
		const model = this.getModel()

		// Re-detect GLM model if not already done or if model ID changed
		if (!this.glmConfig || this.glmConfig.originalModelId !== model.id) {
			this.glmConfig = detectGlmModel(model.id)
			logGlmDetection(this.providerName, model.id, this.glmConfig)
		}

		// Convert messages based on whether this is a GLM model
		// GLM models benefit from mergeToolResultText to prevent reasoning_content loss
		const convertedMessages = this.glmConfig.isGlmModel
			? convertToZAiFormat(messages, { mergeToolResultText: this.glmConfig.mergeToolResultText })
			: convertToOpenAiMessages(messages)

		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
			{ role: "system", content: systemPrompt },
			...convertedMessages,
		]

		// -------------------------
		// Track token usage
		// -------------------------
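		// Flatten the system prompt and message list to plain text blocks so
		// countTokens can estimate sizes; only text parts are counted, so
		// images and tool blocks do not contribute to the estimate.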
		const toContentBlocks = (
			blocks: Anthropic.Messages.MessageParam[] | string,
		): Anthropic.Messages.ContentBlockParam[] => {
			if (typeof blocks === "string") {
				return [{ type: "text", text: blocks }]
			}

			const result: Anthropic.Messages.ContentBlockParam[] = []
			for (const msg of blocks) {
				if (typeof msg.content === "string") {
					result.push({ type: "text", text: msg.content })
				} else if (Array.isArray(msg.content)) {
					for (const part of msg.content) {
						if (part.type === "text") {
							result.push({ type: "text", text: part.text })
						}
					}
				}
			}
			return result
		}

		let inputTokens = 0
		try {
			inputTokens = await this.countTokens([{ type: "text", text: systemPrompt }, ...toContentBlocks(messages)])
		} catch (err) {
			console.error("[LmStudio] Failed to count input tokens:", err)
			inputTokens = 0
		}

		let assistantText = ""

		try {
			// Determine parallel_tool_calls setting
			// Disable for GLM models as they may not support it properly
			let parallelToolCalls: boolean
			if (this.glmConfig.isGlmModel && this.glmConfig.disableParallelToolCalls) {
				parallelToolCalls = false
				console.log(`[${this.providerName}] parallel_tool_calls disabled for GLM model`)
			} else {
				parallelToolCalls = metadata?.parallelToolCalls ?? true
			}

			const params: OpenAI.Chat.ChatCompletionCreateParamsStreaming & { draft_model?: string } = {
				model: model.id,
				messages: openAiMessages,
				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
				stream: true,
				tools: this.convertToolsForOpenAI(metadata?.tools),
				tool_choice: metadata?.tool_choice,
				parallel_tool_calls: parallelToolCalls,
			}
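
			// draft_model is LM Studio's extension for speculative decoding:
			// a smaller draft model proposes tokens that the main model verifies.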
			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
				params.draft_model = this.options.lmStudioDraftModelId
			}

			// For GLM-4.7 models with thinking support, add thinking parameter
			if (this.glmConfig.isGlmModel && this.glmConfig.supportsThinking) {
				const useReasoning = this.options.enableReasoningEffort !== false // Default to enabled for GLM-4.7
				;(params as any).thinking = useReasoning ? { type: "enabled" } : { type: "disabled" }
				console.log(`[${this.providerName}] GLM-4.7 thinking mode: ${useReasoning ? "enabled" : "disabled"}`)
			}

			let results
			try {
				results = await this.client.chat.completions.create(params)
			} catch (error) {
				throw handleOpenAIError(error, this.providerName)
			}
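
			// TagMatcher splits <think>...</think> spans out of the content
			// stream: matched spans are emitted as "reasoning" chunks, the
			// rest as plain "text" chunks.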
			const matcher = new TagMatcher(
				"think",
				(chunk) =>
					({
						type: chunk.matched ? "reasoning" : "text",
						text: chunk.data,
					}) as const,
			)

			for await (const chunk of results) {
				const delta = chunk.choices[0]?.delta
				const finishReason = chunk.choices[0]?.finish_reason

				if (delta?.content) {
					assistantText += delta.content
					for (const processedChunk of matcher.update(delta.content)) {
						yield processedChunk
					}
				}

				// Handle reasoning_content for GLM models (similar to Z.ai)
				if (delta) {
					for (const key of ["reasoning_content", "reasoning"] as const) {
						if (key in delta) {
							const reasoning_content = ((delta as any)[key] as string | undefined) || ""
							if (reasoning_content?.trim()) {
								yield { type: "reasoning", text: reasoning_content }
							}
							break
						}
					}
				}

				// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
				if (delta?.tool_calls) {
					for (const toolCall of delta.tool_calls) {
						yield {
							type: "tool_call_partial",
							index: toolCall.index,
							id: toolCall.id,
							name: toolCall.function?.name,
							arguments: toolCall.function?.arguments,
						}
					}
				}

				// Process finish_reason to emit tool_call_end events
				if (finishReason) {
					const endEvents = NativeToolCallParser.processFinishReason(finishReason)
					for (const event of endEvents) {
						yield event
					}
				}
			}
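
			// Flush any text still buffered in the tag matcher (for example,
			// an unterminated <think> block at the end of the stream).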
			for (const processedChunk of matcher.final()) {
				yield processedChunk
			}

			let outputTokens = 0
			try {
				outputTokens = await this.countTokens([{ type: "text", text: assistantText }])
			} catch (err) {
				console.error("[LmStudio] Failed to count output tokens:", err)
				outputTokens = 0
			}

			yield {
				type: "usage",
				inputTokens,
				outputTokens,
			} as const
		} catch (error) {
			throw new Error(
				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Roo Code's prompts.",
			)
		}
	}
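
	/**
	 * Resolves the configured model from the cached LM Studio model list,
	 * falling back to generic OpenAI defaults when no cache entry exists.
	 */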
	override getModel(): { id: string; info: ModelInfo } {
		const models = getModelsFromCache("lmstudio")
		if (models && this.options.lmStudioModelId && models[this.options.lmStudioModelId]) {
			return {
				id: this.options.lmStudioModelId,
				info: models[this.options.lmStudioModelId],
			}
		} else {
			return {
				id: this.options.lmStudioModelId || "",
				info: openAiModelInfoSaneDefaults,
			}
		}
	}
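
	/**
	 * Sends a single non-streaming user prompt and returns the model's
	 * response text (empty string if the response has no content).
	 */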
	async completePrompt(prompt: string): Promise<string> {
		try {
			// Create params object with optional draft model
			const params: any = {
				model: this.getModel().id,
				messages: [{ role: "user", content: prompt }],
				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
				stream: false,
			}

			// Add draft model if speculative decoding is enabled and a draft model is specified
			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
				params.draft_model = this.options.lmStudioDraftModelId
			}

			let response
			try {
				response = await this.client.chat.completions.create(params)
			} catch (error) {
				throw handleOpenAIError(error, this.providerName)
			}
			return response.choices[0]?.message.content || ""
		} catch (error) {
			throw new Error(
				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Roo Code's prompts.",
			)
		}
	}
}
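
/**
 * Fetches the IDs of models currently available from an LM Studio server.
 * Returns an empty array (rather than throwing) when the URL is invalid
 * or the server is unreachable, so callers can treat it as "no models".
 */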
export async function getLmStudioModels(baseUrl = "http://localhost:1234") {
	try {
		if (!URL.canParse(baseUrl)) {
			return []
		}

		const response = await axios.get(`${baseUrl}/v1/models`)
		const modelsArray = response.data?.data?.map((model: any) => model.id) || []
		return [...new Set<string>(modelsArray)]
	} catch (error) {
		return []
	}
}
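
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the module): a minimal way to
// drive the handler against a local LM Studio server. The option values and
// the model id below are assumptions for the example; ApiHandlerOptions
// carries many more fields in practice.
//
//   const handler = new LmStudioHandler({
//       lmStudioBaseUrl: "http://localhost:1234",
//       lmStudioModelId: "some-local-model", // hypothetical model id
//   } as ApiHandlerOptions)
//
//   for await (const chunk of handler.createMessage("You are a helpful assistant.", [
//       { role: "user", content: "Hello!" },
//   ])) {
//       if (chunk.type === "text") process.stdout.write(chunk.text)
//       if (chunk.type === "usage") console.log(chunk.inputTokens, chunk.outputTokens)
//   }
// ---------------------------------------------------------------------------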