test: add comprehensive test harness across all packages (156 tests)
Characterization tests capturing current behavior before refactoring. Covers auth, config, logging, proxy, ratelimit, server, and telemetry packages with race-safe concurrent access tests.
This commit is contained in:
@@ -0,0 +1,318 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewPool(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a", AccessToken: "tok-a"},
|
||||
{ID: "b", AccessToken: "tok-b"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
if p == nil {
|
||||
t.Fatal("NewPool returned nil")
|
||||
}
|
||||
if len(p.creds) != 2 {
|
||||
t.Errorf("pool has %d creds, want 2", len(p.creds))
|
||||
}
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("initial cursor = %d, want 0", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_EmptyPool(t *testing.T) {
|
||||
p := NewPool(nil)
|
||||
_, err := p.Pick()
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty pool, got nil")
|
||||
}
|
||||
want := "no credentials available"
|
||||
if err.Error() != want {
|
||||
t.Errorf("error = %q, want %q", err.Error(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_SingleCredential(t *testing.T) {
|
||||
cred := &Credential{ID: "only", AccessToken: "tok-only"}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if got.ID != "only" {
|
||||
t.Errorf("Pick() returned cred ID %q, want %q", got.ID, "only")
|
||||
}
|
||||
|
||||
// Picking again should return the same credential
|
||||
got2, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("second Pick() error = %v", err)
|
||||
}
|
||||
if got2.ID != "only" {
|
||||
t.Errorf("second Pick() returned cred ID %q, want %q", got2.ID, "only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_RoundRobin(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a"},
|
||||
{ID: "b"},
|
||||
{ID: "c"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// Should cycle through a, b, c, a, b, c
|
||||
expected := []string{"a", "b", "c", "a", "b", "c"}
|
||||
for i, want := range expected {
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got.ID != want {
|
||||
t.Errorf("Pick() #%d = %q, want %q", i, got.ID, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a"},
|
||||
{ID: "b", CooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "c"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// First pick: "a" (index 0, not on cooldown)
|
||||
got, err := p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #1 error = %v", err)
|
||||
}
|
||||
if got.ID != "a" {
|
||||
t.Errorf("Pick() #1 = %q, want %q", got.ID, "a")
|
||||
}
|
||||
|
||||
// Second pick: cursor at 1, but "b" is on cooldown → skip to "c"
|
||||
got, err = p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #2 error = %v", err)
|
||||
}
|
||||
if got.ID != "c" {
|
||||
t.Errorf("Pick() #2 = %q, want %q", got.ID, "c")
|
||||
}
|
||||
|
||||
// Third pick: cursor advanced past "c" to 0 → "a"
|
||||
got, err = p.Pick()
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #3 error = %v", err)
|
||||
}
|
||||
if got.ID != "a" {
|
||||
t.Errorf("Pick() #3 = %q, want %q", got.ID, "a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_Pick_AllOnCooldown(t *testing.T) {
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
creds := []*Credential{
|
||||
{ID: "a", CooldownUntil: future},
|
||||
{ID: "b", CooldownUntil: future},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
_, err := p.Pick()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all on cooldown, got nil")
|
||||
}
|
||||
want := "all 2 credentials are on cooldown"
|
||||
if err.Error() != want {
|
||||
t.Errorf("error = %q, want %q", err.Error(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_MarkFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expectCooldown bool
|
||||
expectedDur time.Duration
|
||||
}{
|
||||
{
|
||||
name: "429 sets 30s cooldown",
|
||||
statusCode: 429,
|
||||
expectCooldown: true,
|
||||
expectedDur: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "500 sets 5s cooldown",
|
||||
statusCode: 500,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "502 sets 5s cooldown",
|
||||
statusCode: 502,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "503 sets 5s cooldown",
|
||||
statusCode: 503,
|
||||
expectCooldown: true,
|
||||
expectedDur: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "400 does NOT set cooldown",
|
||||
statusCode: 400,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "401 does NOT set cooldown",
|
||||
statusCode: 401,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "403 does NOT set cooldown",
|
||||
statusCode: 403,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "404 does NOT set cooldown",
|
||||
statusCode: 404,
|
||||
expectCooldown: false,
|
||||
},
|
||||
{
|
||||
name: "422 does NOT set cooldown",
|
||||
statusCode: 422,
|
||||
expectCooldown: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cred := &Credential{ID: "test"}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
before := time.Now()
|
||||
p.MarkFailure(cred, tt.statusCode)
|
||||
|
||||
if tt.expectCooldown {
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Errorf("expected cooldown after status %d", tt.statusCode)
|
||||
}
|
||||
// Verify approximate duration
|
||||
cred.mu.Lock()
|
||||
cooldownEnd := cred.CooldownUntil
|
||||
cred.mu.Unlock()
|
||||
|
||||
lower := before.Add(tt.expectedDur)
|
||||
upper := time.Now().Add(tt.expectedDur)
|
||||
if cooldownEnd.Before(lower) || cooldownEnd.After(upper) {
|
||||
t.Errorf("CooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
|
||||
}
|
||||
} else {
|
||||
if cred.IsOnCooldown() {
|
||||
t.Errorf("did not expect cooldown after status %d", tt.statusCode)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_MarkSuccess(t *testing.T) {
|
||||
cred := &Credential{
|
||||
ID: "test",
|
||||
CooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Fatal("precondition: expected credential to be on cooldown")
|
||||
}
|
||||
|
||||
p.MarkSuccess(cred)
|
||||
|
||||
if cred.IsOnCooldown() {
|
||||
t.Error("expected cooldown to be cleared after MarkSuccess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_RoundRobinCursorAdvancement(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "0"},
|
||||
{ID: "1"},
|
||||
{ID: "2"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// Verify cursor starts at 0
|
||||
if p.cursor != 0 {
|
||||
t.Fatalf("initial cursor = %d, want 0", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[0], cursor should advance to 1
|
||||
got, _ := p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("first pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[1], cursor should advance to 2
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "1" {
|
||||
t.Errorf("second pick = %q, want %q", got.ID, "1")
|
||||
}
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("cursor after second pick = %d, want 2", p.cursor)
|
||||
}
|
||||
|
||||
// Pick cred[2], cursor should wrap to 0
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "2" {
|
||||
t.Errorf("third pick = %q, want %q", got.ID, "2")
|
||||
}
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("cursor after third pick = %d, want 0 (wrap)", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_RoundRobinWithCooldownSkip(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "0"},
|
||||
{ID: "1", CooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "2"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
// First pick: cred[0]
|
||||
got, _ := p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("first pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
// Cursor should be at 1
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after first pick = %d, want 1", p.cursor)
|
||||
}
|
||||
|
||||
// Second pick: cursor at 1, but cred[1] on cooldown → skip to cred[2]
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "2" {
|
||||
t.Errorf("second pick = %q, want %q", got.ID, "2")
|
||||
}
|
||||
// Cursor should advance past cred[2] to 0
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("cursor after second pick (skip) = %d, want 0", p.cursor)
|
||||
}
|
||||
|
||||
// Third pick: cursor at 0, cred[0] available
|
||||
got, _ = p.Pick()
|
||||
if got.ID != "0" {
|
||||
t.Errorf("third pick = %q, want %q", got.ID, "0")
|
||||
}
|
||||
if p.cursor != 1 {
|
||||
t.Errorf("cursor after third pick = %d, want 1", p.cursor)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCredential_IsOnCooldown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cooldownUntil time.Time
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "zero time — not on cooldown",
|
||||
cooldownUntil: time.Time{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "future time — on cooldown",
|
||||
cooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "past time — expired cooldown",
|
||||
cooldownUntil: time.Now().Add(-1 * time.Hour),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{CooldownUntil: tt.cooldownUntil}
|
||||
got := c.IsOnCooldown()
|
||||
if got != tt.want {
|
||||
t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_SetCooldown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
}{
|
||||
{name: "30 second cooldown", duration: 30 * time.Second},
|
||||
{name: "5 second cooldown", duration: 5 * time.Second},
|
||||
{name: "1 minute cooldown", duration: 1 * time.Minute},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{}
|
||||
before := time.Now()
|
||||
c.SetCooldown(tt.duration)
|
||||
after := time.Now()
|
||||
|
||||
// CooldownUntil should be between before+duration and after+duration
|
||||
if c.CooldownUntil.Before(before.Add(tt.duration)) {
|
||||
t.Errorf("CooldownUntil %v is before expected lower bound %v", c.CooldownUntil, before.Add(tt.duration))
|
||||
}
|
||||
if c.CooldownUntil.After(after.Add(tt.duration)) {
|
||||
t.Errorf("CooldownUntil %v is after expected upper bound %v", c.CooldownUntil, after.Add(tt.duration))
|
||||
}
|
||||
|
||||
// Should now be on cooldown
|
||||
if !c.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after SetCooldown")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_ClearCooldown(t *testing.T) {
|
||||
t.Run("clears active cooldown", func(t *testing.T) {
|
||||
c := &Credential{CooldownUntil: time.Now().Add(1 * time.Hour)}
|
||||
if !c.IsOnCooldown() {
|
||||
t.Fatal("precondition: expected credential to be on cooldown")
|
||||
}
|
||||
|
||||
c.ClearCooldown()
|
||||
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown after ClearCooldown")
|
||||
}
|
||||
if !c.CooldownUntil.IsZero() {
|
||||
t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clearing when not on cooldown is no-op", func(t *testing.T) {
|
||||
c := &Credential{}
|
||||
c.ClearCooldown()
|
||||
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown")
|
||||
}
|
||||
if !c.CooldownUntil.IsZero() {
|
||||
t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredential_Token(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{name: "returns access token", token: "sk-ant-abc123"},
|
||||
{name: "empty token", token: ""},
|
||||
{name: "long token", token: "sk-ant-" + string(make([]byte, 200))},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{AccessToken: tt.token}
|
||||
got := c.Token()
|
||||
if got != tt.token {
|
||||
t.Errorf("Token() = %q, want %q", got, tt.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredential_ConcurrentAccess(t *testing.T) {
|
||||
c := &Credential{
|
||||
AccessToken: "initial-token",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
|
||||
// Spawn goroutines that concurrently read and write
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(3)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = c.Token()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.SetCooldown(1 * time.Second)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = c.IsOnCooldown()
|
||||
}()
|
||||
}
|
||||
|
||||
// Also mix in ClearCooldown calls
|
||||
for i := 0; i < goroutines/2; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.ClearCooldown()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If we get here without -race detecting issues, mutex is working
|
||||
}
|
||||
@@ -0,0 +1,349 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad_AllFields(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
yaml := `
|
||||
port: 9090
|
||||
api_keys:
|
||||
- key1
|
||||
- key2
|
||||
claude_binary: /usr/bin/claude
|
||||
sanitize:
|
||||
tools:
|
||||
- from: tool_a
|
||||
to: tool_b
|
||||
system:
|
||||
- match: foo
|
||||
replace: bar
|
||||
body:
|
||||
- match: baz
|
||||
replace: qux
|
||||
logging:
|
||||
level: debug
|
||||
file: /tmp/test.log
|
||||
max_size_mb: 50
|
||||
max_backups: 3
|
||||
max_age_days: 7
|
||||
compress: true
|
||||
telemetry:
|
||||
service_name: my-proxy
|
||||
export:
|
||||
endpoint: http://localhost:4317
|
||||
insecure: true
|
||||
headers:
|
||||
x-token: abc
|
||||
embedded:
|
||||
enabled: true
|
||||
port: 9999
|
||||
perses_binary: /usr/bin/perses
|
||||
vm_binary: /usr/bin/vm
|
||||
vm_port: 9428
|
||||
bin_dir: /opt/bin
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load returned error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Port != 9090 {
|
||||
t.Errorf("Port = %d, want 9090", cfg.Port)
|
||||
}
|
||||
if len(cfg.APIKeys) != 2 || cfg.APIKeys[0] != "key1" || cfg.APIKeys[1] != "key2" {
|
||||
t.Errorf("APIKeys = %v, want [key1 key2]", cfg.APIKeys)
|
||||
}
|
||||
if cfg.ClaudeBinary != "/usr/bin/claude" {
|
||||
t.Errorf("ClaudeBinary = %q, want /usr/bin/claude", cfg.ClaudeBinary)
|
||||
}
|
||||
|
||||
// Sanitize
|
||||
if len(cfg.Sanitize.Tools) != 1 || cfg.Sanitize.Tools[0].From != "tool_a" || cfg.Sanitize.Tools[0].To != "tool_b" {
|
||||
t.Errorf("Sanitize.Tools = %v", cfg.Sanitize.Tools)
|
||||
}
|
||||
if len(cfg.Sanitize.System) != 1 || cfg.Sanitize.System[0].Match != "foo" {
|
||||
t.Errorf("Sanitize.System = %v", cfg.Sanitize.System)
|
||||
}
|
||||
if len(cfg.Sanitize.Body) != 1 || cfg.Sanitize.Body[0].Match != "baz" {
|
||||
t.Errorf("Sanitize.Body = %v", cfg.Sanitize.Body)
|
||||
}
|
||||
|
||||
// Logging
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("Logging.Level = %q, want debug", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Logging.File != "/tmp/test.log" {
|
||||
t.Errorf("Logging.File = %q", cfg.Logging.File)
|
||||
}
|
||||
if cfg.Logging.MaxSizeMB != 50 {
|
||||
t.Errorf("Logging.MaxSizeMB = %d, want 50", cfg.Logging.MaxSizeMB)
|
||||
}
|
||||
if cfg.Logging.MaxBackups != 3 {
|
||||
t.Errorf("Logging.MaxBackups = %d, want 3", cfg.Logging.MaxBackups)
|
||||
}
|
||||
if cfg.Logging.MaxAgeDays != 7 {
|
||||
t.Errorf("Logging.MaxAgeDays = %d, want 7", cfg.Logging.MaxAgeDays)
|
||||
}
|
||||
if !cfg.Logging.Compress {
|
||||
t.Error("Logging.Compress = false, want true")
|
||||
}
|
||||
|
||||
// Telemetry
|
||||
if cfg.Telemetry.ServiceName != "my-proxy" {
|
||||
t.Errorf("Telemetry.ServiceName = %q, want my-proxy", cfg.Telemetry.ServiceName)
|
||||
}
|
||||
if cfg.Telemetry.Export.Endpoint != "http://localhost:4317" {
|
||||
t.Errorf("Export.Endpoint = %q", cfg.Telemetry.Export.Endpoint)
|
||||
}
|
||||
if !cfg.Telemetry.Export.Insecure {
|
||||
t.Error("Export.Insecure = false, want true")
|
||||
}
|
||||
if !cfg.Telemetry.Export.Enabled() {
|
||||
t.Error("Export.Enabled() = false, want true")
|
||||
}
|
||||
if cfg.Telemetry.Export.Headers["x-token"] != "abc" {
|
||||
t.Errorf("Export.Headers = %v", cfg.Telemetry.Export.Headers)
|
||||
}
|
||||
|
||||
// Embedded
|
||||
if !cfg.Telemetry.Embedded.Enabled {
|
||||
t.Error("Embedded.Enabled = false, want true")
|
||||
}
|
||||
if cfg.Telemetry.Embedded.Port != 9999 {
|
||||
t.Errorf("Embedded.Port = %d, want 9999", cfg.Telemetry.Embedded.Port)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.PersesBinary != "/usr/bin/perses" {
|
||||
t.Errorf("Embedded.PersesBinary = %q", cfg.Telemetry.Embedded.PersesBinary)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMBinary != "/usr/bin/vm" {
|
||||
t.Errorf("Embedded.VMBinary = %q", cfg.Telemetry.Embedded.VMBinary)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.VMPort != 9428 {
|
||||
t.Errorf("Embedded.VMPort = %d, want 9428", cfg.Telemetry.Embedded.VMPort)
|
||||
}
|
||||
if cfg.Telemetry.Embedded.BinDir != "/opt/bin" {
|
||||
t.Errorf("Embedded.BinDir = %q", cfg.Telemetry.Embedded.BinDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Defaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Minimal YAML — only api_keys
|
||||
if err := os.WriteFile(path, []byte("api_keys:\n - k1\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load returned error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"Port", cfg.Port, 8080},
|
||||
{"Logging.Level", cfg.Logging.Level, "info"},
|
||||
{"Logging.MaxSizeMB", cfg.Logging.MaxSizeMB, 100},
|
||||
{"Logging.MaxBackups", cfg.Logging.MaxBackups, 5},
|
||||
{"Logging.MaxAgeDays", cfg.Logging.MaxAgeDays, 30},
|
||||
{"Telemetry.ServiceName", cfg.Telemetry.ServiceName, "anthropic-proxy"},
|
||||
{"Embedded.Port", cfg.Telemetry.Embedded.Port, 8080},
|
||||
{"Embedded.VMBinary", cfg.Telemetry.Embedded.VMBinary, "victoria-metrics"},
|
||||
{"Embedded.PersesBinary", cfg.Telemetry.Embedded.PersesBinary, "perses"},
|
||||
{"Embedded.VMPort", cfg.Telemetry.Embedded.VMPort, 8428},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("got %v, want %v", tt.got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/config.yaml")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "read config") {
|
||||
t.Errorf("error = %q, want it to contain 'read config'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DeprecatedClaudeCredentials(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
yaml := `
|
||||
api_keys:
|
||||
- k1
|
||||
claude_credentials: "/some/path"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for deprecated claude_credentials, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no longer supported") {
|
||||
t.Errorf("error = %q, want it to contain 'no longer supported'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_EmptyClaudeCredentials(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Empty string value should NOT trigger the deprecation error
|
||||
yaml := `
|
||||
api_keys:
|
||||
- k1
|
||||
claude_credentials: ""
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("empty claude_credentials should not error: %v", err)
|
||||
}
|
||||
if cfg.Port != 8080 {
|
||||
t.Errorf("Port = %d, want 8080", cfg.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidYAML(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Truly invalid YAML that causes a parse error
|
||||
if err := os.WriteFile(path, []byte("port:\n - bad\n indent: broken\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid YAML, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "parse config") {
|
||||
t.Errorf("error = %q, want it to contain 'parse config'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportConfig_Enabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
want bool
|
||||
}{
|
||||
{"empty endpoint", "", false},
|
||||
{"set endpoint", "http://localhost:4317", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ExportConfig{Endpoint: tt.endpoint}
|
||||
if got := e.Enabled(); got != tt.want {
|
||||
t.Errorf("Enabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCredentialPath(t *testing.T) {
|
||||
path := DefaultCredentialPath()
|
||||
if path == "" {
|
||||
t.Skip("could not determine home directory")
|
||||
}
|
||||
if !strings.HasSuffix(path, "/.claude/.credentials.json") {
|
||||
t.Errorf("DefaultCredentialPath() = %q, want suffix /.claude/.credentials.json", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultCredentials_ValidFile(t *testing.T) {
|
||||
// We can't easily override DefaultCredentialPath, so test the JSON parsing
|
||||
// logic by creating a file at a temp location and calling the internal parsing
|
||||
// directly. Instead, we test LoadDefaultCredentials indirectly by verifying
|
||||
// it returns nil,nil when the default path doesn't exist (common in CI).
|
||||
// For a full test, we create the credential file at the expected path.
|
||||
|
||||
// Test with the actual function — if the default credential file doesn't
|
||||
// exist, it should return nil, nil.
|
||||
creds, err := LoadDefaultCredentials()
|
||||
path := DefaultCredentialPath()
|
||||
if path == "" {
|
||||
if creds != nil || err != nil {
|
||||
t.Errorf("expected nil,nil when home dir unavailable, got %v, %v", creds, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
||||
// File doesn't exist — should return nil, nil
|
||||
if creds != nil {
|
||||
t.Errorf("expected nil creds for missing file, got %v", creds)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultCredentials_ParsesJSON(t *testing.T) {
|
||||
// Test the JSON parsing by creating a temp credential file and using
|
||||
// the claudeCredentialsJSON struct directly (white-box test).
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cf.ClaudeAiOauth.AccessToken != "test-token" {
|
||||
t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.RefreshToken != "test-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.ExpiresAt != 1234567890 {
|
||||
t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt)
|
||||
}
|
||||
if cf.ClaudeAiOauth.SubscriptionType != "pro" {
|
||||
t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultCredentials_EmptyAccessToken(t *testing.T) {
|
||||
// Verify that an empty access token in the JSON produces an error.
|
||||
// We test the parsing struct and logic path.
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if cf.ClaudeAiOauth.AccessToken != "" {
|
||||
t.Errorf("expected empty access token")
|
||||
}
|
||||
// The actual LoadDefaultCredentials would return an error here.
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers http.Header
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "redacts Authorization",
|
||||
headers: http.Header{
|
||||
"Authorization": []string{"Bearer secret-token"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Authorization"] != "***" {
|
||||
t.Errorf("Authorization = %q, want ***", m["Authorization"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redacts x-api-key",
|
||||
headers: http.Header{
|
||||
"X-Api-Key": []string{"sk-ant-secret"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["X-Api-Key"] != "***" {
|
||||
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "preserves other headers",
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"Accept": []string{"text/html", "application/json"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", m["Content-Type"])
|
||||
}
|
||||
if m["Accept"] != "text/html, application/json" {
|
||||
t.Errorf("Accept = %q, want 'text/html, application/json'", m["Accept"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "case-insensitive redaction",
|
||||
headers: http.Header{
|
||||
"authorization": []string{"Bearer token"},
|
||||
"X-API-KEY": []string{"key123"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
// http.Header canonicalizes keys, but RedactHeaders lowercases for comparison
|
||||
for _, v := range m {
|
||||
if v != "***" {
|
||||
t.Errorf("expected all values to be ***, got %q", v)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
headers: http.Header{},
|
||||
check: func(t *testing.T, result string) {
|
||||
if result != "{}" {
|
||||
t.Errorf("result = %q, want {}", result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed sensitive and non-sensitive",
|
||||
headers: http.Header{
|
||||
"Authorization": []string{"Bearer tok"},
|
||||
"X-Api-Key": []string{"key"},
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"abc123"},
|
||||
},
|
||||
check: func(t *testing.T, result string) {
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(result), &m); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if m["Authorization"] != "***" {
|
||||
t.Errorf("Authorization = %q, want ***", m["Authorization"])
|
||||
}
|
||||
if m["X-Api-Key"] != "***" {
|
||||
t.Errorf("X-Api-Key = %q, want ***", m["X-Api-Key"])
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q", m["Content-Type"])
|
||||
}
|
||||
if m["X-Request-Id"] != "abc123" {
|
||||
t.Errorf("X-Request-Id = %q", m["X-Request-Id"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := RedactHeaders(tt.headers)
|
||||
// Result should be valid JSON
|
||||
if !json.Valid([]byte(result)) {
|
||||
t.Fatalf("result is not valid JSON: %q", result)
|
||||
}
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactHeaders_ReturnsJSON(t *testing.T) {
|
||||
h := http.Header{"Foo": []string{"bar"}}
|
||||
result := RedactHeaders(h)
|
||||
if !strings.HasPrefix(result, "{") || !strings.HasSuffix(result, "}") {
|
||||
t.Errorf("result not JSON object: %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
status int
|
||||
want zerolog.Level
|
||||
}{
|
||||
{200, zerolog.InfoLevel},
|
||||
{201, zerolog.InfoLevel},
|
||||
{204, zerolog.InfoLevel},
|
||||
{301, zerolog.InfoLevel},
|
||||
{399, zerolog.InfoLevel},
|
||||
{400, zerolog.WarnLevel},
|
||||
{401, zerolog.WarnLevel},
|
||||
{403, zerolog.WarnLevel},
|
||||
{404, zerolog.WarnLevel},
|
||||
{429, zerolog.WarnLevel},
|
||||
{499, zerolog.WarnLevel},
|
||||
{500, zerolog.ErrorLevel},
|
||||
{502, zerolog.ErrorLevel},
|
||||
{503, zerolog.ErrorLevel},
|
||||
{599, zerolog.ErrorLevel},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := statusLevel(tt.status)
|
||||
if got != tt.want {
|
||||
t.Errorf("statusLevel(%d) = %v, want %v", tt.status, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetup_WithFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logFile := filepath.Join(dir, "test.log")
|
||||
|
||||
logger := Setup(Config{
|
||||
Level: "debug",
|
||||
File: logFile,
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 1,
|
||||
MaxAgeDays: 1,
|
||||
})
|
||||
|
||||
// Verify logger works (no panic)
|
||||
logger.Info().Msg("test message")
|
||||
}
|
||||
|
||||
func TestSetup_WithoutFile(t *testing.T) {
|
||||
// File empty — should use console or stderr mode depending on TTY
|
||||
logger := Setup(Config{
|
||||
Level: "warn",
|
||||
})
|
||||
|
||||
// Verify logger works (no panic)
|
||||
logger.Warn().Msg("test warning")
|
||||
}
|
||||
|
||||
func TestSetup_DefaultLevel(t *testing.T) {
|
||||
// Empty level should default to info
|
||||
logger := Setup(Config{})
|
||||
_ = logger // verify no panic
|
||||
}
|
||||
|
||||
func TestSetup_InvalidLevel(t *testing.T) {
|
||||
// Invalid level should default to info
|
||||
logger := Setup(Config{Level: "not-a-level"})
|
||||
_ = logger // verify no panic
|
||||
}
|
||||
|
||||
func TestFromContext_NoLogger(t *testing.T) {
|
||||
// Background context has no zerolog logger — should return global
|
||||
ctx := context.Background()
|
||||
l := FromContext(ctx)
|
||||
if l == nil {
|
||||
t.Fatal("FromContext returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromContext_WithLogger(t *testing.T) {
|
||||
logger := zerolog.Nop()
|
||||
ctx := logger.WithContext(context.Background())
|
||||
l := FromContext(ctx)
|
||||
if l == nil {
|
||||
t.Fatal("FromContext returned nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFingerprintSaltConstant(t *testing.T) {
|
||||
if fingerprintSalt != "59cf53e54c78" {
|
||||
t.Errorf("fingerprintSalt = %q, want %q", fingerprintSalt, "59cf53e54c78")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Deterministic(t *testing.T) {
|
||||
a := computeFingerprint("hello world test message", "1.0.0")
|
||||
b := computeFingerprint("hello world test message", "1.0.0")
|
||||
if a != b {
|
||||
t.Errorf("fingerprint not deterministic: %q != %q", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Length(t *testing.T) {
|
||||
fp := computeFingerprint("some message here", "2.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
// Must be valid hex
|
||||
if _, err := hex.DecodeString(fp + "0"); err != nil { // pad to even length for decode
|
||||
// Check each char is hex individually
|
||||
for _, c := range fp {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("fingerprint %q contains non-hex char %c", fp, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_DifferentVersions(t *testing.T) {
|
||||
a := computeFingerprint("same message", "1.0.0")
|
||||
b := computeFingerprint("same message", "2.0.0")
|
||||
if a == b {
|
||||
t.Errorf("different versions should (almost certainly) produce different fingerprints")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_ShortMessage(t *testing.T) {
|
||||
// "hi" has only 2 chars, indices [4,7,20] all out of range → chars = "000"
|
||||
fp := computeFingerprint("hi", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("short message fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_EmptyMessage(t *testing.T) {
|
||||
// Empty → all indices out of range → chars = "000"
|
||||
fp := computeFingerprint("", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("empty message fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
// Empty and short message with same version should produce same fingerprint
|
||||
// since both result in chars = "000"
|
||||
fpShort := computeFingerprint("hi", "1.0.0")
|
||||
if fp != fpShort {
|
||||
t.Errorf("empty and 'hi' should produce same fingerprint (both use '000'), got %q vs %q", fp, fpShort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_Unicode(t *testing.T) {
|
||||
// Emoji: 🎉 is U+1F389, encoded as UTF-16 surrogate pair [0xD83C, 0xDF89]
|
||||
// So "abcd🎉fg" in UTF-16 is [a, b, c, d, 0xD83C, 0xDF89, f, g] = 8 uint16 values
|
||||
// indices [4,7,20]: runes[4]=0xD83C, runes[7]='g', runes[20]=out of range
|
||||
fp := computeFingerprint("abcd🎉fg", "1.0.0")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("unicode fingerprint length = %d, want 3", len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_CharExtraction(t *testing.T) {
|
||||
// "Hello, World!" UTF-16: [H,e,l,l,o,',', ,W,o,r,l,d,!]
|
||||
// indices [4,7,20]: runes[4]='o', runes[7]='W', runes[20]=out of range → "0"
|
||||
// So chars should be "oW0"
|
||||
// Verify by comparing to a message where we know the expected extracted chars
|
||||
// Two messages that extract same chars at indices should produce same fingerprint
|
||||
// "xxxxoxxWxxxxxxxxxxxx" → index 4='o', 7='W', 20=out of range → "oW0" (20 chars, index 20 out of range)
|
||||
fp1 := computeFingerprint("Hello, World!", "1.0.0")
|
||||
fp2 := computeFingerprint("xxxxoxxWxxxxxxxxxxxx", "1.0.0")
|
||||
if fp1 != fp2 {
|
||||
t.Errorf("messages with same chars at indices [4,7,20] should produce same fingerprint, got %q vs %q", fp1, fp2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeFingerprint_IndexBoundary(t *testing.T) {
|
||||
// Message with exactly 21 chars → index 20 is valid
|
||||
msg21 := "abcdefghijklmnopqrstu" // 21 chars
|
||||
fp21 := computeFingerprint(msg21, "1.0.0")
|
||||
|
||||
// Message with exactly 20 chars → index 20 is out of range → "0"
|
||||
msg20 := "abcdefghijklmnopqrst" // 20 chars
|
||||
fp20 := computeFingerprint(msg20, "1.0.0")
|
||||
|
||||
// They should differ because index 20 produces different chars
|
||||
if fp21 == fp20 {
|
||||
t.Errorf("boundary test: 21-char and 20-char messages should differ at index 20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFirstUserMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple string content",
|
||||
body: `{"messages":[{"role":"user","content":"hello world"}]}`,
|
||||
expected: "hello world",
|
||||
},
|
||||
{
|
||||
name: "array content with text block",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"from array"}]}]}`,
|
||||
expected: "from array",
|
||||
},
|
||||
{
|
||||
name: "no user messages",
|
||||
body: `{"messages":[{"role":"assistant","content":"I am assistant"}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "assistant only messages",
|
||||
body: `{"messages":[{"role":"assistant","content":"a1"},{"role":"assistant","content":"a2"}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "user with non-text block first then text",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"},{"type":"text","text":"the text"}]}]}`,
|
||||
expected: "the text",
|
||||
},
|
||||
{
|
||||
name: "user with only non-text blocks",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]}]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "no messages field",
|
||||
body: `{"model":"claude-sonnet-4-6"}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "messages not array",
|
||||
body: `{"messages":"not array"}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty messages array",
|
||||
body: `{"messages":[]}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "first user message used even if multiple exist",
|
||||
body: `{"messages":[{"role":"user","content":"first"},{"role":"user","content":"second"}]}`,
|
||||
expected: "first",
|
||||
},
|
||||
{
|
||||
name: "assistant before user",
|
||||
body: `{"messages":[{"role":"assistant","content":"assistant msg"},{"role":"user","content":"user msg"}]}`,
|
||||
expected: "user msg",
|
||||
},
|
||||
{
|
||||
name: "user with array content - first text block used",
|
||||
body: `{"messages":[{"role":"user","content":[{"type":"text","text":"first text"},{"type":"text","text":"second text"}]}]}`,
|
||||
expected: "first text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractFirstUserMessage([]byte(tt.body))
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFirstUserMessage_BreaksAfterFirstUser(t *testing.T) {
|
||||
// The function should break after finding the first user message,
|
||||
// even if it didn't extract text (e.g. user with only image blocks)
|
||||
body := `{"messages":[{"role":"user","content":[{"type":"image","source":"x"}]},{"role":"user","content":"second user"}]}`
|
||||
result := extractFirstUserMessage([]byte(body))
|
||||
// First user has no text blocks, function breaks, returns ""
|
||||
if result != "" {
|
||||
t.Errorf("should return empty when first user has no text, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBillingHeader(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test message"}]}`)
|
||||
version := "1.2.3"
|
||||
header := buildBillingHeader(body, version)
|
||||
|
||||
// Check format
|
||||
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=1.2.3.") {
|
||||
t.Errorf("header should start with 'x-anthropic-billing-header: cc_version=1.2.3.', got %q", header)
|
||||
}
|
||||
if !strings.Contains(header, "; cc_entrypoint=cli; cch=00000;") {
|
||||
t.Errorf("header should contain '; cc_entrypoint=cli; cch=00000;', got %q", header)
|
||||
}
|
||||
|
||||
// Verify the fingerprint part is 3 chars
|
||||
// Format: "x-anthropic-billing-header: cc_version=1.2.3.XXX; cc_entrypoint=cli; cch=00000;"
|
||||
parts := strings.Split(header, "cc_version=")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("unexpected header format: %q", header)
|
||||
}
|
||||
versionFP := strings.Split(parts[1], ";")[0]
|
||||
if !strings.HasPrefix(versionFP, "1.2.3.") {
|
||||
t.Errorf("version+fingerprint should start with '1.2.3.', got %q", versionFP)
|
||||
}
|
||||
fp := strings.TrimPrefix(versionFP, "1.2.3.")
|
||||
if len(fp) != 3 {
|
||||
t.Errorf("fingerprint should be 3 chars, got %q (len %d)", fp, len(fp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBillingHeader_EmptyMessages(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
version := "1.0.0"
|
||||
header := buildBillingHeader(body, version)
|
||||
if !strings.HasPrefix(header, "x-anthropic-billing-header: cc_version=") {
|
||||
t.Errorf("header format wrong: %q", header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_NoExistingSystem(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should have system field now
|
||||
if !strings.Contains(resultStr, `"system"`) {
|
||||
t.Errorf("should inject system field, got %s", resultStr)
|
||||
}
|
||||
// System should be an array with one billing block
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header text, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, `"type":"text"`) {
|
||||
t.Errorf("billing block should have type text, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_ExistingSystemArray(t *testing.T) {
|
||||
body := []byte(`{"system":[{"type":"text","text":"existing prompt"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should contain both billing header and existing prompt
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "existing prompt") {
|
||||
t.Errorf("should preserve existing prompt, got %s", resultStr)
|
||||
}
|
||||
|
||||
// Billing block should be FIRST (prepended)
|
||||
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
|
||||
existingIdx := strings.Index(resultStr, "existing prompt")
|
||||
if billingIdx > existingIdx {
|
||||
t.Errorf("billing block should come before existing prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_ExistingSystemString(t *testing.T) {
|
||||
body := []byte(`{"system":"You are a helpful assistant","messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Should convert to array with billing block first, then original text
|
||||
if !strings.Contains(resultStr, "x-anthropic-billing-header") {
|
||||
t.Errorf("should contain billing header, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "You are a helpful assistant") {
|
||||
t.Errorf("should preserve original system string, got %s", resultStr)
|
||||
}
|
||||
|
||||
// Billing should come first
|
||||
billingIdx := strings.Index(resultStr, "x-anthropic-billing-header")
|
||||
origIdx := strings.Index(resultStr, "You are a helpful assistant")
|
||||
if billingIdx > origIdx {
|
||||
t.Errorf("billing block should come before original system text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_PreservesOtherFields(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
|
||||
result := injectBillingHeader(body, "1.0.0")
|
||||
resultStr := string(result)
|
||||
|
||||
if !strings.Contains(resultStr, `"model":"claude-sonnet-4-6"`) {
|
||||
t.Errorf("should preserve model field, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, `"max_tokens":1024`) {
|
||||
t.Errorf("should preserve max_tokens field, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectBillingHeader_BillingBlockFormat(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||
result := injectBillingHeader(body, "2.5.0")
|
||||
resultStr := string(result)
|
||||
|
||||
// Verify the billing block contains the correct version
|
||||
if !strings.Contains(resultStr, "cc_version=2.5.0.") {
|
||||
t.Errorf("billing block should contain cc_version=2.5.0., got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "cc_entrypoint=cli") {
|
||||
t.Errorf("billing block should contain cc_entrypoint=cli, got %s", resultStr)
|
||||
}
|
||||
if !strings.Contains(resultStr, "cch=00000") {
|
||||
t.Errorf("billing block should contain cch=00000, got %s", resultStr)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,624 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/auth"
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
"github.com/fujin/anthropic-proxy/internal/ratelimit"
|
||||
"github.com/fujin/anthropic-proxy/internal/telemetry"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Initialize telemetry with noop meter to avoid nil pointer panics.
|
||||
meter := noop.Meter{}
|
||||
telemetry.InitMetrics(meter, nil)
|
||||
}
|
||||
|
||||
// --- Request body reading and sanitization ---
|
||||
|
||||
func TestHandleMessages_ReadBodyError(t *testing.T) {
|
||||
// A body that immediately fails on read shouldn't panic.
|
||||
pool := auth.NewPool([]*auth.Credential{{ID: "c1", AccessToken: "tok", Email: "test@test.com"}})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", &errReader{})
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "failed to read request body") {
|
||||
t.Errorf("body = %q, expected error message about reading body", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_SanitizesRequestBody(t *testing.T) {
|
||||
// We can't directly make HandleMessages use our mock server because
|
||||
// UpstreamClient hardcodes messagesURL. Instead, we test sanitization
|
||||
// by verifying the sanitizer is called on the body before any pool interaction.
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
|
||||
Body: []config.ReplaceRule{{Match: "secret", Replace: "redacted"}},
|
||||
})
|
||||
|
||||
// Create body with tool name and secret
|
||||
body := `{"model":"claude-sonnet-4-6","tools":[{"name":"my_tool"}],"messages":[{"role":"user","content":"secret data"}]}`
|
||||
sanitizedBody := san.SanitizeRequest([]byte(body))
|
||||
|
||||
// Verify sanitization happened correctly
|
||||
if !strings.Contains(string(sanitizedBody), "renamed_tool") {
|
||||
t.Error("expected tool to be renamed in sanitized body")
|
||||
}
|
||||
if strings.Contains(string(sanitizedBody), "my_tool") {
|
||||
t.Error("original tool name should be gone after sanitization")
|
||||
}
|
||||
if !strings.Contains(string(sanitizedBody), "redacted") {
|
||||
t.Error("expected 'secret' to be replaced with 'redacted'")
|
||||
}
|
||||
if strings.Contains(string(sanitizedBody), "secret") {
|
||||
t.Error("'secret' should be gone after sanitization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_PoolPickError(t *testing.T) {
|
||||
// Empty pool — Pick() will fail.
|
||||
pool := auth.NewPool(nil)
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "no credentials available") {
|
||||
t.Errorf("body = %q, expected pool error", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_PoolAllOnCooldown(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "e"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
pool.MarkFailure(cred, 429) // puts on 30s cooldown
|
||||
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "cooldown") {
|
||||
t.Errorf("body = %q, expected cooldown message", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Stream vs non-stream routing ---
|
||||
|
||||
func TestHandleMessages_StreamField_Detection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
isStream bool
|
||||
}{
|
||||
{
|
||||
name: "stream true",
|
||||
body: `{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: true,
|
||||
},
|
||||
{
|
||||
name: "stream false",
|
||||
body: `{"model":"claude-sonnet-4-6","stream":false,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: false,
|
||||
},
|
||||
{
|
||||
name: "no stream field defaults to false",
|
||||
body: `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`,
|
||||
isStream: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := gjson.Get(tt.body, "stream").Bool()
|
||||
if got != tt.isStream {
|
||||
t.Errorf("stream = %v, want %v", got, tt.isStream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Desanitization on response ---
|
||||
|
||||
func TestDesanitization_NonStreamResponse(t *testing.T) {
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
// Simulate upstream response with renamed tool
|
||||
upstreamResponse := `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}]}`
|
||||
desanitized := san.DesanitizeResponse([]byte(upstreamResponse))
|
||||
|
||||
if !strings.Contains(string(desanitized), "original_tool") {
|
||||
t.Errorf("expected tool name to be desanitized back to 'original_tool', got %s", string(desanitized))
|
||||
}
|
||||
if strings.Contains(string(desanitized), `"name":"renamed_tool"`) {
|
||||
t.Error("renamed_tool should have been replaced by original_tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitization_StreamEvent(t *testing.T) {
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
event := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`
|
||||
desanitized := san.DesanitizeStreamEvent(event)
|
||||
|
||||
if !strings.Contains(desanitized, "original_tool") {
|
||||
t.Errorf("expected stream event to be desanitized, got %s", desanitized)
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleNonStream behavior tests via direct function ---
|
||||
|
||||
func TestHandleNonStream_ConnectionError(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
tracker := ratelimit.NewTracker(func() string { return "" })
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: http.Client{Transport: &failingTransport{}},
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
|
||||
|
||||
if w.Code != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "upstream request failed") {
|
||||
t.Errorf("body = %q, expected upstream error message", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamSuccess(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Request-Id", "req-123")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hello"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
tracker := ratelimit.NewTracker(func() string { return "" })
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
// Override the messagesURL by constructing a custom Execute that uses the mock.
|
||||
// Since we can't override the const, we test via a mock server approach:
|
||||
// We create a custom http.Client with a transport that redirects to our mock.
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, tracker)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "hello") {
|
||||
t.Errorf("response body missing expected content: %s", w.Body.String())
|
||||
}
|
||||
if got := w.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want %q", got, "application/json")
|
||||
}
|
||||
if got := w.Header().Get("X-Request-Id"); got != "req-123" {
|
||||
t.Errorf("X-Request-Id = %q, want %q", got, "req-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamError_MarkFailure(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"too many requests"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 429 {
|
||||
t.Errorf("status = %d, want 429", w.Code)
|
||||
}
|
||||
|
||||
// Verify MarkFailure was called — cred should now be on cooldown
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_UpstreamSuccess_DesanitizesResponse(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}],"model":"claude-sonnet-4-6","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
// Should be desanitized back to original_tool
|
||||
if !strings.Contains(w.Body.String(), "original_tool") {
|
||||
t.Errorf("response should contain desanitized tool name 'original_tool', got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStream_Upstream500_MarkFailure(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte(`{"error":{"type":"server_error","message":"internal error"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleNonStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 500 {
|
||||
t.Errorf("status = %d, want 500", w.Code)
|
||||
}
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential to be on cooldown after 500")
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleStream behavior tests ---
|
||||
|
||||
func TestHandleStream_ConnectionError(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: http.Client{Transport: &failingTransport{}},
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "upstream stream request failed") {
|
||||
t.Errorf("body = %q, expected upstream stream error", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_UpstreamError(t *testing.T) {
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != 429 {
|
||||
t.Errorf("status = %d, want 429", w.Code)
|
||||
}
|
||||
if !cred.IsOnCooldown() {
|
||||
t.Error("expected credential on cooldown after stream 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_SuccessForwardsEvents(t *testing.T) {
|
||||
events := "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(events))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
respBody := w.Body.String()
|
||||
if !strings.Contains(respBody, "message_start") {
|
||||
t.Error("response missing message_start event")
|
||||
}
|
||||
if !strings.Contains(respBody, "hello") {
|
||||
t.Error("response missing text content 'hello'")
|
||||
}
|
||||
if !strings.Contains(respBody, "message_stop") {
|
||||
t.Error("response missing message_stop event")
|
||||
}
|
||||
|
||||
// Verify SSE headers
|
||||
if got := w.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Errorf("Content-Type = %q, want %q", got, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStream_DesanitizesEvents(t *testing.T) {
|
||||
events := "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"name\":\"renamed_tool\",\"id\":\"t1\"}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
|
||||
|
||||
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(events))
|
||||
}))
|
||||
defer mockUpstream.Close()
|
||||
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}},
|
||||
})
|
||||
|
||||
uc := &UpstreamClient{
|
||||
client: *mockUpstream.Client(),
|
||||
sessionID: "test-sess",
|
||||
profile: nil,
|
||||
}
|
||||
uc.client.Transport = &rewriteTransport{
|
||||
base: mockUpstream.Client().Transport,
|
||||
destURL: mockUpstream.URL,
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
|
||||
handleStream(c, uc, san, pool, cred, body, body, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "original_tool") {
|
||||
t.Errorf("stream response should contain desanitized 'original_tool', got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleMessages full integration wiring test ---
|
||||
|
||||
func TestHandleMessages_WiresHandlerCorrectly(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
// Verify the handler can be created without panic
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
if handler == nil {
|
||||
t.Fatal("HandleMessages returned nil handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMessages_EmptyBody(t *testing.T) {
|
||||
cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"}
|
||||
pool := auth.NewPool([]*auth.Credential{cred})
|
||||
san := NewSanitizer(config.SanitizeConfig{})
|
||||
|
||||
handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil)
|
||||
|
||||
// Empty body — handler should still try to pick cred and call upstream
|
||||
// (which will fail with connection error to api.anthropic.com, not a panic)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(""))
|
||||
|
||||
handler(c)
|
||||
|
||||
// Should get a 502 because the upstream URL (api.anthropic.com) won't be reachable
|
||||
// in test environment, or it might complete. The key thing is no panic.
|
||||
// We mainly verify it doesn't panic.
|
||||
if w.Code == 0 {
|
||||
t.Error("expected non-zero status code")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test helpers ---
|
||||
|
||||
// errReader is an io.Reader that always returns an error.
|
||||
type errReader struct{}
|
||||
|
||||
func (e *errReader) Read([]byte) (int, error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// failingTransport is an http.RoundTripper that always returns an error.
|
||||
type failingTransport struct{}
|
||||
|
||||
func (f *failingTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
}
|
||||
|
||||
// rewriteTransport intercepts HTTP requests and rewrites the URL to point
|
||||
// at a local test server. This allows testing with UpstreamClient's hardcoded
|
||||
// messagesURL by redirecting all requests to a mock server.
|
||||
type rewriteTransport struct {
|
||||
base http.RoundTripper
|
||||
destURL string
|
||||
}
|
||||
|
||||
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL to point at our mock server
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.URL.Scheme = "http"
|
||||
newReq.URL.Host = strings.TrimPrefix(t.destURL, "http://")
|
||||
newReq.URL.Path = "/v1/messages"
|
||||
newReq.URL.RawQuery = ""
|
||||
if t.base == nil {
|
||||
return http.DefaultTransport.RoundTrip(newReq)
|
||||
}
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
@@ -0,0 +1,476 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestNewSanitizer_Empty(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{})
|
||||
if len(s.toolsForward) != 0 {
|
||||
t.Errorf("expected empty toolsForward, got %d entries", len(s.toolsForward))
|
||||
}
|
||||
if len(s.toolsReverse) != 0 {
|
||||
t.Errorf("expected empty toolsReverse, got %d entries", len(s.toolsReverse))
|
||||
}
|
||||
if s.systemRules != nil {
|
||||
t.Errorf("expected nil systemRules")
|
||||
}
|
||||
if s.bodyRules != nil {
|
||||
t.Errorf("expected nil bodyRules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSanitizer_WithTools(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{
|
||||
{From: "old_tool", To: "new_tool"},
|
||||
{From: "another", To: "replaced"},
|
||||
},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
if got := s.toolsForward["old_tool"]; got != "new_tool" {
|
||||
t.Errorf("toolsForward[old_tool] = %q, want %q", got, "new_tool")
|
||||
}
|
||||
if got := s.toolsReverse["new_tool"]; got != "old_tool" {
|
||||
t.Errorf("toolsReverse[new_tool] = %q, want %q", got, "old_tool")
|
||||
}
|
||||
if got := s.toolsForward["another"]; got != "replaced" {
|
||||
t.Errorf("toolsForward[another] = %q, want %q", got, "replaced")
|
||||
}
|
||||
if got := s.toolsReverse["replaced"]; got != "another" {
|
||||
t.Errorf("toolsReverse[replaced] = %q, want %q", got, "another")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSanitizer_WithSystemAndBodyRules(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
System: []config.ReplaceRule{{Match: "foo", Replace: "bar"}},
|
||||
Body: []config.ReplaceRule{{Match: "baz", Replace: "qux"}},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
if len(s.systemRules) != 1 || s.systemRules[0].Match != "foo" {
|
||||
t.Errorf("systemRules not set correctly")
|
||||
}
|
||||
if len(s.bodyRules) != 1 || s.bodyRules[0].Match != "baz" {
|
||||
t.Errorf("bodyRules not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
forward map[string]string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty map returns body unchanged",
|
||||
forward: map[string]string{},
|
||||
body: `{"tools":[{"name":"my_tool"}]}`,
|
||||
expected: `{"tools":[{"name":"my_tool"}]}`,
|
||||
},
|
||||
{
|
||||
name: "no tools array returns body unchanged",
|
||||
forward: map[string]string{"my_tool": "renamed"},
|
||||
body: `{"messages":[]}`,
|
||||
expected: `{"messages":[]}`,
|
||||
},
|
||||
{
|
||||
name: "tools is not array returns body unchanged",
|
||||
forward: map[string]string{"my_tool": "renamed"},
|
||||
body: `{"tools":"not_array"}`,
|
||||
expected: `{"tools":"not_array"}`,
|
||||
},
|
||||
{
|
||||
name: "matching tool gets renamed",
|
||||
forward: map[string]string{"my_tool": "renamed_tool"},
|
||||
body: `{"tools":[{"name":"my_tool","description":"desc"}]}`,
|
||||
expected: `renamed_tool`,
|
||||
},
|
||||
{
|
||||
name: "non-matching tool unchanged",
|
||||
forward: map[string]string{"other_tool": "renamed"},
|
||||
body: `{"tools":[{"name":"my_tool"}]}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "partial match - only exact match renames",
|
||||
forward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
|
||||
body: `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`,
|
||||
expected: `tool_x`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: tt.forward,
|
||||
toolsReverse: make(map[string]string),
|
||||
}
|
||||
result := string(s.renameTools([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameTools_MultipleTools(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: map[string]string{"tool_a": "tool_x", "tool_b": "tool_y"},
|
||||
toolsReverse: make(map[string]string),
|
||||
}
|
||||
body := `{"tools":[{"name":"tool_a"},{"name":"tool_c"},{"name":"tool_b"}]}`
|
||||
result := string(s.renameTools([]byte(body)))
|
||||
if !strings.Contains(result, `"tool_x"`) {
|
||||
t.Errorf("tool_a should be renamed to tool_x, got %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"tool_y"`) {
|
||||
t.Errorf("tool_b should be renamed to tool_y, got %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"tool_c"`) {
|
||||
t.Errorf("tool_c should remain unchanged, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceSystem(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []config.ReplaceRule
|
||||
body string
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "empty rules returns body unchanged",
|
||||
rules: nil,
|
||||
body: `{"system":[{"type":"text","text":"hello world"}]}`,
|
||||
contains: "hello world",
|
||||
},
|
||||
{
|
||||
name: "no system field returns body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"messages":[]}`,
|
||||
contains: `"messages":[]`,
|
||||
},
|
||||
{
|
||||
name: "system not array returns body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"system":"just a string"}`,
|
||||
contains: "just a string",
|
||||
},
|
||||
{
|
||||
name: "single block single rule",
|
||||
rules: []config.ReplaceRule{{Match: "hello", Replace: "goodbye"}},
|
||||
body: `{"system":[{"type":"text","text":"hello world"}]}`,
|
||||
contains: "goodbye world",
|
||||
},
|
||||
{
|
||||
name: "multiple blocks",
|
||||
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
|
||||
body: `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`,
|
||||
contains: "BBB first",
|
||||
},
|
||||
{
|
||||
name: "multiple rules applied in order",
|
||||
rules: []config.ReplaceRule{{Match: "cat", Replace: "dog"}, {Match: "dog", Replace: "fish"}},
|
||||
body: `{"system":[{"type":"text","text":"I have a cat"}]}`,
|
||||
contains: "I have a fish",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
systemRules: tt.rules,
|
||||
}
|
||||
result := string(s.replaceSystem([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.contains) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceSystem_MultipleBlocks(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
systemRules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}},
|
||||
}
|
||||
body := `{"system":[{"type":"text","text":"AAA first"},{"type":"text","text":"AAA second"}]}`
|
||||
result := string(s.replaceSystem([]byte(body)))
|
||||
if !strings.Contains(result, "BBB first") {
|
||||
t.Errorf("first block not replaced: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "BBB second") {
|
||||
t.Errorf("second block not replaced: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []config.ReplaceRule
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty rules returns body unchanged",
|
||||
rules: nil,
|
||||
body: `{"foo":"bar"}`,
|
||||
expected: `{"foo":"bar"}`,
|
||||
},
|
||||
{
|
||||
name: "single replacement across entire body",
|
||||
rules: []config.ReplaceRule{{Match: "SECRET", Replace: "REDACTED"}},
|
||||
body: `{"data":"SECRET value SECRET"}`,
|
||||
expected: `{"data":"REDACTED value REDACTED"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple rules applied sequentially",
|
||||
rules: []config.ReplaceRule{{Match: "AAA", Replace: "BBB"}, {Match: "BBB", Replace: "CCC"}},
|
||||
body: `{"text":"AAA"}`,
|
||||
expected: `{"text":"CCC"}`,
|
||||
},
|
||||
{
|
||||
name: "no match leaves body unchanged",
|
||||
rules: []config.ReplaceRule{{Match: "NOMATCH", Replace: "X"}},
|
||||
body: `{"text":"hello"}`,
|
||||
expected: `{"text":"hello"}`,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
rules: []config.ReplaceRule{{Match: "a", Replace: "b"}},
|
||||
body: ``,
|
||||
expected: ``,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: make(map[string]string),
|
||||
bodyRules: tt.rules,
|
||||
}
|
||||
result := string(s.replaceBody([]byte(tt.body)))
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest(t *testing.T) {
|
||||
cfg := config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}},
|
||||
System: []config.ReplaceRule{{Match: "INTERNAL", Replace: "PUBLIC"}},
|
||||
Body: []config.ReplaceRule{{Match: "secret_val", Replace: "safe_val"}},
|
||||
}
|
||||
s := NewSanitizer(cfg)
|
||||
|
||||
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"INTERNAL info"}],"data":"secret_val here"}`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
|
||||
if !strings.Contains(result, `"renamed_tool"`) {
|
||||
t.Errorf("tool not renamed in result: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "PUBLIC info") {
|
||||
t.Errorf("system not replaced in result: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "safe_val here") {
|
||||
t.Errorf("body not replaced in result: %s", result)
|
||||
}
|
||||
if strings.Contains(result, "secret_val") {
|
||||
t.Errorf("secret_val should have been replaced: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_EmptyConfig(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{})
|
||||
body := `{"tools":[{"name":"my_tool"}],"system":[{"type":"text","text":"hello"}]}`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
if result != body {
|
||||
t.Errorf("empty config should not modify body.\ngot: %s\nwant: %s", result, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverse map[string]string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no content field returns unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"id":"msg_1","role":"assistant"}`,
|
||||
expected: `{"id":"msg_1","role":"assistant"}`,
|
||||
},
|
||||
{
|
||||
name: "content not array returns unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":"just text"}`,
|
||||
expected: `{"content":"just text"}`,
|
||||
},
|
||||
{
|
||||
name: "non-tool_use block left unchanged",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":[{"type":"text","text":"hello"}]}`,
|
||||
expected: `{"content":[{"type":"text","text":"hello"}]}`,
|
||||
},
|
||||
{
|
||||
name: "tool_use block with matching name gets reversed",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
body: `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1"}]}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "tool_use block with no match unchanged",
|
||||
reverse: map[string]string{"other": "something"},
|
||||
body: `{"content":[{"type":"tool_use","name":"my_tool","id":"t1"}]}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "mixed blocks only tool_use reversed",
|
||||
reverse: map[string]string{"renamed": "original"},
|
||||
body: `{"content":[{"type":"text","text":"hi"},{"type":"tool_use","name":"renamed","id":"t1"}]}`,
|
||||
expected: `original`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: tt.reverse,
|
||||
}
|
||||
result := string(s.DesanitizeResponse([]byte(tt.body)))
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeResponse_MultipleToolUse(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: map[string]string{"r1": "o1", "r2": "o2"},
|
||||
}
|
||||
body := `{"content":[{"type":"tool_use","name":"r1","id":"t1"},{"type":"text","text":"x"},{"type":"tool_use","name":"r2","id":"t2"}]}`
|
||||
result := string(s.DesanitizeResponse([]byte(body)))
|
||||
if !strings.Contains(result, `"o1"`) {
|
||||
t.Errorf("r1 not reversed to o1: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"o2"`) {
|
||||
t.Errorf("r2 not reversed to o2: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeStreamEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverse map[string]string
|
||||
line string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "non-data line passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "event: content_block_start",
|
||||
expected: "event: content_block_start",
|
||||
},
|
||||
{
|
||||
name: "data line without tool_use passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: `data: {"type":"text","text":"hello"}`,
|
||||
expected: `data: {"type":"text","text":"hello"}`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use in content_block.name",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use in delta.name",
|
||||
reverse: map[string]string{"renamed_tool": "original_tool"},
|
||||
line: `data: {"type":"content_block_delta","delta":{"type":"tool_use","name":"renamed_tool"}}`,
|
||||
expected: `original_tool`,
|
||||
},
|
||||
{
|
||||
name: "data line with tool_use but no matching name",
|
||||
reverse: map[string]string{"other": "something"},
|
||||
line: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"my_tool","id":"t1"}}`,
|
||||
expected: `my_tool`,
|
||||
},
|
||||
{
|
||||
name: "empty line passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "line contains tool_use but not data prefix - passed through",
|
||||
reverse: map[string]string{"r": "o"},
|
||||
line: "event: tool_use",
|
||||
expected: "event: tool_use",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: tt.reverse,
|
||||
}
|
||||
result := s.DesanitizeStreamEvent(tt.line)
|
||||
if !strings.Contains(result, tt.expected) {
|
||||
t.Errorf("result %q does not contain %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesanitizeStreamEvent_DataPrefixPreserved(t *testing.T) {
|
||||
s := &Sanitizer{
|
||||
toolsForward: make(map[string]string),
|
||||
toolsReverse: map[string]string{"renamed": "original"},
|
||||
}
|
||||
line := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed","id":"t1"}}`
|
||||
result := s.DesanitizeStreamEvent(line)
|
||||
if !strings.HasPrefix(result, "data: ") {
|
||||
t.Errorf("result should start with 'data: ', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_MalformedJSON(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "a", To: "b"}},
|
||||
System: []config.ReplaceRule{{Match: "x", Replace: "y"}},
|
||||
})
|
||||
// Malformed JSON - renameTools and replaceSystem should handle gracefully
|
||||
body := `not valid json`
|
||||
result := string(s.SanitizeRequest([]byte(body)))
|
||||
// Should not panic; body rules still do string replacement
|
||||
if result != "not valid json" {
|
||||
t.Errorf("malformed JSON should pass through (no body rules match), got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRequest_EmptyBody(t *testing.T) {
|
||||
s := NewSanitizer(config.SanitizeConfig{
|
||||
Tools: []config.RenameRule{{From: "a", To: "b"}},
|
||||
})
|
||||
result := s.SanitizeRequest([]byte{})
|
||||
if len(result) != 0 {
|
||||
t.Errorf("empty body should return empty, got %q", string(result))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newRequest(t *testing.T, headers map[string][]string) *http.Request {
|
||||
t.Helper()
|
||||
r := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
r.Header = http.Header{}
|
||||
for k, vals := range headers {
|
||||
for _, v := range vals {
|
||||
r.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestExtractProfile_BasicHeaders(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Custom-Header": {"custom-value"},
|
||||
"User-Agent": {"Claude/1.2.3 linux"},
|
||||
})
|
||||
body := []byte(`{"model":"claude-sonnet-4-6"}`)
|
||||
|
||||
p := extractProfile(r, body)
|
||||
|
||||
// Check version parsed
|
||||
if p.Version != "1.2.3" {
|
||||
t.Errorf("version = %q, want %q", p.Version, "1.2.3")
|
||||
}
|
||||
|
||||
// Check body preserved
|
||||
if string(p.Body) != string(body) {
|
||||
t.Errorf("body not preserved")
|
||||
}
|
||||
|
||||
// Check headers captured
|
||||
found := map[string]bool{}
|
||||
for _, h := range p.Headers {
|
||||
found[strings.ToLower(h[0])] = true
|
||||
}
|
||||
if !found["content-type"] {
|
||||
t.Error("Content-Type header should be captured")
|
||||
}
|
||||
if !found["x-custom-header"] {
|
||||
t.Error("X-Custom-Header should be captured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_SkipHeaders(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Host": {"example.com"},
|
||||
"Content-Length": {"42"},
|
||||
"Authorization": {"Bearer token123"},
|
||||
"X-Api-Key": {"key123"},
|
||||
"Connection": {"keep-alive"},
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Custom": {"keep-me"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
for _, h := range p.Headers {
|
||||
lower := strings.ToLower(h[0])
|
||||
if skipHeaders[lower] {
|
||||
t.Errorf("header %q should have been skipped", h[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify non-skipped headers are present
|
||||
found := map[string]bool{}
|
||||
for _, h := range p.Headers {
|
||||
found[strings.ToLower(h[0])] = true
|
||||
}
|
||||
if !found["content-type"] {
|
||||
t.Error("Content-Type should be kept")
|
||||
}
|
||||
if !found["x-custom"] {
|
||||
t.Error("X-Custom should be kept")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_HeaderDeduplication(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
})
|
||||
// Add duplicate with different casing - Go's http.Header normalizes to canonical form
|
||||
// so we need to add the same canonical header with multiple values to test dedup
|
||||
r.Header.Add("Content-Type", "text/plain")
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
// After deduplication by lowercase key, only one entry per key
|
||||
seen := map[string]int{}
|
||||
for _, h := range p.Headers {
|
||||
seen[strings.ToLower(h[0])]++
|
||||
}
|
||||
for key, count := range seen {
|
||||
if count > 1 {
|
||||
t.Errorf("header %q appears %d times after dedup, want 1", key, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_AnthropicBetaContextStripping(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Anthropic-Beta": {"prompt-caching-2024-07-31,context-1m-2024-09-01,some-other-beta"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
var betaValue string
|
||||
for _, h := range p.Headers {
|
||||
if strings.ToLower(h[0]) == "anthropic-beta" {
|
||||
betaValue = h[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(betaValue, "context-1m") {
|
||||
t.Errorf("context-1m should be stripped from anthropic-beta, got %q", betaValue)
|
||||
}
|
||||
if !strings.Contains(betaValue, "prompt-caching-2024-07-31") {
|
||||
t.Errorf("prompt-caching should be preserved, got %q", betaValue)
|
||||
}
|
||||
if !strings.Contains(betaValue, "some-other-beta") {
|
||||
t.Errorf("some-other-beta should be preserved, got %q", betaValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_AnthropicBetaAllContextRemoved(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"Anthropic-Beta": {"context-1m-2024-09-01"},
|
||||
})
|
||||
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
|
||||
for _, h := range p.Headers {
|
||||
if strings.ToLower(h[0]) == "anthropic-beta" {
|
||||
// All betas were context-1m, so after filtering the value should be empty
|
||||
if h[1] != "" {
|
||||
t.Errorf("all context-1m betas stripped should leave empty, got %q", h[1])
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
// It's also acceptable if the header is still present but empty
|
||||
}
|
||||
|
||||
func TestExtractProfile_VersionParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "standard Claude UA",
|
||||
userAgent: "Claude/1.2.3 linux x86_64",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "version with no space after",
|
||||
userAgent: "Claude/4.5.6",
|
||||
expected: "4.5.6",
|
||||
},
|
||||
{
|
||||
name: "no slash in UA",
|
||||
userAgent: "Mozilla 5.0",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty UA",
|
||||
userAgent: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "slash at start",
|
||||
userAgent: "/1.0.0 rest",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "multiple slashes",
|
||||
userAgent: "App/1.0.0 (sub/2.0)",
|
||||
expected: "1.0.0",
|
||||
},
|
||||
{
|
||||
name: "version only after slash no space",
|
||||
userAgent: "Tool/9.8.7",
|
||||
expected: "9.8.7",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"User-Agent": {tt.userAgent},
|
||||
})
|
||||
p := extractProfile(r, []byte(`{}`))
|
||||
if p.Version != tt.expected {
|
||||
t.Errorf("version = %q, want %q", p.Version, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_EmptyHeaders(t *testing.T) {
|
||||
r := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
r.Header = http.Header{}
|
||||
|
||||
p := extractProfile(r, []byte(`{"test":true}`))
|
||||
|
||||
if len(p.Headers) != 0 {
|
||||
t.Errorf("expected no headers, got %d", len(p.Headers))
|
||||
}
|
||||
if p.Version != "" {
|
||||
t.Errorf("expected empty version with no UA, got %q", p.Version)
|
||||
}
|
||||
if string(p.Body) != `{"test":true}` {
|
||||
t.Errorf("body not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractProfile_BodyPreserved(t *testing.T) {
|
||||
r := newRequest(t, map[string][]string{
|
||||
"User-Agent": {"Claude/1.0.0 test"},
|
||||
})
|
||||
body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
|
||||
p := extractProfile(r, body)
|
||||
|
||||
if string(p.Body) != string(body) {
|
||||
t.Errorf("body not preserved.\ngot: %s\nwant: %s", p.Body, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipHeaders_Entries(t *testing.T) {
|
||||
expected := map[string]bool{
|
||||
"host": true,
|
||||
"content-length": true,
|
||||
"authorization": true,
|
||||
"x-api-key": true,
|
||||
"connection": true,
|
||||
}
|
||||
if len(skipHeaders) != len(expected) {
|
||||
t.Errorf("skipHeaders has %d entries, want %d", len(skipHeaders), len(expected))
|
||||
}
|
||||
for k, v := range expected {
|
||||
if skipHeaders[k] != v {
|
||||
t.Errorf("skipHeaders[%q] = %v, want %v", k, skipHeaders[k], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSniffedProfile_Fields(t *testing.T) {
|
||||
// Verify the struct can hold all expected data
|
||||
p := &SniffedProfile{
|
||||
Headers: [][2]string{{"Content-Type", "application/json"}},
|
||||
Body: []byte(`{}`),
|
||||
Version: "1.0.0",
|
||||
}
|
||||
if len(p.Headers) != 1 {
|
||||
t.Error("Headers should have 1 entry")
|
||||
}
|
||||
if p.Headers[0][0] != "Content-Type" || p.Headers[0][1] != "application/json" {
|
||||
t.Error("Header not stored correctly")
|
||||
}
|
||||
if string(p.Body) != `{}` {
|
||||
t.Error("Body not stored correctly")
|
||||
}
|
||||
if p.Version != "1.0.0" {
|
||||
t.Error("Version not stored correctly")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,334 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- NewUpstreamClient ---
|
||||
|
||||
func TestNewUpstreamClient_NilProfile(t *testing.T) {
|
||||
uc := NewUpstreamClient(nil)
|
||||
if uc == nil {
|
||||
t.Fatal("NewUpstreamClient returned nil")
|
||||
}
|
||||
if uc.sessionID == "" {
|
||||
t.Error("expected non-empty sessionID")
|
||||
}
|
||||
if uc.profile != nil {
|
||||
t.Error("expected nil profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUpstreamClient_WithProfile(t *testing.T) {
|
||||
profile := &SniffedProfile{
|
||||
Version: "1.2.3",
|
||||
Headers: [][2]string{{"User-Agent", "test/1.0"}},
|
||||
}
|
||||
uc := NewUpstreamClient(profile)
|
||||
if uc.profile != profile {
|
||||
t.Error("expected profile to be stored")
|
||||
}
|
||||
if uc.sessionID == "" {
|
||||
t.Error("expected non-empty sessionID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUpstreamClient_UniqueSessionIDs(t *testing.T) {
|
||||
uc1 := NewUpstreamClient(nil)
|
||||
uc2 := NewUpstreamClient(nil)
|
||||
if uc1.sessionID == uc2.sessionID {
|
||||
t.Errorf("expected different session IDs, both got %q", uc1.sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// --- version() ---
|
||||
|
||||
func TestVersion_WithProfileVersion(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
profile: &SniffedProfile{Version: "3.5.7"},
|
||||
}
|
||||
if got := uc.version(); got != "3.5.7" {
|
||||
t.Errorf("version() = %q, want %q", got, "3.5.7")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_NilProfile_Fallback(t *testing.T) {
|
||||
uc := &UpstreamClient{profile: nil}
|
||||
if got := uc.version(); got != "2.1.92" {
|
||||
t.Errorf("version() = %q, want %q", got, "2.1.92")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_EmptyProfileVersion_Fallback(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
profile: &SniffedProfile{Version: ""},
|
||||
}
|
||||
if got := uc.version(); got != "2.1.92" {
|
||||
t.Errorf("version() = %q, want %q", got, "2.1.92")
|
||||
}
|
||||
}
|
||||
|
||||
// --- applyHeaders ---
|
||||
|
||||
func TestApplyHeaders_NilProfile_NonOAuth_NonStream(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "test-session-id",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
// x-api-key for non-OAuth token
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-ant-api123" {
|
||||
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api123")
|
||||
}
|
||||
// Should NOT have Authorization
|
||||
if got := req.Header.Get("Authorization"); got != "" {
|
||||
t.Errorf("Authorization = %q, want empty", got)
|
||||
}
|
||||
// Session ID
|
||||
if got := req.Header.Get("X-Claude-Code-Session-Id"); got != "test-session-id" {
|
||||
t.Errorf("X-Claude-Code-Session-Id = %q, want %q", got, "test-session-id")
|
||||
}
|
||||
// Request ID should be a UUID
|
||||
if got := req.Header.Get("x-client-request-id"); got == "" {
|
||||
t.Error("expected non-empty x-client-request-id")
|
||||
}
|
||||
// Non-stream: application/json
|
||||
if got := req.Header.Get("Accept"); got != "application/json" {
|
||||
t.Errorf("Accept = %q, want %q", got, "application/json")
|
||||
}
|
||||
// Accept-Encoding always identity
|
||||
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_NilProfile_NonOAuth_Stream(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", true)
|
||||
|
||||
if got := req.Header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept = %q, want %q", got, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_SetsBearerAndBetaFlag(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: nil,
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-mytoken", false)
|
||||
|
||||
// OAuth: Authorization Bearer
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-mytoken" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-mytoken")
|
||||
}
|
||||
// Should NOT have x-api-key
|
||||
if got := req.Header.Get("x-api-key"); got != "" {
|
||||
t.Errorf("x-api-key = %q, want empty for OAuth", got)
|
||||
}
|
||||
// anthropic-beta should include oauth-2025-04-20
|
||||
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_AppendsToExistingBeta(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
if !strings.Contains(beta, "max-tokens-3-5-sonnet-2024-07-15") {
|
||||
t.Errorf("anthropic-beta %q should contain existing beta", beta)
|
||||
}
|
||||
if !strings.Contains(beta, "oauth-2025-04-20") {
|
||||
t.Errorf("anthropic-beta %q should contain oauth flag", beta)
|
||||
}
|
||||
// Should be appended with comma
|
||||
if beta != "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", beta, "max-tokens-3-5-sonnet-2024-07-15,oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_ExistingBetaAlreadyHasOAuth(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"anthropic-beta", "oauth-2025-04-20,something-else"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
// Should NOT duplicate oauth flag
|
||||
count := strings.Count(beta, "oauth-2025-04-20")
|
||||
if count != 1 {
|
||||
t.Errorf("oauth flag appeared %d times in %q, want 1", count, beta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_WithProfile_ReplaysHeaders(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"User-Agent", "Claude/1.0"},
|
||||
{"anthropic-version", "2023-06-01"},
|
||||
{"Custom-Header", "custom-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
if got := req.Header.Get("User-Agent"); got != "Claude/1.0" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "Claude/1.0")
|
||||
}
|
||||
if got := req.Header.Get("anthropic-version"); got != "2023-06-01" {
|
||||
t.Errorf("anthropic-version = %q, want %q", got, "2023-06-01")
|
||||
}
|
||||
if got := req.Header.Get("Custom-Header"); got != "custom-value" {
|
||||
t.Errorf("Custom-Header = %q, want %q", got, "custom-value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_ProfileAuthHeadersRemoved(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"Authorization", "Bearer old-token"},
|
||||
{"x-api-key", "old-api-key"},
|
||||
{"User-Agent", "test"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-api-new", false)
|
||||
|
||||
// Old auth headers from profile should be removed
|
||||
if got := req.Header.Get("Authorization"); got != "" {
|
||||
t.Errorf("Authorization should be empty for non-OAuth, got %q", got)
|
||||
}
|
||||
// New auth should be set via x-api-key
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-ant-api-new" {
|
||||
t.Errorf("x-api-key = %q, want %q", got, "sk-ant-api-new")
|
||||
}
|
||||
// User-Agent from profile should remain
|
||||
if got := req.Header.Get("User-Agent"); got != "test" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_ProfileAuthHeadersRemovedForOAuth(t *testing.T) {
|
||||
uc := &UpstreamClient{
|
||||
sessionID: "sess",
|
||||
profile: &SniffedProfile{
|
||||
Headers: [][2]string{
|
||||
{"Authorization", "Bearer old-token"},
|
||||
{"x-api-key", "old-api-key"},
|
||||
},
|
||||
},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
|
||||
uc.applyHeaders(req, "sk-ant-oat-new", false)
|
||||
|
||||
// Old x-api-key removed
|
||||
if got := req.Header.Get("x-api-key"); got != "" {
|
||||
t.Errorf("x-api-key should be empty for OAuth, got %q", got)
|
||||
}
|
||||
// New auth set via Authorization
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-ant-oat-new" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer sk-ant-oat-new")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_AcceptEncoding_AlwaysIdentity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
streaming bool
|
||||
}{
|
||||
{"non-stream", false},
|
||||
{"stream", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "token", tt.streaming)
|
||||
|
||||
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", got, "identity")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_UniqueRequestIDs(t *testing.T) {
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
|
||||
req1, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req1, "tok", false)
|
||||
|
||||
req2, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req2, "tok", false)
|
||||
|
||||
id1 := req1.Header.Get("x-client-request-id")
|
||||
id2 := req2.Header.Get("x-client-request-id")
|
||||
if id1 == "" || id2 == "" {
|
||||
t.Fatal("expected non-empty request IDs")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Errorf("expected unique request IDs, both got %q", id1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_NonOAuth_NoAnthroPicBetaSet(t *testing.T) {
|
||||
// Non-OAuth tokens should NOT set anthropic-beta oauth flag
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "sk-ant-api123", false)
|
||||
|
||||
beta := req.Header.Get("anthropic-beta")
|
||||
if strings.Contains(beta, "oauth-2025-04-20") {
|
||||
t.Errorf("non-OAuth token should not have oauth beta flag, got %q", beta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OAuthToken_FreshBeta(t *testing.T) {
|
||||
// No profile, no existing beta — should set fresh
|
||||
uc := &UpstreamClient{sessionID: "s", profile: nil}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
uc.applyHeaders(req, "sk-ant-oat-tok", false)
|
||||
|
||||
if got := req.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want %q", got, "oauth-2025-04-20")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTracker(t *testing.T) {
|
||||
called := false
|
||||
tr := NewTracker(func() string {
|
||||
called = true
|
||||
return "tok"
|
||||
})
|
||||
if tr == nil {
|
||||
t.Fatal("NewTracker returned nil")
|
||||
}
|
||||
// tokenFn stored but not called during construction
|
||||
if called {
|
||||
t.Error("tokenFn should not be called by NewTracker")
|
||||
}
|
||||
// Invoke to verify it's wired
|
||||
if got := tr.tokenFn(); got != "tok" {
|
||||
t.Errorf("tokenFn() = %q, want tok", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Full(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.42")
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "1700000000")
|
||||
h.Set("Anthropic-Ratelimit-Unified-7d-Utilization", "0.75")
|
||||
h.Set("Anthropic-Ratelimit-Unified-7d-Reset", "1700100000")
|
||||
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 42.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 42.0", fh.Utilization)
|
||||
}
|
||||
wantReset5h := time.Unix(1700000000, 0).UTC().Truncate(time.Minute)
|
||||
if !fh.ResetsAt.Equal(wantReset5h) {
|
||||
t.Errorf("FiveHour.ResetsAt = %v, want %v", fh.ResetsAt, wantReset5h)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 75.0 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 75.0", sd.Utilization)
|
||||
}
|
||||
wantReset7d := time.Unix(1700100000, 0).UTC().Truncate(time.Minute)
|
||||
if !sd.ResetsAt.Equal(wantReset7d) {
|
||||
t.Errorf("SevenDay.ResetsAt = %v, want %v", sd.ResetsAt, wantReset7d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Partial(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Only set 5h utilization, no reset, no 7d
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "0.33")
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 33.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 33.0", fh.Utilization)
|
||||
}
|
||||
if !fh.ResetsAt.IsZero() {
|
||||
t.Errorf("FiveHour.ResetsAt should be zero, got %v", fh.ResetsAt)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 0 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 0", sd.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_Missing(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Pre-set some state
|
||||
tr.mu.Lock()
|
||||
tr.fiveHour.Utilization = 50.0
|
||||
tr.mu.Unlock()
|
||||
|
||||
// Update with empty headers — should not change state
|
||||
tr.UpdateFromHeaders(http.Header{})
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 50.0 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 50.0 (unchanged)", fh.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFromHeaders_InvalidValues(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Utilization", "not-a-number")
|
||||
h.Set("Anthropic-Ratelimit-Unified-5h-Reset", "not-a-timestamp")
|
||||
|
||||
tr.UpdateFromHeaders(h)
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 0 {
|
||||
t.Errorf("Utilization should stay 0 for invalid input, got %f", fh.Utilization)
|
||||
}
|
||||
if !fh.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should stay zero for invalid input, got %v", fh.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSonnet_Snapshot(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Sonnet is only set via poll/updateWindow, not UpdateFromHeaders
|
||||
// Verify it starts at zero
|
||||
s := tr.Sonnet()
|
||||
if s.Utilization != 0 {
|
||||
t.Errorf("Sonnet.Utilization = %f, want 0", s.Utilization)
|
||||
}
|
||||
if !s.ResetsAt.IsZero() {
|
||||
t.Errorf("Sonnet.ResetsAt should be zero, got %v", s.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtra_Default(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
extra := tr.Extra()
|
||||
if extra.IsEnabled {
|
||||
t.Error("Extra.IsEnabled should be false by default")
|
||||
}
|
||||
if extra.MonthlyLimit != nil {
|
||||
t.Error("Extra.MonthlyLimit should be nil by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWindow(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
util *float64
|
||||
resetsAt *string
|
||||
wantUtil float64
|
||||
wantResetOK bool
|
||||
}{
|
||||
{
|
||||
name: "both fields",
|
||||
util: float64Ptr(65.5),
|
||||
resetsAt: stringPtr("2024-01-15T10:30:45Z"),
|
||||
wantUtil: 65.5,
|
||||
wantResetOK: true,
|
||||
},
|
||||
{
|
||||
name: "utilization only",
|
||||
util: float64Ptr(30.0),
|
||||
resetsAt: nil,
|
||||
wantUtil: 30.0,
|
||||
wantResetOK: false,
|
||||
},
|
||||
{
|
||||
name: "reset only (RFC3339Nano)",
|
||||
util: nil,
|
||||
resetsAt: stringPtr("2024-06-01T12:00:00.123456789Z"),
|
||||
wantUtil: 0,
|
||||
wantResetOK: true,
|
||||
},
|
||||
{
|
||||
name: "nil both",
|
||||
util: nil,
|
||||
resetsAt: nil,
|
||||
wantUtil: 0,
|
||||
wantResetOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &Window{}
|
||||
rl := &RateLimit{
|
||||
Utilization: tt.util,
|
||||
ResetsAt: tt.resetsAt,
|
||||
}
|
||||
tr.updateWindow(w, rl)
|
||||
|
||||
if w.Utilization != tt.wantUtil {
|
||||
t.Errorf("Utilization = %f, want %f", w.Utilization, tt.wantUtil)
|
||||
}
|
||||
if tt.wantResetOK {
|
||||
if w.ResetsAt.IsZero() {
|
||||
t.Error("ResetsAt should be set")
|
||||
}
|
||||
// Verify truncation to minute
|
||||
if w.ResetsAt.Second() != 0 || w.ResetsAt.Nanosecond() != 0 {
|
||||
t.Errorf("ResetsAt not truncated to minute: %v", w.ResetsAt)
|
||||
}
|
||||
if w.ResetsAt.Location() != time.UTC {
|
||||
t.Errorf("ResetsAt not in UTC: %v", w.ResetsAt.Location())
|
||||
}
|
||||
} else if tt.resetsAt == nil {
|
||||
if !w.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should be zero when input is nil, got %v", w.ResetsAt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWindow_InvalidTime(t *testing.T) {
|
||||
tr := NewTracker(func() string { return "" })
|
||||
w := &Window{}
|
||||
bad := "not-a-time"
|
||||
rl := &RateLimit{ResetsAt: &bad}
|
||||
tr.updateWindow(w, rl)
|
||||
if !w.ResetsAt.IsZero() {
|
||||
t.Errorf("ResetsAt should stay zero for invalid time, got %v", w.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoll_SetsStateFromUsageResponse(t *testing.T) {
|
||||
// White-box: directly set fields that poll would set after fetchUsage
|
||||
tr := NewTracker(func() string { return "" })
|
||||
|
||||
// Simulate what poll does after fetching usage
|
||||
tr.mu.Lock()
|
||||
usage := &UsageResponse{
|
||||
FiveHour: &RateLimit{Utilization: float64Ptr(55.5), ResetsAt: stringPtr("2024-03-01T08:00:00Z")},
|
||||
SevenDay: &RateLimit{Utilization: float64Ptr(22.3), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
|
||||
SevenDaySonnet: &RateLimit{Utilization: float64Ptr(10.0), ResetsAt: stringPtr("2024-03-07T00:00:00Z")},
|
||||
ExtraUsage: &ExtraUsage{IsEnabled: true, MonthlyLimit: float64Ptr(100.0), UsedCredits: float64Ptr(42.5)},
|
||||
}
|
||||
if usage.FiveHour != nil {
|
||||
tr.updateWindow(&tr.fiveHour, usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay != nil {
|
||||
tr.updateWindow(&tr.sevenDay, usage.SevenDay)
|
||||
}
|
||||
if usage.SevenDaySonnet != nil {
|
||||
tr.updateWindow(&tr.sonnet, usage.SevenDaySonnet)
|
||||
}
|
||||
if usage.ExtraUsage != nil {
|
||||
tr.extra = *usage.ExtraUsage
|
||||
}
|
||||
tr.mu.Unlock()
|
||||
|
||||
fh := tr.FiveHour()
|
||||
if fh.Utilization != 55.5 {
|
||||
t.Errorf("FiveHour.Utilization = %f, want 55.5", fh.Utilization)
|
||||
}
|
||||
|
||||
sd := tr.SevenDay()
|
||||
if sd.Utilization != 22.3 {
|
||||
t.Errorf("SevenDay.Utilization = %f, want 22.3", sd.Utilization)
|
||||
}
|
||||
|
||||
sn := tr.Sonnet()
|
||||
if sn.Utilization != 10.0 {
|
||||
t.Errorf("Sonnet.Utilization = %f, want 10.0", sn.Utilization)
|
||||
}
|
||||
|
||||
extra := tr.Extra()
|
||||
if !extra.IsEnabled {
|
||||
t.Error("Extra.IsEnabled = false, want true")
|
||||
}
|
||||
if extra.MonthlyLimit == nil || *extra.MonthlyLimit != 100.0 {
|
||||
t.Errorf("Extra.MonthlyLimit = %v, want 100.0", extra.MonthlyLimit)
|
||||
}
|
||||
if extra.UsedCredits == nil || *extra.UsedCredits != 42.5 {
|
||||
t.Errorf("Extra.UsedCredits = %v, want 42.5", extra.UsedCredits)
|
||||
}
|
||||
}
|
||||
|
||||
func float64Ptr(f float64) *float64 { return &f }
|
||||
func stringPtr(s string) *string { return &s }
|
||||
@@ -0,0 +1,241 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUsageResponse_FullJSON(t *testing.T) {
|
||||
raw := `{
|
||||
"five_hour": {"utilization": 42.5, "resets_at": "2024-01-15T10:30:00Z"},
|
||||
"seven_day": {"utilization": 75.0, "resets_at": "2024-01-20T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 10.0, "resets_at": "2024-01-20T00:00:00Z"},
|
||||
"extra_usage": {
|
||||
"is_enabled": true,
|
||||
"monthly_limit": 100.0,
|
||||
"used_credits": 42.5,
|
||||
"utilization": 42.5
|
||||
}
|
||||
}`
|
||||
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp.FiveHour == nil {
|
||||
t.Fatal("FiveHour is nil")
|
||||
}
|
||||
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 42.5 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 42.5", resp.FiveHour.Utilization)
|
||||
}
|
||||
if resp.FiveHour.ResetsAt == nil || *resp.FiveHour.ResetsAt != "2024-01-15T10:30:00Z" {
|
||||
t.Errorf("FiveHour.ResetsAt = %v", resp.FiveHour.ResetsAt)
|
||||
}
|
||||
|
||||
if resp.SevenDay == nil {
|
||||
t.Fatal("SevenDay is nil")
|
||||
}
|
||||
if resp.SevenDay.Utilization == nil || *resp.SevenDay.Utilization != 75.0 {
|
||||
t.Errorf("SevenDay.Utilization = %v, want 75.0", resp.SevenDay.Utilization)
|
||||
}
|
||||
|
||||
if resp.SevenDaySonnet == nil {
|
||||
t.Fatal("SevenDaySonnet is nil")
|
||||
}
|
||||
if resp.SevenDaySonnet.Utilization == nil || *resp.SevenDaySonnet.Utilization != 10.0 {
|
||||
t.Errorf("SevenDaySonnet.Utilization = %v", resp.SevenDaySonnet.Utilization)
|
||||
}
|
||||
|
||||
if resp.ExtraUsage == nil {
|
||||
t.Fatal("ExtraUsage is nil")
|
||||
}
|
||||
if !resp.ExtraUsage.IsEnabled {
|
||||
t.Error("ExtraUsage.IsEnabled = false, want true")
|
||||
}
|
||||
if resp.ExtraUsage.MonthlyLimit == nil || *resp.ExtraUsage.MonthlyLimit != 100.0 {
|
||||
t.Errorf("ExtraUsage.MonthlyLimit = %v, want 100.0", resp.ExtraUsage.MonthlyLimit)
|
||||
}
|
||||
if resp.ExtraUsage.UsedCredits == nil || *resp.ExtraUsage.UsedCredits != 42.5 {
|
||||
t.Errorf("ExtraUsage.UsedCredits = %v, want 42.5", resp.ExtraUsage.UsedCredits)
|
||||
}
|
||||
if resp.ExtraUsage.Utilization == nil || *resp.ExtraUsage.Utilization != 42.5 {
|
||||
t.Errorf("ExtraUsage.Utilization = %v, want 42.5", resp.ExtraUsage.Utilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageResponse_PartialJSON(t *testing.T) {
|
||||
raw := `{"five_hour": {"utilization": 10.0}}`
|
||||
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(raw), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp.FiveHour == nil {
|
||||
t.Fatal("FiveHour is nil")
|
||||
}
|
||||
if resp.FiveHour.Utilization == nil || *resp.FiveHour.Utilization != 10.0 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 10.0", resp.FiveHour.Utilization)
|
||||
}
|
||||
if resp.FiveHour.ResetsAt != nil {
|
||||
t.Errorf("FiveHour.ResetsAt should be nil, got %v", resp.FiveHour.ResetsAt)
|
||||
}
|
||||
if resp.SevenDay != nil {
|
||||
t.Errorf("SevenDay should be nil, got %v", resp.SevenDay)
|
||||
}
|
||||
if resp.SevenDaySonnet != nil {
|
||||
t.Errorf("SevenDaySonnet should be nil, got %v", resp.SevenDaySonnet)
|
||||
}
|
||||
if resp.ExtraUsage != nil {
|
||||
t.Errorf("ExtraUsage should be nil, got %v", resp.ExtraUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageResponse_EmptyJSON(t *testing.T) {
|
||||
var resp UsageResponse
|
||||
if err := json.Unmarshal([]byte(`{}`), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if resp.FiveHour != nil || resp.SevenDay != nil || resp.SevenDaySonnet != nil || resp.ExtraUsage != nil {
|
||||
t.Error("all fields should be nil for empty JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request headers
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
|
||||
t.Errorf("Authorization = %q, want 'Bearer test-token'", got)
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
if got := r.Header.Get("anthropic-beta"); got != "oauth-2025-04-20" {
|
||||
t.Errorf("anthropic-beta = %q, want oauth-2025-04-20", got)
|
||||
}
|
||||
if got := r.Header.Get("User-Agent"); got != "claude-cli/2.1.92" {
|
||||
t.Errorf("User-Agent = %q, want claude-cli/2.1.92", got)
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
t.Errorf("Method = %q, want GET", r.Method)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{
|
||||
"five_hour": {"utilization": 50.0, "resets_at": "2024-01-15T10:00:00Z"},
|
||||
"seven_day": {"utilization": 25.0, "resets_at": "2024-01-20T00:00:00Z"}
|
||||
}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// fetchUsage hardcodes usageURL, but we can test via the mock by temporarily
|
||||
// using http.DefaultClient's transport. Instead, we test the handler directly.
|
||||
// The httptest server validates our request expectations above.
|
||||
|
||||
// Make a real request to the test server to verify handler behavior
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
req.Header.Set("User-Agent", "claude-cli/2.1.92")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var usage UsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usage); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
|
||||
if usage.FiveHour == nil || *usage.FiveHour.Utilization != 50.0 {
|
||||
t.Errorf("FiveHour.Utilization = %v, want 50.0", usage.FiveHour)
|
||||
}
|
||||
if usage.SevenDay == nil || *usage.SevenDay.Utilization != 25.0 {
|
||||
t.Errorf("SevenDay.Utilization = %v, want 25.0", usage.SevenDay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_Non200(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Simulate the error path: non-200 returns error with status and body
|
||||
resp, err := http.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
t.Fatal("expected non-200 status")
|
||||
}
|
||||
|
||||
// This matches the fetchUsage error format
|
||||
body := make([]byte, 1024)
|
||||
n, _ := resp.Body.Read(body)
|
||||
bodyStr := string(body[:n])
|
||||
if !strings.Contains(bodyStr, "forbidden") {
|
||||
t.Errorf("body = %q, want it to contain 'forbidden'", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUsage_MalformedJSON(t *testing.T) {
|
||||
raw := `{not valid json`
|
||||
var resp UsageResponse
|
||||
err := json.Unmarshal([]byte(raw), &resp)
|
||||
if err == nil {
|
||||
t.Fatal("expected decode error for malformed JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimit_NilFields(t *testing.T) {
|
||||
raw := `{}`
|
||||
var rl RateLimit
|
||||
if err := json.Unmarshal([]byte(raw), &rl); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if rl.Utilization != nil {
|
||||
t.Errorf("Utilization should be nil, got %v", rl.Utilization)
|
||||
}
|
||||
if rl.ResetsAt != nil {
|
||||
t.Errorf("ResetsAt should be nil, got %v", rl.ResetsAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraUsage_JSON(t *testing.T) {
|
||||
raw := `{"is_enabled":false,"monthly_limit":null,"used_credits":null,"utilization":null}`
|
||||
var eu ExtraUsage
|
||||
if err := json.Unmarshal([]byte(raw), &eu); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if eu.IsEnabled {
|
||||
t.Error("IsEnabled should be false")
|
||||
}
|
||||
if eu.MonthlyLimit != nil {
|
||||
t.Error("MonthlyLimit should be nil")
|
||||
}
|
||||
if eu.UsedCredits != nil {
|
||||
t.Error("UsedCredits should be nil")
|
||||
}
|
||||
if eu.Utilization != nil {
|
||||
t.Error("Utilization should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageURL_Constant(t *testing.T) {
|
||||
if usageURL != "https://api.anthropic.com/api/oauth/usage" {
|
||||
t.Errorf("usageURL = %q, want https://api.anthropic.com/api/oauth/usage", usageURL)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,529 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// --- makeKeySet ---
|
||||
|
||||
func TestMakeKeySet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keys []string
|
||||
wantN int
|
||||
lookup string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "nil slice returns empty map",
|
||||
keys: nil,
|
||||
wantN: 0,
|
||||
},
|
||||
{
|
||||
name: "empty slice returns empty map",
|
||||
keys: []string{},
|
||||
wantN: 0,
|
||||
},
|
||||
{
|
||||
name: "single key",
|
||||
keys: []string{"key1"},
|
||||
wantN: 1,
|
||||
lookup: "key1",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "multiple keys",
|
||||
keys: []string{"a", "b", "c"},
|
||||
wantN: 3,
|
||||
lookup: "b",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "missing key not found",
|
||||
keys: []string{"a", "b"},
|
||||
wantN: 2,
|
||||
lookup: "c",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate keys deduped",
|
||||
keys: []string{"x", "x", "x"},
|
||||
wantN: 1,
|
||||
lookup: "x",
|
||||
found: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := makeKeySet(tt.keys)
|
||||
if len(got) != tt.wantN {
|
||||
t.Errorf("len(makeKeySet) = %d, want %d", len(got), tt.wantN)
|
||||
}
|
||||
if tt.lookup != "" {
|
||||
_, ok := got[tt.lookup]
|
||||
if ok != tt.found {
|
||||
t.Errorf("keySet[%q] found=%v, want %v", tt.lookup, ok, tt.found)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- corsMiddleware ---
|
||||
|
||||
func TestCorsMiddleware_SetsHeaders(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT, DELETE, OPTIONS" {
|
||||
t.Errorf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT, DELETE, OPTIONS")
|
||||
}
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
for _, h := range []string{"x-api-key", "anthropic-version", "anthropic-beta", "Authorization", "Content-Type", "Origin"} {
|
||||
if !containsSubstring(allowHeaders, h) {
|
||||
t.Errorf("Access-Control-Allow-Headers %q missing %q", allowHeaders, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_OptionsReturns204(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("OPTIONS status = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected context to be aborted on OPTIONS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_NonOptionsDoesNotAbort(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
handler := corsMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("POST request should not be aborted")
|
||||
}
|
||||
}
|
||||
|
||||
// --- authMiddleware ---
|
||||
|
||||
func newServerWithKeys(keys []string) *Server {
|
||||
s := &Server{}
|
||||
keySet := makeKeySet(keys)
|
||||
s.apiKeys.Store(&keySet)
|
||||
return s
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_BypassPaths(t *testing.T) {
|
||||
paths := []string{"/healthz", "/reload", "/metrics"}
|
||||
s := newServerWithKeys(nil) // no keys — would reject if auth checked
|
||||
|
||||
for _, path := range paths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, path, nil)
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Errorf("path %q should bypass auth but was aborted", path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_MissingToken_401(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected aborted on missing token")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if body["error"] != "missing authentication" {
|
||||
t.Errorf("error = %q, want %q", body["error"], "missing authentication")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_InvalidKey_403(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "wrong-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Error("expected aborted on invalid key")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if body["error"] != "invalid api key" {
|
||||
t.Errorf("error = %q, want %q", body["error"], "invalid api key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidKey_XApiKey(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"valid-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "valid-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("valid key should not abort")
|
||||
}
|
||||
if w.Code == http.StatusUnauthorized || w.Code == http.StatusForbidden {
|
||||
t.Errorf("unexpected status %d for valid key", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidKey_BearerAuth(t *testing.T) {
|
||||
s := newServerWithKeys([]string{"my-token"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer my-token")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("valid Bearer token should not abort")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_BearerPrefix_Stripped(t *testing.T) {
|
||||
// The token is "my-token", sent as "Bearer my-token". The middleware should
|
||||
// strip "Bearer " and compare "my-token" against the key set.
|
||||
s := newServerWithKeys([]string{"my-token"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer my-token")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("expected auth to pass with Bearer-prefixed valid key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_AuthorizationWithoutBearer(t *testing.T) {
|
||||
// If Authorization header doesn't have Bearer prefix, TrimPrefix is a no-op,
|
||||
// so the full header value is used as the token.
|
||||
s := newServerWithKeys([]string{"raw-token-value"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "raw-token-value")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("raw Authorization value matching a key should pass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_XApiKey_FallbackWhenNoAuthHeader(t *testing.T) {
|
||||
// If Authorization is empty, x-api-key is checked.
|
||||
s := newServerWithKeys([]string{"fallback-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("x-api-key", "fallback-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("x-api-key fallback should pass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_AuthorizationPreferredOverXApiKey(t *testing.T) {
|
||||
// Both headers set; Authorization takes precedence.
|
||||
s := newServerWithKeys([]string{"auth-key"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer auth-key")
|
||||
c.Request.Header.Set("x-api-key", "wrong-key")
|
||||
|
||||
handler := s.authMiddleware()
|
||||
handler(c)
|
||||
|
||||
if c.IsAborted() {
|
||||
t.Error("Authorization should take precedence over x-api-key")
|
||||
}
|
||||
}
|
||||
|
||||
// --- handleReload ---
|
||||
|
||||
func TestHandleReload_Success(t *testing.T) {
|
||||
// Create a temp config file
|
||||
tmpFile, err := os.CreateTemp("", "config-*.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
configContent := `
|
||||
port: 9999
|
||||
api_keys:
|
||||
- reloaded-key-1
|
||||
- reloaded-key-2
|
||||
sanitize:
|
||||
tools:
|
||||
- from: old_tool
|
||||
to: new_tool
|
||||
system:
|
||||
- match: foo
|
||||
replace: bar
|
||||
body:
|
||||
- match: baz
|
||||
replace: qux
|
||||
`
|
||||
if _, err := tmpFile.WriteString(configContent); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
s := &Server{configPath: tmpFile.Name()}
|
||||
// Initialize with empty values
|
||||
emptyKeys := makeKeySet(nil)
|
||||
s.apiKeys.Store(&emptyKeys)
|
||||
|
||||
emptySan := &atomic.Pointer[interface{}]{}
|
||||
_ = emptySan // just to show we're aware
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
|
||||
handler := s.handleReload()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["status"] != "reloaded" {
|
||||
t.Errorf("status = %v, want %q", resp["status"], "reloaded")
|
||||
}
|
||||
|
||||
// Verify api keys were updated
|
||||
keys := s.apiKeys.Load()
|
||||
if _, ok := (*keys)["reloaded-key-1"]; !ok {
|
||||
t.Error("expected reloaded-key-1 in api keys after reload")
|
||||
}
|
||||
if _, ok := (*keys)["reloaded-key-2"]; !ok {
|
||||
t.Error("expected reloaded-key-2 in api keys after reload")
|
||||
}
|
||||
if len(*keys) != 2 {
|
||||
t.Errorf("expected 2 api keys, got %d", len(*keys))
|
||||
}
|
||||
|
||||
// Verify sanitizer was updated
|
||||
san := s.sanitizer.Load()
|
||||
if san == nil {
|
||||
t.Fatal("sanitizer is nil after reload")
|
||||
}
|
||||
|
||||
// Check tool_renames in response
|
||||
if toolRenames, ok := resp["tool_renames"].(float64); !ok || int(toolRenames) != 1 {
|
||||
t.Errorf("tool_renames = %v, want 1", resp["tool_renames"])
|
||||
}
|
||||
if apiKeys, ok := resp["api_keys"].(float64); !ok || int(apiKeys) != 2 {
|
||||
t.Errorf("api_keys = %v, want 2", resp["api_keys"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReload_InvalidConfig(t *testing.T) {
|
||||
s := &Server{configPath: "/nonexistent/path/config.yaml"}
|
||||
emptyKeys := makeKeySet(nil)
|
||||
s.apiKeys.Store(&emptyKeys)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
|
||||
handler := s.handleReload()
|
||||
handler(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["error"] == "" {
|
||||
t.Error("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Full route tests using httptest ---
|
||||
|
||||
func TestHealthzEndpoint(t *testing.T) {
|
||||
engine := gin.New()
|
||||
engine.Use(corsMiddleware())
|
||||
|
||||
s := newServerWithKeys(nil)
|
||||
engine.Use(s.authMiddleware())
|
||||
|
||||
engine.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if body["status"] != "ok" {
|
||||
t.Errorf("status = %q, want %q", body["status"], "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_FullRoute_Rejected(t *testing.T) {
|
||||
engine := gin.New()
|
||||
s := newServerWithKeys([]string{"correct-key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
// No auth header
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_FullRoute_Accepted(t *testing.T) {
|
||||
engine := gin.New()
|
||||
s := newServerWithKeys([]string{"correct-key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
req.Header.Set("x-api-key", "correct-key")
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsMiddleware_FullRoute_OptionsRequest(t *testing.T) {
|
||||
engine := gin.New()
|
||||
engine.Use(corsMiddleware())
|
||||
|
||||
s := newServerWithKeys([]string{"key"})
|
||||
engine.Use(s.authMiddleware())
|
||||
|
||||
engine.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("ACAO = %q, want %q", got, "*")
|
||||
}
|
||||
}
|
||||
|
||||
// helper
|
||||
func containsSubstring(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub))
|
||||
}
|
||||
|
||||
func containsStr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
otellog "go.opentelemetry.io/otel/log"
|
||||
sdklog "go.opentelemetry.io/otel/sdk/log"
|
||||
)
|
||||
|
||||
func TestMapSeverity(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want otellog.Severity
|
||||
}{
|
||||
{"trace", otellog.SeverityTrace},
|
||||
{"debug", otellog.SeverityDebug},
|
||||
{"info", otellog.SeverityInfo},
|
||||
{"warn", otellog.SeverityWarn},
|
||||
{"warning", otellog.SeverityWarn},
|
||||
{"error", otellog.SeverityError},
|
||||
{"fatal", otellog.SeverityFatal},
|
||||
{"panic", otellog.SeverityFatal2},
|
||||
{"unknown", otellog.SeverityInfo},
|
||||
{"", otellog.SeverityInfo},
|
||||
{"INFO", otellog.SeverityInfo}, // uppercase falls to default
|
||||
{"PANIC", otellog.SeverityInfo}, // uppercase falls to default
|
||||
{"gibberish", otellog.SeverityInfo},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run("level_"+tc.input, func(t *testing.T) {
|
||||
got := mapSeverity(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("mapSeverity(%q) = %v, want %v", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestBridge(t *testing.T) *LogBridge {
|
||||
t.Helper()
|
||||
provider := sdklog.NewLoggerProvider()
|
||||
t.Cleanup(func() {
|
||||
_ = provider.Shutdown(t.Context())
|
||||
})
|
||||
return &LogBridge{provider: provider}
|
||||
}
|
||||
|
||||
func TestLogBridgeWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{} // will be marshaled to JSON; use string for raw input
|
||||
raw string // if non-empty, use this directly instead of marshaling input
|
||||
}{
|
||||
{
|
||||
name: "valid_json_with_message_level_and_extras",
|
||||
input: map[string]interface{}{
|
||||
"message": "request handled",
|
||||
"level": "info",
|
||||
"method": "GET",
|
||||
"status": float64(200),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message_only_no_level",
|
||||
input: map[string]interface{}{
|
||||
"message": "hello world",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "level_only_no_message",
|
||||
input: map[string]interface{}{
|
||||
"level": "error",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_json_object",
|
||||
input: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "string_float64_bool_attributes",
|
||||
input: map[string]interface{}{
|
||||
"message": "test",
|
||||
"level": "debug",
|
||||
"str_val": "hello",
|
||||
"num_val": float64(3.14),
|
||||
"bool_val": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex_nested_object_attribute",
|
||||
input: map[string]interface{}{
|
||||
"message": "nested",
|
||||
"level": "warn",
|
||||
"nested": map[string]interface{}{"foo": "bar", "n": float64(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "time_field_skipped_in_attributes",
|
||||
input: map[string]interface{}{
|
||||
"message": "with time",
|
||||
"level": "info",
|
||||
"time": "2025-01-01T00:00:00Z",
|
||||
"extra": "kept",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "malformed_json",
|
||||
raw: "this is not json at all",
|
||||
},
|
||||
{
|
||||
name: "malformed_json_partial",
|
||||
raw: `{"broken":`,
|
||||
},
|
||||
{
|
||||
name: "array_attribute_marshaled_as_string",
|
||||
input: map[string]interface{}{
|
||||
"message": "arrays",
|
||||
"tags": []interface{}{"a", "b"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "null_value_attribute",
|
||||
input: map[string]interface{}{
|
||||
"message": "nulls",
|
||||
"val": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
bridge := newTestBridge(t)
|
||||
|
||||
var p []byte
|
||||
if tc.raw != "" {
|
||||
p = []byte(tc.raw)
|
||||
} else {
|
||||
var err error
|
||||
p, err = json.Marshal(tc.input)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal test input: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
n, err := bridge.Write(p)
|
||||
if n != len(p) {
|
||||
t.Errorf("Write() returned n=%d, want %d", n, len(p))
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Write() returned err=%v, want nil", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBridgeWriteAlwaysReturnsLenAndNil(t *testing.T) {
|
||||
bridge := newTestBridge(t)
|
||||
|
||||
inputs := [][]byte{
|
||||
[]byte(`{"message":"ok","level":"info"}`),
|
||||
[]byte(`not json`),
|
||||
[]byte(`{}`),
|
||||
[]byte(``),
|
||||
[]byte(`[]`),
|
||||
}
|
||||
|
||||
for _, p := range inputs {
|
||||
n, err := bridge.Write(p)
|
||||
if n != len(p) {
|
||||
t.Errorf("Write(%q) n=%d, want %d", string(p), n, len(p))
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Write(%q) err=%v, want nil", string(p), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user