Skip to content

Commit a38f534

Browse files
committed
feat: record user agent in AI Bridge interceptions
1 parent 9f6ce75 commit a38f534

10 files changed

Lines changed: 341 additions & 198 deletions

File tree

enterprise/aibridged/aibridged_integration_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package aibridged_test
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"fmt"
78
"net/http"
89
"net/http/httptest"
@@ -226,10 +227,13 @@ func TestIntegration(t *testing.T) {
226227
}
227228
]
228229
}`))
230+
userAgent := "userAgent123"
229231
require.NoError(t, err, "make request to test server")
230232
req.Header.Add("Authorization", "Bearer "+apiKey.Key)
231233
req.Header.Add("Accept", "application/json")
234+
req.Header.Add("User-Agent", userAgent)
232235

236+
require.Equal(t, userAgent, req.UserAgent())
233237
// When: aibridged handles the request.
234238
rec := httptest.NewRecorder()
235239
srv.ServeHTTP(rec, req)
@@ -252,6 +256,13 @@ func TestIntegration(t *testing.T) {
252256
require.False(t, intc0.EndedAt.Time.Before(intc0.StartedAt), "EndedAt should not be before StartedAt")
253257
require.Less(t, intc0.EndedAt.Time.Sub(intc0.StartedAt), 5*time.Second)
254258

259+
require.True(t, intc0.Metadata.Valid)
260+
meta := map[string]any{}
261+
err = json.Unmarshal(intc0.Metadata.RawMessage, &meta)
262+
require.NoError(t, err)
263+
require.Contains(t, meta, "user-agent")
264+
require.Equal(t, userAgent, meta["user-agent"])
265+
255266
prompts, err := db.GetAIBridgeUserPromptsByInterceptionID(ctx, interceptions[0].ID)
256267
require.NoError(t, err)
257268
require.Len(t, prompts, 1)

enterprise/aibridged/http.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,10 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
7878
}
7979

8080
handler, err := s.GetRequestHandler(ctx, Request{
81-
SessionKey: key,
8281
APIKeyID: resp.ApiKeyId,
8382
InitiatorID: id,
83+
SessionKey: key,
84+
UserAgent: r.UserAgent(),
8485
})
8586
if err != nil {
8687
logger.Warn(ctx, "failed to acquire request handler", slog.Error(err))

enterprise/aibridged/pool.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
146146
return nil, xerrors.Errorf("acquire client: %w", err)
147147
}
148148

149-
return &recorderTranslation{apiKeyID: req.APIKeyID, client: client}, nil
149+
return &recorderTranslation{
150+
apiKeyID: req.APIKeyID,
151+
client: client,
152+
}, nil
150153
})
151154

152155
// Slow path.

enterprise/aibridged/pool_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ func TestPool(t *testing.T) {
5757

5858
// ...and it will return it when acquired again.
5959
instB, err := pool.Acquire(t.Context(), aibridged.Request{
60-
SessionKey: "key",
60+
SessionKey: "different key",
6161
InitiatorID: id,
6262
APIKeyID: apiKeyID1.String(),
63+
UserAgent: "some user-agent",
6364
}, clientFn, newMockMCPFactory(mcpProxy))
6465
require.NoError(t, err, "acquire pool instance")
6566
require.Same(t, inst, instB)

enterprise/aibridged/proto/aibridged.pb.go

Lines changed: 213 additions & 194 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enterprise/aibridged/proto/aibridged.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ message RecordInterceptionRequest {
4343
map<string, google.protobuf.Any> metadata = 5;
4444
google.protobuf.Timestamp started_at = 6;
4545
string api_key_id = 7;
46+
string client = 8;
47+
string user_agent = 9;
4648
}
4749

4850
message RecordInterceptionResponse {}

enterprise/aibridged/request.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ package aibridged
33
import "github.com/google/uuid"
44

55
type Request struct {
6-
SessionKey string
76
APIKeyID string
87
InitiatorID uuid.UUID
8+
SessionKey string
9+
UserAgent string
910
}

enterprise/aibridged/translator.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func (t *recorderTranslation) RecordInterception(ctx context.Context, req *aibri
3030
InitiatorId: req.InitiatorID,
3131
Provider: req.Provider,
3232
Model: req.Model,
33+
UserAgent: req.UserAgent,
3334
Metadata: marshalForProto(req.Metadata),
3435
StartedAt: timestamppb.New(req.StartedAt),
3536
})
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package aibridged //nolint:testpackage
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
"google.golang.org/protobuf/proto"
11+
"google.golang.org/protobuf/types/known/anypb"
12+
"google.golang.org/protobuf/types/known/structpb"
13+
"google.golang.org/protobuf/types/known/timestamppb"
14+
15+
"github.com/coder/aibridge"
16+
abpb "github.com/coder/coder/v2/enterprise/aibridged/proto"
17+
"github.com/coder/coder/v2/testutil"
18+
)
19+
20+
const (
21+
MetaKeyUserAgent = "user-agent"
22+
)
23+
24+
type mockClient struct {
25+
abpb.DRPCRecorderClient
26+
got *abpb.RecordInterceptionRequest
27+
}
28+
29+
func (mc *mockClient) RecordInterception(ctx context.Context, in *abpb.RecordInterceptionRequest) (*abpb.RecordInterceptionResponse, error) {
30+
mc.got = in
31+
return &abpb.RecordInterceptionResponse{}, nil
32+
}
33+
34+
func mustAnypbNew(t *testing.T, src proto.Message) *anypb.Any {
35+
ret, err := anypb.New(src)
36+
require.NoError(t, err)
37+
return ret
38+
}
39+
40+
func TestRecordInterception(t *testing.T) {
41+
t.Parallel()
42+
43+
tests := []struct {
44+
name string
45+
apiKeyID string
46+
userAgent string
47+
in aibridge.InterceptionRecord
48+
expect *abpb.RecordInterceptionRequest
49+
}{
50+
{
51+
name: "ok",
52+
apiKeyID: "key",
53+
in: aibridge.InterceptionRecord{
54+
ID: uuid.UUID{1}.String(),
55+
InitiatorID: uuid.UUID{2}.String(),
56+
Provider: "prov",
57+
Model: "model",
58+
UserAgent: "user-agent",
59+
Metadata: map[string]any{"some": "data"},
60+
StartedAt: time.UnixMicro(123),
61+
},
62+
expect: &abpb.RecordInterceptionRequest{
63+
Id: uuid.UUID{1}.String(),
64+
ApiKeyId: "key",
65+
InitiatorId: uuid.UUID{2}.String(),
66+
Provider: "prov",
67+
Model: "model",
68+
UserAgent: "user-agent",
69+
Metadata: map[string]*anypb.Any{
70+
"some": mustAnypbNew(t, structpb.NewStringValue("data")),
71+
},
72+
StartedAt: timestamppb.New(time.UnixMicro(123)),
73+
},
74+
},
75+
}
76+
77+
for _, tc := range tests {
78+
t.Run(tc.name, func(t *testing.T) {
79+
t.Parallel()
80+
ctx := testutil.Context(t, testutil.WaitShort)
81+
82+
mc := &mockClient{}
83+
rt := &recorderTranslation{
84+
apiKeyID: tc.apiKeyID,
85+
client: mc,
86+
}
87+
rt.RecordInterception(ctx, &tc.in)
88+
require.Equal(t, tc.expect, mc.got)
89+
})
90+
}
91+
}

enterprise/aibridgedserver/aibridgedserver.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ import (
2727
"github.com/coder/coder/v2/enterprise/aibridged/proto"
2828
)
2929

30+
const (
31+
MetaKeyUserAgent = "user-agent"
32+
)
33+
3034
var (
3135
ErrExpiredOrInvalidOAuthToken = xerrors.New("expired or invalid OAuth2 token")
3236
ErrNoMCPConfigFound = xerrors.New("no MCP config found")
@@ -144,6 +148,15 @@ func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterce
144148
)
145149
}
146150

151+
if in.UserAgent != "" {
152+
md, err := anypb.New(structpb.NewStringValue(in.UserAgent))
153+
if err != nil {
154+
s.logger.Warn(ctx, "failed to convert user agent to proto", slog.Error(err))
155+
return nil, xerrors.Errorf("invalid user agent")
156+
}
157+
metadata[MetaKeyUserAgent] = md
158+
}
159+
147160
out, err := json.Marshal(metadata)
148161
if err != nil {
149162
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))

0 commit comments

Comments
 (0)