test: add comprehensive test harness across all packages (156 tests)

Characterization tests capturing current behavior before refactoring.
Covers auth, config, logging, proxy, ratelimit, server, and telemetry
packages with race-safe concurrent access tests.
This commit is contained in:
Alexander
2026-04-15 10:40:43 +02:00
parent d3fbfe8b42
commit 9150f466e5
13 changed files with 4325 additions and 0 deletions
+323
View File
@@ -0,0 +1,323 @@
package proxy
import (
"encoding/hex"
"strings"
"testing"
)
func TestFingerprintSaltConstant(t *testing.T) {
if fingerprintSalt != "59cf53e54c78" {
t.Errorf("fingerprintSalt = %q, want %q", fingerprintSalt, "59cf53e54c78")
}
}
func TestComputeFingerprint_Deterministic(t *testing.T) {
a := computeFingerprint("hello world test message", "1.0.0")
b := computeFingerprint("hello world test message", "1.0.0")
if a != b {
t.Errorf("fingerprint not deterministic: %q != %q", a, b)
}
}
func TestComputeFingerprint_Length(t *testing.T) {
fp := computeFingerprint("some message here", "2.0.0")
if len(fp) != 3 {
t.Errorf("fingerprint length = %d, want 3", len(fp))
}
// Must be valid hex
if _, err := hex.DecodeString(fp + "0"); err != nil { // pad to even length for decode
// Check each char is hex individually
for _, c := range fp {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("fingerprint %q contains non-hex char %c", fp, c)
}
}
}
}
func TestComputeFingerprint_DifferentVersions(t *testing.T) {
a := computeFingerprint("same message", "1.0.0")
b := computeFingerprint("same message", "2.0.0")
if a == b {
t.Errorf("different versions should (almost certainly) produce different fingerprints")
}
}
func TestComputeFingerprint_ShortMessage(t *testing.T) {
// "hi" has only 2 chars, indices [4,7,20] all out of range → chars = "000"
fp := computeFingerprint("hi", "1.0.0")
if len(fp) != 3 {
t.Errorf("short message fingerprint length = %d, want 3", len(fp))
}
}
func TestComputeFingerprint_EmptyMessage(t *testing.T) {
// Empty → all indices out of range → chars = "000"
fp := computeFingerprint("", "1.0.0")
if len(fp) != 3 {
t.Errorf("empty message fingerprint length = %d, want 3", len(fp))
}
// Empty and short message with same version should produce same fingerprint
// since both result in chars = "000"
fpShort := computeFingerprint("hi", "1.0.0")
if fp != fpShort {
t.Errorf("empty and 'hi' should produce same fingerprint (both use '000'), got %q vs %q", fp, fpShort)
}
}
func TestComputeFingerprint_Unicode(t *testing.T) {
// Emoji: 🎉 is U+1F389, encoded as UTF-16 surrogate pair [0xD83C, 0xDF89]
// So "abcd🎉fg" in UTF-16 is [a, b, c, d, 0xD83C, 0xDF89, f, g] = 8 uint16 values
// indices [4,7,20]: runes[4]=0xD83C, runes[7]='g', runes[20]=out of range
fp := computeFingerprint("abcd🎉fg", "1.0.0")
if len(fp) != 3 {
t.Errorf("unicode fingerprint length = %d, want 3", len(fp))
}
}
func TestComputeFingerprint_CharExtraction(t *testing.T) {
// "Hello, World!" UTF-16: [H,e,l,l,o,',', ,W,o,r,l,d,!]
// indices [4,7,20]: runes[4]='o', runes[7]='W', runes[20]=out of range → "0"
// So chars should be "oW0"
// Verify by comparing to a message where we know the expected extracted chars
// Two messages that extract same chars at indices should produce same fingerprint
// "xxxxoxxWxxxxxxxxxxxx" → index 4='o', 7='W', 20=out of range → "oW0" (20 chars, index 20 out of range)
fp1 := computeFingerprint("Hello, World!", "1.0.0")
fp2 := computeFingerprint("xxxxoxxWxxxxxxxxxxxx", "1.0.0")
if fp1 != fp2 {
t.Errorf("messages with same chars at indices [4,7,20] should produce same fingerprint, got %q vs %q", fp1, fp2)
}
}
func TestComputeFingerprint_IndexBoundary(t *testing.T) {
// Message with exactly 21 chars → index 20 is valid
msg21 := "abcdefghijklmnopqrstu" // 21 chars
fp21 := computeFingerprint(msg21, "1.0.0")
// Message with exactly 20 chars → index 20 is out of range → "0"
msg20 := "abcdefghijklmnopqrst" // 20 chars
fp20 := computeFingerprint(msg20, "1.0.0")
// They should differ because index 20 produces different chars
if fp21 == fp20 {
t.Errorf("boundary test: 21-char and 20-char messages should differ at index 20")
}
}
func TestExtractFirstUserMessage(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "simple string content",
body: `{"messages":[{"role":"user","content":"hello world"}]}`,
expected: "hello world",
},
{
name: "array content with text block",
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"from array"}]}]}`,
expected: "from array",
},
{
name: "no user messages",
body: `{"messages":[{"role":"assistant","content":"I am assistant"}]}`,
expected: "",
},
{
name: "assistant only messages",
body: `{"messages":[{"role":"assistant","content":"a1"},{"role":"assistant","content":"a2"}]}`,
expected: "",
},
{
name: "user with non-text block first then text",
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"},{"type":"text","text":"the text"}]}]}`,
expected: "the text",
},
{
name: "user with only non-text blocks",
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]}]}`,
expected: "",
},
{
name: "no messages field",
body: `{"model":"claude-sonnet-4-6"}`,
expected: "",
},
{
name: "messages not array",
body: `{"messages":"not array"}`,
expected: "",
},
{
name: "empty messages array",
body: `{"messages":[]}`,
expected: "",
},
{
name: "first user message used even if multiple exist",
body: `{"messages":[{"role":"user","content":"first"},{"role":"user","content":"second"}]}`,
expected: "first",
},
{
name: "assistant before user",
body: `{"messages":[{"role":"assistant","content":"assistant msg"},{"role":"user","content":"user msg"}]}`,
expected: "user msg",
},
{
name: "user with array content - first text block used",
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"first text"},{"type":"text","text":"second text"}]}]}`,
expected: "first text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractFirstUserMessage([]byte(tt.body))
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
func TestExtractFirstUserMessage_BreaksAfterFirstUser(t *testing.T) {
// The function should break after finding the first user message,
// even if it didn't extract text (e.g. user with only image blocks)
body := `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]},{"role":"user","content":"second user"}]}`
result := extractFirstUserMessage([]byte(body))
// First user has no text blocks, function breaks, returns ""
if result != "" {
t.Errorf("should return empty when first user has no text, got %q", result)
}
}
func TestBuildBillingHeader(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"test message"}]}`)
version := "1.2.3"
header := buildBillingHeader(body, version)
// Check format
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=1.2.3.") {
t.Errorf("header should start with 'x-anthropic-billing-header: cc_version=1.2.3.', got %q", header)
}
if !strings.Contains(header, "; cc_entrypoint=cli; cch=00000;") {
t.Errorf("header should contain '; cc_entrypoint=cli; cch=00000;', got %q", header)
}
// Verify the fingerprint part is 3 chars
// Format: "x-anthropic-billing-header: cc_version=1.2.3.XXX; cc_entrypoint=cli; cch=00000;"
parts := strings.Split(header, "cc_version=")
if len(parts) != 2 {
t.Fatalf("unexpected header format: %q", header)
}
versionFP := strings.Split(parts[1], ";")[0]
if !strings.HasPrefix(versionFP, "1.2.3.") {
t.Errorf("version+fingerprint should start with '1.2.3.', got %q", versionFP)
}
fp := strings.TrimPrefix(versionFP, "1.2.3.")
if len(fp) != 3 {
t.Errorf("fingerprint should be 3 chars, got %q (len %d)", fp, len(fp))
}
}
func TestBuildBillingHeader_EmptyMessages(t *testing.T) {
body := []byte(`{"messages":[]}`)
version := "1.0.0"
header := buildBillingHeader(body, version)
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=") {
t.Errorf("header format wrong: %q", header)
}
}
func TestInjectBillingHeader_NoExistingSystem(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
// Should have system field now
if !strings.Contains(resultStr, `"system"`) {
t.Errorf("should inject system field, got %s", resultStr)
}
// System should be an array with one billing block
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
t.Errorf("should contain billing header text, got %s", resultStr)
}
if !strings.Contains(resultStr, `"type":"text"`) {
t.Errorf("billing block should have type text, got %s", resultStr)
}
}
func TestInjectBillingHeader_ExistingSystemArray(t *testing.T) {
body := []byte(`{"system":[{"type":"text","text":"existing prompt"}],"messages":[{"role":"user","content":"hi"}]}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
// Should contain both billing header and existing prompt
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
t.Errorf("should contain billing header, got %s", resultStr)
}
if !strings.Contains(resultStr, "existing prompt") {
t.Errorf("should preserve existing prompt, got %s", resultStr)
}
// Billing block should be FIRST (prepended)
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
existingIdx := strings.Index(resultStr, "existing prompt")
if billingIdx > existingIdx {
t.Errorf("billing block should come before existing prompt")
}
}
func TestInjectBillingHeader_ExistingSystemString(t *testing.T) {
body := []byte(`{"system":"You are a helpful assistant","messages":[{"role":"user","content":"hi"}]}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
// Should convert to array with billing block first, then original text
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
t.Errorf("should contain billing header, got %s", resultStr)
}
if !strings.Contains(resultStr, "You are a helpful assistant") {
t.Errorf("should preserve original system string, got %s", resultStr)
}
// Billing should come first
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
origIdx := strings.Index(resultStr, "You are a helpful assistant")
if billingIdx > origIdx {
t.Errorf("billing block should come before original system text")
}
}
func TestInjectBillingHeader_PreservesOtherFields(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
result := injectBillingHeader(body, "1.0.0")
resultStr := string(result)
if !strings.Contains(resultStr, `"model":"claude-sonnet-4-6"`) {
t.Errorf("should preserve model field, got %s", resultStr)
}
if !strings.Contains(resultStr, `"max_tokens":1024`) {
t.Errorf("should preserve max_tokens field, got %s", resultStr)
}
}
func TestInjectBillingHeader_BillingBlockFormat(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
result := injectBillingHeader(body, "2.5.0")
resultStr := string(result)
// Verify the billing block contains the correct version
if !strings.Contains(resultStr, "cc_version=2.5.0.") {
t.Errorf("billing block should contain cc_version=2.5.0., got %s", resultStr)
}
if !strings.Contains(resultStr, "cc_entrypoint=cli") {
t.Errorf("billing block should contain cc_entrypoint=cli, got %s", resultStr)
}
if !strings.Contains(resultStr, "cch=00000") {
t.Errorf("billing block should contain cch=00000, got %s", resultStr)
}
}
+624
View File
@@ -0,0 +1,624 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/fujin/anthropic-proxy/internal/auth"
"github.com/fujin/anthropic-proxy/internal/config"
"github.com/fujin/anthropic-proxy/internal/ratelimit"
"github.com/fujin/anthropic-proxy/internal/telemetry"
"go.opentelemetry.io/otel/metric/noop"
)
func init() {
gin.SetMode(gin.TestMode)
// Initialize telemetry with noop meter to avoid nil pointer panics.
meter := noop.Meter{}
telemetry.InitMetrics(meter, nil)
}
// --- Request body reading and sanitization ---
func TestHandleMessages_ReadBodyError(t *testing.T) {
// A body that immediately fails on read shouldn't panic.
pool := auth.NewPool([]*auth.Credential{{ID: "c1", AccessToken: "tok", Email: "test@test.com"}})
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", &errReader{})
handler(c)
if w.Code != http.StatusBadRequest {
t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
}
if !strings.Contains(w.Body.String(), "failed to read request body") {
t.Errorf("body = %q, expected error message about reading body", w.Body.String())
}
}
func TestHandleMessages_SanitizesRequestBody(t *testing.T) {
// We can't directly make HandleMessages use our mock server because
// UpstreamClient hardcodes messagesURL. Instead, we test sanitization
// by verifying the sanitizer is called on the body before any pool interaction.
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
Body: []config.ReplaceRule{{Match: "secret", Replace: "redacted"}},
})
// Create body with tool name and secret
body := `{"model":"claude-sonnet-4-6","tools":[{"name":"my_tool"}],"messages":[{"role":"user","content":"secret data"}]}`
sanitizedBody := san.SanitizeRequest([]byte(body))
// Verify sanitization happened correctly
if !strings.Contains(string(sanitizedBody), "renamed_tool") {
t.Error("expected tool to be renamed in sanitized body")
}
if strings.Contains(string(sanitizedBody), "my_tool") {
t.Error("original tool name should be gone after sanitization")
}
if !strings.Contains(string(sanitizedBody), "redacted") {
t.Error("expected 'secret' to be replaced with 'redacted'")
}
if strings.Contains(string(sanitizedBody), "secret") {
t.Error("'secret' should be gone after sanitization")
}
}
func TestHandleMessages_PoolPickError(t *testing.T) {
// Empty pool — Pick() will fail.
pool := auth.NewPool(nil)
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
handler(c)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
}
if !strings.Contains(w.Body.String(), "no credentials available") {
t.Errorf("body = %q, expected pool error", w.Body.String())
}
}
func TestHandleMessages_PoolAllOnCooldown(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "e"}
pool := auth.NewPool([]*auth.Credential{cred})
pool.MarkFailure(cred, 429) // puts on 30s cooldown
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
handler(c)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
}
if !strings.Contains(w.Body.String(), "cooldown") {
t.Errorf("body = %q, expected cooldown message", w.Body.String())
}
}
// --- Stream vs non-stream routing ---
func TestHandleMessages_StreamField_Detection(t *testing.T) {
tests := []struct {
name string
body string
isStream bool
}{
{
name: "stream true",
body: `{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`,
isStream: true,
},
{
name: "stream false",
body: `{"model":"claude-sonnet-4-6","stream":false,"messages":[{"role":"user","content":"hi"}]}`,
isStream: false,
},
{
name: "no stream field defaults to false",
body: `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`,
isStream: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := gjson.Get(tt.body, "stream").Bool()
if got != tt.isStream {
t.Errorf("stream = %v, want %v", got, tt.isStream)
}
})
}
}
// --- Desanitization on response ---
func TestDesanitization_NonStreamResponse(t *testing.T) {
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
// Simulate upstream response with renamed tool
upstreamResponse := `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}]}`
desanitized := san.DesanitizeResponse([]byte(upstreamResponse))
if !strings.Contains(string(desanitized), "original_tool") {
t.Errorf("expected tool name to be desanitized back to 'original_tool', got %s", string(desanitized))
}
if strings.Contains(string(desanitized), `"name":"renamed_tool"`) {
t.Error("renamed_tool should have been replaced by original_tool")
}
}
func TestDesanitization_StreamEvent(t *testing.T) {
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
event := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`
desanitized := san.DesanitizeStreamEvent(event)
if !strings.Contains(desanitized, "original_tool") {
t.Errorf("expected stream event to be desanitized, got %s", desanitized)
}
}
// --- handleNonStream behavior tests via direct function ---
func TestHandleNonStream_ConnectionError(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
tracker := ratelimit.NewTracker(func() string { return "" })
uc := &UpstreamClient{
client: http.Client{Transport: &failingTransport{}},
sessionID: "test-sess",
profile: nil,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
if w.Code != http.StatusBadGateway {
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
}
if !strings.Contains(w.Body.String(), "upstream request failed") {
t.Errorf("body = %q, expected upstream error message", w.Body.String())
}
}
func TestHandleNonStream_UpstreamSuccess(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Request-Id", "req-123")
w.WriteHeader(200)
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hello"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
tracker := ratelimit.NewTracker(func() string { return "" })
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
// Override the messagesURL by constructing a custom Execute that uses the mock.
// Since we can't override the const, we test via a mock server approach:
// We create a custom http.Client with a transport that redirects to our mock.
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
}
if !strings.Contains(w.Body.String(), "hello") {
t.Errorf("response body missing expected content: %s", w.Body.String())
}
if got := w.Header().Get("Content-Type"); got != "application/json" {
t.Errorf("Content-Type = %q, want %q", got, "application/json")
}
if got := w.Header().Get("X-Request-Id"); got != "req-123" {
t.Errorf("X-Request-Id = %q, want %q", got, "req-123")
}
}
func TestHandleNonStream_UpstreamError_MarkFailure(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429)
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"too many requests"}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != 429 {
t.Errorf("status = %d, want 429", w.Code)
}
// Verify MarkFailure was called — cred should now be on cooldown
if !cred.IsOnCooldown() {
t.Error("expected credential to be on cooldown after 429")
}
}
func TestHandleNonStream_UpstreamSuccess_DesanitizesResponse(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}],"model":"claude-sonnet-4-6","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
}
// Should be desanitized back to original_tool
if !strings.Contains(w.Body.String(), "original_tool") {
t.Errorf("response should contain desanitized tool name 'original_tool', got %s", w.Body.String())
}
}
func TestHandleNonStream_Upstream500_MarkFailure(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
w.Write([]byte(`{"error":{"type":"server_error","message":"internal error"}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleNonStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != 500 {
t.Errorf("status = %d, want 500", w.Code)
}
if !cred.IsOnCooldown() {
t.Error("expected credential to be on cooldown after 500")
}
}
// --- handleStream behavior tests ---
func TestHandleStream_ConnectionError(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: http.Client{Transport: &failingTransport{}},
sessionID: "test-sess",
profile: nil,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusBadGateway {
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
}
if !strings.Contains(w.Body.String(), "upstream stream request failed") {
t.Errorf("body = %q, expected upstream stream error", w.Body.String())
}
}
func TestHandleStream_UpstreamError(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429)
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != 429 {
t.Errorf("status = %d, want 429", w.Code)
}
if !cred.IsOnCooldown() {
t.Error("expected credential on cooldown after stream 429")
}
}
func TestHandleStream_SuccessForwardsEvents(t *testing.T) {
events := "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(200)
w.Write([]byte(events))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
respBody := w.Body.String()
if !strings.Contains(respBody, "message_start") {
t.Error("response missing message_start event")
}
if !strings.Contains(respBody, "hello") {
t.Error("response missing text content 'hello'")
}
if !strings.Contains(respBody, "message_stop") {
t.Error("response missing message_stop event")
}
// Verify SSE headers
if got := w.Header().Get("Content-Type"); got != "text/event-stream" {
t.Errorf("Content-Type = %q, want %q", got, "text/event-stream")
}
}
func TestHandleStream_DesanitizesEvents(t *testing.T) {
events := "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"name\":\"renamed_tool\",\"id\":\"t1\"}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(200)
w.Write([]byte(events))
}))
defer mockUpstream.Close()
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
})
uc := &UpstreamClient{
client: *mockUpstream.Client(),
sessionID: "test-sess",
profile: nil,
}
uc.client.Transport = &rewriteTransport{
base: mockUpstream.Client().Transport,
destURL: mockUpstream.URL,
}
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
handleStream(c, uc, san, pool, cred, body, body, nil)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
}
if !strings.Contains(w.Body.String(), "original_tool") {
t.Errorf("stream response should contain desanitized 'original_tool', got %s", w.Body.String())
}
}
// --- HandleMessages full integration wiring test ---
func TestHandleMessages_WiresHandlerCorrectly(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
// Verify the handler can be created without panic
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
if handler == nil {
t.Fatal("HandleMessages returned nil handler")
}
}
func TestHandleMessages_EmptyBody(t *testing.T) {
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
pool := auth.NewPool([]*auth.Credential{cred})
san := NewSanitizer(config.SanitizeConfig{})
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
// Empty body — handler should still try to pick cred and call upstream
// (which will fail with connection error to api.anthropic.com, not a panic)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(""))
handler(c)
// Should get a 502 because the upstream URL (api.anthropic.com) won't be reachable
// in test environment, or it might complete. The key thing is no panic.
// We mainly verify it doesn't panic.
if w.Code == 0 {
t.Error("expected non-zero status code")
}
}
// --- Test helpers ---
// errReader is an io.Reader that always returns an error.
type errReader struct{}
func (e *errReader) Read([]byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
// failingTransport is an http.RoundTripper that always returns an error.
type failingTransport struct{}
func (f *failingTransport) RoundTrip(*http.Request) (*http.Response, error) {
return nil, fmt.Errorf("connection refused")
}
// rewriteTransport intercepts HTTP requests and rewrites the URL to point
// at a local test server. This allows testing with UpstreamClient's hardcoded
// messagesURL by redirecting all requests to a mock server.
type rewriteTransport struct {
base http.RoundTripper
destURL string
}
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL to point at our mock server
newReq := req.Clone(req.Context())
newReq.URL.Scheme = "http"
newReq.URL.Host = strings.TrimPrefix(t.destURL, "http://")
newReq.URL.Path = "/v1/messages"
newReq.URL.RawQuery = ""
if t.base == nil {
return http.DefaultTransport.RoundTrip(newReq)
}
return t.base.RoundTrip(newReq)
}
+476
View File
@@ -0,0 +1,476 @@
package proxy
import (
"strings"
"testing"
"github.com/fujin/anthropic-proxy/internal/config"
)
func TestNewSanitizer_Empty(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{})
if len(s.toolsForward) != 0 {
t.Errorf("expected empty toolsForward, got %d entries", len(s.toolsForward))
}
if len(s.toolsReverse) != 0 {
t.Errorf("expected empty toolsReverse, got %d entries", len(s.toolsReverse))
}
if s.systemRules != nil {
t.Errorf("expected nil systemRules")
}
if s.bodyRules != nil {
t.Errorf("expected nil bodyRules")
}
}
func TestNewSanitizer_WithTools(t *testing.T) {
cfg := config.SanitizeConfig{
Tools: []config.RenameRule{
{From: "old_tool", To: "new_tool"},
{From: "another", To: "replaced"},
},
}
s := NewSanitizer(cfg)
if got := s.toolsForward["old_tool"]; got != "new_tool" {
t.Errorf("toolsForward[old_tool] = %q, want %q", got, "new_tool")
}
if got := s.toolsReverse["new_tool"]; got != "old_tool" {
t.Errorf("toolsReverse[new_tool] = %q, want %q", got, "old_tool")
}
if got := s.toolsForward["another"]; got != "replaced" {
t.Errorf("toolsForward[another] = %q, want %q", got, "replaced")
}
if got := s.toolsReverse["replaced"]; got != "another" {
t.Errorf("toolsReverse[replaced] = %q, want %q", got, "another")
}
}
func TestNewSanitizer_WithSystemAndBodyRules(t *testing.T) {
cfg := config.SanitizeConfig{
System: []config.ReplaceRule{{Match: "foo", Replace: "bar"}},
Body: []config.ReplaceRule{{Match: "baz", Replace: "qux"}},
}
s := NewSanitizer(cfg)
if len(s.systemRules) != 1 || s.systemRules[0].Match != "foo" {
t.Errorf("systemRules not set correctly")
}
if len(s.bodyRules) != 1 || s.bodyRules[0].Match != "baz" {
t.Errorf("bodyRules not set correctly")
}
}
func TestRenameTools(t *testing.T) {
tests := []struct {
name string
forward map[string]string
body string
expected string
}{
{
name: "empty map returns body unchanged",
forward: map[string]string{},
body: `{"tools":[{"name":"my_tool"}]}`,
expected: `{"tools":[{"name":"my_tool"}]}`,
},
{
name: "no tools array returns body unchanged",
forward: map[string]string{"my_tool": "renamed"},
body: `{"messages":[]}`,
expected: `{"messages":[]}`,
},
{
name: "tools is not array returns body unchanged",
forward: map[string]string{"my_tool": "renamed"},
body: `{"tools":"not_array"}`,
expected: `{"tools":"not_array"}`,
},
{
name: "matching tool gets renamed",
forward: map[string]string{"my_tool": "renamed_tool"},
body: `{"tools":[{"name":"my_tool","description":"desc"}]}`,
expected: `renamed_tool`,
},
{
name: "non-matching tool unchanged",
forward: map[string]string{"other_tool": "renamed"},
body: `{"tools":[{"name":"my_tool"}]}`,
expected: `my_tool`,
},
{
name: "partial match - only exact match renames",
forward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
body: `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`,
expected: `tool_x`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: tt.forward,
toolsReverse: make(map[string]string),
}
result := string(s.renameTools([]byte(tt.body)))
if !strings.Contains(result, tt.expected) {
t.Errorf("result %q does not contain %q", result, tt.expected)
}
})
}
}
func TestRenameTools_MultipleTools(t *testing.T) {
s := &Sanitizer{
toolsForward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
toolsReverse: make(map[string]string),
}
body := `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`
result := string(s.renameTools([]byte(body)))
if !strings.Contains(result, `"tool_x"`) {
t.Errorf("tool_a should be renamed to tool_x, got %s", result)
}
if !strings.Contains(result, `"tool_y"`) {
t.Errorf("tool_b should be renamed to tool_y, got %s", result)
}
if !strings.Contains(result, `"tool_c"`) {
t.Errorf("tool_c should remain unchanged, got %s", result)
}
}
func TestReplaceSystem(t *testing.T) {
tests := []struct {
name string
rules []config.ReplaceRule
body string
contains string
}{
{
name: "empty rules returns body unchanged",
rules: nil,
body: `{"system":[{"type":"text","text":"hello world"}]}`,
contains: "hello world",
},
{
name: "no system field returns body unchanged",
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
body: `{"messages":[]}`,
contains: `"messages":[]`,
},
{
name: "system not array returns body unchanged",
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
body: `{"system":"just a string"}`,
contains: "just a string",
},
{
name: "single block single rule",
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
body: `{"system":[{"type":"text","text":"hello world"}]}`,
contains: "goodbye world",
},
{
name: "multiple blocks",
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
body: `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`,
contains: "BBB first",
},
{
name: "multiple rules applied in order",
rules: []config.ReplaceRule{{Match: "cat", Replace: "dog"}, {Match: "dog", Replace: "fish"}},
body: `{"system":[{"type":"text","text":"I have a cat"}]}`,
contains: "I have a fish",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: make(map[string]string),
systemRules: tt.rules,
}
result := string(s.replaceSystem([]byte(tt.body)))
if !strings.Contains(result, tt.contains) {
t.Errorf("result %q does not contain %q", result, tt.contains)
}
})
}
}
func TestReplaceSystem_MultipleBlocks(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: make(map[string]string),
systemRules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
}
body := `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`
result := string(s.replaceSystem([]byte(body)))
if !strings.Contains(result, "BBB first") {
t.Errorf("first block not replaced: %s", result)
}
if !strings.Contains(result, "BBB second") {
t.Errorf("second block not replaced: %s", result)
}
}
func TestReplaceBody(t *testing.T) {
tests := []struct {
name string
rules []config.ReplaceRule
body string
expected string
}{
{
name: "empty rules returns body unchanged",
rules: nil,
body: `{"foo":"bar"}`,
expected: `{"foo":"bar"}`,
},
{
name: "single replacement across entire body",
rules: []config.ReplaceRule{{Match: "SECRET", Replace: "REDACTED"}},
body: `{"data":"SECRET value SECRET"}`,
expected: `{"data":"REDACTED value REDACTED"}`,
},
{
name: "multiple rules applied sequentially",
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}, {Match: "BBB", Replace: "CCC"}},
body: `{"text":"AAA"}`,
expected: `{"text":"CCC"}`,
},
{
name: "no match leaves body unchanged",
rules: []config.ReplaceRule{{Match: "NOMATCH", Replace: "X"}},
body: `{"text":"hello"}`,
expected: `{"text":"hello"}`,
},
{
name: "empty body",
rules: []config.ReplaceRule{{Match: "a", Replace: "b"}},
body: ``,
expected: ``,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: make(map[string]string),
bodyRules: tt.rules,
}
result := string(s.replaceBody([]byte(tt.body)))
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
func TestSanitizeRequest(t *testing.T) {
cfg := config.SanitizeConfig{
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
System: []config.ReplaceRule{{Match: "INTERNAL", Replace: "PUBLIC"}},
Body: []config.ReplaceRule{{Match: "secret_val", Replace: "safe_val"}},
}
s := NewSanitizer(cfg)
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"INTERNAL info"}],"data":"secret_val here"}`
result := string(s.SanitizeRequest([]byte(body)))
if !strings.Contains(result, `"renamed_tool"`) {
t.Errorf("tool not renamed in result: %s", result)
}
if !strings.Contains(result, "PUBLIC info") {
t.Errorf("system not replaced in result: %s", result)
}
if !strings.Contains(result, "safe_val here") {
t.Errorf("body not replaced in result: %s", result)
}
if strings.Contains(result, "secret_val") {
t.Errorf("secret_val should have been replaced: %s", result)
}
}
func TestSanitizeRequest_EmptyConfig(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{})
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"hello"}]}`
result := string(s.SanitizeRequest([]byte(body)))
if result != body {
t.Errorf("empty config should not modify body.\ngot: %s\nwant: %s", result, body)
}
}
func TestDesanitizeResponse(t *testing.T) {
tests := []struct {
name string
reverse map[string]string
body string
expected string
}{
{
name: "no content field returns unchanged",
reverse: map[string]string{"renamed": "original"},
body: `{"id":"msg_1","role":"assistant"}`,
expected: `{"id":"msg_1","role":"assistant"}`,
},
{
name: "content not array returns unchanged",
reverse: map[string]string{"renamed": "original"},
body: `{"content":"just text"}`,
expected: `{"content":"just text"}`,
},
{
name: "non-tool_use block left unchanged",
reverse: map[string]string{"renamed": "original"},
body: `{"content":[{"type":"text","text":"hello"}]}`,
expected: `{"content":[{"type":"text","text":"hello"}]}`,
},
{
name: "tool_use block with matching name gets reversed",
reverse: map[string]string{"renamed_tool": "original_tool"},
body: `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1"}]}`,
expected: `original_tool`,
},
{
name: "tool_use block with no match unchanged",
reverse: map[string]string{"other": "something"},
body: `{"content":[{"type":"tool_use","name":"my_tool","id":"t1"}]}`,
expected: `my_tool`,
},
{
name: "mixed blocks only tool_use reversed",
reverse: map[string]string{"renamed": "original"},
body: `{"content":[{"type":"text","text":"hi"},{"type":"tool_use","name":"renamed","id":"t1"}]}`,
expected: `original`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: tt.reverse,
}
result := string(s.DesanitizeResponse([]byte(tt.body)))
if !strings.Contains(result, tt.expected) {
t.Errorf("result %q does not contain %q", result, tt.expected)
}
})
}
}
func TestDesanitizeResponse_MultipleToolUse(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: map[string]string{"r1": "o1", "r2": "o2"},
}
body := `{"content":[{"type":"tool_use","name":"r1","id":"t1"},{"type":"text","text":"x"},{"type":"tool_use","name":"r2","id":"t2"}]}`
result := string(s.DesanitizeResponse([]byte(body)))
if !strings.Contains(result, `"o1"`) {
t.Errorf("r1 not reversed to o1: %s", result)
}
if !strings.Contains(result, `"o2"`) {
t.Errorf("r2 not reversed to o2: %s", result)
}
}
func TestDesanitizeStreamEvent(t *testing.T) {
tests := []struct {
name string
reverse map[string]string
line string
expected string
}{
{
name: "non-data line passed through",
reverse: map[string]string{"r": "o"},
line: "event: content_block_start",
expected: "event: content_block_start",
},
{
name: "data line without tool_use passed through",
reverse: map[string]string{"r": "o"},
line: `data: {"type":"text","text":"hello"}`,
expected: `data: {"type":"text","text":"hello"}`,
},
{
name: "data line with tool_use in content_block.name",
reverse: map[string]string{"renamed_tool": "original_tool"},
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`,
expected: `original_tool`,
},
{
name: "data line with tool_use in delta.name",
reverse: map[string]string{"renamed_tool": "original_tool"},
line: `data: {"type":"content_block_delta","delta":{"type":"tool_use","name":"renamed_tool"}}`,
expected: `original_tool`,
},
{
name: "data line with tool_use but no matching name",
reverse: map[string]string{"other": "something"},
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"my_tool","id":"t1"}}`,
expected: `my_tool`,
},
{
name: "empty line passed through",
reverse: map[string]string{"r": "o"},
line: "",
expected: "",
},
{
name: "line contains tool_use but not data prefix - passed through",
reverse: map[string]string{"r": "o"},
line: "event: tool_use",
expected: "event: tool_use",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: tt.reverse,
}
result := s.DesanitizeStreamEvent(tt.line)
if !strings.Contains(result, tt.expected) {
t.Errorf("result %q does not contain %q", result, tt.expected)
}
})
}
}
func TestDesanitizeStreamEvent_DataPrefixPreserved(t *testing.T) {
s := &Sanitizer{
toolsForward: make(map[string]string),
toolsReverse: map[string]string{"renamed": "original"},
}
line := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed","id":"t1"}}`
result := s.DesanitizeStreamEvent(line)
if !strings.HasPrefix(result, "data: ") {
t.Errorf("result should start with 'data: ', got %q", result)
}
}
func TestSanitizeRequest_MalformedJSON(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "a", To: "b"}},
System: []config.ReplaceRule{{Match: "x", Replace: "y"}},
})
// Malformed JSON - renameTools and replaceSystem should handle gracefully
body := `not valid json`
result := string(s.SanitizeRequest([]byte(body)))
// Should not panic; body rules still do string replacement
if result != "not valid json" {
t.Errorf("malformed JSON should pass through (no body rules match), got %q", result)
}
}
func TestSanitizeRequest_EmptyBody(t *testing.T) {
s := NewSanitizer(config.SanitizeConfig{
Tools: []config.RenameRule{{From: "a", To: "b"}},
})
result := s.SanitizeRequest([]byte{})
if len(result) != 0 {
t.Errorf("empty body should return empty, got %q", string(result))
}
}
+278
View File
@@ -0,0 +1,278 @@
package proxy
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func newRequest(t *testing.T, headers map[string][]string) *http.Request {
t.Helper()
r := httptest.NewRequest("POST", "/v1/messages", nil)
r.Header = http.Header{}
for k, vals := range headers {
for _, v := range vals {
r.Header.Add(k, v)
}
}
return r
}
func TestExtractProfile_BasicHeaders(t *testing.T) {
r := newRequest(t, map[string][]string{
"Content-Type": {"application/json"},
"X-Custom-Header": {"custom-value"},
"User-Agent": {"Claude/1.2.3 linux"},
})
body := []byte(`{"model":"claude-sonnet-4-6"}`)
p := extractProfile(r, body)
// Check version parsed
if p.Version != "1.2.3" {
t.Errorf("version = %q, want %q", p.Version, "1.2.3")
}
// Check body preserved
if string(p.Body) != string(body) {
t.Errorf("body not preserved")
}
// Check headers captured
found := map[string]bool{}
for _, h := range p.Headers {
found[strings.ToLower(h[0])] = true
}
if !found["content-type"] {
t.Error("Content-Type header should be captured")
}
if !found["x-custom-header"] {
t.Error("X-Custom-Header should be captured")
}
}
func TestExtractProfile_SkipHeaders(t *testing.T) {
r := newRequest(t, map[string][]string{
"Host": {"example.com"},
"Content-Length": {"42"},
"Authorization": {"Bearer token123"},
"X-Api-Key": {"key123"},
"Connection": {"keep-alive"},
"Content-Type": {"application/json"},
"X-Custom": {"keep-me"},
})
p := extractProfile(r, []byte(`{}`))
for _, h := range p.Headers {
lower := strings.ToLower(h[0])
if skipHeaders[lower] {
t.Errorf("header %q should have been skipped", h[0])
}
}
// Verify non-skipped headers are present
found := map[string]bool{}
for _, h := range p.Headers {
found[strings.ToLower(h[0])] = true
}
if !found["content-type"] {
t.Error("Content-Type should be kept")
}
if !found["x-custom"] {
t.Error("X-Custom should be kept")
}
}
func TestExtractProfile_HeaderDeduplication(t *testing.T) {
r := newRequest(t, map[string][]string{
"Content-Type": {"application/json"},
})
// Add duplicate with different casing - Go's http.Header normalizes to canonical form
// so we need to add the same canonical header with multiple values to test dedup
r.Header.Add("Content-Type", "text/plain")
p := extractProfile(r, []byte(`{}`))
// After deduplication by lowercase key, only one entry per key
seen := map[string]int{}
for _, h := range p.Headers {
seen[strings.ToLower(h[0])]++
}
for key, count := range seen {
if count > 1 {
t.Errorf("header %q appears %d times after dedup, want 1", key, count)
}
}
}
func TestExtractProfile_AnthropicBetaContextStripping(t *testing.T) {
r := newRequest(t, map[string][]string{
"Anthropic-Beta": {"prompt-caching-2024-07-31,context-1m-2024-09-01,some-other-beta"},
})
p := extractProfile(r, []byte(`{}`))
var betaValue string
for _, h := range p.Headers {
if strings.ToLower(h[0]) == "anthropic-beta" {
betaValue = h[1]
break
}
}
if strings.Contains(betaValue, "context-1m") {
t.Errorf("context-1m should be stripped from anthropic-beta, got %q", betaValue)
}
if !strings.Contains(betaValue, "prompt-caching-2024-07-31") {
t.Errorf("prompt-caching should be preserved, got %q", betaValue)
}
if !strings.Contains(betaValue, "some-other-beta") {
t.Errorf("some-other-beta should be preserved, got %q", betaValue)
}
}
func TestExtractProfile_AnthropicBetaAllContextRemoved(t *testing.T) {
r := newRequest(t, map[string][]string{
"Anthropic-Beta": {"context-1m-2024-09-01"},
})
p := extractProfile(r, []byte(`{}`))
for _, h := range p.Headers {
if strings.ToLower(h[0]) == "anthropic-beta" {
// All betas were context-1m, so after filtering the value should be empty
if h[1] != "" {
t.Errorf("all context-1m betas stripped should leave empty, got %q", h[1])
}
return
}
}
// It's also acceptable if the header is still present but empty
}
func TestExtractProfile_VersionParsing(t *testing.T) {
tests := []struct {
name string
userAgent string
expected string
}{
{
name: "standard Claude UA",
userAgent: "Claude/1.2.3 linux x86_64",
expected: "1.2.3",
},
{
name: "version with no space after",
userAgent: "Claude/4.5.6",
expected: "4.5.6",
},
{
name: "no slash in UA",
userAgent: "Mozilla 5.0",
expected: "",
},
{
name: "empty UA",
userAgent: "",
expected: "",
},
{
name: "slash at start",
userAgent: "/1.0.0 rest",
expected: "",
},
{
name: "multiple slashes",
userAgent: "App/1.0.0 (sub/2.0)",
expected: "1.0.0",
},
{
name: "version only after slash no space",
userAgent: "Tool/9.8.7",
expected: "9.8.7",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := newRequest(t, map[string][]string{
"User-Agent": {tt.userAgent},
})
p := extractProfile(r, []byte(`{}`))
if p.Version != tt.expected {
t.Errorf("version = %q, want %q", p.Version, tt.expected)
}
})
}
}
func TestExtractProfile_EmptyHeaders(t *testing.T) {
r := httptest.NewRequest("POST", "/v1/messages", nil)
r.Header = http.Header{}
p := extractProfile(r, []byte(`{"test":true}`))
if len(p.Headers) != 0 {
t.Errorf("expected no headers, got %d", len(p.Headers))
}
if p.Version != "" {
t.Errorf("expected empty version with no UA, got %q", p.Version)
}
if string(p.Body) != `{"test":true}` {
t.Errorf("body not preserved")
}
}
func TestExtractProfile_BodyPreserved(t *testing.T) {
r := newRequest(t, map[string][]string{
"User-Agent": {"Claude/1.0.0 test"},
})
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello"}],"stream":true}`)
p := extractProfile(r, body)
if string(p.Body) != string(body) {
t.Errorf("body not preserved.\ngot: %s\nwant: %s", p.Body, body)
}
}
func TestSkipHeaders_Entries(t *testing.T) {
expected := map[string]bool{
"host": true,
"content-length": true,
"authorization": true,
"x-api-key": true,
"connection": true,
}
if len(skipHeaders) != len(expected) {
t.Errorf("skipHeaders has %d entries, want %d", len(skipHeaders), len(expected))
}
for k, v := range expected {
if skipHeaders[k] != v {
t.Errorf("skipHeaders[%q] = %v, want %v", k, skipHeaders[k], v)
}
}
}
func TestSniffedProfile_Fields(t *testing.T) {
// Verify the struct can hold all expected data
p := &SniffedProfile{
Headers: [][2]string{{"Content-Type", "application/json"}},
Body: []byte(`{}`),
Version: "1.0.0",
}
if len(p.Headers) != 1 {
t.Error("Headers should have 1 entry")
}
if p.Headers[0][0] != "Content-Type" || p.Headers[0][1] != "application/json" {
t.Error("Header not stored correctly")
}
if string(p.Body) != `{}` {
t.Error("Body not stored correctly")
}
if p.Version != "1.0.0" {
t.Error("Version not stored correctly")
}
}
+334
View File
@@ -0,0 +1,334 @@
package proxy
import (
"net/http"
"strings"
"testing"
)
// --- NewUpstreamClient ---
func TestNewUpstreamClient_NilProfile(t *testing.T) {
uc := NewUpstreamClient(nil)
if uc == nil {
t.Fatal("NewUpstreamClient returned nil")
}
if uc.sessionID == "" {
t.Error("expected non-empty sessionID")
}
if uc.profile != nil {
t.Error("expected nil profile")
}
}
func TestNewUpstreamClient_WithProfile(t *testing.T) {
profile := &SniffedProfile{
Version: "1.2.3",
Headers: [][2]string{{"User-Agent", "test/1.0"}},
}
uc := NewUpstreamClient(profile)
if uc.profile != profile {
t.Error("expected profile to be stored")
}
if uc.sessionID == "" {
t.Error("expected non-empty sessionID")
}
}
func TestNewUpstreamClient_UniqueSessionIDs(t *testing.T) {
uc1 := NewUpstreamClient(nil)
uc2 := NewUpstreamClient(nil)
if uc1.sessionID == uc2.sessionID {
t.Errorf("expected different session IDs, both got %q", uc1.sessionID)
}
}
// --- version() ---
func TestVersion_WithProfileVersion(t *testing.T) {
uc := &UpstreamClient{
profile: &SniffedProfile{Version: "3.5.7"},
}
if got := uc.version(); got != "3.5.7" {
t.Errorf("version() = %q, want %q", got, "3.5.7")
}
}
func TestVersion_NilProfile_Fallback(t *testing.T) {
uc := &UpstreamClient{profile: nil}
if got := uc.version(); got != "2.1.92" {
t.Errorf("version() = %q, want %q", got, "2.1.92")
}
}
func TestVersion_EmptyProfileVersion_Fallback(t *testing.T) {
uc := &UpstreamClient{
profile: &SniffedProfile{Version: ""},
}
if got := uc.version(); got != "2.1.92" {
t.Errorf("version() = %q, want %q", got, "2.1.92")
}
}
// --- applyHeaders ---
func TestApplyHeaders_NilProfile_NonOAuth_NonStream(t *testing.T) {
uc := &UpstreamClient{
sessionID: "test-session-id",
profile: nil,
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", false)
// x-api-key for non-OAuth token
if got := req.Header.Get("x-api-key"); got != "sk-ant-api123" {
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api123")
}
// Should NOT have Authorization
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("Authorization = %q, want empty", got)
}
// Session ID
if got := req.Header.Get("X-Claude-Code-Session-Id"); got != "test-session-id" {
t.Errorf("X-Claude-Code-Session-Id = %q, want %q", got, "test-session-id")
}
// Request ID should be a UUID
if got := req.Header.Get("x-client-request-id"); got == "" {
t.Error("expected non-empty x-client-request-id")
}
// Non-stream: application/json
if got := req.Header.Get("Accept"); got != "application/json" {
t.Errorf("Accept = %q, want %q", got, "application/json")
}
// Accept-Encoding always identity
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
}
}
func TestApplyHeaders_NilProfile_NonOAuth_Stream(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: nil,
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", true)
if got := req.Header.Get("Accept"); got != "text/event-stream" {
t.Errorf("Accept = %q, want %q", got, "text/event-stream")
}
}
func TestApplyHeaders_OAuthToken_SetsBearerAndBetaFlag(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: nil,
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-mytoken", false)
// OAuth: Authorization Bearer
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-mytoken" {
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-mytoken")
}
// Should NOT have x-api-key
if got := req.Header.Get("x-api-key"); got != "" {
t.Errorf("x-api-key = %q, want empty for OAuth", got)
}
// anthropic-beta should include oauth-2025-04-20
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
}
}
func TestApplyHeaders_OAuthToken_AppendsToExistingBeta(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-tok", false)
beta := req.Header.Get("anthropic-beta")
if !strings.Contains(beta, "max-tokens-3-5-sonnet-2024-07-15") {
t.Errorf("anthropic-beta %q should contain existing beta", beta)
}
if !strings.Contains(beta, "oauth-2025-04-20") {
t.Errorf("anthropic-beta %q should contain oauth flag", beta)
}
// Should be appended with comma
if beta != "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20" {
t.Errorf("anthropic-beta = %q, want %q", beta, "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20")
}
}
func TestApplyHeaders_OAuthToken_ExistingBetaAlreadyHasOAuth(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"anthropic-beta", "oauth-2025-04-20,something-else"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-tok", false)
beta := req.Header.Get("anthropic-beta")
// Should NOT duplicate oauth flag
count := strings.Count(beta, "oauth-2025-04-20")
if count != 1 {
t.Errorf("oauth flag appeared %d times in %q, want 1", count, beta)
}
}
func TestApplyHeaders_WithProfile_ReplaysHeaders(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"User-Agent", "Claude/1.0"},
{"anthropic-version", "2023-06-01"},
{"Custom-Header", "custom-value"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", false)
if got := req.Header.Get("User-Agent"); got != "Claude/1.0" {
t.Errorf("User-Agent = %q, want %q", got, "Claude/1.0")
}
if got := req.Header.Get("anthropic-version"); got != "2023-06-01" {
t.Errorf("anthropic-version = %q, want %q", got, "2023-06-01")
}
if got := req.Header.Get("Custom-Header"); got != "custom-value" {
t.Errorf("Custom-Header = %q, want %q", got, "custom-value")
}
}
func TestApplyHeaders_ProfileAuthHeadersRemoved(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"Authorization", "Bearer old-token"},
{"x-api-key", "old-api-key"},
{"User-Agent", "test"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api-new", false)
// Old auth headers from profile should be removed
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("Authorization should be empty for non-OAuth, got %q", got)
}
// New auth should be set via x-api-key
if got := req.Header.Get("x-api-key"); got != "sk-ant-api-new" {
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api-new")
}
// User-Agent from profile should remain
if got := req.Header.Get("User-Agent"); got != "test" {
t.Errorf("User-Agent = %q, want %q", got, "test")
}
}
func TestApplyHeaders_ProfileAuthHeadersRemovedForOAuth(t *testing.T) {
uc := &UpstreamClient{
sessionID: "sess",
profile: &SniffedProfile{
Headers: [][2]string{
{"Authorization", "Bearer old-token"},
{"x-api-key", "old-api-key"},
},
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-new", false)
// Old x-api-key removed
if got := req.Header.Get("x-api-key"); got != "" {
t.Errorf("x-api-key should be empty for OAuth, got %q", got)
}
// New auth set via Authorization
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-new" {
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-new")
}
}
func TestApplyHeaders_AcceptEncoding_AlwaysIdentity(t *testing.T) {
tests := []struct {
name string
streaming bool
}{
{"non-stream", false},
{"stream", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uc := &UpstreamClient{sessionID: "s", profile: nil}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "token", tt.streaming)
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
}
})
}
}
func TestApplyHeaders_UniqueRequestIDs(t *testing.T) {
uc := &UpstreamClient{sessionID: "s", profile: nil}
req1, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req1, "tok", false)
req2, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req2, "tok", false)
id1 := req1.Header.Get("x-client-request-id")
id2 := req2.Header.Get("x-client-request-id")
if id1 == "" || id2 == "" {
t.Fatal("expected non-empty request IDs")
}
if id1 == id2 {
t.Errorf("expected unique request IDs, both got %q", id1)
}
}
func TestApplyHeaders_NonOAuth_NoAnthroPicBetaSet(t *testing.T) {
// Non-OAuth tokens should NOT set anthropic-beta oauth flag
uc := &UpstreamClient{sessionID: "s", profile: nil}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-api123", false)
beta := req.Header.Get("anthropic-beta")
if strings.Contains(beta, "oauth-2025-04-20") {
t.Errorf("non-OAuth token should not have oauth beta flag, got %q", beta)
}
}
func TestApplyHeaders_OAuthToken_FreshBeta(t *testing.T) {
// No profile, no existing beta — should set fresh
uc := &UpstreamClient{sessionID: "s", profile: nil}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
uc.applyHeaders(req, "sk-ant-oat-tok", false)
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
}
}