package server import ( "encoding/json" "net/http" "net/http/httptest" "os" "sync/atomic" "testing" "github.com/gin-gonic/gin" ) func init() { gin.SetMode(gin.TestMode) } // --- makeKeySet --- func TestMakeKeySet(t *testing.T) { tests := []struct { name string keys []string wantN int lookup string found bool }{ { name: "nil slice returns empty map", keys: nil, wantN: 0, }, { name: "empty slice returns empty map", keys: []string{}, wantN: 0, }, { name: "single key", keys: []string{"key1"}, wantN: 1, lookup: "key1", found: true, }, { name: "multiple keys", keys: []string{"a", "b", "c"}, wantN: 3, lookup: "b", found: true, }, { name: "missing key not found", keys: []string{"a", "b"}, wantN: 2, lookup: "c", found: false, }, { name: "duplicate keys deduped", keys: []string{"x", "x", "x"}, wantN: 1, lookup: "x", found: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := makeKeySet(tt.keys) if len(got) != tt.wantN { t.Errorf("len(makeKeySet) = %d, want %d", len(got), tt.wantN) } if tt.lookup != "" { _, ok := got[tt.lookup] if ok != tt.found { t.Errorf("keySet[%q] found=%v, want %v", tt.lookup, ok, tt.found) } } }) } } // --- corsMiddleware --- func TestCorsMiddleware_SetsHeaders(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodGet, "/", nil) handler := corsMiddleware() handler(c) if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" { t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*") } if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT, DELETE, OPTIONS" { t.Errorf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT, DELETE, OPTIONS") } allowHeaders := w.Header().Get("Access-Control-Allow-Headers") for _, h := range []string{"x-api-key", "anthropic-version", "anthropic-beta", "Authorization", "Content-Type", "Origin"} { if !containsSubstring(allowHeaders, h) { t.Errorf("Access-Control-Allow-Headers %q missing %q", allowHeaders, h) } } } func TestCorsMiddleware_OptionsReturns204(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodOptions, "/v1/messages", nil) handler := corsMiddleware() handler(c) if w.Code != http.StatusNoContent { t.Errorf("OPTIONS status = %d, want %d", w.Code, http.StatusNoContent) } if !c.IsAborted() { t.Error("expected context to be aborted on OPTIONS") } } func TestCorsMiddleware_NonOptionsDoesNotAbort(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) handler := corsMiddleware() handler(c) if c.IsAborted() { t.Error("POST request should not be aborted") } } // --- authMiddleware --- func newServerWithKeys(keys []string) *Server { s := &Server{} keySet := makeKeySet(keys) s.apiKeys.Store(&keySet) return s } func TestAuthMiddleware_BypassPaths(t *testing.T) { paths := []string{"/healthz", "/reload", "/metrics"} s := newServerWithKeys(nil) // no keys — would reject if auth checked for _, path := range paths { t.Run(path, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodGet, path, nil) handler := s.authMiddleware() handler(c) if c.IsAborted() { t.Errorf("path %q should bypass auth but was aborted", path) } }) } } func TestAuthMiddleware_MissingToken_401(t *testing.T) { s := newServerWithKeys([]string{"valid-key"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) handler := s.authMiddleware() handler(c) if w.Code != http.StatusUnauthorized { t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized) } if !c.IsAborted() { t.Error("expected aborted on missing token") } var body map[string]string if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } if body["error"] != "missing authentication" { t.Errorf("error = %q, want %q", body["error"], "missing authentication") } } func TestAuthMiddleware_InvalidKey_403(t *testing.T) { s := newServerWithKeys([]string{"valid-key"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) c.Request.Header.Set("x-api-key", "wrong-key") handler := s.authMiddleware() handler(c) if w.Code != http.StatusForbidden { t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden) } if !c.IsAborted() { t.Error("expected aborted on invalid key") } var body map[string]string if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } if body["error"] != "invalid api key" { t.Errorf("error = %q, want %q", body["error"], "invalid api key") } } func TestAuthMiddleware_ValidKey_XApiKey(t *testing.T) { s := newServerWithKeys([]string{"valid-key"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) c.Request.Header.Set("x-api-key", "valid-key") handler := s.authMiddleware() handler(c) if c.IsAborted() { t.Error("valid key should not abort") } if w.Code == http.StatusUnauthorized || w.Code == http.StatusForbidden { t.Errorf("unexpected status %d for valid key", w.Code) } } func TestAuthMiddleware_ValidKey_BearerAuth(t *testing.T) { s := newServerWithKeys([]string{"my-token"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) c.Request.Header.Set("Authorization", "Bearer my-token") handler := s.authMiddleware() handler(c) if c.IsAborted() { t.Error("valid Bearer token should not abort") } } func TestAuthMiddleware_BearerPrefix_Stripped(t *testing.T) { // The token is "my-token", sent as "Bearer my-token". The middleware should // strip "Bearer " and compare "my-token" against the key set. s := newServerWithKeys([]string{"my-token"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) c.Request.Header.Set("Authorization", "Bearer my-token") handler := s.authMiddleware() handler(c) if c.IsAborted() { t.Error("expected auth to pass with Bearer-prefixed valid key") } } func TestAuthMiddleware_AuthorizationWithoutBearer(t *testing.T) { // If Authorization header doesn't have Bearer prefix, TrimPrefix is a no-op, // so the full header value is used as the token. s := newServerWithKeys([]string{"raw-token-value"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) c.Request.Header.Set("Authorization", "raw-token-value") handler := s.authMiddleware() handler(c) if c.IsAborted() { t.Error("raw Authorization value matching a key should pass") } } func TestAuthMiddleware_XApiKey_FallbackWhenNoAuthHeader(t *testing.T) { // If Authorization is empty, x-api-key is checked. s := newServerWithKeys([]string{"fallback-key"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) c.Request.Header.Set("x-api-key", "fallback-key") handler := s.authMiddleware() handler(c) if c.IsAborted() { t.Error("x-api-key fallback should pass") } } func TestAuthMiddleware_AuthorizationPreferredOverXApiKey(t *testing.T) { // Both headers set; Authorization takes precedence. s := newServerWithKeys([]string{"auth-key"}) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) c.Request.Header.Set("Authorization", "Bearer auth-key") c.Request.Header.Set("x-api-key", "wrong-key") handler := s.authMiddleware() handler(c) if c.IsAborted() { t.Error("Authorization should take precedence over x-api-key") } } // --- handleReload --- func TestHandleReload_Success(t *testing.T) { // Create a temp config file tmpFile, err := os.CreateTemp("", "config-*.yaml") if err != nil { t.Fatalf("failed to create temp file: %v", err) } defer os.Remove(tmpFile.Name()) configContent := ` port: 9999 api_keys: - reloaded-key-1 - reloaded-key-2 sanitize: tools: - from: old_tool to: new_tool system: - match: foo replace: bar body: - match: baz replace: qux ` if _, err := tmpFile.WriteString(configContent); err != nil { t.Fatalf("failed to write config: %v", err) } tmpFile.Close() s := &Server{configPath: tmpFile.Name()} // Initialize with empty values emptyKeys := makeKeySet(nil) s.apiKeys.Store(&emptyKeys) emptySan := &atomic.Pointer[interface{}]{} _ = emptySan // just to show we're aware w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil) handler := s.handleReload() handler(c) if w.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) } var resp map[string]interface{} if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("failed to unmarshal: %v", err) } if resp["status"] != "reloaded" { t.Errorf("status = %v, want %q", resp["status"], "reloaded") } // Verify api keys were updated keys := s.apiKeys.Load() if _, ok := (*keys)["reloaded-key-1"]; !ok { t.Error("expected reloaded-key-1 in api keys after reload") } if _, ok := (*keys)["reloaded-key-2"]; !ok { t.Error("expected reloaded-key-2 in api keys after reload") } if len(*keys) != 2 { t.Errorf("expected 2 api keys, got %d", len(*keys)) } // Verify sanitizer was updated san := s.sanitizer.Load() if san == nil { t.Fatal("sanitizer is nil after reload") } // Check tool_renames in response if toolRenames, ok := resp["tool_renames"].(float64); !ok || int(toolRenames) != 1 { t.Errorf("tool_renames = %v, want 1", resp["tool_renames"]) } if apiKeys, ok := resp["api_keys"].(float64); !ok || int(apiKeys) != 2 { t.Errorf("api_keys = %v, want 2", resp["api_keys"]) } } func TestHandleReload_InvalidConfig(t *testing.T) { s := &Server{configPath: "/nonexistent/path/config.yaml"} emptyKeys := makeKeySet(nil) s.apiKeys.Store(&emptyKeys) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodPost, "/reload", nil) handler := s.handleReload() handler(c) if w.Code != http.StatusInternalServerError { t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError) } var resp map[string]string if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("failed to unmarshal: %v", err) } if resp["error"] == "" { t.Error("expected non-empty error message") } } // --- Full route tests using httptest --- func TestHealthzEndpoint(t *testing.T) { engine := gin.New() engine.Use(corsMiddleware()) s := newServerWithKeys(nil) engine.Use(s.authMiddleware()) engine.GET("/healthz", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) req := httptest.NewRequest(http.MethodGet, "/healthz", nil) w := httptest.NewRecorder() engine.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("status = %d, want %d", w.Code, http.StatusOK) } var body map[string]string if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { t.Fatalf("unmarshal: %v", err) } if body["status"] != "ok" { t.Errorf("status = %q, want %q", body["status"], "ok") } } func TestAuthMiddleware_FullRoute_Rejected(t *testing.T) { engine := gin.New() s := newServerWithKeys([]string{"correct-key"}) engine.Use(s.authMiddleware()) engine.POST("/v1/messages", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) // No auth header req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) w := httptest.NewRecorder() engine.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("status = %d, want %d", w.Code, http.StatusUnauthorized) } } func TestAuthMiddleware_FullRoute_Accepted(t *testing.T) { engine := gin.New() s := newServerWithKeys([]string{"correct-key"}) engine.Use(s.authMiddleware()) engine.POST("/v1/messages", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) req.Header.Set("x-api-key", "correct-key") w := httptest.NewRecorder() engine.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("status = %d, want %d", w.Code, http.StatusOK) } } func TestCorsMiddleware_FullRoute_OptionsRequest(t *testing.T) { engine := gin.New() engine.Use(corsMiddleware()) s := newServerWithKeys([]string{"key"}) engine.Use(s.authMiddleware()) engine.POST("/v1/messages", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodOptions, "/v1/messages", nil) w := httptest.NewRecorder() engine.ServeHTTP(w, req) if w.Code != http.StatusNoContent { t.Errorf("status = %d, want %d", w.Code, http.StatusNoContent) } if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" { t.Errorf("ACAO = %q, want %q", got, "*") } } // helper func containsSubstring(s, sub string) bool { return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub)) } func containsStr(s, sub string) bool { for i := 0; i <= len(s)-len(sub); i++ { if s[i:i+len(sub)] == sub { return true } } return false }