package proxy import ( "bytes" "fmt" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "github.com/fujin/anthropic-proxy/internal/auth" "github.com/fujin/anthropic-proxy/internal/config" "github.com/fujin/anthropic-proxy/internal/ratelimit" "github.com/fujin/anthropic-proxy/internal/telemetry" "go.opentelemetry.io/otel/metric/noop" ) func init() { gin.SetMode(gin.TestMode) // Initialize telemetry with noop meter to avoid nil pointer panics. meter := noop.Meter{} telemetry.InitMetrics(meter, nil) } // --- Request body reading and sanitization --- func TestHandleMessages_ReadBodyError(t *testing.T) { // A body that immediately fails on read shouldn't panic. pool := auth.NewPool([]*auth.Credential{{ID: "c1", AccessToken: "tok", Email: "test@test.com"}}) san := NewSanitizer(config.SanitizeConfig{}) handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", &errReader{}) handler(c) if w.Code != http.StatusBadRequest { t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest) } if !strings.Contains(w.Body.String(), "failed to read request body") { t.Errorf("body = %q, expected error message about reading body", w.Body.String()) } } func TestHandleMessages_SanitizesRequestBody(t *testing.T) { // We can't directly make HandleMessages use our mock server because // UpstreamClient hardcodes messagesURL. Instead, we test sanitization // by verifying the sanitizer is called on the body before any pool interaction. 
san := NewSanitizer(config.SanitizeConfig{ Tools: []config.RenameRule{{From: "my_tool", To: "renamed_tool"}}, Body: []config.ReplaceRule{{Match: "secret", Replace: "redacted"}}, }) // Create body with tool name and secret body := `{"model":"claude-sonnet-4-6","tools":[{"name":"my_tool"}],"messages":[{"role":"user","content":"secret data"}]}` sanitizedBody := san.SanitizeRequest([]byte(body)) // Verify sanitization happened correctly if !strings.Contains(string(sanitizedBody), "renamed_tool") { t.Error("expected tool to be renamed in sanitized body") } if strings.Contains(string(sanitizedBody), "my_tool") { t.Error("original tool name should be gone after sanitization") } if !strings.Contains(string(sanitizedBody), "redacted") { t.Error("expected 'secret' to be replaced with 'redacted'") } if strings.Contains(string(sanitizedBody), "secret") { t.Error("'secret' should be gone after sanitization") } } func TestHandleMessages_PoolPickError(t *testing.T) { // Empty pool — Pick() will fail. pool := auth.NewPool(nil) san := NewSanitizer(config.SanitizeConfig{}) handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) body := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}` w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) handler(c) if w.Code != http.StatusServiceUnavailable { t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) } if !strings.Contains(w.Body.String(), "no credentials available") { t.Errorf("body = %q, expected pool error", w.Body.String()) } } func TestHandleMessages_PoolAllOnCooldown(t *testing.T) { cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "e"} pool := auth.NewPool([]*auth.Credential{cred}) pool.MarkFailure(cred, 429) // puts on 30s cooldown san := NewSanitizer(config.SanitizeConfig{}) handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) body := 
`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}` w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) handler(c) if w.Code != http.StatusServiceUnavailable { t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) } if !strings.Contains(w.Body.String(), "cooldown") { t.Errorf("body = %q, expected cooldown message", w.Body.String()) } } // --- Stream vs non-stream routing --- func TestHandleMessages_StreamField_Detection(t *testing.T) { tests := []struct { name string body string isStream bool }{ { name: "stream true", body: `{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`, isStream: true, }, { name: "stream false", body: `{"model":"claude-sonnet-4-6","stream":false,"messages":[{"role":"user","content":"hi"}]}`, isStream: false, }, { name: "no stream field defaults to false", body: `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`, isStream: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := gjson.Get(tt.body, "stream").Bool() if got != tt.isStream { t.Errorf("stream = %v, want %v", got, tt.isStream) } }) } } // --- Desanitization on response --- func TestDesanitization_NonStreamResponse(t *testing.T) { san := NewSanitizer(config.SanitizeConfig{ Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, }) // Simulate upstream response with renamed tool upstreamResponse := `{"content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}]}` desanitized := san.DesanitizeResponse([]byte(upstreamResponse)) if !strings.Contains(string(desanitized), "original_tool") { t.Errorf("expected tool name to be desanitized back to 'original_tool', got %s", string(desanitized)) } if strings.Contains(string(desanitized), `"name":"renamed_tool"`) { t.Error("renamed_tool should have been replaced by original_tool") } } func 
TestDesanitization_StreamEvent(t *testing.T) { san := NewSanitizer(config.SanitizeConfig{ Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, }) event := `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"renamed_tool","id":"t1"}}` desanitized := san.DesanitizeStreamEvent(event) if !strings.Contains(desanitized, "original_tool") { t.Errorf("expected stream event to be desanitized, got %s", desanitized) } } // --- handleNonStream behavior tests via direct function --- func TestHandleNonStream_ConnectionError(t *testing.T) { cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) tracker := ratelimit.NewTracker(func() string { return "" }) uc := &UpstreamClient{ client: http.Client{Transport: &failingTransport{}}, sessionID: "test-sess", profile: nil, } body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleNonStream(c, uc, san, pool, cred, body, body, tracker) if w.Code != http.StatusBadGateway { t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway) } if !strings.Contains(w.Body.String(), "upstream request failed") { t.Errorf("body = %q, expected upstream error message", w.Body.String()) } } func TestHandleNonStream_UpstreamSuccess(t *testing.T) { mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Request-Id", "req-123") w.WriteHeader(200) w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hello"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) })) defer mockUpstream.Close() cred := 
&auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) tracker := ratelimit.NewTracker(func() string { return "" }) uc := &UpstreamClient{ client: *mockUpstream.Client(), sessionID: "test-sess", profile: nil, } // Override the messagesURL by constructing a custom Execute that uses the mock. // Since we can't override the const, we test via a mock server approach: // We create a custom http.Client with a transport that redirects to our mock. uc.client.Transport = &rewriteTransport{ base: mockUpstream.Client().Transport, destURL: mockUpstream.URL, } body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleNonStream(c, uc, san, pool, cred, body, body, tracker) if w.Code != http.StatusOK { t.Errorf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String()) } if !strings.Contains(w.Body.String(), "hello") { t.Errorf("response body missing expected content: %s", w.Body.String()) } if got := w.Header().Get("Content-Type"); got != "application/json" { t.Errorf("Content-Type = %q, want %q", got, "application/json") } if got := w.Header().Get("X-Request-Id"); got != "req-123" { t.Errorf("X-Request-Id = %q, want %q", got, "req-123") } } func TestHandleNonStream_UpstreamError_MarkFailure(t *testing.T) { mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"too many requests"}}`)) })) defer mockUpstream.Close() cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) uc := 
&UpstreamClient{ client: *mockUpstream.Client(), sessionID: "test-sess", profile: nil, } uc.client.Transport = &rewriteTransport{ base: mockUpstream.Client().Transport, destURL: mockUpstream.URL, } body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleNonStream(c, uc, san, pool, cred, body, body, nil) if w.Code != 429 { t.Errorf("status = %d, want 429", w.Code) } // Verify MarkFailure was called — cred should now be on cooldown if !cred.IsOnCooldown() { t.Error("expected credential to be on cooldown after 429") } } func TestHandleNonStream_UpstreamSuccess_DesanitizesResponse(t *testing.T) { mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"tool_use","name":"renamed_tool","id":"t1","input":{}}],"model":"claude-sonnet-4-6","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`)) })) defer mockUpstream.Close() cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{ Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, }) uc := &UpstreamClient{ client: *mockUpstream.Client(), sessionID: "test-sess", profile: nil, } uc.client.Transport = &rewriteTransport{ base: mockUpstream.Client().Transport, destURL: mockUpstream.URL, } body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleNonStream(c, uc, san, pool, cred, body, body, nil) if w.Code != 
http.StatusOK { t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String()) } // Should be desanitized back to original_tool if !strings.Contains(w.Body.String(), "original_tool") { t.Errorf("response should contain desanitized tool name 'original_tool', got %s", w.Body.String()) } } func TestHandleNonStream_Upstream500_MarkFailure(t *testing.T) { mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(500) w.Write([]byte(`{"error":{"type":"server_error","message":"internal error"}}`)) })) defer mockUpstream.Close() cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) uc := &UpstreamClient{ client: *mockUpstream.Client(), sessionID: "test-sess", profile: nil, } uc.client.Transport = &rewriteTransport{ base: mockUpstream.Client().Transport, destURL: mockUpstream.URL, } body := []byte(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleNonStream(c, uc, san, pool, cred, body, body, nil) if w.Code != 500 { t.Errorf("status = %d, want 500", w.Code) } if !cred.IsOnCooldown() { t.Error("expected credential to be on cooldown after 500") } } // --- handleStream behavior tests --- func TestHandleStream_ConnectionError(t *testing.T) { cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) uc := &UpstreamClient{ client: http.Client{Transport: &failingTransport{}}, sessionID: "test-sess", profile: nil, } body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ 
:= gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleStream(c, uc, san, pool, cred, body, body, nil) if w.Code != http.StatusBadGateway { t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway) } if !strings.Contains(w.Body.String(), "upstream stream request failed") { t.Errorf("body = %q, expected upstream stream error", w.Body.String()) } } func TestHandleStream_UpstreamError(t *testing.T) { mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) w.Write([]byte(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`)) })) defer mockUpstream.Close() cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) uc := &UpstreamClient{ client: *mockUpstream.Client(), sessionID: "test-sess", profile: nil, } uc.client.Transport = &rewriteTransport{ base: mockUpstream.Client().Transport, destURL: mockUpstream.URL, } body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleStream(c, uc, san, pool, cred, body, body, nil) if w.Code != 429 { t.Errorf("status = %d, want 429", w.Code) } if !cred.IsOnCooldown() { t.Error("expected credential on cooldown after stream 429") } } func TestHandleStream_SuccessForwardsEvents(t *testing.T) { events := "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\nevent: content_block_start\ndata: 
{\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n" mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(200) w.Write([]byte(events)) })) defer mockUpstream.Close() cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) uc := &UpstreamClient{ client: *mockUpstream.Client(), sessionID: "test-sess", profile: nil, } uc.client.Transport = &rewriteTransport{ base: mockUpstream.Client().Transport, destURL: mockUpstream.URL, } body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleStream(c, uc, san, pool, cred, body, body, nil) if w.Code != http.StatusOK { t.Errorf("status = %d, want %d", w.Code, http.StatusOK) } respBody := w.Body.String() if !strings.Contains(respBody, "message_start") { t.Error("response missing message_start event") } if !strings.Contains(respBody, "hello") { t.Error("response missing text content 'hello'") } if !strings.Contains(respBody, "message_stop") { t.Error("response missing message_stop event") } // Verify SSE headers if got := w.Header().Get("Content-Type"); got != "text/event-stream" { t.Errorf("Content-Type = %q, want %q", got, "text/event-stream") } } func 
TestHandleStream_DesanitizesEvents(t *testing.T) { events := "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"name\":\"renamed_tool\",\"id\":\"t1\"}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":5}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n" mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(200) w.Write([]byte(events)) })) defer mockUpstream.Close() cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{ Tools: []config.RenameRule{{From: "original_tool", To: "renamed_tool"}}, }) uc := &UpstreamClient{ client: *mockUpstream.Client(), sessionID: "test-sess", profile: nil, } uc.client.Transport = &rewriteTransport{ base: mockUpstream.Client().Transport, destURL: mockUpstream.URL, } body := []byte(`{"model":"claude-sonnet-4-6","stream":true,"messages":[{"role":"user","content":"hi"}]}`) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) handleStream(c, uc, san, pool, cred, body, body, nil) if w.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body = %s", w.Code, http.StatusOK, w.Body.String()) } if !strings.Contains(w.Body.String(), "original_tool") { t.Errorf("stream response should contain desanitized 'original_tool', got %s", w.Body.String()) } } // --- HandleMessages full integration wiring test --- func TestHandleMessages_WiresHandlerCorrectly(t *testing.T) { cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) // Verify the handler can be 
created without panic handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) if handler == nil { t.Fatal("HandleMessages returned nil handler") } } func TestHandleMessages_EmptyBody(t *testing.T) { cred := &auth.Credential{ID: "c1", AccessToken: "tok", Email: "test@test.com"} pool := auth.NewPool([]*auth.Credential{cred}) san := NewSanitizer(config.SanitizeConfig{}) handler := HandleMessages(pool, nil, func() *Sanitizer { return san }, nil) // Empty body — handler should still try to pick cred and call upstream // (which will fail with connection error to api.anthropic.com, not a panic) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader("")) handler(c) // Should get a 502 because the upstream URL (api.anthropic.com) won't be reachable // in test environment, or it might complete. The key thing is no panic. // We mainly verify it doesn't panic. if w.Code == 0 { t.Error("expected non-zero status code") } } // --- Test helpers --- // errReader is an io.Reader that always returns an error. type errReader struct{} func (e *errReader) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } // failingTransport is an http.RoundTripper that always returns an error. type failingTransport struct{} func (f *failingTransport) RoundTrip(*http.Request) (*http.Response, error) { return nil, fmt.Errorf("connection refused") } // rewriteTransport intercepts HTTP requests and rewrites the URL to point // at a local test server. This allows testing with UpstreamClient's hardcoded // messagesURL by redirecting all requests to a mock server. 
type rewriteTransport struct {
	// base performs the actual round trip; nil means http.DefaultTransport.
	// Pass the mock server's own Client().Transport so TLS test servers work.
	base http.RoundTripper
	// destURL is the mock server's base URL, e.g. httptest.Server.URL
	// ("http://127.0.0.1:port" or "https://127.0.0.1:port").
	destURL string
}

// RoundTrip sends a clone of req to destURL's scheme and host, with the path
// forced to /v1/messages and the query string removed. The caller's request
// is left unmodified, as the http.RoundTripper contract requires.
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Honour destURL's scheme instead of assuming "http://", so servers from
	// httptest.NewTLSServer are reachable too. A bare "host:port" is treated
	// as plain HTTP.
	scheme, host, found := strings.Cut(t.destURL, "://")
	if !found {
		scheme, host = "http", t.destURL
	}
	// Discard any path component carried by destURL; only host:port is used.
	host, _, _ = strings.Cut(host, "/")

	newReq := req.Clone(req.Context())
	newReq.URL.Scheme = scheme
	newReq.URL.Host = host
	newReq.Host = host
	newReq.URL.Path = "/v1/messages"
	newReq.URL.RawPath = ""
	newReq.URL.RawQuery = ""

	if t.base == nil {
		return http.DefaultTransport.RoundTrip(newReq)
	}
	return t.base.RoundTrip(newReq)
}