0df28e9dd8
- Unify duplicate uTLS transports into shared internal/transport package - Extract shared version constant into internal/version - Move LoadDefaultCredentials from config to auth (remove config→auth import) - Deduplicate handler.go: extract telemetry/error helpers (324→268 lines) - Break up main.go::run() into initCredential/initEmbedded - Eliminate logging.Config duplication (use config.LoggingConfig directly) - Extract logWriter to embedded/log.go, SSE fixtures to consts in sniff.go - Use uTLS client for usage polling (consistent TLS fingerprint) - Handle sjson.SetBytes errors in sanitize.go instead of silently swallowing - Document reverse-engineered magic values in billing.go - Unexport Credential.CooldownUntil (internal state) - Replace hardcoded auth bypass paths with map in server.go
271 lines
6.9 KiB
Go
271 lines
6.9 KiB
Go
package config
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestLoad_AllFields(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "config.yaml")
|
|
|
|
yaml := `
|
|
port: 9090
|
|
api_keys:
|
|
- key1
|
|
- key2
|
|
claude_binary: /usr/bin/claude
|
|
sanitize:
|
|
tools:
|
|
- from: tool_a
|
|
to: tool_b
|
|
system:
|
|
- match: foo
|
|
replace: bar
|
|
body:
|
|
- match: baz
|
|
replace: qux
|
|
logging:
|
|
level: debug
|
|
file: /tmp/test.log
|
|
max_size_mb: 50
|
|
max_backups: 3
|
|
max_age_days: 7
|
|
compress: true
|
|
telemetry:
|
|
service_name: my-proxy
|
|
export:
|
|
endpoint: http://localhost:4317
|
|
insecure: true
|
|
headers:
|
|
x-token: abc
|
|
embedded:
|
|
enabled: true
|
|
port: 9999
|
|
perses_binary: /usr/bin/perses
|
|
vm_binary: /usr/bin/vm
|
|
vm_port: 9428
|
|
bin_dir: /opt/bin
|
|
`
|
|
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cfg, err := Load(path)
|
|
if err != nil {
|
|
t.Fatalf("Load returned error: %v", err)
|
|
}
|
|
|
|
if cfg.Port != 9090 {
|
|
t.Errorf("Port = %d, want 9090", cfg.Port)
|
|
}
|
|
if len(cfg.APIKeys) != 2 || cfg.APIKeys[0] != "key1" || cfg.APIKeys[1] != "key2" {
|
|
t.Errorf("APIKeys = %v, want [key1 key2]", cfg.APIKeys)
|
|
}
|
|
if cfg.ClaudeBinary != "/usr/bin/claude" {
|
|
t.Errorf("ClaudeBinary = %q, want /usr/bin/claude", cfg.ClaudeBinary)
|
|
}
|
|
|
|
// Sanitize
|
|
if len(cfg.Sanitize.Tools) != 1 || cfg.Sanitize.Tools[0].From != "tool_a" || cfg.Sanitize.Tools[0].To != "tool_b" {
|
|
t.Errorf("Sanitize.Tools = %v", cfg.Sanitize.Tools)
|
|
}
|
|
if len(cfg.Sanitize.System) != 1 || cfg.Sanitize.System[0].Match != "foo" {
|
|
t.Errorf("Sanitize.System = %v", cfg.Sanitize.System)
|
|
}
|
|
if len(cfg.Sanitize.Body) != 1 || cfg.Sanitize.Body[0].Match != "baz" {
|
|
t.Errorf("Sanitize.Body = %v", cfg.Sanitize.Body)
|
|
}
|
|
|
|
// Logging
|
|
if cfg.Logging.Level != "debug" {
|
|
t.Errorf("Logging.Level = %q, want debug", cfg.Logging.Level)
|
|
}
|
|
if cfg.Logging.File != "/tmp/test.log" {
|
|
t.Errorf("Logging.File = %q", cfg.Logging.File)
|
|
}
|
|
if cfg.Logging.MaxSizeMB != 50 {
|
|
t.Errorf("Logging.MaxSizeMB = %d, want 50", cfg.Logging.MaxSizeMB)
|
|
}
|
|
if cfg.Logging.MaxBackups != 3 {
|
|
t.Errorf("Logging.MaxBackups = %d, want 3", cfg.Logging.MaxBackups)
|
|
}
|
|
if cfg.Logging.MaxAgeDays != 7 {
|
|
t.Errorf("Logging.MaxAgeDays = %d, want 7", cfg.Logging.MaxAgeDays)
|
|
}
|
|
if !cfg.Logging.Compress {
|
|
t.Error("Logging.Compress = false, want true")
|
|
}
|
|
|
|
// Telemetry
|
|
if cfg.Telemetry.ServiceName != "my-proxy" {
|
|
t.Errorf("Telemetry.ServiceName = %q, want my-proxy", cfg.Telemetry.ServiceName)
|
|
}
|
|
if cfg.Telemetry.Export.Endpoint != "http://localhost:4317" {
|
|
t.Errorf("Export.Endpoint = %q", cfg.Telemetry.Export.Endpoint)
|
|
}
|
|
if !cfg.Telemetry.Export.Insecure {
|
|
t.Error("Export.Insecure = false, want true")
|
|
}
|
|
if !cfg.Telemetry.Export.Enabled() {
|
|
t.Error("Export.Enabled() = false, want true")
|
|
}
|
|
if cfg.Telemetry.Export.Headers["x-token"] != "abc" {
|
|
t.Errorf("Export.Headers = %v", cfg.Telemetry.Export.Headers)
|
|
}
|
|
|
|
// Embedded
|
|
if !cfg.Telemetry.Embedded.Enabled {
|
|
t.Error("Embedded.Enabled = false, want true")
|
|
}
|
|
if cfg.Telemetry.Embedded.Port != 9999 {
|
|
t.Errorf("Embedded.Port = %d, want 9999", cfg.Telemetry.Embedded.Port)
|
|
}
|
|
if cfg.Telemetry.Embedded.PersesBinary != "/usr/bin/perses" {
|
|
t.Errorf("Embedded.PersesBinary = %q", cfg.Telemetry.Embedded.PersesBinary)
|
|
}
|
|
if cfg.Telemetry.Embedded.VMBinary != "/usr/bin/vm" {
|
|
t.Errorf("Embedded.VMBinary = %q", cfg.Telemetry.Embedded.VMBinary)
|
|
}
|
|
if cfg.Telemetry.Embedded.VMPort != 9428 {
|
|
t.Errorf("Embedded.VMPort = %d, want 9428", cfg.Telemetry.Embedded.VMPort)
|
|
}
|
|
if cfg.Telemetry.Embedded.BinDir != "/opt/bin" {
|
|
t.Errorf("Embedded.BinDir = %q", cfg.Telemetry.Embedded.BinDir)
|
|
}
|
|
}
|
|
|
|
func TestLoad_Defaults(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "config.yaml")
|
|
|
|
// Minimal YAML — only api_keys
|
|
if err := os.WriteFile(path, []byte("api_keys:\n - k1\n"), 0644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cfg, err := Load(path)
|
|
if err != nil {
|
|
t.Fatalf("Load returned error: %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
got interface{}
|
|
want interface{}
|
|
}{
|
|
{"Port", cfg.Port, 8080},
|
|
{"Logging.Level", cfg.Logging.Level, "info"},
|
|
{"Logging.MaxSizeMB", cfg.Logging.MaxSizeMB, 100},
|
|
{"Logging.MaxBackups", cfg.Logging.MaxBackups, 5},
|
|
{"Logging.MaxAgeDays", cfg.Logging.MaxAgeDays, 30},
|
|
{"Telemetry.ServiceName", cfg.Telemetry.ServiceName, "anthropic-proxy"},
|
|
{"Embedded.Port", cfg.Telemetry.Embedded.Port, 8080},
|
|
{"Embedded.VMBinary", cfg.Telemetry.Embedded.VMBinary, "victoria-metrics"},
|
|
{"Embedded.PersesBinary", cfg.Telemetry.Embedded.PersesBinary, "perses"},
|
|
{"Embedded.VMPort", cfg.Telemetry.Embedded.VMPort, 8428},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.got != tt.want {
|
|
t.Errorf("got %v, want %v", tt.got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLoad_MissingFile(t *testing.T) {
|
|
_, err := Load("/nonexistent/path/config.yaml")
|
|
if err == nil {
|
|
t.Fatal("expected error for missing file, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "read config") {
|
|
t.Errorf("error = %q, want it to contain 'read config'", err.Error())
|
|
}
|
|
}
|
|
|
|
func TestLoad_DeprecatedClaudeCredentials(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "config.yaml")
|
|
|
|
yaml := `
|
|
api_keys:
|
|
- k1
|
|
claude_credentials: "/some/path"
|
|
`
|
|
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err := Load(path)
|
|
if err == nil {
|
|
t.Fatal("expected error for deprecated claude_credentials, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "no longer supported") {
|
|
t.Errorf("error = %q, want it to contain 'no longer supported'", err.Error())
|
|
}
|
|
}
|
|
|
|
func TestLoad_EmptyClaudeCredentials(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "config.yaml")
|
|
|
|
// Empty string value should NOT trigger the deprecation error
|
|
yaml := `
|
|
api_keys:
|
|
- k1
|
|
claude_credentials: ""
|
|
`
|
|
if err := os.WriteFile(path, []byte(yaml), 0644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cfg, err := Load(path)
|
|
if err != nil {
|
|
t.Fatalf("empty claude_credentials should not error: %v", err)
|
|
}
|
|
if cfg.Port != 8080 {
|
|
t.Errorf("Port = %d, want 8080", cfg.Port)
|
|
}
|
|
}
|
|
|
|
func TestLoad_InvalidYAML(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "config.yaml")
|
|
|
|
// Truly invalid YAML that causes a parse error
|
|
if err := os.WriteFile(path, []byte("port:\n - bad\n indent: broken\n"), 0644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err := Load(path)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid YAML, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "parse config") {
|
|
t.Errorf("error = %q, want it to contain 'parse config'", err.Error())
|
|
}
|
|
}
|
|
|
|
func TestExportConfig_Enabled(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
endpoint string
|
|
want bool
|
|
}{
|
|
{"empty endpoint", "", false},
|
|
{"set endpoint", "http://localhost:4317", true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
e := ExportConfig{Endpoint: tt.endpoint}
|
|
if got := e.Enabled(); got != tt.want {
|
|
t.Errorf("Enabled() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|