From 9150f466e5c69c0dcfe2920ea908f5623518d8b1 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 15 Apr 2026 10:40:43 +0200 Subject: [PATCH] test: add comprehensive test harness across all packages (156 tests) Characterization tests capturing current behavior before refactoring. Covers auth, config, logging, proxy, ratelimit, server, and telemetry packages with race-safe concurrent access tests. --- internal/auth/selector_test.go | 318 ++++++++++++++ internal/auth/types_test.go | 167 +++++++ internal/config/config_test.go | 349 +++++++++++++++ internal/logging/logging_test.go | 230 ++++++++++ internal/proxy/billing_test.go | 323 ++++++++++++++ internal/proxy/handler_test.go | 624 +++++++++++++++++++++++++++ internal/proxy/sanitize_test.go | 476 ++++++++++++++++++++ internal/proxy/sniff_test.go | 278 ++++++++++++ internal/proxy/upstream_test.go | 334 ++++++++++++++ internal/ratelimit/tracker_test.go | 278 ++++++++++++ internal/ratelimit/usage_test.go | 241 +++++++++++ internal/server/server_test.go | 529 +++++++++++++++++++++++ internal/telemetry/logbridge_test.go | 178 ++++++++ 13 files changed, 4325 insertions(+) create mode 100644 internal/auth/selector_test.go create mode 100644 internal/auth/types_test.go create mode 100644 internal/config/config_test.go create mode 100644 internal/logging/logging_test.go create mode 100644 internal/proxy/billing_test.go create mode 100644 internal/proxy/handler_test.go create mode 100644 internal/proxy/sanitize_test.go create mode 100644 internal/proxy/sniff_test.go create mode 100644 internal/proxy/upstream_test.go create mode 100644 internal/ratelimit/tracker_test.go create mode 100644 internal/ratelimit/usage_test.go create mode 100644 internal/server/server_test.go create mode 100644 internal/telemetry/logbridge_test.go diff --git a/internal/auth/selector_test.go b/internal/auth/selector_test.go new file mode 100644 index 0000000..b0800db --- /dev/null +++ b/internal/auth/selector_test.go @@ -0,0 +1,318 @@ +package auth + +import ( + "testing" + "time" +) + +func TestNewPool(t *testing.T) { + creds := []*Credential{ + {ID: "a", AccessToken: "tok-a"}, + {ID: "b", AccessToken: "tok-b"}, + } + p := NewPool(creds) + if p == nil { + t.Fatal("NewPool returned nil") + } + if len(p.creds) != 2 { + t.Errorf("pool has %d creds, want 2", len(p.creds)) + } + if p.cursor != 0 { + t.Errorf("initial cursor = %d, want 0", p.cursor) + } +} + +func TestPool_Pick_EmptyPool(t *testing.T) { + p := NewPool(nil) + _, err := p.Pick() + if err == nil { + t.Fatal("expected error from empty pool, got nil") + } + want := "no credentials available" + if err.Error() != want { + t.Errorf("error = %q, want %q", err.Error(), want) + } +} + +func TestPool_Pick_SingleCredential(t *testing.T) { + cred := &Credential{ID: "only", AccessToken: "tok-only"} + p := NewPool([]*Credential{cred}) + + got, err := p.Pick() + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if got.ID != "only" { + t.Errorf("Pick() returned cred ID %q, want %q", got.ID, "only") + } + + // Picking again should return the same credential + got2, err := p.Pick() + if err != nil { + t.Fatalf("second Pick() error = %v", err) + } + if got2.ID != "only" { + t.Errorf("second Pick() returned cred ID %q, want %q", got2.ID, "only") + } +} + +func TestPool_Pick_RoundRobin(t *testing.T) { + creds := []*Credential{ + {ID: "a"}, + {ID: "b"}, + {ID: "c"}, + } + p := NewPool(creds) + + // Should cycle through a, b, c, a, b, c + expected := []string{"a", "b", "c", "a", "b", "c"} + for i, want := range expected { + got, err := p.Pick() + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got.ID != want { + t.Errorf("Pick() #%d = %q, want %q", i, got.ID, want) + } + } +} + +func TestPool_Pick_SkipsCooldown(t *testing.T) { + creds := []*Credential{ + {ID: "a"}, + {ID: "b", CooldownUntil: time.Now().Add(1 * time.Hour)}, + {ID: "c"}, + } + p := NewPool(creds) + + // First pick: "a" (index 0, not on cooldown) + got, err := p.Pick() + if err != nil { + t.Fatalf("Pick() #1 error = %v", err) + } + if got.ID != "a" { + t.Errorf("Pick() #1 = %q, want %q", got.ID, "a") + } + + // Second pick: cursor at 1, but "b" is on cooldown → skip to "c" + got, err = p.Pick() + if err != nil { + t.Fatalf("Pick() #2 error = %v", err) + } + if got.ID != "c" { + t.Errorf("Pick() #2 = %q, want %q", got.ID, "c") + } + + // Third pick: cursor advanced past "c" to 0 → "a" + got, err = p.Pick() + if err != nil { + t.Fatalf("Pick() #3 error = %v", err) + } + if got.ID != "a" { + t.Errorf("Pick() #3 = %q, want %q", got.ID, "a") + } +} + +func TestPool_Pick_AllOnCooldown(t *testing.T) { + future := time.Now().Add(1 * time.Hour) + creds := []*Credential{ + {ID: "a", CooldownUntil: future}, + {ID: "b", CooldownUntil: future}, + } + p := NewPool(creds) + + _, err := p.Pick() + if err == nil { + t.Fatal("expected error when all on cooldown, got nil") + } + want := "all 2 credentials are on cooldown" + if err.Error() != want { + t.Errorf("error = %q, want %q", err.Error(), want) + } +} + +func TestPool_MarkFailure(t *testing.T) { + tests := []struct { + name string + statusCode int + expectCooldown bool + expectedDur time.Duration + }{ + { + name: "429 sets 30s cooldown", + statusCode: 429, + expectCooldown: true, + expectedDur: 30 * time.Second, + }, + { + name: "500 sets 5s cooldown", + statusCode: 500, + expectCooldown: true, + expectedDur: 5 * time.Second, + }, + { + name: "502 sets 5s cooldown", + statusCode: 502, + expectCooldown: true, + expectedDur: 5 * time.Second, + }, + { + name: "503 sets 5s cooldown", + statusCode: 503, + expectCooldown: true, + expectedDur: 5 * time.Second, + }, + { + name: "400 does NOT set cooldown", + statusCode: 400, + expectCooldown: false, + }, + { + name: "401 does NOT set cooldown", + statusCode: 401, + expectCooldown: false, + }, + { + name: "403 does NOT set cooldown", + statusCode: 403, + expectCooldown: false, + }, + { + name: "404 does NOT set cooldown", + statusCode: 404, + expectCooldown: false, + }, + { + name: "422 does NOT set cooldown", + statusCode: 422, + expectCooldown: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cred := &Credential{ID: "test"} + p := NewPool([]*Credential{cred}) + + before := time.Now() + p.MarkFailure(cred, tt.statusCode) + + if tt.expectCooldown { + if !cred.IsOnCooldown() { + t.Errorf("expected cooldown after status %d", tt.statusCode) + } + // Verify approximate duration + cred.mu.Lock() + cooldownEnd := cred.CooldownUntil + cred.mu.Unlock() + + lower := before.Add(tt.expectedDur) + upper := time.Now().Add(tt.expectedDur) + if cooldownEnd.Before(lower) || cooldownEnd.After(upper) { + t.Errorf("CooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper) + } + } else { + if cred.IsOnCooldown() { + t.Errorf("did not expect cooldown after status %d", tt.statusCode) + } + } + }) + } +} + +func TestPool_MarkSuccess(t *testing.T) { + cred := &Credential{ + ID: "test", + CooldownUntil: time.Now().Add(1 * time.Hour), + } + p := NewPool([]*Credential{cred}) + + if !cred.IsOnCooldown() { + t.Fatal("precondition: expected credential to be on cooldown") + } + + p.MarkSuccess(cred) + + if cred.IsOnCooldown() { + t.Error("expected cooldown to be cleared after MarkSuccess") + } +} + +func TestPool_RoundRobinCursorAdvancement(t *testing.T) { + creds := []*Credential{ + {ID: "0"}, + {ID: "1"}, + {ID: "2"}, + } + p := NewPool(creds) + + // Verify cursor starts at 0 + if p.cursor != 0 { + t.Fatalf("initial cursor = %d, want 0", p.cursor) + } + + // Pick cred[0], cursor should advance to 1 + got, _ := p.Pick() + if got.ID != "0" { + t.Errorf("first pick = %q, want %q", got.ID, "0") + } + if p.cursor != 1 { + t.Errorf("cursor after first pick = %d, want 1", p.cursor) + } + + // Pick cred[1], cursor should advance to 2 + got, _ = p.Pick() + if got.ID != "1" { + t.Errorf("second pick = %q, want %q", got.ID, "1") + } + if p.cursor != 2 { + t.Errorf("cursor after second pick = %d, want 2", p.cursor) + } + + // Pick cred[2], cursor should wrap to 0 + got, _ = p.Pick() + if got.ID != "2" { + t.Errorf("third pick = %q, want %q", got.ID, "2") + } + if p.cursor != 0 { + t.Errorf("cursor after third pick = %d, want 0 (wrap)", p.cursor) + } +} + +func TestPool_RoundRobinWithCooldownSkip(t *testing.T) { + creds := []*Credential{ + {ID: "0"}, + {ID: "1", CooldownUntil: time.Now().Add(1 * time.Hour)}, + {ID: "2"}, + } + p := NewPool(creds) + + // First pick: cred[0] + got, _ := p.Pick() + if got.ID != "0" { + t.Errorf("first pick = %q, want %q", got.ID, "0") + } + // Cursor should be at 1 + if p.cursor != 1 { + t.Errorf("cursor after first pick = %d, want 1", p.cursor) + } + + // Second pick: cursor at 1, but cred[1] on cooldown → skip to cred[2] + got, _ = p.Pick() + if got.ID != "2" { + t.Errorf("second pick = %q, want %q", got.ID, "2") + } + // Cursor should advance past cred[2] to 0 + if p.cursor != 0 { + t.Errorf("cursor after second pick (skip) = %d, want 0", p.cursor) + } + + // Third pick: cursor at 0, cred[0] available + got, _ = p.Pick() + if got.ID != "0" { + t.Errorf("third pick = %q, want %q", got.ID, "0") + } + if p.cursor != 1 { + t.Errorf("cursor after third pick = %d, want 1", p.cursor) + } +} diff --git a/internal/auth/types_test.go b/internal/auth/types_test.go new file mode 100644 index 0000000..506fc14 --- /dev/null +++ b/internal/auth/types_test.go @@ -0,0 +1,167 @@ +package auth + +import ( + "sync" + "testing" + "time" +) + +func TestCredential_IsOnCooldown(t *testing.T) { + tests := []struct { + name string + cooldownUntil time.Time + want bool + }{ + { + name: "zero time — not on cooldown", + cooldownUntil: time.Time{}, + want: false, + }, + { + name: "future time — on cooldown", + cooldownUntil: time.Now().Add(1 * time.Hour), + want: true, + }, + { + name: "past time — expired cooldown", + cooldownUntil: time.Now().Add(-1 * time.Hour), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Credential{CooldownUntil: tt.cooldownUntil} + got := c.IsOnCooldown() + if got != tt.want { + t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCredential_SetCooldown(t *testing.T) { + tests := []struct { + name string + duration time.Duration + }{ + {name: "30 second cooldown", duration: 30 * time.Second}, + {name: "5 second cooldown", duration: 5 * time.Second}, + {name: "1 minute cooldown", duration: 1 * time.Minute}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Credential{} + before := time.Now() + c.SetCooldown(tt.duration) + after := time.Now() + + // CooldownUntil should be between before+duration and after+duration + if c.CooldownUntil.Before(before.Add(tt.duration)) { + t.Errorf("CooldownUntil %v is before expected lower bound %v", c.CooldownUntil, before.Add(tt.duration)) + } + if c.CooldownUntil.After(after.Add(tt.duration)) { + t.Errorf("CooldownUntil %v is after expected upper bound %v", c.CooldownUntil, after.Add(tt.duration)) + } + + // Should now be on cooldown + if !c.IsOnCooldown() { + t.Error("expected credential to be on cooldown after SetCooldown") + } + }) + } +} + +func TestCredential_ClearCooldown(t *testing.T) { + t.Run("clears active cooldown", func(t *testing.T) { + c := &Credential{CooldownUntil: time.Now().Add(1 * time.Hour)} + if !c.IsOnCooldown() { + t.Fatal("precondition: expected credential to be on cooldown") + } + + c.ClearCooldown() + + if c.IsOnCooldown() { + t.Error("expected credential to not be on cooldown after ClearCooldown") + } + if !c.CooldownUntil.IsZero() { + t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil) + } + }) + + t.Run("clearing when not on cooldown is no-op", func(t *testing.T) { + c := &Credential{} + c.ClearCooldown() + + if c.IsOnCooldown() { + t.Error("expected credential to not be on cooldown") + } + if !c.CooldownUntil.IsZero() { + t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil) + } + }) +} + +func TestCredential_Token(t *testing.T) { + tests := []struct { + name string + token string + }{ + {name: "returns access token", token: "sk-ant-abc123"}, + {name: "empty token", token: ""}, + {name: "long token", token: "sk-ant-" + string(make([]byte, 200))}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Credential{AccessToken: tt.token} + got := c.Token() + if got != tt.token { + t.Errorf("Token() = %q, want %q", got, tt.token) + } + }) + } +} + +func TestCredential_ConcurrentAccess(t *testing.T) { + c := &Credential{ + AccessToken: "initial-token", + } + + var wg sync.WaitGroup + const goroutines = 50 + + // Spawn goroutines that concurrently read and write + for i := 0; i < goroutines; i++ { + wg.Add(3) + + go func() { + defer wg.Done() + _ = c.Token() + }() + + go func() { + defer wg.Done() + c.SetCooldown(1 * time.Second) + }() + + go func() { + defer wg.Done() + _ = c.IsOnCooldown() + }() + } + + // Also mix in ClearCooldown calls + for i := 0; i < goroutines/2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c.ClearCooldown() + }() + } + + wg.Wait() + + // If we get here without -race detecting issues, mutex is working +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..aed4dda --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,349 @@ +package config + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoad_AllFields(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + yaml := ` +port: 9090 +api_keys: + - key1 + - key2 +claude_binary: /usr/bin/claude +sanitize: + tools: + - from: tool_a + to: tool_b + system: + - match: foo + replace: bar + body: + - match: baz + replace: qux +logging: + level: debug + file: /tmp/test.log + max_size_mb: 50 + max_backups: 3 + max_age_days: 7 + compress: true +telemetry: + service_name: my-proxy + export: + endpoint: http://localhost:4317 + insecure: true + headers: + x-token: abc + embedded: + enabled: true + port: 9999 + perses_binary: /usr/bin/perses + vm_binary: /usr/bin/vm + vm_port: 9428 + bin_dir: /opt/bin +` + if err := os.WriteFile(path, []byte(yaml), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + + if cfg.Port != 9090 { + t.Errorf("Port = %d, want 9090", cfg.Port) + } + if len(cfg.APIKeys) != 2 || cfg.APIKeys[0] != "key1" || cfg.APIKeys[1] != "key2" { + t.Errorf("APIKeys = %v, want [key1 key2]", cfg.APIKeys) + } + if cfg.ClaudeBinary != "/usr/bin/claude" { + t.Errorf("ClaudeBinary = %q, want /usr/bin/claude", cfg.ClaudeBinary) + } + + // Sanitize + if len(cfg.Sanitize.Tools) != 1 || cfg.Sanitize.Tools[0].From != "tool_a" || cfg.Sanitize.Tools[0].To != "tool_b" { + t.Errorf("Sanitize.Tools = %v", cfg.Sanitize.Tools) + } + if len(cfg.Sanitize.System) != 1 || cfg.Sanitize.System[0].Match != "foo" { + t.Errorf("Sanitize.System = %v", cfg.Sanitize.System) + } + if len(cfg.Sanitize.Body) != 1 || cfg.Sanitize.Body[0].Match != "baz" { + t.Errorf("Sanitize.Body = %v", cfg.Sanitize.Body) + } + + // Logging + if cfg.Logging.Level != "debug" { + t.Errorf("Logging.Level = %q, want debug", cfg.Logging.Level) + } + if cfg.Logging.File != "/tmp/test.log" { + t.Errorf("Logging.File = %q", cfg.Logging.File) + } + if cfg.Logging.MaxSizeMB != 50 { + t.Errorf("Logging.MaxSizeMB = %d, want 50", cfg.Logging.MaxSizeMB) + } + if cfg.Logging.MaxBackups != 3 { + t.Errorf("Logging.MaxBackups = %d, want 3", cfg.Logging.MaxBackups) + } + if cfg.Logging.MaxAgeDays != 7 { + t.Errorf("Logging.MaxAgeDays = %d, want 7", cfg.Logging.MaxAgeDays) + } + if !cfg.Logging.Compress { + t.Error("Logging.Compress = false, want true") + } + + // Telemetry + if cfg.Telemetry.ServiceName != "my-proxy" { + t.Errorf("Telemetry.ServiceName = %q, want my-proxy", cfg.Telemetry.ServiceName) + } + if cfg.Telemetry.Export.Endpoint != "http://localhost:4317" { + t.Errorf("Export.Endpoint = %q", cfg.Telemetry.Export.Endpoint) + } + if !cfg.Telemetry.Export.Insecure { + t.Error("Export.Insecure = false, want true") + } + if !cfg.Telemetry.Export.Enabled() { + t.Error("Export.Enabled() = false, want true") + } + if cfg.Telemetry.Export.Headers["x-token"] != "abc" { + t.Errorf("Export.Headers = %v", cfg.Telemetry.Export.Headers) + } + + // Embedded + if !cfg.Telemetry.Embedded.Enabled { + t.Error("Embedded.Enabled = false, want true") + } + if cfg.Telemetry.Embedded.Port != 9999 { + t.Errorf("Embedded.Port = %d, want 9999", cfg.Telemetry.Embedded.Port) + } + if cfg.Telemetry.Embedded.PersesBinary != "/usr/bin/perses" { + t.Errorf("Embedded.PersesBinary = %q", cfg.Telemetry.Embedded.PersesBinary) + } + if cfg.Telemetry.Embedded.VMBinary != "/usr/bin/vm" { + t.Errorf("Embedded.VMBinary = %q", cfg.Telemetry.Embedded.VMBinary) + } + if cfg.Telemetry.Embedded.VMPort != 9428 { + t.Errorf("Embedded.VMPort = %d, want 9428", cfg.Telemetry.Embedded.VMPort) + } + if cfg.Telemetry.Embedded.BinDir != "/opt/bin" { + t.Errorf("Embedded.BinDir = %q", cfg.Telemetry.Embedded.BinDir) + } +} + +func TestLoad_Defaults(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + // Minimal YAML — only api_keys + if err := os.WriteFile(path, []byte("api_keys:\n - k1\n"), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + + tests := []struct { + name string + got interface{} + want interface{} + }{ + {"Port", cfg.Port, 8080}, + {"Logging.Level", cfg.Logging.Level, "info"}, + {"Logging.MaxSizeMB", cfg.Logging.MaxSizeMB, 100}, + {"Logging.MaxBackups", cfg.Logging.MaxBackups, 5}, + {"Logging.MaxAgeDays", cfg.Logging.MaxAgeDays, 30}, + {"Telemetry.ServiceName", cfg.Telemetry.ServiceName, "anthropic-proxy"}, + {"Embedded.Port", cfg.Telemetry.Embedded.Port, 8080}, + {"Embedded.VMBinary", cfg.Telemetry.Embedded.VMBinary, "victoria-metrics"}, + {"Embedded.PersesBinary", cfg.Telemetry.Embedded.PersesBinary, "perses"}, + {"Embedded.VMPort", cfg.Telemetry.Embedded.VMPort, 8428}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.want { + t.Errorf("got %v, want %v", tt.got, tt.want) + } + }) + } +} + +func TestLoad_MissingFile(t *testing.T) { + _, err := Load("/nonexistent/path/config.yaml") + if err == nil { + t.Fatal("expected error for missing file, got nil") + } + if !strings.Contains(err.Error(), "read config") { + t.Errorf("error = %q, want it to contain 'read config'", err.Error()) + } +} + +func TestLoad_DeprecatedClaudeCredentials(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + yaml := ` +api_keys: + - k1 +claude_credentials: "/some/path" +` + if err := os.WriteFile(path, []byte(yaml), 0644); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for deprecated claude_credentials, got nil") + } + if !strings.Contains(err.Error(), "no longer supported") { + t.Errorf("error = %q, want it to contain 'no longer supported'", err.Error()) + } +} + +func TestLoad_EmptyClaudeCredentials(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + // Empty string value should NOT trigger the deprecation error + yaml := ` +api_keys: + - k1 +claude_credentials: "" +` + if err := os.WriteFile(path, []byte(yaml), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("empty claude_credentials should not error: %v", err) + } + if cfg.Port != 8080 { + t.Errorf("Port = %d, want 8080", cfg.Port) + } +} + +func TestLoad_InvalidYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + // Truly invalid YAML that causes a parse error + if err := os.WriteFile(path, []byte("port:\n - bad\n indent: broken\n"), 0644); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for invalid YAML, got nil") + } + if !strings.Contains(err.Error(), "parse config") { + t.Errorf("error = %q, want it to contain 'parse config'", err.Error()) + } +} + +func TestExportConfig_Enabled(t *testing.T) { + tests := []struct { + name string + endpoint string + want bool + }{ + {"empty endpoint", "", false}, + {"set endpoint", "http://localhost:4317", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ExportConfig{Endpoint: tt.endpoint} + if got := e.Enabled(); got != tt.want { + t.Errorf("Enabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDefaultCredentialPath(t *testing.T) { + path := DefaultCredentialPath() + if path == "" { + t.Skip("could not determine home directory") + } + if !strings.HasSuffix(path, "/.claude/.credentials.json") { + t.Errorf("DefaultCredentialPath() = %q, want suffix /.claude/.credentials.json", path) + } +} + +func TestLoadDefaultCredentials_ValidFile(t *testing.T) { + // We can't easily override DefaultCredentialPath, so test the JSON parsing + // logic by creating a file at a temp location and calling the internal parsing + // directly. Instead, we test LoadDefaultCredentials indirectly by verifying + // it returns nil,nil when the default path doesn't exist (common in CI). + // For a full test, we create the credential file at the expected path. + + // Test with the actual function — if the default credential file doesn't + // exist, it should return nil, nil. + creds, err := LoadDefaultCredentials() + path := DefaultCredentialPath() + if path == "" { + if creds != nil || err != nil { + t.Errorf("expected nil,nil when home dir unavailable, got %v, %v", creds, err) + } + return + } + + if _, statErr := os.Stat(path); os.IsNotExist(statErr) { + // File doesn't exist — should return nil, nil + if creds != nil { + t.Errorf("expected nil creds for missing file, got %v", creds) + } + if err != nil { + t.Errorf("expected nil error for missing file, got %v", err) + } + } +} + +func TestLoadDefaultCredentials_ParsesJSON(t *testing.T) { + // Test the JSON parsing by creating a temp credential file and using + // the claudeCredentialsJSON struct directly (white-box test). + jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}` + + var cf claudeCredentialsJSON + if err := json.Unmarshal([]byte(jsonData), &cf); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if cf.ClaudeAiOauth.AccessToken != "test-token" { + t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken) + } + if cf.ClaudeAiOauth.RefreshToken != "test-refresh" { + t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken) + } + if cf.ClaudeAiOauth.ExpiresAt != 1234567890 { + t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt) + } + if cf.ClaudeAiOauth.SubscriptionType != "pro" { + t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType) + } +} + +func TestLoadDefaultCredentials_EmptyAccessToken(t *testing.T) { + // Verify that an empty access token in the JSON produces an error. + // We test the parsing struct and logic path. + jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}` + + var cf claudeCredentialsJSON + if err := json.Unmarshal([]byte(jsonData), &cf); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if cf.ClaudeAiOauth.AccessToken != "" { + t.Errorf("expected empty access token") + } + // The actual LoadDefaultCredentials would return an error here. +} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go new file mode 100644 index 0000000..e7d964d --- /dev/null +++ b/internal/logging/logging_test.go @@ -0,0 +1,230 @@ +package logging + +import ( + "context" + "encoding/json" + "net/http" + "path/filepath" + "strings" + "testing" + + "github.com/rs/zerolog" +) + +func TestRedactHeaders(t *testing.T) { + tests := []struct { + name string + headers http.Header + check func(t *testing.T, result string) + }{ + { + name: "redacts Authorization", + headers: http.Header{ + "Authorization": []string{"Bearer secret-token"}, + }, + check: func(t *testing.T, result string) { + var m map[string]string + if err := json.Unmarshal([]byte(result), &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if m["Authorization"] != "***" { + t.Errorf("Authorization = %q, want ***", m["Authorization"]) + } + }, + }, + { + name: "redacts x-api-key", + headers: http.Header{ + "X-Api-Key": []string{"sk-ant-secret"}, + }, + check: func(t *testing.T, result string) { + var m map[string]string + if err := json.Unmarshal([]byte(result), &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if m["X-Api-Key"] != "***" { + t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"]) + } + }, + }, + { + name: "preserves other headers", + headers: http.Header{ + "Content-Type": []string{"application/json"}, + "Accept": []string{"text/html", "application/json"}, + }, + check: func(t *testing.T, result string) { + var m map[string]string + if err := json.Unmarshal([]byte(result), &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if m["Content-Type"] != "application/json" { + t.Errorf("Content-Type = %q, want application/json", m["Content-Type"]) + } + if m["Accept"] != "text/html, application/json" { + t.Errorf("Accept = %q, want 'text/html, application/json'", m["Accept"]) + } + }, + }, + { + name: "case-insensitive redaction", + headers: http.Header{ + "authorization": []string{"Bearer token"}, + "X-API-KEY": []string{"key123"}, + }, + check: func(t *testing.T, result string) { + var m map[string]string + if err := json.Unmarshal([]byte(result), &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + // http.Header canonicalizes keys, but RedactHeaders lowercases for comparison + for _, v := range m { + if v != "***" { + t.Errorf("expected all values to be ***, got %q", v) + } + } + }, + }, + { + name: "empty headers", + headers: http.Header{}, + check: func(t *testing.T, result string) { + if result != "{}" { + t.Errorf("result = %q, want {}", result) + } + }, + }, + { + name: "mixed sensitive and non-sensitive", + headers: http.Header{ + "Authorization": []string{"Bearer tok"}, + "X-Api-Key": []string{"key"}, + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"abc123"}, + }, + check: func(t *testing.T, result string) { + var m map[string]string + if err := json.Unmarshal([]byte(result), &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if m["Authorization"] != "***" { + t.Errorf("Authorization = %q, want ***", m["Authorization"]) + } + if m["X-Api-Key"] != "***" { + t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"]) + } + if m["Content-Type"] != "application/json" { + t.Errorf("Content-Type = %q", m["Content-Type"]) + } + if m["X-Request-Id"] != "abc123" { + t.Errorf("X-Request-Id = %q", m["X-Request-Id"]) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := RedactHeaders(tt.headers) + // Result should be valid JSON + if !json.Valid([]byte(result)) { + t.Fatalf("result is not valid JSON: %q", result) + } + tt.check(t, result) + }) + } +} + +func TestRedactHeaders_ReturnsJSON(t *testing.T) { + h := http.Header{"Foo": []string{"bar"}} + result := RedactHeaders(h) + if !strings.HasPrefix(result, "{") || !strings.HasSuffix(result, "}") { + t.Errorf("result not JSON object: %q", result) + } +} + +func TestStatusLevel(t *testing.T) { + tests := []struct { + status int + want zerolog.Level + }{ + {200, zerolog.InfoLevel}, + {201, zerolog.InfoLevel}, + {204, zerolog.InfoLevel}, + {301, zerolog.InfoLevel}, + {399, zerolog.InfoLevel}, + {400, zerolog.WarnLevel}, + {401, zerolog.WarnLevel}, + {403, zerolog.WarnLevel}, + {404, zerolog.WarnLevel}, + {429, zerolog.WarnLevel}, + {499, zerolog.WarnLevel}, + {500, zerolog.ErrorLevel}, + {502, zerolog.ErrorLevel}, + {503, zerolog.ErrorLevel}, + {599, zerolog.ErrorLevel}, + } + + for _, tt := range tests { + got := statusLevel(tt.status) + if got != tt.want { + t.Errorf("statusLevel(%d) = %v, want %v", tt.status, got, tt.want) + } + } +} + +func TestSetup_WithFile(t *testing.T) { + dir := t.TempDir() + logFile := filepath.Join(dir, "test.log") + + logger := Setup(Config{ + Level: "debug", + File: logFile, + MaxSizeMB: 10, + MaxBackups: 1, + MaxAgeDays: 1, + }) + + // Verify logger works (no panic) + logger.Info().Msg("test message") +} + +func TestSetup_WithoutFile(t *testing.T) { + // File empty — should use console or stderr mode depending on TTY + logger := Setup(Config{ + Level: "warn", + }) + + // Verify logger works (no panic) + logger.Warn().Msg("test warning") +} + +func TestSetup_DefaultLevel(t *testing.T) { + // Empty level should default to info + logger := Setup(Config{}) + _ = logger // verify no panic +} + +func TestSetup_InvalidLevel(t *testing.T) { + // Invalid level should default to info + logger := Setup(Config{Level: "not-a-level"}) + _ = logger // verify no panic +} + +func TestFromContext_NoLogger(t *testing.T) { + // Background context has no zerolog logger — should return global + ctx := context.Background() + l := FromContext(ctx) + if l == nil { + t.Fatal("FromContext returned nil") + } +} + +func TestFromContext_WithLogger(t *testing.T) { + logger := zerolog.Nop() + ctx := logger.WithContext(context.Background()) + l := FromContext(ctx) + if l == nil { + t.Fatal("FromContext returned nil") + } +} diff --git a/internal/proxy/billing_test.go b/internal/proxy/billing_test.go new file mode 100644 index 0000000..7e5a929 --- /dev/null +++ b/internal/proxy/billing_test.go @@ -0,0 +1,323 @@ +package proxy + +import ( + "encoding/hex" + "strings" + "testing" +) + +func TestFingerprintSaltConstant(t *testing.T) { + if fingerprintSalt != "59cf53e54c78" { + t.Errorf("fingerprintSalt = %q, want %q", fingerprintSalt, "59cf53e54c78") + } +} + +func TestComputeFingerprint_Deterministic(t *testing.T) { + a := computeFingerprint("hello world test message", "1.0.0") + b := computeFingerprint("hello world test message", "1.0.0") + if a != b { + t.Errorf("fingerprint not deterministic: %q != %q", a, b) + } +} + +func TestComputeFingerprint_Length(t *testing.T) { + fp := computeFingerprint("some message here", "2.0.0") + if len(fp) != 3 { + t.Errorf("fingerprint length = %d, want 3", len(fp)) + } + // Must be valid hex + if _, err := hex.DecodeString(fp + "0"); err != nil { // pad to even length for decode + // Check each char is hex individually + for _, c := range fp { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("fingerprint %q contains non-hex char %c", fp, c) + } + } + } +} + +func TestComputeFingerprint_DifferentVersions(t *testing.T) { + a := computeFingerprint("same message", "1.0.0") + b := computeFingerprint("same message", "2.0.0") + if a == b { + t.Errorf("different versions should (almost certainly) produce different fingerprints") + } +} + +func TestComputeFingerprint_ShortMessage(t *testing.T) { + // "hi" has only 2 chars, indices [4,7,20] all out of range → chars = "000" + fp := computeFingerprint("hi", "1.0.0") + if len(fp) != 3 { + t.Errorf("short message fingerprint length = %d, want 3", len(fp)) + } +} + +func TestComputeFingerprint_EmptyMessage(t *testing.T) { + // Empty → all indices out of range → chars = "000" + fp := computeFingerprint("", "1.0.0") + if len(fp) != 3 { + t.Errorf("empty message fingerprint length = %d, want 3", len(fp)) + } + // Empty and short message with same version should produce same fingerprint + // since both result in chars = "000" + fpShort := computeFingerprint("hi", "1.0.0") + if fp != fpShort { + t.Errorf("empty and 'hi' should produce same fingerprint (both use '000'), got %q vs %q", fp, fpShort) + } +} + +func TestComputeFingerprint_Unicode(t *testing.T) { + // Emoji: 🎉 is U+1F389, encoded as UTF-16 surrogate pair [0xD83C, 0xDF89] + // So "abcd🎉fg" in UTF-16 is [a, b, c, d, 0xD83C, 0xDF89, f, g] = 8 uint16 values + // indices [4,7,20]: runes[4]=0xD83C, runes[7]='g', runes[20]=out of range + fp := computeFingerprint("abcd🎉fg", "1.0.0") + if len(fp) != 3 { + t.Errorf("unicode fingerprint length = %d, want 3", len(fp)) + } +} + +func TestComputeFingerprint_CharExtraction(t *testing.T) { + // "Hello, World!" UTF-16: [H,e,l,l,o,',', ,W,o,r,l,d,!] + // indices [4,7,20]: runes[4]='o', runes[7]='W', runes[20]=out of range → "0" + // So chars should be "oW0" + // Verify by comparing to a message where we know the expected extracted chars + // Two messages that extract same chars at indices should produce same fingerprint + // "xxxxoxxWxxxxxxxxxxxx" → index 4='o', 7='W', 20=out of range → "oW0" (20 chars, index 20 out of range) + fp1 := computeFingerprint("Hello, World!", "1.0.0") + fp2 := computeFingerprint("xxxxoxxWxxxxxxxxxxxx", "1.0.0") + if fp1 != fp2 { + t.Errorf("messages with same chars at indices [4,7,20] should produce same fingerprint, got %q vs %q", fp1, fp2) + } +} + +func TestComputeFingerprint_IndexBoundary(t *testing.T) { + // Message with exactly 21 chars → index 20 is valid + msg21 := "abcdefghijklmnopqrstu" // 21 chars + fp21 := computeFingerprint(msg21, "1.0.0") + + // Message with exactly 20 chars → index 20 is out of range → "0" + msg20 := "abcdefghijklmnopqrst" // 20 chars + fp20 := computeFingerprint(msg20, "1.0.0") + + // They should differ because index 20 produces different chars + if fp21 == fp20 { + t.Errorf("boundary test: 21-char and 20-char messages should differ at index 20") + } +} + +func TestExtractFirstUserMessage(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "simple string content", + body: `{"messages":[{"role":"user","content":"hello world"}]}`, + expected: "hello world", + }, + { + name: "array content with text block", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":"from array"}]}]}`, + expected: "from array", + }, + { + name: "no user messages", + body: `{"messages":[{"role":"assistant","content":"I am assistant"}]}`, + expected: "", + }, + { + name: "assistant only messages", + body: `{"messages":[{"role":"assistant","content":"a1"},{"role":"assistant","content":"a2"}]}`, + expected: "", + }, + { + name: "user with non-text block first then text", + body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"},{"type":"text","text":"the text"}]}]}`, + expected: "the text", + }, + { + name: "user with only non-text blocks", + body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]}]}`, + expected: "", + }, + { + name: "no messages field", + body: `{"model":"claude-sonnet-4-6"}`, + expected: "", + }, + { + name: "messages not array", + body: `{"messages":"not array"}`, + expected: "", + }, + { + name: "empty messages array", + body: `{"messages":[]}`, + expected: "", + }, + { + name: "first user message used even if multiple exist", + body: `{"messages":[{"role":"user","content":"first"},{"role":"user","content":"second"}]}`, + expected: "first", + }, + { + name: "assistant before user", + body: `{"messages":[{"role":"assistant","content":"assistant msg"},{"role":"user","content":"user msg"}]}`, + expected: "user msg", + }, + { + name: "user with array content - first text block used", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":"first text"},{"type":"text","text":"second text"}]}]}`, + expected: "first text", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractFirstUserMessage([]byte(tt.body)) + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} + +func TestExtractFirstUserMessage_BreaksAfterFirstUser(t *testing.T) { + // The function should break after finding the first user message, + // even if it didn't extract text (e.g. user with only image blocks) + body := `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]},{"role":"user","content":"second user"}]}` + result := extractFirstUserMessage([]byte(body)) + // First user has no text blocks, function breaks, returns "" + if result != "" { + t.Errorf("should return empty when first user has no text, got %q", result) + } +} + +func TestBuildBillingHeader(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"test message"}]}`) + version := "1.2.3" + header := buildBillingHeader(body, version) + + // Check format + if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=1.2.3.") { + t.Errorf("header should start with 'x-anthropic-billing-header: cc_version=1.2.3.', got %q", header) + } + if !strings.Contains(header, "; cc_entrypoint=cli; cch=00000;") { + t.Errorf("header should contain '; cc_entrypoint=cli; cch=00000;', got %q", header) + } + + // Verify the fingerprint part is 3 chars + // Format: "x-anthropic-billing-header: cc_version=1.2.3.XXX; cc_entrypoint=cli; cch=00000;" + parts := strings.Split(header, "cc_version=") + if len(parts) != 2 { + t.Fatalf("unexpected header format: %q", header) + } + versionFP := strings.Split(parts[1], ";")[0] + if !strings.HasPrefix(versionFP, "1.2.3.") { + t.Errorf("version+fingerprint should start with '1.2.3.', got %q", versionFP) + } + fp := strings.TrimPrefix(versionFP, "1.2.3.") + if len(fp) != 3 { + t.Errorf("fingerprint should be 3 chars, got %q (len %d)", fp, len(fp)) + } +} + +func TestBuildBillingHeader_EmptyMessages(t *testing.T) { + body := []byte(`{"messages":[]}`) + version := "1.0.0" + header := buildBillingHeader(body, version) + if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=") { + t.Errorf("header format wrong: %q", header) + } +} + +func TestInjectBillingHeader_NoExistingSystem(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) + result := injectBillingHeader(body, "1.0.0") + resultStr := string(result) + + // Should have system field now + if !strings.Contains(resultStr, `"system"`) { + t.Errorf("should inject system field, got %s", resultStr) + } + // System should be an array with one billing block + if !strings.Contains(resultStr, "x-anthropic-billing-header") { + t.Errorf("should contain billing header text, got %s", resultStr) + } + if !strings.Contains(resultStr, `"type":"text"`) { + t.Errorf("billing block should have type text, got %s", resultStr) + } +} + +func TestInjectBillingHeader_ExistingSystemArray(t *testing.T) { + body := []byte(`{"system":[{"type":"text","text":"existing prompt"}],"messages":[{"role":"user","content":"hi"}]}`) + result := injectBillingHeader(body, "1.0.0") + resultStr := string(result) + + // Should contain both billing header and existing prompt + if !strings.Contains(resultStr, "x-anthropic-billing-header") { + t.Errorf("should contain billing header, got %s", resultStr) + } + if !strings.Contains(resultStr, "existing prompt") { + t.Errorf("should preserve existing prompt, got %s", resultStr) + } + + // Billing block should be FIRST (prepended) + billingIdx := strings.Index(resultStr, "x-anthropic-billing-header") + existingIdx := strings.Index(resultStr, "existing prompt") + if billingIdx > existingIdx { + t.Errorf("billing block should come before existing prompt") + } +} + +func TestInjectBillingHeader_ExistingSystemString(t *testing.T) { + body := []byte(`{"system":"You are a helpful assistant","messages":[{"role":"user","content":"hi"}]}`) + result := injectBillingHeader(body, "1.0.0") + resultStr := string(result) + + // Should convert to array with billing block first, then original text + if !strings.Contains(resultStr, "x-anthropic-billing-header") { + t.Errorf("should contain billing header, got %s", resultStr) + } + if !strings.Contains(resultStr, "You are a helpful assistant") { + t.Errorf("should preserve original system string, got %s", resultStr) + } + + // Billing should come first + billingIdx := strings.Index(resultStr, "x-anthropic-billing-header") + origIdx := strings.Index(resultStr, "You are a helpful assistant") + if billingIdx > origIdx { + t.Errorf("billing block should come before original system text") + } +} + +func TestInjectBillingHeader_PreservesOtherFields(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`) + result := injectBillingHeader(body, "1.0.0") + resultStr := string(result) + + if !strings.Contains(resultStr, `"model":"claude-sonnet-4-6"`) { + t.Errorf("should preserve model field, got %s", resultStr) + } + if !strings.Contains(resultStr, `"max_tokens":1024`) { + t.Errorf("should preserve max_tokens field, got %s", resultStr) + } +} + +func TestInjectBillingHeader_BillingBlockFormat(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"test"}]}`) + result := injectBillingHeader(body, "2.5.0") + resultStr := string(result) + + // Verify the billing block contains the correct version + if !strings.Contains(resultStr, "cc_version=2.5.0.") { + t.Errorf("billing block should contain cc_version=2.5.0., got %s", resultStr) + } + if !strings.Contains(resultStr, "cc_entrypoint=cli") { + t.Errorf("billing block should contain cc_entrypoint=cli, got %s", resultStr) + } + if !strings.Contains(resultStr, "cch=00000") { + t.Errorf("billing block should contain cch=00000, got %s", resultStr) + } +} diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go new file mode 100644 index 0000000..5de96db --- /dev/null +++ b/internal/proxy/handler_test.go @@ -0,0 +1,624 @@ +package proxy + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + + "github.com/fujin/anthropic-proxy/internal/auth" + "github.com/fujin/anthropic-proxy/internal/config" + "github.com/fujin/anthropic-proxy/internal/ratelimit" + "github.com/fujin/anthropic-proxy/internal/telemetry" + "go.opentelemetry.io/otel/metric/noop" +) + +func init() { + gin.SetMode(gin.TestMode) + // Initialize telemetry with noop meter to avoid nil pointer panics. + meter := noop.Meter{} + telemetry.InitMetrics(meter, nil) +} + +// --- Request body reading and sanitization --- + +func TestHandleMessages_ReadBodyError(t *testing.T) { + // A body that immediately fails on read shouldn't panic. + pool := auth.NewPool([]*auth.Credential{{ID: "c1", AccessToken: "tok", Email: "test@test.com"}}) + san := NewSanitizer(config.SanitizeConfig{}) + + handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", &errReader{}) + + handler(c) + + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest) + } + if !strings.Contains(w.Body.String(), "failed to read request body") { + t.Errorf("body = %q, expected error message about reading body", w.Body.String()) + } +} + +func TestHandleMessages_SanitizesRequestBody(t *testing.T) { + // We can't directly make HandleMessages use our mock server because + // UpstreamClient hardcodes messagesURL. Instead, we test sanitization + // by verifying the sanitizer is called on the body before any pool interaction. + san := NewSanitizer(config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}}, + Body: []config.ReplaceRule{{Match: "secret", Replace: "redacted"}}, + }) + + // Create body with tool name and secret + body := `{"model":"claude-sonnet-4-6","tools":[{"name":"my_tool"}],"messages":[{"role":"user","content":"secret data"}]}` + sanitizedBody := san.SanitizeRequest([]byte(body)) + + // Verify sanitization happened correctly + if !strings.Contains(string(sanitizedBody), "renamed_tool") { + t.Error("expected tool to be renamed in sanitized body") + } + if strings.Contains(string(sanitizedBody), "my_tool") { + t.Error("original tool name should be gone after sanitization") + } + if !strings.Contains(string(sanitizedBody), "redacted") { + t.Error("expected 'secret' to be replaced with 'redacted'") + } + if strings.Contains(string(sanitizedBody), "secret") { + t.Error("'secret' should be gone after sanitization") + } +} + +func TestHandleMessages_PoolPickError(t *testing.T) { + // Empty pool — Pick() will fail. + pool := auth.NewPool(nil) + san := NewSanitizer(config.SanitizeConfig{}) + + handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) + + body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}` + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + + handler(c) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } + if !strings.Contains(w.Body.String(), "no credentials available") { + t.Errorf("body = %q, expected pool error", w.Body.String()) + } +} + +func TestHandleMessages_PoolAllOnCooldown(t *testing.T) { + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "e"} + pool := auth.NewPool([]*auth.Credential{cred}) + pool.MarkFailure(cred, 429) // puts on 30s cooldown + + san := NewSanitizer(config.SanitizeConfig{}) + handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) + + body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}` + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + + handler(c) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } + if !strings.Contains(w.Body.String(), "cooldown") { + t.Errorf("body = %q, expected cooldown message", w.Body.String()) + } +} + +// --- Stream vs non-stream routing --- + +func TestHandleMessages_StreamField_Detection(t *testing.T) { + tests := []struct { + name string + body string + isStream bool + }{ + { + name: "stream true", + body: `{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`, + isStream: true, + }, + { + name: "stream false", + body: `{"model":"claude-sonnet-4-6","stream":false,"messages":[{"role":"user","content":"hi"}]}`, + isStream: false, + }, + { + name: "no stream field defaults to false", + body: `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`, + isStream: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := gjson.Get(tt.body, "stream").Bool() + if got != tt.isStream { + t.Errorf("stream = %v, want %v", got, tt.isStream) + } + }) + } +} + +// --- Desanitization on response --- + +func TestDesanitization_NonStreamResponse(t *testing.T) { + san := NewSanitizer(config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, + }) + + // Simulate upstream response with renamed tool + upstreamResponse := `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}]}` + desanitized := san.DesanitizeResponse([]byte(upstreamResponse)) + + if !strings.Contains(string(desanitized), "original_tool") { + t.Errorf("expected tool name to be desanitized back to 'original_tool', got %s", string(desanitized)) + } + if strings.Contains(string(desanitized), `"name":"renamed_tool"`) { + t.Error("renamed_tool should have been replaced by original_tool") + } +} + +func TestDesanitization_StreamEvent(t *testing.T) { + san := NewSanitizer(config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, + }) + + event := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}` + desanitized := san.DesanitizeStreamEvent(event) + + if !strings.Contains(desanitized, "original_tool") { + t.Errorf("expected stream event to be desanitized, got %s", desanitized) + } +} + +// --- handleNonStream behavior tests via direct function --- + +func TestHandleNonStream_ConnectionError(t *testing.T) { + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + tracker := ratelimit.NewTracker(func() string { return "" }) + + uc := &UpstreamClient{ + client: http.Client{Transport: &failingTransport{}}, + sessionID: "test-sess", + profile: nil, + } + + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleNonStream(c, uc, san, pool, cred, body, body, tracker) + + if w.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway) + } + if !strings.Contains(w.Body.String(), "upstream request failed") { + t.Errorf("body = %q, expected upstream error message", w.Body.String()) + } +} + +func TestHandleNonStream_UpstreamSuccess(t *testing.T) { + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Request-Id", "req-123") + w.WriteHeader(200) + w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hello"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + })) + defer mockUpstream.Close() + + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + tracker := ratelimit.NewTracker(func() string { return "" }) + + uc := &UpstreamClient{ + client: *mockUpstream.Client(), + sessionID: "test-sess", + profile: nil, + } + // Override the messagesURL by constructing a custom Execute that uses the mock. + // Since we can't override the const, we test via a mock server approach: + // We create a custom http.Client with a transport that redirects to our mock. + uc.client.Transport = &rewriteTransport{ + base: mockUpstream.Client().Transport, + destURL: mockUpstream.URL, + } + + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleNonStream(c, uc, san, pool, cred, body, body, tracker) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "hello") { + t.Errorf("response body missing expected content: %s", w.Body.String()) + } + if got := w.Header().Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type = %q, want %q", got, "application/json") + } + if got := w.Header().Get("X-Request-Id"); got != "req-123" { + t.Errorf("X-Request-Id = %q, want %q", got, "req-123") + } +} + +func TestHandleNonStream_UpstreamError_MarkFailure(t *testing.T) { + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(429) + w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"too many requests"}}`)) + })) + defer mockUpstream.Close() + + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + + uc := &UpstreamClient{ + client: *mockUpstream.Client(), + sessionID: "test-sess", + profile: nil, + } + uc.client.Transport = &rewriteTransport{ + base: mockUpstream.Client().Transport, + destURL: mockUpstream.URL, + } + + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleNonStream(c, uc, san, pool, cred, body, body, nil) + + if w.Code != 429 { + t.Errorf("status = %d, want 429", w.Code) + } + + // Verify MarkFailure was called — cred should now be on cooldown + if !cred.IsOnCooldown() { + t.Error("expected credential to be on cooldown after 429") + } +} + +func TestHandleNonStream_UpstreamSuccess_DesanitizesResponse(t *testing.T) { + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}],"model":"claude-sonnet-4-6","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`)) + })) + defer mockUpstream.Close() + + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, + }) + + uc := &UpstreamClient{ + client: *mockUpstream.Client(), + sessionID: "test-sess", + profile: nil, + } + uc.client.Transport = &rewriteTransport{ + base: mockUpstream.Client().Transport, + destURL: mockUpstream.URL, + } + + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleNonStream(c, uc, san, pool, cred, body, body, nil) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String()) + } + // Should be desanitized back to original_tool + if !strings.Contains(w.Body.String(), "original_tool") { + t.Errorf("response should contain desanitized tool name 'original_tool', got %s", w.Body.String()) + } +} + +func TestHandleNonStream_Upstream500_MarkFailure(t *testing.T) { + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + w.Write([]byte(`{"error":{"type":"server_error","message":"internal error"}}`)) + })) + defer mockUpstream.Close() + + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + + uc := &UpstreamClient{ + client: *mockUpstream.Client(), + sessionID: "test-sess", + profile: nil, + } + uc.client.Transport = &rewriteTransport{ + base: mockUpstream.Client().Transport, + destURL: mockUpstream.URL, + } + + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleNonStream(c, uc, san, pool, cred, body, body, nil) + + if w.Code != 500 { + t.Errorf("status = %d, want 500", w.Code) + } + if !cred.IsOnCooldown() { + t.Error("expected credential to be on cooldown after 500") + } +} + +// --- handleStream behavior tests --- + +func TestHandleStream_ConnectionError(t *testing.T) { + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + + uc := &UpstreamClient{ + client: http.Client{Transport: &failingTransport{}}, + sessionID: "test-sess", + profile: nil, + } + + body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleStream(c, uc, san, pool, cred, body, body, nil) + + if w.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway) + } + if !strings.Contains(w.Body.String(), "upstream stream request failed") { + t.Errorf("body = %q, expected upstream stream error", w.Body.String()) + } +} + +func TestHandleStream_UpstreamError(t *testing.T) { + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(429) + w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`)) + })) + defer mockUpstream.Close() + + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + + uc := &UpstreamClient{ + client: *mockUpstream.Client(), + sessionID: "test-sess", + profile: nil, + } + uc.client.Transport = &rewriteTransport{ + base: mockUpstream.Client().Transport, + destURL: mockUpstream.URL, + } + + body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleStream(c, uc, san, pool, cred, body, body, nil) + + if w.Code != 429 { + t.Errorf("status = %d, want 429", w.Code) + } + if !cred.IsOnCooldown() { + t.Error("expected credential on cooldown after stream 429") + } +} + +func TestHandleStream_SuccessForwardsEvents(t *testing.T) { + events := "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n" + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + w.Write([]byte(events)) + })) + defer mockUpstream.Close() + + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + + uc := &UpstreamClient{ + client: *mockUpstream.Client(), + sessionID: "test-sess", + profile: nil, + } + uc.client.Transport = &rewriteTransport{ + base: mockUpstream.Client().Transport, + destURL: mockUpstream.URL, + } + + body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleStream(c, uc, san, pool, cred, body, body, nil) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + + respBody := w.Body.String() + if !strings.Contains(respBody, "message_start") { + t.Error("response missing message_start event") + } + if !strings.Contains(respBody, "hello") { + t.Error("response missing text content 'hello'") + } + if !strings.Contains(respBody, "message_stop") { + t.Error("response missing message_stop event") + } + + // Verify SSE headers + if got := w.Header().Get("Content-Type"); got != "text/event-stream" { + t.Errorf("Content-Type = %q, want %q", got, "text/event-stream") + } +} + +func TestHandleStream_DesanitizesEvents(t *testing.T) { + events := "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"name\":\"renamed_tool\",\"id\":\"t1\"}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n" + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + w.Write([]byte(events)) + })) + defer mockUpstream.Close() + + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, + }) + + uc := &UpstreamClient{ + client: *mockUpstream.Client(), + sessionID: "test-sess", + profile: nil, + } + uc.client.Transport = &rewriteTransport{ + base: mockUpstream.Client().Transport, + destURL: mockUpstream.URL, + } + + body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + handleStream(c, uc, san, pool, cred, body, body, nil) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "original_tool") { + t.Errorf("stream response should contain desanitized 'original_tool', got %s", w.Body.String()) + } +} + +// --- HandleMessages full integration wiring test --- + +func TestHandleMessages_WiresHandlerCorrectly(t *testing.T) { + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + + // Verify the handler can be created without panic + handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) + if handler == nil { + t.Fatal("HandleMessages returned nil handler") + } +} + +func TestHandleMessages_EmptyBody(t *testing.T) { + cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} + pool := auth.NewPool([]*auth.Credential{cred}) + san := NewSanitizer(config.SanitizeConfig{}) + + handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) + + // Empty body — handler should still try to pick cred and call upstream + // (which will fail with connection error to api.anthropic.com, not a panic) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader("")) + + handler(c) + + // Should get a 502 because the upstream URL (api.anthropic.com) won't be reachable + // in test environment, or it might complete. The key thing is no panic. + // We mainly verify it doesn't panic. + if w.Code == 0 { + t.Error("expected non-zero status code") + } +} + +// --- Test helpers --- + +// errReader is an io.Reader that always returns an error. +type errReader struct{} + +func (e *errReader) Read([]byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +// failingTransport is an http.RoundTripper that always returns an error. +type failingTransport struct{} + +func (f *failingTransport) RoundTrip(*http.Request) (*http.Response, error) { + return nil, fmt.Errorf("connection refused") +} + +// rewriteTransport intercepts HTTP requests and rewrites the URL to point +// at a local test server. This allows testing with UpstreamClient's hardcoded +// messagesURL by redirecting all requests to a mock server. +type rewriteTransport struct { + base http.RoundTripper + destURL string +} + +func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite the request URL to point at our mock server + newReq := req.Clone(req.Context()) + newReq.URL.Scheme = "http" + newReq.URL.Host = strings.TrimPrefix(t.destURL, "http://") + newReq.URL.Path = "/v1/messages" + newReq.URL.RawQuery = "" + if t.base == nil { + return http.DefaultTransport.RoundTrip(newReq) + } + return t.base.RoundTrip(newReq) +} diff --git a/internal/proxy/sanitize_test.go b/internal/proxy/sanitize_test.go new file mode 100644 index 0000000..4b71172 --- /dev/null +++ b/internal/proxy/sanitize_test.go @@ -0,0 +1,476 @@ +package proxy + +import ( + "strings" + "testing" + + "github.com/fujin/anthropic-proxy/internal/config" +) + +func TestNewSanitizer_Empty(t *testing.T) { + s := NewSanitizer(config.SanitizeConfig{}) + if len(s.toolsForward) != 0 { + t.Errorf("expected empty toolsForward, got %d entries", len(s.toolsForward)) + } + if len(s.toolsReverse) != 0 { + t.Errorf("expected empty toolsReverse, got %d entries", len(s.toolsReverse)) + } + if s.systemRules != nil { + t.Errorf("expected nil systemRules") + } + if s.bodyRules != nil { + t.Errorf("expected nil bodyRules") + } +} + +func TestNewSanitizer_WithTools(t *testing.T) { + cfg := config.SanitizeConfig{ + Tools: []config.RenameRule{ + {From: "old_tool", To: "new_tool"}, + {From: "another", To: "replaced"}, + }, + } + s := NewSanitizer(cfg) + if got := s.toolsForward["old_tool"]; got != "new_tool" { + t.Errorf("toolsForward[old_tool] = %q, want %q", got, "new_tool") + } + if got := s.toolsReverse["new_tool"]; got != "old_tool" { + t.Errorf("toolsReverse[new_tool] = %q, want %q", got, "old_tool") + } + if got := s.toolsForward["another"]; got != "replaced" { + t.Errorf("toolsForward[another] = %q, want %q", got, "replaced") + } + if got := s.toolsReverse["replaced"]; got != "another" { + t.Errorf("toolsReverse[replaced] = %q, want %q", got, "another") + } +} + +func TestNewSanitizer_WithSystemAndBodyRules(t *testing.T) { + cfg := config.SanitizeConfig{ + System: []config.ReplaceRule{{Match: "foo", Replace: "bar"}}, + Body: []config.ReplaceRule{{Match: "baz", Replace: "qux"}}, + } + s := NewSanitizer(cfg) + if len(s.systemRules) != 1 || s.systemRules[0].Match != "foo" { + t.Errorf("systemRules not set correctly") + } + if len(s.bodyRules) != 1 || s.bodyRules[0].Match != "baz" { + t.Errorf("bodyRules not set correctly") + } +} + +func TestRenameTools(t *testing.T) { + tests := []struct { + name string + forward map[string]string + body string + expected string + }{ + { + name: "empty map returns body unchanged", + forward: map[string]string{}, + body: `{"tools":[{"name":"my_tool"}]}`, + expected: `{"tools":[{"name":"my_tool"}]}`, + }, + { + name: "no tools array returns body unchanged", + forward: map[string]string{"my_tool": "renamed"}, + body: `{"messages":[]}`, + expected: `{"messages":[]}`, + }, + { + name: "tools is not array returns body unchanged", + forward: map[string]string{"my_tool": "renamed"}, + body: `{"tools":"not_array"}`, + expected: `{"tools":"not_array"}`, + }, + { + name: "matching tool gets renamed", + forward: map[string]string{"my_tool": "renamed_tool"}, + body: `{"tools":[{"name":"my_tool","description":"desc"}]}`, + expected: `renamed_tool`, + }, + { + name: "non-matching tool unchanged", + forward: map[string]string{"other_tool": "renamed"}, + body: `{"tools":[{"name":"my_tool"}]}`, + expected: `my_tool`, + }, + { + name: "partial match - only exact match renames", + forward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"}, + body: `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`, + expected: `tool_x`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Sanitizer{ + toolsForward: tt.forward, + toolsReverse: make(map[string]string), + } + result := string(s.renameTools([]byte(tt.body))) + if !strings.Contains(result, tt.expected) { + t.Errorf("result %q does not contain %q", result, tt.expected) + } + }) + } +} + +func TestRenameTools_MultipleTools(t *testing.T) { + s := &Sanitizer{ + toolsForward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"}, + toolsReverse: make(map[string]string), + } + body := `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}` + result := string(s.renameTools([]byte(body))) + if !strings.Contains(result, `"tool_x"`) { + t.Errorf("tool_a should be renamed to tool_x, got %s", result) + } + if !strings.Contains(result, `"tool_y"`) { + t.Errorf("tool_b should be renamed to tool_y, got %s", result) + } + if !strings.Contains(result, `"tool_c"`) { + t.Errorf("tool_c should remain unchanged, got %s", result) + } +} + +func TestReplaceSystem(t *testing.T) { + tests := []struct { + name string + rules []config.ReplaceRule + body string + contains string + }{ + { + name: "empty rules returns body unchanged", + rules: nil, + body: `{"system":[{"type":"text","text":"hello world"}]}`, + contains: "hello world", + }, + { + name: "no system field returns body unchanged", + rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}}, + body: `{"messages":[]}`, + contains: `"messages":[]`, + }, + { + name: "system not array returns body unchanged", + rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}}, + body: `{"system":"just a string"}`, + contains: "just a string", + }, + { + name: "single block single rule", + rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}}, + body: `{"system":[{"type":"text","text":"hello world"}]}`, + contains: "goodbye world", + }, + { + name: "multiple blocks", + rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}}, + body: `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`, + contains: "BBB first", + }, + { + name: "multiple rules applied in order", + rules: []config.ReplaceRule{{Match: "cat", Replace: "dog"}, {Match: "dog", Replace: "fish"}}, + body: `{"system":[{"type":"text","text":"I have a cat"}]}`, + contains: "I have a fish", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Sanitizer{ + toolsForward: make(map[string]string), + toolsReverse: make(map[string]string), + systemRules: tt.rules, + } + result := string(s.replaceSystem([]byte(tt.body))) + if !strings.Contains(result, tt.contains) { + t.Errorf("result %q does not contain %q", result, tt.contains) + } + }) + } +} + +func TestReplaceSystem_MultipleBlocks(t *testing.T) { + s := &Sanitizer{ + toolsForward: make(map[string]string), + toolsReverse: make(map[string]string), + systemRules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}}, + } + body := `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}` + result := string(s.replaceSystem([]byte(body))) + if !strings.Contains(result, "BBB first") { + t.Errorf("first block not replaced: %s", result) + } + if !strings.Contains(result, "BBB second") { + t.Errorf("second block not replaced: %s", result) + } +} + +func TestReplaceBody(t *testing.T) { + tests := []struct { + name string + rules []config.ReplaceRule + body string + expected string + }{ + { + name: "empty rules returns body unchanged", + rules: nil, + body: `{"foo":"bar"}`, + expected: `{"foo":"bar"}`, + }, + { + name: "single replacement across entire body", + rules: []config.ReplaceRule{{Match: "SECRET", Replace: "REDACTED"}}, + body: `{"data":"SECRET value SECRET"}`, + expected: `{"data":"REDACTED value REDACTED"}`, + }, + { + name: "multiple rules applied sequentially", + rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}, {Match: "BBB", Replace: "CCC"}}, + body: `{"text":"AAA"}`, + expected: `{"text":"CCC"}`, + }, + { + name: "no match leaves body unchanged", + rules: []config.ReplaceRule{{Match: "NOMATCH", Replace: "X"}}, + body: `{"text":"hello"}`, + expected: `{"text":"hello"}`, + }, + { + name: "empty body", + rules: []config.ReplaceRule{{Match: "a", Replace: "b"}}, + body: ``, + expected: ``, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Sanitizer{ + toolsForward: make(map[string]string), + toolsReverse: make(map[string]string), + bodyRules: tt.rules, + } + result := string(s.replaceBody([]byte(tt.body))) + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} + +func TestSanitizeRequest(t *testing.T) { + cfg := config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}}, + System: []config.ReplaceRule{{Match: "INTERNAL", Replace: "PUBLIC"}}, + Body: []config.ReplaceRule{{Match: "secret_val", Replace: "safe_val"}}, + } + s := NewSanitizer(cfg) + + body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"INTERNAL info"}],"data":"secret_val here"}` + result := string(s.SanitizeRequest([]byte(body))) + + if !strings.Contains(result, `"renamed_tool"`) { + t.Errorf("tool not renamed in result: %s", result) + } + if !strings.Contains(result, "PUBLIC info") { + t.Errorf("system not replaced in result: %s", result) + } + if !strings.Contains(result, "safe_val here") { + t.Errorf("body not replaced in result: %s", result) + } + if strings.Contains(result, "secret_val") { + t.Errorf("secret_val should have been replaced: %s", result) + } +} + +func TestSanitizeRequest_EmptyConfig(t *testing.T) { + s := NewSanitizer(config.SanitizeConfig{}) + body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"hello"}]}` + result := string(s.SanitizeRequest([]byte(body))) + if result != body { + t.Errorf("empty config should not modify body.\ngot: %s\nwant: %s", result, body) + } +} + +func TestDesanitizeResponse(t *testing.T) { + tests := []struct { + name string + reverse map[string]string + body string + expected string + }{ + { + name: "no content field returns unchanged", + reverse: map[string]string{"renamed": "original"}, + body: `{"id":"msg_1","role":"assistant"}`, + expected: `{"id":"msg_1","role":"assistant"}`, + }, + { + name: "content not array returns unchanged", + reverse: map[string]string{"renamed": "original"}, + body: `{"content":"just text"}`, + expected: `{"content":"just text"}`, + }, + { + name: "non-tool_use block left unchanged", + reverse: map[string]string{"renamed": "original"}, + body: `{"content":[{"type":"text","text":"hello"}]}`, + expected: `{"content":[{"type":"text","text":"hello"}]}`, + }, + { + name: "tool_use block with matching name gets reversed", + reverse: map[string]string{"renamed_tool": "original_tool"}, + body: `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1"}]}`, + expected: `original_tool`, + }, + { + name: "tool_use block with no match unchanged", + reverse: map[string]string{"other": "something"}, + body: `{"content":[{"type":"tool_use","name":"my_tool","id":"t1"}]}`, + expected: `my_tool`, + }, + { + name: "mixed blocks only tool_use reversed", + reverse: map[string]string{"renamed": "original"}, + body: `{"content":[{"type":"text","text":"hi"},{"type":"tool_use","name":"renamed","id":"t1"}]}`, + expected: `original`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Sanitizer{ + toolsForward: make(map[string]string), + toolsReverse: tt.reverse, + } + result := string(s.DesanitizeResponse([]byte(tt.body))) + if !strings.Contains(result, tt.expected) { + t.Errorf("result %q does not contain %q", result, tt.expected) + } + }) + } +} + +func TestDesanitizeResponse_MultipleToolUse(t *testing.T) { + s := &Sanitizer{ + toolsForward: make(map[string]string), + toolsReverse: map[string]string{"r1": "o1", "r2": "o2"}, + } + body := `{"content":[{"type":"tool_use","name":"r1","id":"t1"},{"type":"text","text":"x"},{"type":"tool_use","name":"r2","id":"t2"}]}` + result := string(s.DesanitizeResponse([]byte(body))) + if !strings.Contains(result, `"o1"`) { + t.Errorf("r1 not reversed to o1: %s", result) + } + if !strings.Contains(result, `"o2"`) { + t.Errorf("r2 not reversed to o2: %s", result) + } +} + +func TestDesanitizeStreamEvent(t *testing.T) { + tests := []struct { + name string + reverse map[string]string + line string + expected string + }{ + { + name: "non-data line passed through", + reverse: map[string]string{"r": "o"}, + line: "event: content_block_start", + expected: "event: content_block_start", + }, + { + name: "data line without tool_use passed through", + reverse: map[string]string{"r": "o"}, + line: `data: {"type":"text","text":"hello"}`, + expected: `data: {"type":"text","text":"hello"}`, + }, + { + name: "data line with tool_use in content_block.name", + reverse: map[string]string{"renamed_tool": "original_tool"}, + line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`, + expected: `original_tool`, + }, + { + name: "data line with tool_use in delta.name", + reverse: map[string]string{"renamed_tool": "original_tool"}, + line: `data: {"type":"content_block_delta","delta":{"type":"tool_use","name":"renamed_tool"}}`, + expected: `original_tool`, + }, + { + name: "data line with tool_use but no matching name", + reverse: map[string]string{"other": "something"}, + line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"my_tool","id":"t1"}}`, + expected: `my_tool`, + }, + { + name: "empty line passed through", + reverse: map[string]string{"r": "o"}, + line: "", + expected: "", + }, + { + name: "line contains tool_use but not data prefix - passed through", + reverse: map[string]string{"r": "o"}, + line: "event: tool_use", + expected: "event: tool_use", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Sanitizer{ + toolsForward: make(map[string]string), + toolsReverse: tt.reverse, + } + result := s.DesanitizeStreamEvent(tt.line) + if !strings.Contains(result, tt.expected) { + t.Errorf("result %q does not contain %q", result, tt.expected) + } + }) + } +} + +func TestDesanitizeStreamEvent_DataPrefixPreserved(t *testing.T) { + s := &Sanitizer{ + toolsForward: make(map[string]string), + toolsReverse: map[string]string{"renamed": "original"}, + } + line := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed","id":"t1"}}` + result := s.DesanitizeStreamEvent(line) + if !strings.HasPrefix(result, "data: ") { + t.Errorf("result should start with 'data: ', got %q", result) + } +} + +func TestSanitizeRequest_MalformedJSON(t *testing.T) { + s := NewSanitizer(config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "a", To: "b"}}, + System: []config.ReplaceRule{{Match: "x", Replace: "y"}}, + }) + // Malformed JSON - renameTools and replaceSystem should handle gracefully + body := `not valid json` + result := string(s.SanitizeRequest([]byte(body))) + // Should not panic; body rules still do string replacement + if result != "not valid json" { + t.Errorf("malformed JSON should pass through (no body rules match), got %q", result) + } +} + +func TestSanitizeRequest_EmptyBody(t *testing.T) { + s := NewSanitizer(config.SanitizeConfig{ + Tools: []config.RenameRule{{From: "a", To: "b"}}, + }) + result := s.SanitizeRequest([]byte{}) + if len(result) != 0 { + t.Errorf("empty body should return empty, got %q", string(result)) + } +} diff --git a/internal/proxy/sniff_test.go b/internal/proxy/sniff_test.go new file mode 100644 index 0000000..3bcd485 --- /dev/null +++ b/internal/proxy/sniff_test.go @@ -0,0 +1,278 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newRequest(t *testing.T, headers map[string][]string) *http.Request { + t.Helper() + r := httptest.NewRequest("POST", "/v1/messages", nil) + r.Header = http.Header{} + for k, vals := range headers { + for _, v := range vals { + r.Header.Add(k, v) + } + } + return r +} + +func TestExtractProfile_BasicHeaders(t *testing.T) { + r := newRequest(t, map[string][]string{ + "Content-Type": {"application/json"}, + "X-Custom-Header": {"custom-value"}, + "User-Agent": {"Claude/1.2.3 linux"}, + }) + body := []byte(`{"model":"claude-sonnet-4-6"}`) + + p := extractProfile(r, body) + + // Check version parsed + if p.Version != "1.2.3" { + t.Errorf("version = %q, want %q", p.Version, "1.2.3") + } + + // Check body preserved + if string(p.Body) != string(body) { + t.Errorf("body not preserved") + } + + // Check headers captured + found := map[string]bool{} + for _, h := range p.Headers { + found[strings.ToLower(h[0])] = true + } + if !found["content-type"] { + t.Error("Content-Type header should be captured") + } + if !found["x-custom-header"] { + t.Error("X-Custom-Header should be captured") + } +} + +func TestExtractProfile_SkipHeaders(t *testing.T) { + r := newRequest(t, map[string][]string{ + "Host": {"example.com"}, + "Content-Length": {"42"}, + "Authorization": {"Bearer token123"}, + "X-Api-Key": {"key123"}, + "Connection": {"keep-alive"}, + "Content-Type": {"application/json"}, + "X-Custom": {"keep-me"}, + }) + + p := extractProfile(r, []byte(`{}`)) + + for _, h := range p.Headers { + lower := strings.ToLower(h[0]) + if skipHeaders[lower] { + t.Errorf("header %q should have been skipped", h[0]) + } + } + + // Verify non-skipped headers are present + found := map[string]bool{} + for _, h := range p.Headers { + found[strings.ToLower(h[0])] = true + } + if !found["content-type"] { + t.Error("Content-Type should be kept") + } + if !found["x-custom"] { + t.Error("X-Custom should be kept") + } +} + +func TestExtractProfile_HeaderDeduplication(t *testing.T) { + r := newRequest(t, map[string][]string{ + "Content-Type": {"application/json"}, + }) + // Add duplicate with different casing - Go's http.Header normalizes to canonical form + // so we need to add the same canonical header with multiple values to test dedup + r.Header.Add("Content-Type", "text/plain") + + p := extractProfile(r, []byte(`{}`)) + + // After deduplication by lowercase key, only one entry per key + seen := map[string]int{} + for _, h := range p.Headers { + seen[strings.ToLower(h[0])]++ + } + for key, count := range seen { + if count > 1 { + t.Errorf("header %q appears %d times after dedup, want 1", key, count) + } + } +} + +func TestExtractProfile_AnthropicBetaContextStripping(t *testing.T) { + r := newRequest(t, map[string][]string{ + "Anthropic-Beta": {"prompt-caching-2024-07-31,context-1m-2024-09-01,some-other-beta"}, + }) + + p := extractProfile(r, []byte(`{}`)) + + var betaValue string + for _, h := range p.Headers { + if strings.ToLower(h[0]) == "anthropic-beta" { + betaValue = h[1] + break + } + } + + if strings.Contains(betaValue, "context-1m") { + t.Errorf("context-1m should be stripped from anthropic-beta, got %q", betaValue) + } + if !strings.Contains(betaValue, "prompt-caching-2024-07-31") { + t.Errorf("prompt-caching should be preserved, got %q", betaValue) + } + if !strings.Contains(betaValue, "some-other-beta") { + t.Errorf("some-other-beta should be preserved, got %q", betaValue) + } +} + +func TestExtractProfile_AnthropicBetaAllContextRemoved(t *testing.T) { + r := newRequest(t, map[string][]string{ + "Anthropic-Beta": {"context-1m-2024-09-01"}, + }) + + p := extractProfile(r, []byte(`{}`)) + + for _, h := range p.Headers { + if strings.ToLower(h[0]) == "anthropic-beta" { + // All betas were context-1m, so after filtering the value should be empty + if h[1] != "" { + t.Errorf("all context-1m betas stripped should leave empty, got %q", h[1]) + } + return + } + } + // It's also acceptable if the header is still present but empty +} + +func TestExtractProfile_VersionParsing(t *testing.T) { + tests := []struct { + name string + userAgent string + expected string + }{ + { + name: "standard Claude UA", + userAgent: "Claude/1.2.3 linux x86_64", + expected: "1.2.3", + }, + { + name: "version with no space after", + userAgent: "Claude/4.5.6", + expected: "4.5.6", + }, + { + name: "no slash in UA", + userAgent: "Mozilla 5.0", + expected: "", + }, + { + name: "empty UA", + userAgent: "", + expected: "", + }, + { + name: "slash at start", + userAgent: "/1.0.0 rest", + expected: "", + }, + { + name: "multiple slashes", + userAgent: "App/1.0.0 (sub/2.0)", + expected: "1.0.0", + }, + { + name: "version only after slash no space", + userAgent: "Tool/9.8.7", + expected: "9.8.7", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := newRequest(t, map[string][]string{ + "User-Agent": {tt.userAgent}, + }) + p := extractProfile(r, []byte(`{}`)) + if p.Version != tt.expected { + t.Errorf("version = %q, want %q", p.Version, tt.expected) + } + }) + } +} + +func TestExtractProfile_EmptyHeaders(t *testing.T) { + r := httptest.NewRequest("POST", "/v1/messages", nil) + r.Header = http.Header{} + + p := extractProfile(r, []byte(`{"test":true}`)) + + if len(p.Headers) != 0 { + t.Errorf("expected no headers, got %d", len(p.Headers)) + } + if p.Version != "" { + t.Errorf("expected empty version with no UA, got %q", p.Version) + } + if string(p.Body) != `{"test":true}` { + t.Errorf("body not preserved") + } +} + +func TestExtractProfile_BodyPreserved(t *testing.T) { + r := newRequest(t, map[string][]string{ + "User-Agent": {"Claude/1.0.0 test"}, + }) + body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello"}],"stream":true}`) + + p := extractProfile(r, body) + + if string(p.Body) != string(body) { + t.Errorf("body not preserved.\ngot: %s\nwant: %s", p.Body, body) + } +} + +func TestSkipHeaders_Entries(t *testing.T) { + expected := map[string]bool{ + "host": true, + "content-length": true, + "authorization": true, + "x-api-key": true, + "connection": true, + } + if len(skipHeaders) != len(expected) { + t.Errorf("skipHeaders has %d entries, want %d", len(skipHeaders), len(expected)) + } + for k, v := range expected { + if skipHeaders[k] != v { + t.Errorf("skipHeaders[%q] = %v, want %v", k, skipHeaders[k], v) + } + } +} + +func TestSniffedProfile_Fields(t *testing.T) { + // Verify the struct can hold all expected data + p := &SniffedProfile{ + Headers: [][2]string{{"Content-Type", "application/json"}}, + Body: []byte(`{}`), + Version: "1.0.0", + } + if len(p.Headers) != 1 { + t.Error("Headers should have 1 entry") + } + if p.Headers[0][0] != "Content-Type" || p.Headers[0][1] != "application/json" { + t.Error("Header not stored correctly") + } + if string(p.Body) != `{}` { + t.Error("Body not stored correctly") + } + if p.Version != "1.0.0" { + t.Error("Version not stored correctly") + } +} diff --git a/internal/proxy/upstream_test.go b/internal/proxy/upstream_test.go new file mode 100644 index 0000000..71978e5 --- /dev/null +++ b/internal/proxy/upstream_test.go @@ -0,0 +1,334 @@ +package proxy + +import ( + "net/http" + "strings" + "testing" +) + +// --- NewUpstreamClient --- + +func TestNewUpstreamClient_NilProfile(t *testing.T) { + uc := NewUpstreamClient(nil) + if uc == nil { + t.Fatal("NewUpstreamClient returned nil") + } + if uc.sessionID == "" { + t.Error("expected non-empty sessionID") + } + if uc.profile != nil { + t.Error("expected nil profile") + } +} + +func TestNewUpstreamClient_WithProfile(t *testing.T) { + profile := &SniffedProfile{ + Version: "1.2.3", + Headers: [][2]string{{"User-Agent", "test/1.0"}}, + } + uc := NewUpstreamClient(profile) + if uc.profile != profile { + t.Error("expected profile to be stored") + } + if uc.sessionID == "" { + t.Error("expected non-empty sessionID") + } +} + +func TestNewUpstreamClient_UniqueSessionIDs(t *testing.T) { + uc1 := NewUpstreamClient(nil) + uc2 := NewUpstreamClient(nil) + if uc1.sessionID == uc2.sessionID { + t.Errorf("expected different session IDs, both got %q", uc1.sessionID) + } +} + +// --- version() --- + +func TestVersion_WithProfileVersion(t *testing.T) { + uc := &UpstreamClient{ + profile: &SniffedProfile{Version: "3.5.7"}, + } + if got := uc.version(); got != "3.5.7" { + t.Errorf("version() = %q, want %q", got, "3.5.7") + } +} + +func TestVersion_NilProfile_Fallback(t *testing.T) { + uc := &UpstreamClient{profile: nil} + if got := uc.version(); got != "2.1.92" { + t.Errorf("version() = %q, want %q", got, "2.1.92") + } +} + +func TestVersion_EmptyProfileVersion_Fallback(t *testing.T) { + uc := &UpstreamClient{ + profile: &SniffedProfile{Version: ""}, + } + if got := uc.version(); got != "2.1.92" { + t.Errorf("version() = %q, want %q", got, "2.1.92") + } +} + +// --- applyHeaders --- + +func TestApplyHeaders_NilProfile_NonOAuth_NonStream(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "test-session-id", + profile: nil, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-api123", false) + + // x-api-key for non-OAuth token + if got := req.Header.Get("x-api-key"); got != "sk-ant-api123" { + t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api123") + } + // Should NOT have Authorization + if got := req.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization = %q, want empty", got) + } + // Session ID + if got := req.Header.Get("X-Claude-Code-Session-Id"); got != "test-session-id" { + t.Errorf("X-Claude-Code-Session-Id = %q, want %q", got, "test-session-id") + } + // Request ID should be a UUID + if got := req.Header.Get("x-client-request-id"); got == "" { + t.Error("expected non-empty x-client-request-id") + } + // Non-stream: application/json + if got := req.Header.Get("Accept"); got != "application/json" { + t.Errorf("Accept = %q, want %q", got, "application/json") + } + // Accept-Encoding always identity + if got := req.Header.Get("Accept-Encoding"); got != "identity" { + t.Errorf("Accept-Encoding = %q, want %q", got, "identity") + } +} + +func TestApplyHeaders_NilProfile_NonOAuth_Stream(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "sess", + profile: nil, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-api123", true) + + if got := req.Header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept = %q, want %q", got, "text/event-stream") + } +} + +func TestApplyHeaders_OAuthToken_SetsBearerAndBetaFlag(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "sess", + profile: nil, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-oat-mytoken", false) + + // OAuth: Authorization Bearer + if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-mytoken" { + t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-mytoken") + } + // Should NOT have x-api-key + if got := req.Header.Get("x-api-key"); got != "" { + t.Errorf("x-api-key = %q, want empty for OAuth", got) + } + // anthropic-beta should include oauth-2025-04-20 + if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" { + t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20") + } +} + +func TestApplyHeaders_OAuthToken_AppendsToExistingBeta(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "sess", + profile: &SniffedProfile{ + Headers: [][2]string{ + {"anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15"}, + }, + }, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-oat-tok", false) + + beta := req.Header.Get("anthropic-beta") + if !strings.Contains(beta, "max-tokens-3-5-sonnet-2024-07-15") { + t.Errorf("anthropic-beta %q should contain existing beta", beta) + } + if !strings.Contains(beta, "oauth-2025-04-20") { + t.Errorf("anthropic-beta %q should contain oauth flag", beta) + } + // Should be appended with comma + if beta != "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20" { + t.Errorf("anthropic-beta = %q, want %q", beta, "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20") + } +} + +func TestApplyHeaders_OAuthToken_ExistingBetaAlreadyHasOAuth(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "sess", + profile: &SniffedProfile{ + Headers: [][2]string{ + {"anthropic-beta", "oauth-2025-04-20,something-else"}, + }, + }, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-oat-tok", false) + + beta := req.Header.Get("anthropic-beta") + // Should NOT duplicate oauth flag + count := strings.Count(beta, "oauth-2025-04-20") + if count != 1 { + t.Errorf("oauth flag appeared %d times in %q, want 1", count, beta) + } +} + +func TestApplyHeaders_WithProfile_ReplaysHeaders(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "sess", + profile: &SniffedProfile{ + Headers: [][2]string{ + {"User-Agent", "Claude/1.0"}, + {"anthropic-version", "2023-06-01"}, + {"Custom-Header", "custom-value"}, + }, + }, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-api123", false) + + if got := req.Header.Get("User-Agent"); got != "Claude/1.0" { + t.Errorf("User-Agent = %q, want %q", got, "Claude/1.0") + } + if got := req.Header.Get("anthropic-version"); got != "2023-06-01" { + t.Errorf("anthropic-version = %q, want %q", got, "2023-06-01") + } + if got := req.Header.Get("Custom-Header"); got != "custom-value" { + t.Errorf("Custom-Header = %q, want %q", got, "custom-value") + } +} + +func TestApplyHeaders_ProfileAuthHeadersRemoved(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "sess", + profile: &SniffedProfile{ + Headers: [][2]string{ + {"Authorization", "Bearer old-token"}, + {"x-api-key", "old-api-key"}, + {"User-Agent", "test"}, + }, + }, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-api-new", false) + + // Old auth headers from profile should be removed + if got := req.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization should be empty for non-OAuth, got %q", got) + } + // New auth should be set via x-api-key + if got := req.Header.Get("x-api-key"); got != "sk-ant-api-new" { + t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api-new") + } + // User-Agent from profile should remain + if got := req.Header.Get("User-Agent"); got != "test" { + t.Errorf("User-Agent = %q, want %q", got, "test") + } +} + +func TestApplyHeaders_ProfileAuthHeadersRemovedForOAuth(t *testing.T) { + uc := &UpstreamClient{ + sessionID: "sess", + profile: &SniffedProfile{ + Headers: [][2]string{ + {"Authorization", "Bearer old-token"}, + {"x-api-key", "old-api-key"}, + }, + }, + } + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + + uc.applyHeaders(req, "sk-ant-oat-new", false) + + // Old x-api-key removed + if got := req.Header.Get("x-api-key"); got != "" { + t.Errorf("x-api-key should be empty for OAuth, got %q", got) + } + // New auth set via Authorization + if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-new" { + t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-new") + } +} + +func TestApplyHeaders_AcceptEncoding_AlwaysIdentity(t *testing.T) { + tests := []struct { + name string + streaming bool + }{ + {"non-stream", false}, + {"stream", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uc := &UpstreamClient{sessionID: "s", profile: nil} + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + uc.applyHeaders(req, "token", tt.streaming) + + if got := req.Header.Get("Accept-Encoding"); got != "identity" { + t.Errorf("Accept-Encoding = %q, want %q", got, "identity") + } + }) + } +} + +func TestApplyHeaders_UniqueRequestIDs(t *testing.T) { + uc := &UpstreamClient{sessionID: "s", profile: nil} + + req1, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + uc.applyHeaders(req1, "tok", false) + + req2, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + uc.applyHeaders(req2, "tok", false) + + id1 := req1.Header.Get("x-client-request-id") + id2 := req2.Header.Get("x-client-request-id") + if id1 == "" || id2 == "" { + t.Fatal("expected non-empty request IDs") + } + if id1 == id2 { + t.Errorf("expected unique request IDs, both got %q", id1) + } +} + +func TestApplyHeaders_NonOAuth_NoAnthroPicBetaSet(t *testing.T) { + // Non-OAuth tokens should NOT set anthropic-beta oauth flag + uc := &UpstreamClient{sessionID: "s", profile: nil} + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + uc.applyHeaders(req, "sk-ant-api123", false) + + beta := req.Header.Get("anthropic-beta") + if strings.Contains(beta, "oauth-2025-04-20") { + t.Errorf("non-OAuth token should not have oauth beta flag, got %q", beta) + } +} + +func TestApplyHeaders_OAuthToken_FreshBeta(t *testing.T) { + // No profile, no existing beta — should set fresh + uc := &UpstreamClient{sessionID: "s", profile: nil} + req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) + uc.applyHeaders(req, "sk-ant-oat-tok", false) + + if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" { + t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20") + } +} diff --git a/internal/ratelimit/tracker_test.go b/internal/ratelimit/tracker_test.go new file mode 100644 index 0000000..23cc427 --- /dev/null +++ b/internal/ratelimit/tracker_test.go @@ -0,0 +1,278 @@ +package ratelimit + +import ( + "net/http" + "testing" + "time" +) + +func TestNewTracker(t *testing.T) { + called := false + tr := NewTracker(func() string { + called = true + return "tok" + }) + if tr == nil { + t.Fatal("NewTracker returned nil") + } + // tokenFn stored but not called during construction + if called { + t.Error("tokenFn should not be called by NewTracker") + } + // Invoke to verify it's wired + if got := tr.tokenFn(); got != "tok" { + t.Errorf("tokenFn() = %q, want tok", got) + } +} + +func TestUpdateFromHeaders_Full(t *testing.T) { + tr := NewTracker(func() string { return "" }) + + h := http.Header{} + h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.42") + h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "1700000000") + h.Set("Anthropic-Ratelimit-Unified-7d-Utilization", "0.75") + h.Set("Anthropic-Ratelimit-Unified-7d-Reset", "1700100000") + + tr.UpdateFromHeaders(h) + + fh := tr.FiveHour() + if fh.Utilization != 42.0 { + t.Errorf("FiveHour.Utilization = %f, want 42.0", fh.Utilization) + } + wantReset5h := time.Unix(1700000000, 0).UTC().Truncate(time.Minute) + if !fh.ResetsAt.Equal(wantReset5h) { + t.Errorf("FiveHour.ResetsAt = %v, want %v", fh.ResetsAt, wantReset5h) + } + + sd := tr.SevenDay() + if sd.Utilization != 75.0 { + t.Errorf("SevenDay.Utilization = %f, want 75.0", sd.Utilization) + } + wantReset7d := time.Unix(1700100000, 0).UTC().Truncate(time.Minute) + if !sd.ResetsAt.Equal(wantReset7d) { + t.Errorf("SevenDay.ResetsAt = %v, want %v", sd.ResetsAt, wantReset7d) + } +} + +func TestUpdateFromHeaders_Partial(t *testing.T) { + tr := NewTracker(func() string { return "" }) + + // Only set 5h utilization, no reset, no 7d + h := http.Header{} + h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.33") + tr.UpdateFromHeaders(h) + + fh := tr.FiveHour() + if fh.Utilization != 33.0 { + t.Errorf("FiveHour.Utilization = %f, want 33.0", fh.Utilization) + } + if !fh.ResetsAt.IsZero() { + t.Errorf("FiveHour.ResetsAt should be zero, got %v", fh.ResetsAt) + } + + sd := tr.SevenDay() + if sd.Utilization != 0 { + t.Errorf("SevenDay.Utilization = %f, want 0", sd.Utilization) + } +} + +func TestUpdateFromHeaders_Missing(t *testing.T) { + tr := NewTracker(func() string { return "" }) + + // Pre-set some state + tr.mu.Lock() + tr.fiveHour.Utilization = 50.0 + tr.mu.Unlock() + + // Update with empty headers — should not change state + tr.UpdateFromHeaders(http.Header{}) + + fh := tr.FiveHour() + if fh.Utilization != 50.0 { + t.Errorf("FiveHour.Utilization = %f, want 50.0 (unchanged)", fh.Utilization) + } +} + +func TestUpdateFromHeaders_InvalidValues(t *testing.T) { + tr := NewTracker(func() string { return "" }) + + h := http.Header{} + h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "not-a-number") + h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "not-a-timestamp") + + tr.UpdateFromHeaders(h) + + fh := tr.FiveHour() + if fh.Utilization != 0 { + t.Errorf("Utilization should stay 0 for invalid input, got %f", fh.Utilization) + } + if !fh.ResetsAt.IsZero() { + t.Errorf("ResetsAt should stay zero for invalid input, got %v", fh.ResetsAt) + } +} + +func TestSonnet_Snapshot(t *testing.T) { + tr := NewTracker(func() string { return "" }) + + // Sonnet is only set via poll/updateWindow, not UpdateFromHeaders + // Verify it starts at zero + s := tr.Sonnet() + if s.Utilization != 0 { + t.Errorf("Sonnet.Utilization = %f, want 0", s.Utilization) + } + if !s.ResetsAt.IsZero() { + t.Errorf("Sonnet.ResetsAt should be zero, got %v", s.ResetsAt) + } +} + +func TestExtra_Default(t *testing.T) { + tr := NewTracker(func() string { return "" }) + + extra := tr.Extra() + if extra.IsEnabled { + t.Error("Extra.IsEnabled should be false by default") + } + if extra.MonthlyLimit != nil { + t.Error("Extra.MonthlyLimit should be nil by default") + } +} + +func TestUpdateWindow(t *testing.T) { + tr := NewTracker(func() string { return "" }) + + tests := []struct { + name string + util *float64 + resetsAt *string + wantUtil float64 + wantResetOK bool + }{ + { + name: "both fields", + util: float64Ptr(65.5), + resetsAt: stringPtr("2024-01-15T10:30:45Z"), + wantUtil: 65.5, + wantResetOK: true, + }, + { + name: "utilization only", + util: float64Ptr(30.0), + resetsAt: nil, + wantUtil: 30.0, + wantResetOK: false, + }, + { + name: "reset only (RFC3339Nano)", + util: nil, + resetsAt: stringPtr("2024-06-01T12:00:00.123456789Z"), + wantUtil: 0, + wantResetOK: true, + }, + { + name: "nil both", + util: nil, + resetsAt: nil, + wantUtil: 0, + wantResetOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &Window{} + rl := &RateLimit{ + Utilization: tt.util, + ResetsAt: tt.resetsAt, + } + tr.updateWindow(w, rl) + + if w.Utilization != tt.wantUtil { + t.Errorf("Utilization = %f, want %f", w.Utilization, tt.wantUtil) + } + if tt.wantResetOK { + if w.ResetsAt.IsZero() { + t.Error("ResetsAt should be set") + } + // Verify truncation to minute + if w.ResetsAt.Second() != 0 || w.ResetsAt.Nanosecond() != 0 { + t.Errorf("ResetsAt not truncated to minute: %v", w.ResetsAt) + } + if w.ResetsAt.Location() != time.UTC { + t.Errorf("ResetsAt not in UTC: %v", w.ResetsAt.Location()) + } + } else if tt.resetsAt == nil { + if !w.ResetsAt.IsZero() { + t.Errorf("ResetsAt should be zero when input is nil, got %v", w.ResetsAt) + } + } + }) + } +} + +func TestUpdateWindow_InvalidTime(t *testing.T) { + tr := NewTracker(func() string { return "" }) + w := &Window{} + bad := "not-a-time" + rl := &RateLimit{ResetsAt: &bad} + tr.updateWindow(w, rl) + if !w.ResetsAt.IsZero() { + t.Errorf("ResetsAt should stay zero for invalid time, got %v", w.ResetsAt) + } +} + +func TestPoll_SetsStateFromUsageResponse(t *testing.T) { + // White-box: directly set fields that poll would set after fetchUsage + tr := NewTracker(func() string { return "" }) + + // Simulate what poll does after fetching usage + tr.mu.Lock() + usage := &UsageResponse{ + FiveHour: &RateLimit{Utilization: float64Ptr(55.5), ResetsAt: stringPtr("2024-03-01T08:00:00Z")}, + SevenDay: &RateLimit{Utilization: float64Ptr(22.3), ResetsAt: stringPtr("2024-03-07T00:00:00Z")}, + SevenDaySonnet: &RateLimit{Utilization: float64Ptr(10.0), ResetsAt: stringPtr("2024-03-07T00:00:00Z")}, + ExtraUsage: &ExtraUsage{IsEnabled: true, MonthlyLimit: float64Ptr(100.0), UsedCredits: float64Ptr(42.5)}, + } + if usage.FiveHour != nil { + tr.updateWindow(&tr.fiveHour, usage.FiveHour) + } + if usage.SevenDay != nil { + tr.updateWindow(&tr.sevenDay, usage.SevenDay) + } + if usage.SevenDaySonnet != nil { + tr.updateWindow(&tr.sonnet, usage.SevenDaySonnet) + } + if usage.ExtraUsage != nil { + tr.extra = *usage.ExtraUsage + } + tr.mu.Unlock() + + fh := tr.FiveHour() + if fh.Utilization != 55.5 { + t.Errorf("FiveHour.Utilization = %f, want 55.5", fh.Utilization) + } + + sd := tr.SevenDay() + if sd.Utilization != 22.3 { + t.Errorf("SevenDay.Utilization = %f, want 22.3", sd.Utilization) + } + + sn := tr.Sonnet() + if sn.Utilization != 10.0 { + t.Errorf("Sonnet.Utilization = %f, want 10.0", sn.Utilization) + } + + extra := tr.Extra() + if !extra.IsEnabled { + t.Error("Extra.IsEnabled = false, want true") + } + if extra.MonthlyLimit == nil || *extra.MonthlyLimit != 100.0 { + t.Errorf("Extra.MonthlyLimit = %v, want 100.0", extra.MonthlyLimit) + } + if extra.UsedCredits == nil || *extra.UsedCredits != 42.5 { + t.Errorf("Extra.UsedCredits = %v, want 42.5", extra.UsedCredits) + } +} + +func float64Ptr(f float64) *float64 { return &f } +func stringPtr(s string) *string { return &s } diff --git a/internal/ratelimit/usage_test.go b/internal/ratelimit/usage_test.go new file mode 100644 index 0000000..b07eebe --- /dev/null +++ b/internal/ratelimit/usage_test.go @@ -0,0 +1,241 @@ +package ratelimit + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestUsageResponse_FullJSON(t *testing.T) { + raw := `{ + "five_hour": {"utilization": 42.5, "resets_at": "2024-01-15T10:30:00Z"}, + "seven_day": {"utilization": 75.0, "resets_at": "2024-01-20T00:00:00Z"}, + "seven_day_sonnet": {"utilization": 10.0, "resets_at": "2024-01-20T00:00:00Z"}, + "extra_usage": { + "is_enabled": true, + "monthly_limit": 100.0, + "used_credits": 42.5, + "utilization": 42.5 + } + }` + + var resp UsageResponse + if err := json.Unmarshal([]byte(raw), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if resp.FiveHour == nil { + t.Fatal("FiveHour is nil") + } + if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 42.5 { + t.Errorf("FiveHour.Utilization = %v, want 42.5", resp.FiveHour.Utilization) + } + if resp.FiveHour.ResetsAt == nil || *resp.FiveHour.ResetsAt != "2024-01-15T10:30:00Z" { + t.Errorf("FiveHour.ResetsAt = %v", resp.FiveHour.ResetsAt) + } + + if resp.SevenDay == nil { + t.Fatal("SevenDay is nil") + } + if resp.SevenDay.Utilization == nil || *resp.SevenDay.Utilization != 75.0 { + t.Errorf("SevenDay.Utilization = %v, want 75.0", resp.SevenDay.Utilization) + } + + if resp.SevenDaySonnet == nil { + t.Fatal("SevenDaySonnet is nil") + } + if resp.SevenDaySonnet.Utilization == nil || *resp.SevenDaySonnet.Utilization != 10.0 { + t.Errorf("SevenDaySonnet.Utilization = %v", resp.SevenDaySonnet.Utilization) + } + + if resp.ExtraUsage == nil { + t.Fatal("ExtraUsage is nil") + } + if !resp.ExtraUsage.IsEnabled { + t.Error("ExtraUsage.IsEnabled = false, want true") + } + if resp.ExtraUsage.MonthlyLimit == nil || *resp.ExtraUsage.MonthlyLimit != 100.0 { + t.Errorf("ExtraUsage.MonthlyLimit = %v, want 100.0", resp.ExtraUsage.MonthlyLimit) + } + if resp.ExtraUsage.UsedCredits == nil || *resp.ExtraUsage.UsedCredits != 42.5 { + t.Errorf("ExtraUsage.UsedCredits = %v, want 42.5", resp.ExtraUsage.UsedCredits) + } + if resp.ExtraUsage.Utilization == nil || *resp.ExtraUsage.Utilization != 42.5 { + t.Errorf("ExtraUsage.Utilization = %v, want 42.5", resp.ExtraUsage.Utilization) + } +} + +func TestUsageResponse_PartialJSON(t *testing.T) { + raw := `{"five_hour": {"utilization": 10.0}}` + + var resp UsageResponse + if err := json.Unmarshal([]byte(raw), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if resp.FiveHour == nil { + t.Fatal("FiveHour is nil") + } + if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 10.0 { + t.Errorf("FiveHour.Utilization = %v, want 10.0", resp.FiveHour.Utilization) + } + if resp.FiveHour.ResetsAt != nil { + t.Errorf("FiveHour.ResetsAt should be nil, got %v", resp.FiveHour.ResetsAt) + } + if resp.SevenDay != nil { + t.Errorf("SevenDay should be nil, got %v", resp.SevenDay) + } + if resp.SevenDaySonnet != nil { + t.Errorf("SevenDaySonnet should be nil, got %v", resp.SevenDaySonnet) + } + if resp.ExtraUsage != nil { + t.Errorf("ExtraUsage should be nil, got %v", resp.ExtraUsage) + } +} + +func TestUsageResponse_EmptyJSON(t *testing.T) { + var resp UsageResponse + if err := json.Unmarshal([]byte(`{}`), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp.FiveHour != nil || resp.SevenDay != nil || resp.SevenDaySonnet != nil || resp.ExtraUsage != nil { + t.Error("all fields should be nil for empty JSON") + } +} + +func TestFetchUsage_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request headers + if got := r.Header.Get("Authorization"); got != "Bearer test-token" { + t.Errorf("Authorization = %q, want 'Bearer test-token'", got) + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type = %q, want application/json", got) + } + if got := r.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" { + t.Errorf("anthropic-beta = %q, want oauth-2025-04-20", got) + } + if got := r.Header.Get("User-Agent"); got != "claude-cli/2.1.92" { + t.Errorf("User-Agent = %q, want claude-cli/2.1.92", got) + } + if r.Method != http.MethodGet { + t.Errorf("Method = %q, want GET", r.Method) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "five_hour": {"utilization": 50.0, "resets_at": "2024-01-15T10:00:00Z"}, + "seven_day": {"utilization": 25.0, "resets_at": "2024-01-20T00:00:00Z"} + }`)) + })) + defer srv.Close() + + // fetchUsage hardcodes usageURL, but we can test via the mock by temporarily + // using http.DefaultClient's transport. Instead, we test the handler directly. + // The httptest server validates our request expectations above. + + // Make a real request to the test server to verify handler behavior + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("anthropic-beta", "oauth-2025-04-20") + req.Header.Set("User-Agent", "claude-cli/2.1.92") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + var usage UsageResponse + if err := json.NewDecoder(resp.Body).Decode(&usage); err != nil { + t.Fatalf("decode: %v", err) + } + + if usage.FiveHour == nil || *usage.FiveHour.Utilization != 50.0 { + t.Errorf("FiveHour.Utilization = %v, want 50.0", usage.FiveHour) + } + if usage.SevenDay == nil || *usage.SevenDay.Utilization != 25.0 { + t.Errorf("SevenDay.Utilization = %v, want 25.0", usage.SevenDay) + } +} + +func TestFetchUsage_Non200(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer srv.Close() + + // Simulate the error path: non-200 returns error with status and body + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + t.Fatal("expected non-200 status") + } + + // This matches the fetchUsage error format + body := make([]byte, 1024) + n, _ := resp.Body.Read(body) + bodyStr := string(body[:n]) + if !strings.Contains(bodyStr, "forbidden") { + t.Errorf("body = %q, want it to contain 'forbidden'", bodyStr) + } +} + +func TestFetchUsage_MalformedJSON(t *testing.T) { + raw := `{not valid json` + var resp UsageResponse + err := json.Unmarshal([]byte(raw), &resp) + if err == nil { + t.Fatal("expected decode error for malformed JSON") + } +} + +func TestRateLimit_NilFields(t *testing.T) { + raw := `{}` + var rl RateLimit + if err := json.Unmarshal([]byte(raw), &rl); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if rl.Utilization != nil { + t.Errorf("Utilization should be nil, got %v", rl.Utilization) + } + if rl.ResetsAt != nil { + t.Errorf("ResetsAt should be nil, got %v", rl.ResetsAt) + } +} + +func TestExtraUsage_JSON(t *testing.T) { + raw := `{"is_enabled":false,"monthly_limit":null,"used_credits":null,"utilization":null}` + var eu ExtraUsage + if err := json.Unmarshal([]byte(raw), &eu); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if eu.IsEnabled { + t.Error("IsEnabled should be false") + } + if eu.MonthlyLimit != nil { + t.Error("MonthlyLimit should be nil") + } + if eu.UsedCredits != nil { + t.Error("UsedCredits should be nil") + } + if eu.Utilization != nil { + t.Error("Utilization should be nil") + } +} + +func TestUsageURL_Constant(t *testing.T) { + if usageURL != "https://api.anthropic.com/api/oauth/usage" { + t.Errorf("usageURL = %q, want https://api.anthropic.com/api/oauth/usage", usageURL) + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..ef352b6 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,529 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "sync/atomic" + "testing" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// --- makeKeySet --- + +func TestMakeKeySet(t *testing.T) { + tests := []struct { + name string + keys []string + wantN int + lookup string + found bool + }{ + { + name: "nil slice returns empty map", + keys: nil, + wantN: 0, + }, + { + name: "empty slice returns empty map", + keys: []string{}, + wantN: 0, + }, + { + name: "single key", + keys: []string{"key1"}, + wantN: 1, + lookup: "key1", + found: true, + }, + { + name: "multiple keys", + keys: []string{"a", "b", "c"}, + wantN: 3, + lookup: "b", + found: true, + }, + { + name: "missing key not found", + keys: []string{"a", "b"}, + wantN: 2, + lookup: "c", + found: false, + }, + { + name: "duplicate keys deduped", + keys: []string{"x", "x", "x"}, + wantN: 1, + lookup: "x", + found: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := makeKeySet(tt.keys) + if len(got) != tt.wantN { + t.Errorf("len(makeKeySet) = %d, want %d", len(got), tt.wantN) + } + if tt.lookup != "" { + _, ok := got[tt.lookup] + if ok != tt.found { + t.Errorf("keySet[%q] found=%v, want %v", tt.lookup, ok, tt.found) + } + } + }) + } +} + +// --- corsMiddleware --- + +func TestCorsMiddleware_SetsHeaders(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + handler := corsMiddleware() + handler(c) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*") + } + if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT, DELETE, OPTIONS" { + t.Errorf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT, DELETE, OPTIONS") + } + allowHeaders := w.Header().Get("Access-Control-Allow-Headers") + for _, h := range []string{"x-api-key", "anthropic-version", "anthropic-beta", "Authorization", "Content-Type", "Origin"} { + if !containsSubstring(allowHeaders, h) { + t.Errorf("Access-Control-Allow-Headers %q missing %q", allowHeaders, h) + } + } +} + +func TestCorsMiddleware_OptionsReturns204(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/v1/messages", nil) + + handler := corsMiddleware() + handler(c) + + if w.Code != http.StatusNoContent { + t.Errorf("OPTIONS status = %d, want %d", w.Code, http.StatusNoContent) + } + if !c.IsAborted() { + t.Error("expected context to be aborted on OPTIONS") + } +} + +func TestCorsMiddleware_NonOptionsDoesNotAbort(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + handler := corsMiddleware() + handler(c) + + if c.IsAborted() { + t.Error("POST request should not be aborted") + } +} + +// --- authMiddleware --- + +func newServerWithKeys(keys []string) *Server { + s := &Server{} + keySet := makeKeySet(keys) + s.apiKeys.Store(&keySet) + return s +} + +func TestAuthMiddleware_BypassPaths(t *testing.T) { + paths := []string{"/healthz", "/reload", "/metrics"} + s := newServerWithKeys(nil) // no keys — would reject if auth checked + + for _, path := range paths { + t.Run(path, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, path, nil) + + handler := s.authMiddleware() + handler(c) + + if c.IsAborted() { + t.Errorf("path %q should bypass auth but was aborted", path) + } + }) + } +} + +func TestAuthMiddleware_MissingToken_401(t *testing.T) { + s := newServerWithKeys([]string{"valid-key"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + handler := s.authMiddleware() + handler(c) + + if w.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized) + } + if !c.IsAborted() { + t.Error("expected aborted on missing token") + } + + var body map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if body["error"] != "missing authentication" { + t.Errorf("error = %q, want %q", body["error"], "missing authentication") + } +} + +func TestAuthMiddleware_InvalidKey_403(t *testing.T) { + s := newServerWithKeys([]string{"valid-key"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("x-api-key", "wrong-key") + + handler := s.authMiddleware() + handler(c) + + if w.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden) + } + if !c.IsAborted() { + t.Error("expected aborted on invalid key") + } + + var body map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if body["error"] != "invalid api key" { + t.Errorf("error = %q, want %q", body["error"], "invalid api key") + } +} + +func TestAuthMiddleware_ValidKey_XApiKey(t *testing.T) { + s := newServerWithKeys([]string{"valid-key"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("x-api-key", "valid-key") + + handler := s.authMiddleware() + handler(c) + + if c.IsAborted() { + t.Error("valid key should not abort") + } + if w.Code == http.StatusUnauthorized || w.Code == http.StatusForbidden { + t.Errorf("unexpected status %d for valid key", w.Code) + } +} + +func TestAuthMiddleware_ValidKey_BearerAuth(t *testing.T) { + s := newServerWithKeys([]string{"my-token"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("Authorization", "Bearer my-token") + + handler := s.authMiddleware() + handler(c) + + if c.IsAborted() { + t.Error("valid Bearer token should not abort") + } +} + +func TestAuthMiddleware_BearerPrefix_Stripped(t *testing.T) { + // The token is "my-token", sent as "Bearer my-token". The middleware should + // strip "Bearer " and compare "my-token" against the key set. + s := newServerWithKeys([]string{"my-token"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("Authorization", "Bearer my-token") + + handler := s.authMiddleware() + handler(c) + + if c.IsAborted() { + t.Error("expected auth to pass with Bearer-prefixed valid key") + } +} + +func TestAuthMiddleware_AuthorizationWithoutBearer(t *testing.T) { + // If Authorization header doesn't have Bearer prefix, TrimPrefix is a no-op, + // so the full header value is used as the token. + s := newServerWithKeys([]string{"raw-token-value"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("Authorization", "raw-token-value") + + handler := s.authMiddleware() + handler(c) + + if c.IsAborted() { + t.Error("raw Authorization value matching a key should pass") + } +} + +func TestAuthMiddleware_XApiKey_FallbackWhenNoAuthHeader(t *testing.T) { + // If Authorization is empty, x-api-key is checked. + s := newServerWithKeys([]string{"fallback-key"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("x-api-key", "fallback-key") + + handler := s.authMiddleware() + handler(c) + + if c.IsAborted() { + t.Error("x-api-key fallback should pass") + } +} + +func TestAuthMiddleware_AuthorizationPreferredOverXApiKey(t *testing.T) { + // Both headers set; Authorization takes precedence. + s := newServerWithKeys([]string{"auth-key"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("Authorization", "Bearer auth-key") + c.Request.Header.Set("x-api-key", "wrong-key") + + handler := s.authMiddleware() + handler(c) + + if c.IsAborted() { + t.Error("Authorization should take precedence over x-api-key") + } +} + +// --- handleReload --- + +func TestHandleReload_Success(t *testing.T) { + // Create a temp config file + tmpFile, err := os.CreateTemp("", "config-*.yaml") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + configContent := ` +port: 9999 +api_keys: + - reloaded-key-1 + - reloaded-key-2 +sanitize: + tools: + - from: old_tool + to: new_tool + system: + - match: foo + replace: bar + body: + - match: baz + replace: qux +` + if _, err := tmpFile.WriteString(configContent); err != nil { + t.Fatalf("failed to write config: %v", err) + } + tmpFile.Close() + + s := &Server{configPath: tmpFile.Name()} + // Initialize with empty values + emptyKeys := makeKeySet(nil) + s.apiKeys.Store(&emptyKeys) + + emptySan := &atomic.Pointer[interface{}]{} + _ = emptySan // just to show we're aware + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil) + + handler := s.handleReload() + handler(c) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["status"] != "reloaded" { + t.Errorf("status = %v, want %q", resp["status"], "reloaded") + } + + // Verify api keys were updated + keys := s.apiKeys.Load() + if _, ok := (*keys)["reloaded-key-1"]; !ok { + t.Error("expected reloaded-key-1 in api keys after reload") + } + if _, ok := (*keys)["reloaded-key-2"]; !ok { + t.Error("expected reloaded-key-2 in api keys after reload") + } + if len(*keys) != 2 { + t.Errorf("expected 2 api keys, got %d", len(*keys)) + } + + // Verify sanitizer was updated + san := s.sanitizer.Load() + if san == nil { + t.Fatal("sanitizer is nil after reload") + } + + // Check tool_renames in response + if toolRenames, ok := resp["tool_renames"].(float64); !ok || int(toolRenames) != 1 { + t.Errorf("tool_renames = %v, want 1", resp["tool_renames"]) + } + if apiKeys, ok := resp["api_keys"].(float64); !ok || int(apiKeys) != 2 { + t.Errorf("api_keys = %v, want 2", resp["api_keys"]) + } +} + +func TestHandleReload_InvalidConfig(t *testing.T) { + s := &Server{configPath: "/nonexistent/path/config.yaml"} + emptyKeys := makeKeySet(nil) + s.apiKeys.Store(&emptyKeys) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil) + + handler := s.handleReload() + handler(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["error"] == "" { + t.Error("expected non-empty error message") + } +} + +// --- Full route tests using httptest --- + +func TestHealthzEndpoint(t *testing.T) { + engine := gin.New() + engine.Use(corsMiddleware()) + + s := newServerWithKeys(nil) + engine.Use(s.authMiddleware()) + + engine.GET("/healthz", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + + var body map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if body["status"] != "ok" { + t.Errorf("status = %q, want %q", body["status"], "ok") + } +} + +func TestAuthMiddleware_FullRoute_Rejected(t *testing.T) { + engine := gin.New() + s := newServerWithKeys([]string{"correct-key"}) + engine.Use(s.authMiddleware()) + engine.POST("/v1/messages", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + // No auth header + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized) + } +} + +func TestAuthMiddleware_FullRoute_Accepted(t *testing.T) { + engine := gin.New() + s := newServerWithKeys([]string{"correct-key"}) + engine.Use(s.authMiddleware()) + engine.POST("/v1/messages", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + req.Header.Set("x-api-key", "correct-key") + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } +} + +func TestCorsMiddleware_FullRoute_OptionsRequest(t *testing.T) { + engine := gin.New() + engine.Use(corsMiddleware()) + + s := newServerWithKeys([]string{"key"}) + engine.Use(s.authMiddleware()) + + engine.POST("/v1/messages", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodOptions, "/v1/messages", nil) + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Errorf("status = %d, want %d", w.Code, http.StatusNoContent) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("ACAO = %q, want %q", got, "*") + } +} + +// helper +func containsSubstring(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub)) +} + +func containsStr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/telemetry/logbridge_test.go b/internal/telemetry/logbridge_test.go new file mode 100644 index 0000000..73e1a13 --- /dev/null +++ b/internal/telemetry/logbridge_test.go @@ -0,0 +1,178 @@ +package telemetry + +import ( + "encoding/json" + "testing" + + otellog "go.opentelemetry.io/otel/log" + sdklog "go.opentelemetry.io/otel/sdk/log" +) + +func TestMapSeverity(t *testing.T) { + tests := []struct { + input string + want otellog.Severity + }{ + {"trace", otellog.SeverityTrace}, + {"debug", otellog.SeverityDebug}, + {"info", otellog.SeverityInfo}, + {"warn", otellog.SeverityWarn}, + {"warning", otellog.SeverityWarn}, + {"error", otellog.SeverityError}, + {"fatal", otellog.SeverityFatal}, + {"panic", otellog.SeverityFatal2}, + {"unknown", otellog.SeverityInfo}, + {"", otellog.SeverityInfo}, + {"INFO", otellog.SeverityInfo}, // uppercase falls to default + {"PANIC", otellog.SeverityInfo}, // uppercase falls to default + {"gibberish", otellog.SeverityInfo}, + } + + for _, tc := range tests { + t.Run("level_"+tc.input, func(t *testing.T) { + got := mapSeverity(tc.input) + if got != tc.want { + t.Errorf("mapSeverity(%q) = %v, want %v", tc.input, got, tc.want) + } + }) + } +} + +func newTestBridge(t *testing.T) *LogBridge { + t.Helper() + provider := sdklog.NewLoggerProvider() + t.Cleanup(func() { + _ = provider.Shutdown(t.Context()) + }) + return &LogBridge{provider: provider} +} + +func TestLogBridgeWrite(t *testing.T) { + tests := []struct { + name string + input interface{} // will be marshaled to JSON; use string for raw input + raw string // if non-empty, use this directly instead of marshaling input + }{ + { + name: "valid_json_with_message_level_and_extras", + input: map[string]interface{}{ + "message": "request handled", + "level": "info", + "method": "GET", + "status": float64(200), + }, + }, + { + name: "message_only_no_level", + input: map[string]interface{}{ + "message": "hello world", + }, + }, + { + name: "level_only_no_message", + input: map[string]interface{}{ + "level": "error", + }, + }, + { + name: "empty_json_object", + input: map[string]interface{}{}, + }, + { + name: "string_float64_bool_attributes", + input: map[string]interface{}{ + "message": "test", + "level": "debug", + "str_val": "hello", + "num_val": float64(3.14), + "bool_val": true, + }, + }, + { + name: "complex_nested_object_attribute", + input: map[string]interface{}{ + "message": "nested", + "level": "warn", + "nested": map[string]interface{}{"foo": "bar", "n": float64(1)}, + }, + }, + { + name: "time_field_skipped_in_attributes", + input: map[string]interface{}{ + "message": "with time", + "level": "info", + "time": "2025-01-01T00:00:00Z", + "extra": "kept", + }, + }, + { + name: "malformed_json", + raw: "this is not json at all", + }, + { + name: "malformed_json_partial", + raw: `{"broken":`, + }, + { + name: "array_attribute_marshaled_as_string", + input: map[string]interface{}{ + "message": "arrays", + "tags": []interface{}{"a", "b"}, + }, + }, + { + name: "null_value_attribute", + input: map[string]interface{}{ + "message": "nulls", + "val": nil, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + bridge := newTestBridge(t) + + var p []byte + if tc.raw != "" { + p = []byte(tc.raw) + } else { + var err error + p, err = json.Marshal(tc.input) + if err != nil { + t.Fatalf("failed to marshal test input: %v", err) + } + } + + n, err := bridge.Write(p) + if n != len(p) { + t.Errorf("Write() returned n=%d, want %d", n, len(p)) + } + if err != nil { + t.Errorf("Write() returned err=%v, want nil", err) + } + }) + } +} + +func TestLogBridgeWriteAlwaysReturnsLenAndNil(t *testing.T) { + bridge := newTestBridge(t) + + inputs := [][]byte{ + []byte(`{"message":"ok","level":"info"}`), + []byte(`not json`), + []byte(`{}`), + []byte(``), + []byte(`[]`), + } + + for _, p := range inputs { + n, err := bridge.Write(p) + if n != len(p) { + t.Errorf("Write(%q) n=%d, want %d", string(p), n, len(p)) + } + if err != nil { + t.Errorf("Write(%q) err=%v, want nil", string(p), err) + } + } +}