9150f466e5
Characterization tests capturing current behavior before refactoring. Covers auth, config, logging, proxy, ratelimit, server, and telemetry packages with race-safe concurrent access tests.
530 lines
14 KiB
Go
530 lines
14 KiB
Go
package server
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"sync/atomic"
	"testing"

	"github.com/gin-gonic/gin"
)
|
|
|
|
func init() {
|
|
gin.SetMode(gin.TestMode)
|
|
}
|
|
|
|
// --- makeKeySet ---
|
|
|
|
// TestMakeKeySet pins how makeKeySet turns a key slice into a lookup set:
// the resulting size (duplicates collapse to one entry) and, when a probe
// key is given, whether that key is a member.
func TestMakeKeySet(t *testing.T) {
	cases := []struct {
		name  string
		in    []string
		size  int
		probe string // empty means "skip the membership check"
		want  bool
	}{
		{name: "nil slice returns empty map", in: nil, size: 0},
		{name: "empty slice returns empty map", in: []string{}, size: 0},
		{name: "single key", in: []string{"key1"}, size: 1, probe: "key1", want: true},
		{name: "multiple keys", in: []string{"a", "b", "c"}, size: 3, probe: "b", want: true},
		{name: "missing key not found", in: []string{"a", "b"}, size: 2, probe: "c", want: false},
		{name: "duplicate keys deduped", in: []string{"x", "x", "x"}, size: 1, probe: "x", want: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			set := makeKeySet(tc.in)
			if n := len(set); n != tc.size {
				t.Errorf("len(makeKeySet) = %d, want %d", n, tc.size)
			}
			if tc.probe == "" {
				return
			}
			if _, present := set[tc.probe]; present != tc.want {
				t.Errorf("keySet[%q] found=%v, want %v", tc.probe, present, tc.want)
			}
		})
	}
}
|
|
|
|
// --- corsMiddleware ---
|
|
|
|
func TestCorsMiddleware_SetsHeaders(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
|
|
handler := corsMiddleware()
|
|
handler(c)
|
|
|
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
|
t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
|
}
|
|
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT, DELETE, OPTIONS" {
|
|
t.Errorf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT, DELETE, OPTIONS")
|
|
}
|
|
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
|
for _, h := range []string{"x-api-key", "anthropic-version", "anthropic-beta", "Authorization", "Content-Type", "Origin"} {
|
|
if !containsSubstring(allowHeaders, h) {
|
|
t.Errorf("Access-Control-Allow-Headers %q missing %q", allowHeaders, h)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_OptionsReturns204 checks that corsMiddleware answers an
// OPTIONS request itself with 204 No Content and aborts the handler chain.
func TestCorsMiddleware_OptionsReturns204(t *testing.T) {
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodOptions, "/v1/messages", nil)

	corsMiddleware()(ctx)

	if code := rec.Code; code != http.StatusNoContent {
		t.Errorf("OPTIONS status = %d, want %d", code, http.StatusNoContent)
	}
	if aborted := ctx.IsAborted(); !aborted {
		t.Error("expected context to be aborted on OPTIONS")
	}
}
|
|
|
|
// TestCorsMiddleware_NonOptionsDoesNotAbort checks that corsMiddleware lets a
// non-OPTIONS request (here a POST) continue down the handler chain.
func TestCorsMiddleware_NonOptionsDoesNotAbort(t *testing.T) {
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	corsMiddleware()(ctx)

	if ctx.IsAborted() {
		t.Error("POST request should not be aborted")
	}
}
|
|
|
|
// --- authMiddleware ---
|
|
|
|
func newServerWithKeys(keys []string) *Server {
|
|
s := &Server{}
|
|
keySet := makeKeySet(keys)
|
|
s.apiKeys.Store(&keySet)
|
|
return s
|
|
}
|
|
|
|
// TestAuthMiddleware_BypassPaths checks that the operational endpoints skip
// authentication. The requests carry no credentials and the key set is
// empty, so a request that was actually checked would be aborted; not being
// aborted shows the path is exempt.
func TestAuthMiddleware_BypassPaths(t *testing.T) {
	srv := newServerWithKeys(nil) // no keys — would reject if auth checked

	for _, p := range []string{"/healthz", "/reload", "/metrics"} {
		t.Run(p, func(t *testing.T) {
			ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
			ctx.Request = httptest.NewRequest(http.MethodGet, p, nil)

			srv.authMiddleware()(ctx)

			if ctx.IsAborted() {
				t.Errorf("path %q should bypass auth but was aborted", p)
			}
		})
	}
}
|
|
|
|
// TestAuthMiddleware_MissingToken_401 checks that a protected request carrying
// neither Authorization nor x-api-key is aborted with 401 and a JSON body
// whose "error" field is "missing authentication".
func TestAuthMiddleware_MissingToken_401(t *testing.T) {
	srv := newServerWithKeys([]string{"valid-key"})
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	srv.authMiddleware()(ctx)

	if got, want := rec.Code, http.StatusUnauthorized; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}
	if !ctx.IsAborted() {
		t.Error("expected aborted on missing token")
	}

	var body map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	const wantErr = "missing authentication"
	if got := body["error"]; got != wantErr {
		t.Errorf("error = %q, want %q", got, wantErr)
	}
}
|
|
|
|
// TestAuthMiddleware_InvalidKey_403 checks that an x-api-key which is not in
// the key set is aborted with 403 and a JSON body whose "error" field is
// "invalid api key".
func TestAuthMiddleware_InvalidKey_403(t *testing.T) {
	srv := newServerWithKeys([]string{"valid-key"})
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	ctx.Request.Header.Set("x-api-key", "wrong-key")

	srv.authMiddleware()(ctx)

	if got, want := rec.Code, http.StatusForbidden; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}
	if !ctx.IsAborted() {
		t.Error("expected aborted on invalid key")
	}

	var body map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	const wantErr = "invalid api key"
	if got := body["error"]; got != wantErr {
		t.Errorf("error = %q, want %q", got, wantErr)
	}
}
|
|
|
|
// TestAuthMiddleware_ValidKey_XApiKey checks that a known key sent in the
// x-api-key header passes: the chain is not aborted and no 401/403 is written.
func TestAuthMiddleware_ValidKey_XApiKey(t *testing.T) {
	srv := newServerWithKeys([]string{"valid-key"})
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	req.Header.Set("x-api-key", "valid-key")
	ctx.Request = req

	srv.authMiddleware()(ctx)

	if ctx.IsAborted() {
		t.Error("valid key should not abort")
	}
	switch rec.Code {
	case http.StatusUnauthorized, http.StatusForbidden:
		t.Errorf("unexpected status %d for valid key", rec.Code)
	}
}
|
|
|
|
// TestAuthMiddleware_ValidKey_BearerAuth checks that a known key sent as
// "Authorization: Bearer <key>" is not aborted.
func TestAuthMiddleware_ValidKey_BearerAuth(t *testing.T) {
	srv := newServerWithKeys([]string{"my-token"})
	req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	req.Header.Set("Authorization", "Bearer my-token")

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = req

	srv.authMiddleware()(ctx)

	if ctx.IsAborted() {
		t.Error("valid Bearer token should not abort")
	}
}
|
|
|
|
// TestAuthMiddleware_BearerPrefix_Stripped pins that the "Bearer " prefix of
// the Authorization header is removed before the key-set lookup.
//
// The positive case alone ("Bearer my-token" accepted for key "my-token") is
// the same scenario as TestAuthMiddleware_ValidKey_BearerAuth and cannot tell
// stripping apart from other matching strategies. The negative case therefore
// registers the literal header value "Bearer my-token" as the only key. With
// the prefix stripped, the looked-up token is "my-token", which is not in the
// set, so the request must be rejected as an invalid key (403).
func TestAuthMiddleware_BearerPrefix_Stripped(t *testing.T) {
	// run sends "Authorization: Bearer my-token" through authMiddleware for a
	// server holding keys, returning the context and recorder for assertions.
	run := func(t *testing.T, keys []string) (*gin.Context, *httptest.ResponseRecorder) {
		t.Helper()
		s := newServerWithKeys(keys)
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
		c.Request.Header.Set("Authorization", "Bearer my-token")
		s.authMiddleware()(c)
		return c, w
	}

	t.Run("stripped token matches key", func(t *testing.T) {
		c, _ := run(t, []string{"my-token"})
		if c.IsAborted() {
			t.Error("expected auth to pass with Bearer-prefixed valid key")
		}
	})

	t.Run("unstripped header value is not the token", func(t *testing.T) {
		c, w := run(t, []string{"Bearer my-token"})
		if !c.IsAborted() {
			t.Error("expected auth to fail when only the unstripped header value is a key")
		}
		if w.Code != http.StatusForbidden {
			t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
		}
	})
}
|
|
|
|
func TestAuthMiddleware_AuthorizationWithoutBearer(t *testing.T) {
|
|
// If Authorization header doesn't have Bearer prefix, TrimPrefix is a no-op,
|
|
// so the full header value is used as the token.
|
|
s := newServerWithKeys([]string{"raw-token-value"})
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
|
c.Request.Header.Set("Authorization", "raw-token-value")
|
|
|
|
handler := s.authMiddleware()
|
|
handler(c)
|
|
|
|
if c.IsAborted() {
|
|
t.Error("raw Authorization value matching a key should pass")
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_XApiKey_FallbackWhenNoAuthHeader(t *testing.T) {
|
|
// If Authorization is empty, x-api-key is checked.
|
|
s := newServerWithKeys([]string{"fallback-key"})
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
|
c.Request.Header.Set("x-api-key", "fallback-key")
|
|
|
|
handler := s.authMiddleware()
|
|
handler(c)
|
|
|
|
if c.IsAborted() {
|
|
t.Error("x-api-key fallback should pass")
|
|
}
|
|
}
|
|
|
|
// TestAuthMiddleware_AuthorizationPreferredOverXApiKey sends a valid Bearer
// token together with a wrong x-api-key. Passing means the Authorization
// header is the one consulted when both are present.
func TestAuthMiddleware_AuthorizationPreferredOverXApiKey(t *testing.T) {
	srv := newServerWithKeys([]string{"auth-key"})

	req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	for name, value := range map[string]string{
		"Authorization": "Bearer auth-key",
		"x-api-key":     "wrong-key",
	} {
		req.Header.Set(name, value)
	}
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = req

	srv.authMiddleware()(ctx)

	if ctx.IsAborted() {
		t.Error("Authorization should take precedence over x-api-key")
	}
}
|
|
|
|
// --- handleReload ---
|
|
|
|
// TestHandleReload_Success writes a complete config file, points a Server at
// it, and calls handleReload directly. It checks:
//   - a 200 JSON response with status "reloaded", api_keys 2, and
//     tool_renames 1;
//   - that the key set now holds exactly the two keys from the file;
//   - that a sanitizer has been stored.
func TestHandleReload_Success(t *testing.T) {
	// The YAML nesting matters: tool_renames == 1 below relies on the tools
	// list sitting under sanitize.
	const configContent = `
port: 9999
api_keys:
  - reloaded-key-1
  - reloaded-key-2
sanitize:
  tools:
    - from: old_tool
      to: new_tool
  system:
    - match: foo
      replace: bar
  body:
    - match: baz
      replace: qux
`

	// Create the file under t.TempDir so the testing package removes it.
	// That happens even when an assertion below calls t.Fatal.
	tmpFile, err := os.CreateTemp(t.TempDir(), "config-*.yaml")
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	if _, err := tmpFile.WriteString(configContent); err != nil {
		tmpFile.Close()
		t.Fatalf("failed to write config: %v", err)
	}
	// A failed Close can leave the file incomplete; don't reload from it.
	if err := tmpFile.Close(); err != nil {
		t.Fatalf("failed to close config: %v", err)
	}

	s := &Server{configPath: tmpFile.Name()}
	// Start from an empty key set, so the key assertions can only pass if
	// the reload installed the keys from the file.
	emptyKeys := makeKeySet(nil)
	s.apiKeys.Store(&emptyKeys)
	// NOTE(review): s.sanitizer is deliberately left unset before the reload.
	// sync/atomic is otherwise unused in this file; this blank declaration
	// keeps the import referenced.
	var _ atomic.Pointer[any]

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)

	s.handleReload()(c)

	if w.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if resp["status"] != "reloaded" {
		t.Errorf("status = %v, want %q", resp["status"], "reloaded")
	}

	// Verify api keys were updated. Guard against a nil pointer so a broken
	// reload fails the test instead of panicking.
	keys := s.apiKeys.Load()
	if keys == nil {
		t.Fatal("api keys are nil after reload")
	}
	for _, k := range []string{"reloaded-key-1", "reloaded-key-2"} {
		if _, ok := (*keys)[k]; !ok {
			t.Errorf("expected %s in api keys after reload", k)
		}
	}
	if len(*keys) != 2 {
		t.Errorf("expected 2 api keys, got %d", len(*keys))
	}

	// Verify sanitizer was updated.
	if s.sanitizer.Load() == nil {
		t.Fatal("sanitizer is nil after reload")
	}

	// encoding/json decodes JSON numbers into float64 when the target is an
	// interface value, hence the float64 assertions.
	if toolRenames, ok := resp["tool_renames"].(float64); !ok || int(toolRenames) != 1 {
		t.Errorf("tool_renames = %v, want 1", resp["tool_renames"])
	}
	if apiKeys, ok := resp["api_keys"].(float64); !ok || int(apiKeys) != 2 {
		t.Errorf("api_keys = %v, want 2", resp["api_keys"])
	}
}
|
|
|
|
// TestHandleReload_InvalidConfig points the Server at a config path that does
// not exist and expects handleReload to answer 500 with a non-empty JSON
// "error" message.
func TestHandleReload_InvalidConfig(t *testing.T) {
	emptyKeys := makeKeySet(nil)
	srv := &Server{configPath: "/nonexistent/path/config.yaml"}
	srv.apiKeys.Store(&emptyKeys)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/reload", nil)

	srv.handleReload()(ctx)

	if got, want := rec.Code, http.StatusInternalServerError; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}

	var payload map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if len(payload["error"]) == 0 {
		t.Error("expected non-empty error message")
	}
}
|
|
|
|
// --- Full route tests using httptest ---
|
|
|
|
// TestHealthzEndpoint drives a real gin engine with corsMiddleware and
// authMiddleware (empty key set, no credentials) installed. It expects the
// stand-in /healthz handler registered here to be reached and to return 200
// with {"status": "ok"}.
func TestHealthzEndpoint(t *testing.T) {
	srv := newServerWithKeys(nil)

	router := gin.New()
	router.Use(corsMiddleware(), srv.authMiddleware())
	router.GET("/healthz", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/healthz", nil))

	if got, want := rec.Code, http.StatusOK; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}

	var payload map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if got := payload["status"]; got != "ok" {
		t.Errorf("status = %q, want %q", got, "ok")
	}
}
|
|
|
|
// TestAuthMiddleware_FullRoute_Rejected sends an unauthenticated POST through
// an engine guarded by authMiddleware and expects 401 before the route
// handler can answer.
func TestAuthMiddleware_FullRoute_Rejected(t *testing.T) {
	srv := newServerWithKeys([]string{"correct-key"})

	router := gin.New()
	router.Use(srv.authMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	// No auth header
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/v1/messages", nil))

	if got, want := rec.Code, http.StatusUnauthorized; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}
}
|
|
|
|
// TestAuthMiddleware_FullRoute_Accepted sends a POST with a known x-api-key
// through an engine guarded by authMiddleware and expects the route
// handler's 200.
func TestAuthMiddleware_FullRoute_Accepted(t *testing.T) {
	srv := newServerWithKeys([]string{"correct-key"})

	router := gin.New()
	router.Use(srv.authMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	req.Header.Set("x-api-key", "correct-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if got, want := rec.Code, http.StatusOK; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}
}
|
|
|
|
// TestCorsMiddleware_FullRoute_OptionsRequest sends an unauthenticated OPTIONS
// request through an engine with corsMiddleware installed ahead of
// authMiddleware. Only a POST route is registered, so the expected 204 and
// "*" origin header have to come from the middleware, not a route handler.
func TestCorsMiddleware_FullRoute_OptionsRequest(t *testing.T) {
	srv := newServerWithKeys([]string{"key"})

	router := gin.New()
	router.Use(corsMiddleware(), srv.authMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodOptions, "/v1/messages", nil))

	if got, want := rec.Code, http.StatusNoContent; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}
	if origin := rec.Header().Get("Access-Control-Allow-Origin"); origin != "*" {
		t.Errorf("ACAO = %q, want %q", origin, "*")
	}
}
|
|
|
|
// containsSubstring reports whether sub occurs anywhere in s. An empty sub is
// contained in every string, and a sub longer than s never is.
// It is a thin alias for strings.Contains, kept so existing call sites read
// unchanged.
func containsSubstring(s, sub string) bool {
	return strings.Contains(s, sub)
}

// containsStr reports whether sub occurs anywhere in s, with the same
// semantics as containsSubstring. It previously hand-rolled a byte-window
// scan; it now delegates to strings.Contains and is retained only for any
// other callers in the package.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}
|