Skip to content

Commit 3077d0d

Browse files
authored
Merge pull request #3 from subpop/add-vertex
Add Vertex AI support for Google and Anthropic models
2 parents d899770 + 1402696 commit 3077d0d

6 files changed

Lines changed: 1058 additions & 0 deletions

File tree

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import Foundation
2+
3+
// MARK: - ADC Credential File
4+
5+
/// On-disk representation of an Application Default Credentials JSON file.
///
/// Only the fields required for the OAuth2 refresh-token grant are decoded.
private struct ADCCredentials: Decodable {
    /// Credential kind as written in the file's `type` field.
    let type: String
    /// OAuth2 client identifier (`client_id`).
    let clientId: String
    /// OAuth2 client secret (`client_secret`).
    let clientSecret: String
    /// Long-lived refresh token (`refresh_token`).
    let refreshToken: String

    /// Maps the snake_case JSON keys onto the camelCase Swift properties.
    enum CodingKeys: String, CodingKey {
        case type
        case clientId = "client_id", clientSecret = "client_secret", refreshToken = "refresh_token"
    }
}
18+
19+
// MARK: - Token Response
20+
21+
/// JSON payload decoded from a successful response of the token endpoint.
private struct TokenResponse: Decodable {
    /// The bearer token value (`access_token`).
    let accessToken: String
    /// Token lifetime in seconds, as reported by the server (`expires_in`).
    let expiresIn: Int
    /// Token scheme reported by the server (`token_type`).
    let tokenType: String

    /// Maps the snake_case JSON keys onto the camelCase Swift properties.
    enum CodingKeys: String, CodingKey {
        case accessToken = "access_token", expiresIn = "expires_in", tokenType = "token_type"
    }
}
32+
33+
/// Manages Google OAuth2 tokens from Application Default Credentials (ADC).
///
/// Reads `~/.config/gcloud/application_default_credentials.json` (created by
/// `gcloud auth application-default login`) and transparently refreshes access
/// tokens as needed.
///
/// Mutable state is protected by `actor` isolation. Actors are reentrant across
/// `await`, so isolation alone does not serialize refreshes; callers that find
/// the cache stale while a refresh is running share that single in-flight
/// request instead of issuing their own.
public actor GoogleAuthService {
    // MARK: - Errors

    public enum GoogleAuthError: Error, LocalizedError, Sendable {
        case credentialsFileNotFound(path: String)
        case unsupportedCredentialType(String)
        case refreshFailed(statusCode: Int, body: String)
        case decodingFailed(String)

        public var errorDescription: String? {
            switch self {
            case let .credentialsFileNotFound(path):
                "Google ADC credentials not found at \(path). Run `gcloud auth application-default login`."
            case let .unsupportedCredentialType(type):
                "Unsupported ADC credential type: \(type). Only 'authorized_user' is supported."
            case let .refreshFailed(code, body):
                "Token refresh failed (HTTP \(code)): \(body)"
            case let .decodingFailed(message):
                "Failed to decode ADC credentials: \(message)"
            }
        }
    }

    // MARK: - State

    private let clientID: String
    private let clientSecret: String
    private let refreshToken: String
    private let session: URLSession

    private var cachedAccessToken: String?
    private var tokenExpiry: Date?

    /// The refresh currently in progress, if any. Shared by every caller that
    /// arrives while it is running so only one token request is issued.
    private var inFlightRefresh: Task<String, Error>?

    /// Refresh the token when it has fewer than this many seconds remaining.
    private let refreshMargin: TimeInterval = 300 // 5 minutes

    private static let tokenEndpoint = URL(string: "https://oauth2.googleapis.com/token")!

    /// Characters left unescaped in `application/x-www-form-urlencoded` values:
    /// the RFC 3986 "unreserved" set. Everything else — notably `&`, `=` and
    /// `+`, which are significant in a form body — is percent-encoded.
    private static let formValueAllowedCharacters = CharacterSet(
        charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
    )

    // MARK: - Init

    /// Creates an auth service by reading the ADC file at the default path.
    ///
    /// - Parameter session: The URL session used for token refresh requests.
    /// - Throws: ``GoogleAuthError`` if the file is missing, unreadable,
    ///   malformed, or of an unsupported credential type.
    public init(session: URLSession = .shared) throws {
        try self.init(credentialsPath: Self.defaultCredentialsPath(), session: session)
    }

    /// Creates an auth service by reading the ADC file at a custom path.
    ///
    /// - Parameters:
    ///   - credentialsPath: File system path of the ADC JSON file.
    ///   - session: The URL session used for token refresh requests.
    /// - Throws: ``GoogleAuthError/credentialsFileNotFound(path:)`` if no file
    ///   exists at the path, ``GoogleAuthError/decodingFailed(_:)`` if it cannot
    ///   be read or decoded, and
    ///   ``GoogleAuthError/unsupportedCredentialType(_:)`` if its `type` is not
    ///   `authorized_user`.
    public init(credentialsPath: String, session: URLSession = .shared) throws {
        guard FileManager.default.fileExists(atPath: credentialsPath) else {
            throw GoogleAuthError.credentialsFileNotFound(path: credentialsPath)
        }
        let data: Data
        do {
            data = try Data(contentsOf: URL(fileURLWithPath: credentialsPath))
        } catch {
            throw GoogleAuthError.decodingFailed("Failed to read file: \(error.localizedDescription)")
        }
        let credentials: ADCCredentials
        do {
            credentials = try JSONDecoder().decode(ADCCredentials.self, from: data)
        } catch {
            throw GoogleAuthError.decodingFailed(error.localizedDescription)
        }
        guard credentials.type == "authorized_user" else {
            throw GoogleAuthError.unsupportedCredentialType(credentials.type)
        }
        clientID = credentials.clientId
        clientSecret = credentials.clientSecret
        refreshToken = credentials.refreshToken
        self.session = session
    }

    // MARK: - Public API

    /// Returns a valid access token, refreshing if necessary.
    ///
    /// The cached token is returned while it has more than `refreshMargin`
    /// seconds left. Otherwise the caller joins the refresh already in flight,
    /// or starts one if none is running.
    ///
    /// - Throws: ``GoogleAuthError/refreshFailed(statusCode:body:)`` on a
    ///   non-200 or non-HTTP response, or any error thrown by the URL session
    ///   or by decoding the token response.
    public func accessToken() async throws -> String {
        if let token = cachedAccessToken,
           let expiry = tokenExpiry,
           Date() < expiry.addingTimeInterval(-refreshMargin) {
            return token
        }
        // A refresh is already running (started by another caller that
        // suspended at the network await): share its result.
        if let pending = inFlightRefresh {
            return try await pending.value
        }
        // The task body runs on this actor, so it cannot start before the
        // assignment below completes; it clears the slot itself when done so a
        // failed refresh is retried by the next caller. The task is
        // unstructured on purpose: cancelling one waiter must not cancel the
        // refresh that other waiters depend on.
        let refresh = Task<String, Error> {
            defer { self.inFlightRefresh = nil }
            return try await self.refreshAccessToken()
        }
        inFlightRefresh = refresh
        return try await refresh.value
    }

    // MARK: - Private

    /// Performs the OAuth2 refresh-token grant and updates the cached token
    /// and its expiry on success.
    private func refreshAccessToken() async throws -> String {
        var request = URLRequest(url: Self.tokenEndpoint)
        request.httpMethod = "POST"
        request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type")

        let body = [
            "client_id=\(urlEncode(clientID))",
            "client_secret=\(urlEncode(clientSecret))",
            "refresh_token=\(urlEncode(refreshToken))",
            "grant_type=refresh_token",
        ].joined(separator: "&")
        request.httpBody = Data(body.utf8)

        let (data, response) = try await session.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw GoogleAuthError.refreshFailed(statusCode: 0, body: "Invalid response")
        }
        guard httpResponse.statusCode == 200 else {
            let responseBody = String(data: data, encoding: .utf8) ?? "<unreadable>"
            throw GoogleAuthError.refreshFailed(statusCode: httpResponse.statusCode, body: responseBody)
        }

        let tokenResponse = try JSONDecoder().decode(TokenResponse.self, from: data)
        cachedAccessToken = tokenResponse.accessToken
        tokenExpiry = Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn))
        return tokenResponse.accessToken
    }

    /// Percent-encodes a value for use in an
    /// `application/x-www-form-urlencoded` body.
    ///
    /// `.urlQueryAllowed` is deliberately not used: it leaves `&`, `=` and `+`
    /// untouched, which would split or alter the value when the server parses
    /// the form.
    private func urlEncode(_ string: String) -> String {
        string.addingPercentEncoding(withAllowedCharacters: Self.formValueAllowedCharacters) ?? string
    }

    /// The default path to the ADC credentials file.
    public static func defaultCredentialsPath() -> String {
        let home = FileManager.default.homeDirectoryForCurrentUser.path
        return "\(home)/.config/gcloud/application_default_credentials.json"
    }

    /// Whether an ADC credentials file exists at the default path.
    public static func credentialsAvailable() -> Bool {
        FileManager.default.fileExists(atPath: defaultCredentialsPath())
    }
}
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import Foundation
2+
3+
/// An LLM client for Anthropic Claude models served via Vertex AI.
///
/// Uses OAuth2 Bearer token authentication (via ``GoogleAuthService`` or a custom
/// token provider closure) instead of Anthropic API key authentication.
///
/// The wire format is the standard Anthropic Messages API with a
/// `"anthropic_version": "vertex-2023-10-16"` field injected into the request body.
/// Response parsing and SSE streaming are delegated to an internal ``AnthropicClient``.
///
/// ```swift
/// let auth = try GoogleAuthService()
/// let client = VertexAnthropicClient(
///     projectID: "my-project",
///     location: "us-east5",
///     model: "claude-sonnet-4-6",
///     authService: auth
/// )
/// ```
public struct VertexAnthropicClient: LLMClient, Sendable {
    public let contextWindowSize: Int?

    /// Inner client used only to build request bodies and parse responses;
    /// it is constructed with an empty API key and never sends requests itself
    /// from this type.
    let anthropic: AnthropicClient
    private let projectID: String
    private let location: String
    private let model: String
    private let tokenProvider: @Sendable () async throws -> String
    private let session: URLSession
    private let retryPolicy: RetryPolicy

    /// Creates a client that obtains its bearer token from `tokenProvider`.
    ///
    /// - Parameters:
    ///   - projectID: Google Cloud project identifier used in the request path.
    ///   - location: Vertex AI location (for example `us-east5` or `global`).
    ///   - model: Model identifier used in the request path.
    ///   - tokenProvider: Called before every request to obtain an access token.
    ///   - maxTokens: Forwarded to the inner ``AnthropicClient``.
    ///   - contextWindowSize: Exposed as ``contextWindowSize`` and forwarded to
    ///     the inner client.
    ///   - session: URL session used for HTTP requests.
    ///   - retryPolicy: Retry policy passed to `HTTPRetry`.
    ///   - reasoningConfig: Forwarded to the inner ``AnthropicClient``.
    ///   - interleavedThinking: Forwarded to the inner ``AnthropicClient``.
    ///   - cachingEnabled: Forwarded to the inner ``AnthropicClient``.
    public init(
        projectID: String,
        location: String,
        model: String,
        tokenProvider: @Sendable @escaping () async throws -> String,
        maxTokens: Int = 8192,
        contextWindowSize: Int? = nil,
        session: URLSession = .shared,
        retryPolicy: RetryPolicy = .default,
        reasoningConfig: ReasoningConfig? = nil,
        interleavedThinking: Bool = true,
        cachingEnabled: Bool = false
    ) {
        self.projectID = projectID
        self.location = location
        self.model = model
        self.tokenProvider = tokenProvider
        self.session = session
        self.retryPolicy = retryPolicy
        self.contextWindowSize = contextWindowSize
        anthropic = AnthropicClient(
            apiKey: "",
            model: model,
            maxTokens: maxTokens,
            contextWindowSize: contextWindowSize,
            session: session,
            retryPolicy: retryPolicy,
            reasoningConfig: reasoningConfig,
            interleavedThinking: interleavedThinking,
            cachingEnabled: cachingEnabled
        )
    }

    /// Convenience initializer that uses a ``GoogleAuthService`` for authentication.
    public init(
        projectID: String,
        location: String,
        model: String,
        authService: GoogleAuthService,
        maxTokens: Int = 8192,
        contextWindowSize: Int? = nil,
        session: URLSession = .shared,
        retryPolicy: RetryPolicy = .default,
        reasoningConfig: ReasoningConfig? = nil,
        interleavedThinking: Bool = true,
        cachingEnabled: Bool = false
    ) {
        self.init(
            projectID: projectID,
            location: location,
            model: model,
            tokenProvider: { try await authService.accessToken() },
            maxTokens: maxTokens,
            contextWindowSize: contextWindowSize,
            session: session,
            retryPolicy: retryPolicy,
            reasoningConfig: reasoningConfig,
            interleavedThinking: interleavedThinking,
            cachingEnabled: cachingEnabled
        )
    }

    // MARK: - LLMClient

    /// Sends a non-streaming request and returns the parsed assistant message.
    ///
    /// - Throws: `AgentError.llmError` if `responseFormat` is non-nil (it is not
    ///   supported by this client), plus any error from request building, the
    ///   token provider, the HTTP layer, or response parsing.
    public func generate(
        messages: [ChatMessage],
        tools: [ToolDefinition],
        responseFormat: ResponseFormat?,
        requestContext: RequestContext?
    ) async throws -> AssistantMessage {
        if responseFormat != nil {
            throw AgentError.llmError(.other("VertexAnthropicClient does not support responseFormat"))
        }
        let request = try anthropic.buildRequest(
            messages: messages,
            tools: tools,
            extraFields: requestContext?.extraFields ?? [:]
        )
        let token = try await tokenProvider()
        let urlRequest = try buildVertexURLRequest(
            VertexAnthropicRequest(inner: request), stream: false, token: token
        )
        let (data, httpResponse) = try await HTTPRetry.performData(
            urlRequest: urlRequest, session: session, retryPolicy: retryPolicy
        )
        requestContext?.onResponse?(httpResponse)
        return try anthropic.parseResponse(data)
    }

    /// Starts a streaming request and returns a stream of deltas.
    ///
    /// The work runs in a task that is cancelled when the returned stream
    /// terminates; any thrown error finishes the stream with that error.
    public func stream(
        messages: [ChatMessage],
        tools: [ToolDefinition],
        requestContext: RequestContext?
    ) -> AsyncThrowingStream<StreamDelta, Error> {
        AsyncThrowingStream { continuation in
            let task = Task {
                do {
                    try await performStreamRequest(
                        messages: messages,
                        tools: tools,
                        extraFields: requestContext?.extraFields ?? [:],
                        onResponse: requestContext?.onResponse,
                        continuation: continuation
                    )
                } catch {
                    continuation.finish(throwing: error)
                }
            }
            continuation.onTermination = { _ in task.cancel() }
        }
    }

    // MARK: - Streaming

    /// Builds and sends the streaming request, then feeds each SSE line to the
    /// inner client's handler until the byte stream ends, finishing
    /// `continuation` on success.
    private func performStreamRequest(
        messages: [ChatMessage],
        tools: [ToolDefinition],
        extraFields: [String: JSONValue],
        onResponse: (@Sendable (HTTPURLResponse) -> Void)?,
        continuation: AsyncThrowingStream<StreamDelta, Error>.Continuation
    ) async throws {
        let request = try anthropic.buildRequest(
            messages: messages, tools: tools,
            stream: true, extraFields: extraFields
        )
        let token = try await tokenProvider()
        let urlRequest = try buildVertexURLRequest(
            VertexAnthropicRequest(inner: request), stream: true, token: token
        )
        let (bytes, httpResponse) = try await HTTPRetry.performStream(
            urlRequest: urlRequest, session: session, retryPolicy: retryPolicy
        )
        onResponse?(httpResponse)

        let state = AnthropicStreamState()

        try await processSSEStream(
            bytes: bytes,
            stallTimeout: retryPolicy.streamStallTimeout
        ) { line in
            try await anthropic.handleSSELine(
                line, state: state, continuation: continuation
            )
        }
        continuation.finish()
    }

    // MARK: - URL Construction

    /// Builds the authenticated Vertex AI POST request for `request`.
    ///
    /// The action is `streamRawPredict` when `stream` is true and `rawPredict`
    /// otherwise. Regional locations use the `<location>-aiplatform.googleapis.com`
    /// host; the `global` location uses `aiplatform.googleapis.com`, which has
    /// no location prefix.
    ///
    /// - Throws: `AgentError.llmError` if `location` does not yield a valid
    ///   host URL, plus any error from `buildJSONPostRequest`.
    func buildVertexURLRequest(
        _ request: VertexAnthropicRequest,
        stream: Bool,
        token: String
    ) throws -> URLRequest {
        let action = stream ? "streamRawPredict" : "rawPredict"
        let host = location == "global"
            ? "aiplatform.googleapis.com"
            : "\(location)-aiplatform.googleapis.com"
        // `location` is caller-supplied, so a malformed value must surface as
        // an error rather than crash on a force-unwrap.
        guard let baseURL = URL(string: "https://\(host)") else {
            throw AgentError.llmError(.other("Invalid Vertex AI location: \(location)"))
        }
        let basePath = "v1/projects/\(projectID)/locations/\(location)"
            + "/publishers/anthropic/models/\(model):\(action)"
        let url = baseURL.appendingPathComponent(basePath)

        let headers = ["Authorization": "Bearer \(token)"]
        return try buildJSONPostRequest(url: url, body: request, headers: headers)
    }
}
196+
197+
// MARK: - Vertex Anthropic Request Wrapper
198+
199+
/// Encodable wrapper around an ``AnthropicRequest`` that adds the
/// `"anthropic_version": "vertex-2023-10-16"` key to the encoded body, as
/// required for Vertex AI compatibility.
struct VertexAnthropicRequest: Encodable {
    /// Value written under the `anthropic_version` key.
    static let vertexAnthropicVersion = "vertex-2023-10-16"

    /// The request whose fields form the rest of the body.
    let inner: AnthropicRequest

    /// Encodes `inner` into `encoder`, then writes the version key into a keyed
    /// container obtained from the same encoder.
    func encode(to encoder: any Encoder) throws {
        try inner.encode(to: encoder)

        let versionKey = DynamicCodingKey("anthropic_version")
        var extras = encoder.container(keyedBy: DynamicCodingKey.self)
        try extras.encode(Self.vertexAnthropicVersion, forKey: versionKey)
    }
}

0 commit comments

Comments
 (0)