Skip to content

Commit 561f10e

Browse files
committed
chore: review comments
Signed-off-by: Danny Kopping <danny@coder.com>
1 parent e4c2a82 commit 561f10e

3 files changed

Lines changed: 138 additions & 39 deletions

File tree

apidump_integration_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,15 @@ func TestAPIDump(t *testing.T) {
7272
},
7373
createRequestFunc: createOpenAIChatCompletionsReq,
7474
},
75+
{
76+
name: config.ProviderOpenAI,
77+
fixture: fixtures.OaiResponsesBlockingSimple,
78+
providerName: config.ProviderOpenAI,
79+
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
80+
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
81+
},
82+
createRequestFunc: createOpenAIResponsesReq,
83+
},
7584
}
7685

7786
for _, tc := range cases {

intercept/apidump/apidump.go

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo
5454
logger.Named("apidump").Warn(context.Background(), "failed to dump request", slog.Error(err))
5555
}
5656

57+
// TODO: https://github.com/coder/aibridge/issues/129
5758
resp, err := next(req)
5859
if err != nil {
5960
return resp, err
@@ -98,29 +99,10 @@ func (d *dumper) dumpRequest(req *http.Request) error {
9899
// Build raw HTTP request format
99100
var buf bytes.Buffer
100101
fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
101-
fmt.Fprintf(&buf, "Host: %s\r\n", req.Host)
102-
fmt.Fprintf(&buf, "Content-Length: %d\r\n", len(prettyBody))
102+
d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{
103+
"Content-Length": fmt.Sprintf("%d", len(prettyBody)),
104+
})
103105

104-
// Sort header keys for deterministic output.
105-
headerKeys := make([]string, 0, len(req.Header))
106-
for key := range req.Header {
107-
headerKeys = append(headerKeys, key)
108-
}
109-
slices.Sort(headerKeys)
110-
111-
for _, key := range headerKeys {
112-
// Skip Content-Length since we write it explicitly above with the pretty-printed body length.
113-
if key == "Content-Length" {
114-
continue
115-
}
116-
_, sensitive := sensitiveRequestHeaders[key]
117-
for _, value := range req.Header[key] {
118-
if sensitive {
119-
value = redactHeaderValue(value)
120-
}
121-
fmt.Fprintf(&buf, "%s: %s\r\n", key, value)
122-
}
123-
}
124106
fmt.Fprintf(&buf, "\r\n")
125107
buf.Write(prettyBody)
126108

@@ -133,23 +115,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
133115
// Build raw HTTP response headers
134116
var headerBuf bytes.Buffer
135117
fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status)
136-
137-
// Sort header keys for deterministic output.
138-
headerKeys := make([]string, 0, len(resp.Header))
139-
for key := range resp.Header {
140-
headerKeys = append(headerKeys, key)
141-
}
142-
slices.Sort(headerKeys)
143-
144-
for _, key := range headerKeys {
145-
_, sensitive := sensitiveResponseHeaders[key]
146-
for _, value := range resp.Header[key] {
147-
if sensitive {
148-
value = redactHeaderValue(value)
149-
}
150-
fmt.Fprintf(&headerBuf, "%s: %s\r\n", key, value)
151-
}
152-
}
118+
d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil)
153119
fmt.Fprintf(&headerBuf, "\r\n")
154120

155121
// Wrap the response body to capture it as it streams
@@ -170,6 +136,50 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
170136
return nil
171137
}
172138

139+
// writeRedactedHeaders writes HTTP headers in wire format (Key: Value\r\n) to w,
140+
// redacting sensitive values and applying any overrides. Headers are sorted by key
141+
// for deterministic output.
142+
// `sensitive` and `overrides` must both supply keys in canoncialized form.
143+
// See [textproto.MIMEHeader].
144+
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) {
145+
// Collect all header keys including overrides.
146+
headerKeys := make([]string, 0, len(headers)+len(overrides))
147+
seen := make(map[string]struct{}, len(headers)+len(overrides))
148+
for key := range headers {
149+
headerKeys = append(headerKeys, key)
150+
seen[key] = struct{}{}
151+
}
152+
// Add override keys that don't exist in headers.
153+
for key := range overrides {
154+
if _, ok := seen[key]; !ok {
155+
headerKeys = append(headerKeys, key)
156+
}
157+
}
158+
slices.Sort(headerKeys)
159+
160+
for _, key := range headerKeys {
161+
_, isSensitive := sensitive[key]
162+
values := headers[key]
163+
// If no values exist but we have an override, use that.
164+
if len(values) == 0 {
165+
if override, ok := overrides[key]; ok {
166+
fmt.Fprintf(w, "%s: %s\r\n", key, override)
167+
}
168+
continue
169+
}
170+
for _, value := range values {
171+
if override, ok := overrides[key]; ok {
172+
value = override
173+
}
174+
175+
if isSensitive {
176+
value = redactHeaderValue(value)
177+
}
178+
fmt.Fprintf(w, "%s: %s\r\n", key, value)
179+
}
180+
}
181+
}
182+
173183
// path returns the path to a request/response dump file for a given interception.
174184
// suffix should be SuffixRequest or SuffixResponse.
175185
func (d *dumper) path(suffix string) string {

intercept/apidump/headers_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
package apidump
22

33
import (
4+
"bytes"
5+
"net/http"
46
"testing"
57

8+
"cdr.dev/slog/v3"
9+
10+
"github.com/coder/quartz"
11+
"github.com/google/uuid"
612
"github.com/stretchr/testify/require"
713
)
814

@@ -90,3 +96,77 @@ func TestSensitiveHeaderLists(t *testing.T) {
9096
require.True(t, ok, "expected %q to be in sensitiveResponseHeaders", h)
9197
}
9298
}
99+
100+
func TestWriteRedactedHeaders(t *testing.T) {
101+
t.Parallel()
102+
103+
d := &dumper{
104+
baseDir: "/tmp",
105+
provider: "test",
106+
model: "test",
107+
interceptionID: uuid.New(),
108+
clk: quartz.NewMock(t),
109+
logger: slog.Make(),
110+
}
111+
112+
tests := []struct {
113+
name string
114+
headers http.Header
115+
sensitive map[string]struct{}
116+
overrides map[string]string
117+
expected string
118+
}{
119+
{
120+
name: "empty headers",
121+
headers: http.Header{},
122+
expected: "",
123+
},
124+
{
125+
name: "single header",
126+
headers: http.Header{"Content-Type": {"application/json"}},
127+
expected: "Content-Type: application/json\r\n",
128+
},
129+
{
130+
name: "sorted alphabetically",
131+
headers: http.Header{
132+
"Zebra": {"last"},
133+
"Alpha": {"first"},
134+
},
135+
expected: "Alpha: first\r\nZebra: last\r\n",
136+
},
137+
{
138+
name: "override applied",
139+
headers: http.Header{"Content-Length": {"100"}},
140+
overrides: map[string]string{"Content-Length": "200"},
141+
expected: "Content-Length: 200\r\n",
142+
},
143+
{
144+
name: "sensitive header redacted",
145+
headers: http.Header{"Set-Cookie": {"session=abcdefghij"}},
146+
sensitive: sensitiveResponseHeaders,
147+
expected: "Set-Cookie: sess...ghij\r\n",
148+
},
149+
{
150+
name: "multi-value header",
151+
headers: http.Header{
152+
"Accept": {"text/html", "application/json"},
153+
},
154+
expected: "Accept: text/html\r\nAccept: application/json\r\n",
155+
},
156+
{
157+
name: "override for non-existent header",
158+
headers: http.Header{},
159+
overrides: map[string]string{"Host": "example.com"},
160+
expected: "Host: example.com\r\n",
161+
},
162+
}
163+
164+
for _, tc := range tests {
165+
t.Run(tc.name, func(t *testing.T) {
166+
t.Parallel()
167+
var buf bytes.Buffer
168+
d.writeRedactedHeaders(&buf, tc.headers, tc.sensitive, tc.overrides)
169+
require.Equal(t, tc.expected, buf.String())
170+
})
171+
}
172+
}

0 commit comments

Comments
 (0)