From 0df28e9dd8603b7a70466397aedc43cea1de6b6c Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 15 Apr 2026 11:01:29 +0200 Subject: [PATCH] =?UTF-8?q?refactor:=20modularize=20codebase=20=E2=80=94?= =?UTF-8?q?=20deduplicate,=20extract,=20clean=20up?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Unify duplicate uTLS transports into shared internal/transport package - Extract shared version constant into internal/version - Move LoadDefaultCredentials from config to auth (remove config→auth import) - Deduplicate handler.go: extract telemetry/error helpers (324→268 lines) - Break up main.go::run() into initCredential/initEmbedded - Eliminate logging.Config duplication (use config.LoggingConfig directly) - Extract logWriter to embedded/log.go, SSE fixtures to consts in sniff.go - Use uTLS client for usage polling (consistent TLS fingerprint) - Handle sjson.SetBytes errors in sanitize.go instead of silently swallowing - Document reverse-engineered magic values in billing.go - Unexport Credential.CooldownUntil (internal state) - Replace hardcoded auth bypass paths with map in server.go --- internal/auth/credentials.go | 56 +++++ internal/auth/credentials_test.go | 70 ++++++ internal/auth/refresh.go | 69 +----- internal/auth/selector_test.go | 14 +- internal/auth/types.go | 8 +- internal/auth/types_test.go | 22 +- internal/config/config.go | 56 ----- internal/config/config_test.go | 79 ------ internal/embedded/log.go | 20 ++ internal/embedded/perses.go | 16 -- internal/logging/logging.go | 14 +- internal/logging/logging_test.go | 10 +- internal/proxy/billing.go | 4 + internal/proxy/handler.go | 226 +++++++----------- internal/proxy/sanitize.go | 35 ++- internal/proxy/sniff.go | 94 ++++---- internal/proxy/upstream.go | 6 +- internal/ratelimit/usage.go | 19 +- internal/server/server.go | 10 +- .../{proxy/transport.go => transport/utls.go} | 38 ++- internal/transport/utls_test.go | 78 ++++++ internal/version/version.go | 8 + main.go | 136 ++++++----- 23 files changed, 568 insertions(+), 520 deletions(-) create mode 100644 internal/auth/credentials.go create mode 100644 internal/auth/credentials_test.go create mode 100644 internal/embedded/log.go rename internal/{proxy/transport.go => transport/utls.go} (58%) create mode 100644 internal/transport/utls_test.go create mode 100644 internal/version/version.go diff --git a/internal/auth/credentials.go b/internal/auth/credentials.go new file mode 100644 index 0000000..ec0e6d9 --- /dev/null +++ b/internal/auth/credentials.go @@ -0,0 +1,56 @@ +package auth + +import ( + "encoding/json" + "fmt" + "os" + "time" +) + +// claudeCredentialsJSON matches the structure of ~/.claude/.credentials.json. +type claudeCredentialsJSON struct { + ClaudeAiOauth struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresAt int64 `json:"expiresAt"` + SubscriptionType string `json:"subscriptionType"` + } `json:"claudeAiOauth"` +} + +// LoadDefaultCredentials reads credentials from ~/.claude/.credentials.json. +// Returns nil, nil if the file does not exist. +func LoadDefaultCredentials() ([]*Credential, error) { + path, err := DefaultCredentialPath() + if err != nil { + return nil, nil + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + var cf claudeCredentialsJSON + if err := json.Unmarshal(data, &cf); err != nil { + return nil, err + } + + oauth := cf.ClaudeAiOauth + if oauth.AccessToken == "" { + return nil, fmt.Errorf("no access token in %s", path) + } + + cred := &Credential{ + ID: "claude-native", + Email: oauth.SubscriptionType, + AccessToken: oauth.AccessToken, + RefreshToken: oauth.RefreshToken, + ExpiresAt: time.UnixMilli(oauth.ExpiresAt), + FilePath: path, + } + + return []*Credential{cred}, nil +} diff --git a/internal/auth/credentials_test.go b/internal/auth/credentials_test.go new file mode 100644 index 0000000..6838b2f --- /dev/null +++ b/internal/auth/credentials_test.go @@ -0,0 +1,70 @@ +package auth + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestDefaultCredentialPath(t *testing.T) { + path, err := DefaultCredentialPath() + if err != nil { + t.Fatalf("DefaultCredentialPath error: %v", err) + } + if !strings.HasSuffix(path, filepath.Join(".claude", ".credentials.json")) { + t.Errorf("path = %q, want suffix .claude/.credentials.json", path) + } +} + +func TestLoadDefaultCredentials_MissingFile(t *testing.T) { + // When credential file doesn't exist, returns nil, nil + path, err := DefaultCredentialPath() + if err != nil { + t.Skip("cannot determine home dir") + } + if _, statErr := os.Stat(path); os.IsNotExist(statErr) { + creds, err := LoadDefaultCredentials() + if creds != nil { + t.Errorf("expected nil creds for missing file, got %v", creds) + } + if err != nil { + t.Errorf("expected nil error for missing file, got %v", err) + } + } +} + +func TestClaudeCredentialsJSON_ParsesCorrectly(t *testing.T) { + jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}` + + var cf claudeCredentialsJSON + if err := json.Unmarshal([]byte(jsonData), &cf); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if cf.ClaudeAiOauth.AccessToken != "test-token" { + t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken) + } + if cf.ClaudeAiOauth.RefreshToken != "test-refresh" { + t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken) + } + if cf.ClaudeAiOauth.ExpiresAt != 1234567890 { + t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt) + } + if cf.ClaudeAiOauth.SubscriptionType != "pro" { + t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType) + } +} + +func TestClaudeCredentialsJSON_EmptyAccessToken(t *testing.T) { + jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}` + + var cf claudeCredentialsJSON + if err := json.Unmarshal([]byte(jsonData), &cf); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if cf.ClaudeAiOauth.AccessToken != "" { + t.Errorf("expected empty access token") + } +} diff --git a/internal/auth/refresh.go b/internal/auth/refresh.go index 7452f01..d7e335f 100644 --- a/internal/auth/refresh.go +++ b/internal/auth/refresh.go @@ -6,16 +6,14 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "os" "path/filepath" - "sync" "time" - tls "github.com/refraction-networking/utls" "github.com/rs/zerolog/log" - "golang.org/x/net/http2" + + "github.com/fujin/anthropic-proxy/internal/transport" ) const ( @@ -28,7 +26,7 @@ const ( refreshBackoff = 5 * time.Minute ) -var utlsClient = newUTLSClient() +var utlsClient = transport.NewHTTPClient(15 * time.Second) type tokenRequest struct { ClientID string `json:"client_id"` @@ -147,67 +145,6 @@ func persistCredential(cred *Credential) error { return os.WriteFile(filePath, out, 0600) } -func newUTLSClient() *http.Client { - return &http.Client{ - Timeout: 15 * time.Second, - Transport: &utlsRefreshTransport{}, - } -} - -type utlsRefreshTransport struct { - mu sync.Mutex - conn *http2.ClientConn - host string -} - -func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) { - host := req.URL.Hostname() - port := req.URL.Port() - if port == "" { - port = "443" - } - - t.mu.Lock() - if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() { - conn := t.conn - t.mu.Unlock() - resp, err := conn.RoundTrip(req) - if err == nil { - return resp, nil - } - t.mu.Lock() - t.conn = nil - t.mu.Unlock() - } else { - t.mu.Unlock() - } - - addr := net.JoinHostPort(host, port) - rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second) - if err != nil { - return nil, err - } - - tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto) - if err := tlsConn.Handshake(); err != nil { - rawConn.Close() - return nil, err - } - - h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn) - if err != nil { - tlsConn.Close() - return nil, err - } - - t.mu.Lock() - t.conn = h2Conn - t.host = host - t.mu.Unlock() - - return h2Conn.RoundTrip(req) -} - func StartBackgroundRefresh(ctx context.Context, pool *Pool) { go func() { for { diff --git a/internal/auth/selector_test.go b/internal/auth/selector_test.go index b0800db..7f0a9e8 100644 --- a/internal/auth/selector_test.go +++ b/internal/auth/selector_test.go @@ -80,7 +80,7 @@ func TestPool_Pick_RoundRobin(t *testing.T) { func TestPool_Pick_SkipsCooldown(t *testing.T) { creds := []*Credential{ {ID: "a"}, - {ID: "b", CooldownUntil: time.Now().Add(1 * time.Hour)}, + {ID: "b", cooldownUntil: time.Now().Add(1 * time.Hour)}, {ID: "c"}, } p := NewPool(creds) @@ -116,8 +116,8 @@ func TestPool_Pick_SkipsCooldown(t *testing.T) { func TestPool_Pick_AllOnCooldown(t *testing.T) { future := time.Now().Add(1 * time.Hour) creds := []*Credential{ - {ID: "a", CooldownUntil: future}, - {ID: "b", CooldownUntil: future}, + {ID: "a", cooldownUntil: future}, + {ID: "b", cooldownUntil: future}, } p := NewPool(creds) @@ -203,13 +203,13 @@ func TestPool_MarkFailure(t *testing.T) { } // Verify approximate duration cred.mu.Lock() - cooldownEnd := cred.CooldownUntil + cooldownEnd := cred.cooldownUntil cred.mu.Unlock() lower := before.Add(tt.expectedDur) upper := time.Now().Add(tt.expectedDur) if cooldownEnd.Before(lower) || cooldownEnd.After(upper) { - t.Errorf("CooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper) + t.Errorf("cooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper) } } else { if cred.IsOnCooldown() { @@ -223,7 +223,7 @@ func TestPool_MarkFailure(t *testing.T) { func TestPool_MarkSuccess(t *testing.T) { cred := &Credential{ ID: "test", - CooldownUntil: time.Now().Add(1 * time.Hour), + cooldownUntil: time.Now().Add(1 * time.Hour), } p := NewPool([]*Credential{cred}) @@ -282,7 +282,7 @@ func TestPool_RoundRobinCursorAdvancement(t *testing.T) { func TestPool_RoundRobinWithCooldownSkip(t *testing.T) { creds := []*Credential{ {ID: "0"}, - {ID: "1", CooldownUntil: time.Now().Add(1 * time.Hour)}, + {ID: "1", cooldownUntil: time.Now().Add(1 * time.Hour)}, {ID: "2"}, } p := NewPool(creds) diff --git a/internal/auth/types.go b/internal/auth/types.go index adf12a4..3c4aaaf 100644 --- a/internal/auth/types.go +++ b/internal/auth/types.go @@ -13,7 +13,7 @@ type Credential struct { RefreshToken string ExpiresAt time.Time FilePath string - CooldownUntil time.Time + cooldownUntil time.Time nextRefreshAfter time.Time mu sync.Mutex } @@ -22,21 +22,21 @@ type Credential struct { func (c *Credential) IsOnCooldown() bool { c.mu.Lock() defer c.mu.Unlock() - return time.Now().Before(c.CooldownUntil) + return time.Now().Before(c.cooldownUntil) } // SetCooldown puts the credential on cooldown for the given duration. func (c *Credential) SetCooldown(duration time.Duration) { c.mu.Lock() defer c.mu.Unlock() - c.CooldownUntil = time.Now().Add(duration) + c.cooldownUntil = time.Now().Add(duration) } // ClearCooldown removes any active cooldown on the credential. func (c *Credential) ClearCooldown() { c.mu.Lock() defer c.mu.Unlock() - c.CooldownUntil = time.Time{} + c.cooldownUntil = time.Time{} } // Token returns the current access token. diff --git a/internal/auth/types_test.go b/internal/auth/types_test.go index 506fc14..99185d3 100644 --- a/internal/auth/types_test.go +++ b/internal/auth/types_test.go @@ -31,7 +31,7 @@ func TestCredential_IsOnCooldown(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Credential{CooldownUntil: tt.cooldownUntil} + c := &Credential{cooldownUntil: tt.cooldownUntil} got := c.IsOnCooldown() if got != tt.want { t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want) @@ -57,12 +57,12 @@ func TestCredential_SetCooldown(t *testing.T) { c.SetCooldown(tt.duration) after := time.Now() - // CooldownUntil should be between before+duration and after+duration - if c.CooldownUntil.Before(before.Add(tt.duration)) { - t.Errorf("CooldownUntil %v is before expected lower bound %v", c.CooldownUntil, before.Add(tt.duration)) + // cooldownUntil should be between before+duration and after+duration + if c.cooldownUntil.Before(before.Add(tt.duration)) { + t.Errorf("cooldownUntil %v is before expected lower bound %v", c.cooldownUntil, before.Add(tt.duration)) } - if c.CooldownUntil.After(after.Add(tt.duration)) { - t.Errorf("CooldownUntil %v is after expected upper bound %v", c.CooldownUntil, after.Add(tt.duration)) + if c.cooldownUntil.After(after.Add(tt.duration)) { + t.Errorf("cooldownUntil %v is after expected upper bound %v", c.cooldownUntil, after.Add(tt.duration)) } // Should now be on cooldown @@ -75,7 +75,7 @@ func TestCredential_SetCooldown(t *testing.T) { func TestCredential_ClearCooldown(t *testing.T) { t.Run("clears active cooldown", func(t *testing.T) { - c := &Credential{CooldownUntil: time.Now().Add(1 * time.Hour)} + c := &Credential{cooldownUntil: time.Now().Add(1 * time.Hour)} if !c.IsOnCooldown() { t.Fatal("precondition: expected credential to be on cooldown") } @@ -85,8 +85,8 @@ func TestCredential_ClearCooldown(t *testing.T) { if c.IsOnCooldown() { t.Error("expected credential to not be on cooldown after ClearCooldown") } - if !c.CooldownUntil.IsZero() { - t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil) + if !c.cooldownUntil.IsZero() { + t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil) } }) @@ -97,8 +97,8 @@ func TestCredential_ClearCooldown(t *testing.T) { if c.IsOnCooldown() { t.Error("expected credential to not be on cooldown") } - if !c.CooldownUntil.IsZero() { - t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil) + if !c.cooldownUntil.IsZero() { + t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil) } }) } diff --git a/internal/config/config.go b/internal/config/config.go index 97f9753..a938994 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,12 +1,9 @@ package config import ( - "encoding/json" "fmt" "os" - "time" - "github.com/fujin/anthropic-proxy/internal/auth" "gopkg.in/yaml.v3" ) @@ -67,15 +64,6 @@ type LoggingConfig struct { Compress bool `yaml:"compress"` } -type claudeCredentialsJSON struct { - ClaudeAiOauth struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - ExpiresAt int64 `json:"expiresAt"` - SubscriptionType string `json:"subscriptionType"` - } `json:"claudeAiOauth"` -} - func Load(path string) (*Config, error) { data, err := os.ReadFile(path) if err != nil { @@ -128,47 +116,3 @@ func Load(path string) (*Config, error) { return cfg, nil } - -func DefaultCredentialPath() string { - home, err := os.UserHomeDir() - if err != nil { - return "" - } - return home + "/.claude/.credentials.json" -} - -func LoadDefaultCredentials() ([]*auth.Credential, error) { - path := DefaultCredentialPath() - if path == "" { - return nil, nil - } - - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - return nil, err - } - - var cf claudeCredentialsJSON - if err := json.Unmarshal(data, &cf); err != nil { - return nil, err - } - - oauth := cf.ClaudeAiOauth - if oauth.AccessToken == "" { - return nil, fmt.Errorf("no access token in %s", path) - } - - cred := &auth.Credential{ - ID: "claude-native", - Email: oauth.SubscriptionType, - AccessToken: oauth.AccessToken, - RefreshToken: oauth.RefreshToken, - ExpiresAt: time.UnixMilli(oauth.ExpiresAt), - FilePath: path, - } - - return []*auth.Credential{cred}, nil -} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index aed4dda..3064753 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,7 +1,6 @@ package config import ( - "encoding/json" "os" "path/filepath" "strings" @@ -269,81 +268,3 @@ func TestExportConfig_Enabled(t *testing.T) { }) } } - -func TestDefaultCredentialPath(t *testing.T) { - path := DefaultCredentialPath() - if path == "" { - t.Skip("could not determine home directory") - } - if !strings.HasSuffix(path, "/.claude/.credentials.json") { - t.Errorf("DefaultCredentialPath() = %q, want suffix /.claude/.credentials.json", path) - } -} - -func TestLoadDefaultCredentials_ValidFile(t *testing.T) { - // We can't easily override DefaultCredentialPath, so test the JSON parsing - // logic by creating a file at a temp location and calling the internal parsing - // directly. Instead, we test LoadDefaultCredentials indirectly by verifying - // it returns nil,nil when the default path doesn't exist (common in CI). - // For a full test, we create the credential file at the expected path. - - // Test with the actual function — if the default credential file doesn't - // exist, it should return nil, nil. - creds, err := LoadDefaultCredentials() - path := DefaultCredentialPath() - if path == "" { - if creds != nil || err != nil { - t.Errorf("expected nil,nil when home dir unavailable, got %v, %v", creds, err) - } - return - } - - if _, statErr := os.Stat(path); os.IsNotExist(statErr) { - // File doesn't exist — should return nil, nil - if creds != nil { - t.Errorf("expected nil creds for missing file, got %v", creds) - } - if err != nil { - t.Errorf("expected nil error for missing file, got %v", err) - } - } -} - -func TestLoadDefaultCredentials_ParsesJSON(t *testing.T) { - // Test the JSON parsing by creating a temp credential file and using - // the claudeCredentialsJSON struct directly (white-box test). - jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}` - - var cf claudeCredentialsJSON - if err := json.Unmarshal([]byte(jsonData), &cf); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if cf.ClaudeAiOauth.AccessToken != "test-token" { - t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken) - } - if cf.ClaudeAiOauth.RefreshToken != "test-refresh" { - t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken) - } - if cf.ClaudeAiOauth.ExpiresAt != 1234567890 { - t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt) - } - if cf.ClaudeAiOauth.SubscriptionType != "pro" { - t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType) - } -} - -func TestLoadDefaultCredentials_EmptyAccessToken(t *testing.T) { - // Verify that an empty access token in the JSON produces an error. - // We test the parsing struct and logic path. - jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}` - - var cf claudeCredentialsJSON - if err := json.Unmarshal([]byte(jsonData), &cf); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if cf.ClaudeAiOauth.AccessToken != "" { - t.Errorf("expected empty access token") - } - // The actual LoadDefaultCredentials would return an error here. -} diff --git a/internal/embedded/log.go b/internal/embedded/log.go new file mode 100644 index 0000000..6db54af --- /dev/null +++ b/internal/embedded/log.go @@ -0,0 +1,20 @@ +package embedded + +import "github.com/rs/zerolog/log" + +// logWriter bridges subprocess stdout/stderr to zerolog. +type logWriter struct { + level string + component string +} + +func (w *logWriter) Write(p []byte) (n int, err error) { + msg := string(p) + switch w.level { + case "error": + log.Error().Str("component", w.component).Msg(msg) + default: + log.Debug().Str("component", w.component).Msg(msg) + } + return len(p), nil +} diff --git a/internal/embedded/perses.go b/internal/embedded/perses.go index 1f80d87..c879d07 100644 --- a/internal/embedded/perses.go +++ b/internal/embedded/perses.go @@ -131,19 +131,3 @@ func (p *Perses) writeDashboardProvision() error { dashData, 0o644, ) } - -type logWriter struct { - level string - component string -} - -func (w *logWriter) Write(p []byte) (n int, err error) { - msg := string(p) - switch w.level { - case "error": - log.Error().Str("component", w.component).Msg(msg) - default: - log.Debug().Str("component", w.component).Msg(msg) - } - return len(p), nil -} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index f378d7f..3d5fb3c 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -13,17 +13,9 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" "gopkg.in/lumberjack.v2" -) -// Config holds logging configuration, mirrors config.LoggingConfig. -type Config struct { - Level string - File string - MaxSizeMB int - MaxBackups int - MaxAgeDays int - Compress bool -} + "github.com/fujin/anthropic-proxy/internal/config" +) // Setup initializes the global zerolog logger. // - File set: JSON → lumberjack rotating file @@ -31,7 +23,7 @@ type Config struct { // - File empty + not TTY: JSON → stderr (for systemd journal) // Extra writers (e.g., OTLP log bridge) are added via io.MultiWriter so logs // are written to both the primary destination and any extra writers. -func Setup(cfg Config, extraWriters ...io.Writer) zerolog.Logger { +func Setup(cfg config.LoggingConfig, extraWriters ...io.Writer) zerolog.Logger { // Parse log level level, err := zerolog.ParseLevel(cfg.Level) if err != nil || cfg.Level == "" { diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go index e7d964d..64864a0 100644 --- a/internal/logging/logging_test.go +++ b/internal/logging/logging_test.go @@ -9,6 +9,8 @@ import ( "testing" "github.com/rs/zerolog" + + "github.com/fujin/anthropic-proxy/internal/config" ) func TestRedactHeaders(t *testing.T) { @@ -177,7 +179,7 @@ func TestSetup_WithFile(t *testing.T) { dir := t.TempDir() logFile := filepath.Join(dir, "test.log") - logger := Setup(Config{ + logger := Setup(config.LoggingConfig{ Level: "debug", File: logFile, MaxSizeMB: 10, @@ -191,7 +193,7 @@ func TestSetup_WithFile(t *testing.T) { func TestSetup_WithoutFile(t *testing.T) { // File empty — should use console or stderr mode depending on TTY - logger := Setup(Config{ + logger := Setup(config.LoggingConfig{ Level: "warn", }) @@ -201,13 +203,13 @@ func TestSetup_WithoutFile(t *testing.T) { func TestSetup_DefaultLevel(t *testing.T) { // Empty level should default to info - logger := Setup(Config{}) + logger := Setup(config.LoggingConfig{}) _ = logger // verify no panic } func TestSetup_InvalidLevel(t *testing.T) { // Invalid level should default to info - logger := Setup(Config{Level: "not-a-level"}) + logger := Setup(config.LoggingConfig{Level: "not-a-level"}) _ = logger // verify no panic } diff --git a/internal/proxy/billing.go b/internal/proxy/billing.go index de39b4b..abc4176 100644 --- a/internal/proxy/billing.go +++ b/internal/proxy/billing.go @@ -11,9 +11,13 @@ import ( "github.com/tidwall/sjson" ) +// fingerprintSalt is the fixed salt used by Claude Code for billing header +// fingerprint computation. Extracted from the Claude Code CLI source. const fingerprintSalt = "59cf53e54c78" func computeFingerprint(firstUserMessage string, version string) string { + // UTF-16 character indices sampled from the first user message, matching + // the Claude Code CLI's fingerprinting algorithm. indices := []int{4, 7, 20} runes := utf16.Encode([]rune(firstUserMessage)) var chars string diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 04d65d7..4602275 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -2,6 +2,7 @@ package proxy import ( "bufio" + "context" "io" "net/http" "time" @@ -18,6 +19,15 @@ import ( "github.com/fujin/anthropic-proxy/internal/telemetry" ) +// requestInfo bundles common request context passed to logging/telemetry helpers. +type requestInfo struct { + model string + stream bool + cred *auth.Credential + body []byte + originalBody []byte +} + func HandleMessages(pool *auth.Pool, profile *SniffedProfile, getSanitizer func() *Sanitizer, tracker *ratelimit.Tracker) gin.HandlerFunc { upstream := NewUpstreamClient(profile) @@ -61,6 +71,7 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p startTime := time.Now() model := gjson.GetBytes(body, "model").String() ctx := c.Request.Context() + ri := requestInfo{model: model, stream: false, cred: cred, body: body, originalBody: originalBody} telemetry.RequestBodySize.Record(ctx, int64(len(body)), metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false))) @@ -69,85 +80,25 @@ func handleNonStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, p latencyMs := float64(time.Since(startTime).Milliseconds()) if err != nil { - log.Error(). - Err(err). - Str("credential", cred.Email). - Str("model", model). - Bool("stream", false). - Str("request_body_original", string(originalBody)). - Str("request_body_sanitized", string(body)). - Int("request_body_size", len(body)). - Float64("latency_ms", latencyMs). - Msg("upstream connection error") - - telemetry.UpstreamErrors.Add(ctx, 1, - metric.WithAttributes( - attribute.String("error_type", "connection"), - attribute.String("credential", cred.Email), - attribute.Int("status_code", http.StatusBadGateway), - )) - telemetry.RequestCounter.Add(ctx, 1, - metric.WithAttributes( - attribute.String("model", model), - attribute.Bool("stream", false), - attribute.Int("status_code", http.StatusBadGateway), - )) - telemetry.RequestDuration.Record(ctx, latencyMs, - metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", false), attribute.Int("status_code", http.StatusBadGateway))) - + recordConnectionError(ctx, err, ri, latencyMs) c.JSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"}) return } - attrs := []attribute.KeyValue{ - attribute.String("model", model), - attribute.Bool("stream", false), - attribute.Int("status_code", statusCode), - } - - telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...)) - telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...)) + recordRequestMetrics(ctx, ri, statusCode, latencyMs) if statusCode >= 400 { pool.MarkFailure(cred, statusCode) telemetry.CredentialCooldowns.Add(ctx, 1, metric.WithAttributes(attribute.Int("status_code", statusCode))) - errorType := gjson.GetBytes(respBody, "error.type").String() - errorMessage := gjson.GetBytes(respBody, "error.message").String() - log.Error(). - Int("status", statusCode). - Str("error_type", errorType). - Str("error_message", errorMessage). - Str("response_body", string(respBody)). - Str("request_id", headers.Get("X-Request-Id")). - Float64("latency_ms", latencyMs). - Str("credential", cred.Email). - Str("model", model). - Bool("stream", false). - Str("request_body_original", string(originalBody)). - Str("request_body_sanitized", string(body)). - Int("request_body_size", len(body)). - Str("request_headers", logging.RedactHeaders(c.Request.Header)). - Msg("upstream error") - - telemetry.UpstreamErrors.Add(ctx, 1, - metric.WithAttributes( - attribute.Int("status_code", statusCode), - attribute.String("error_type", errorType), - attribute.String("credential", cred.Email), - )) + recordUpstreamError(ctx, statusCode, respBody, headers.Get("X-Request-Id"), latencyMs, ri, c.Request.Header) } else { pool.MarkSuccess(cred) respBody = san.DesanitizeResponse(respBody) inputTokens := gjson.GetBytes(respBody, "usage.input_tokens").Int() outputTokens := gjson.GetBytes(respBody, "usage.output_tokens").Int() - tokenAttrs := metric.WithAttributes( - attribute.String("model", model), - attribute.String("credential", cred.Email), - ) - telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs) - telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs) + recordTokenUsage(ctx, model, cred, inputTokens, outputTokens) if tracker != nil { tracker.UpdateFromHeaders(headers) } @@ -174,6 +125,7 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool startTime := time.Now() model := gjson.GetBytes(body, "model").String() ctx := c.Request.Context() + ri := requestInfo{model: model, stream: true, cred: cred, body: body, originalBody: originalBody} telemetry.StreamRequests.Add(ctx, 1, metric.WithAttributes(attribute.String("model", model))) telemetry.RequestBodySize.Record(ctx, int64(len(body)), @@ -182,32 +134,7 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool resp, err := upstream.ExecuteStream(ctx, cred, body) if err != nil { latencyMs := float64(time.Since(startTime).Milliseconds()) - log.Error(). - Err(err). - Str("credential", cred.Email). - Str("model", model). - Bool("stream", true). - Str("request_body_original", string(originalBody)). - Str("request_body_sanitized", string(body)). - Int("request_body_size", len(body)). - Float64("latency_ms", latencyMs). - Msg("upstream connection error") - - telemetry.UpstreamErrors.Add(ctx, 1, - metric.WithAttributes( - attribute.String("error_type", "connection"), - attribute.String("credential", cred.Email), - attribute.Int("status_code", http.StatusBadGateway), - )) - telemetry.RequestCounter.Add(ctx, 1, - metric.WithAttributes( - attribute.String("model", model), - attribute.Bool("stream", true), - attribute.Int("status_code", http.StatusBadGateway), - )) - telemetry.RequestDuration.Record(ctx, latencyMs, - metric.WithAttributes(attribute.String("model", model), attribute.Bool("stream", true), attribute.Int("status_code", http.StatusBadGateway))) - + recordConnectionError(ctx, err, ri, latencyMs) c.JSON(http.StatusBadGateway, gin.H{"error": "upstream stream request failed"}) return } @@ -219,37 +146,8 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool metric.WithAttributes(attribute.Int("status_code", resp.StatusCode))) respBody, _ := io.ReadAll(resp.Body) latencyMs := float64(time.Since(startTime).Milliseconds()) - errorType := gjson.GetBytes(respBody, "error.type").String() - errorMessage := gjson.GetBytes(respBody, "error.message").String() - log.Error(). - Int("status", resp.StatusCode). - Str("error_type", errorType). - Str("error_message", errorMessage). - Str("response_body", string(respBody)). - Str("request_id", resp.Header.Get("X-Request-Id")). - Float64("latency_ms", latencyMs). - Str("credential", cred.Email). - Str("model", model). - Bool("stream", true). - Str("request_body_original", string(originalBody)). - Str("request_body_sanitized", string(body)). - Int("request_body_size", len(body)). - Str("request_headers", logging.RedactHeaders(c.Request.Header)). - Msg("upstream error") - - attrs := []attribute.KeyValue{ - attribute.String("model", model), - attribute.Bool("stream", true), - attribute.Int("status_code", resp.StatusCode), - } - telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...)) - telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...)) - telemetry.UpstreamErrors.Add(ctx, 1, - metric.WithAttributes( - attribute.Int("status_code", resp.StatusCode), - attribute.String("error_type", errorType), - attribute.String("credential", cred.Email), - )) + recordRequestMetrics(ctx, ri, resp.StatusCode, latencyMs) + recordUpstreamError(ctx, resp.StatusCode, respBody, resp.Header.Get("X-Request-Id"), latencyMs, ri, c.Request.Header) c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody) return @@ -290,21 +188,10 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool } latencyMs := float64(time.Since(startTime).Milliseconds()) - attrs := []attribute.KeyValue{ - attribute.String("model", model), - attribute.Bool("stream", true), - attribute.Int("status_code", http.StatusOK), - } - telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...)) - telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...)) + recordRequestMetrics(ctx, ri, http.StatusOK, latencyMs) if inputTokens > 0 || outputTokens > 0 { - tokenAttrs := metric.WithAttributes( - attribute.String("model", model), - attribute.String("credential", cred.Email), - ) - telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs) - telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs) + recordTokenUsage(ctx, model, cred, inputTokens, outputTokens) if tracker != nil { tracker.UpdateFromHeaders(resp.Header) } @@ -322,3 +209,74 @@ func handleStream(c *gin.Context, upstream *UpstreamClient, san *Sanitizer, pool log.Error().Err(err).Msg("stream scan error") } } + +// recordConnectionError logs and records metrics for upstream connection failures. +func recordConnectionError(ctx context.Context, err error, ri requestInfo, latencyMs float64) { + log.Error(). + Err(err). + Str("credential", ri.cred.Email). + Str("model", ri.model). + Bool("stream", ri.stream). + Str("request_body_original", string(ri.originalBody)). + Str("request_body_sanitized", string(ri.body)). + Int("request_body_size", len(ri.body)). + Float64("latency_ms", latencyMs). + Msg("upstream connection error") + + telemetry.UpstreamErrors.Add(ctx, 1, + metric.WithAttributes( + attribute.String("error_type", "connection"), + attribute.String("credential", ri.cred.Email), + attribute.Int("status_code", http.StatusBadGateway), + )) + recordRequestMetrics(ctx, ri, http.StatusBadGateway, latencyMs) +} + +// recordUpstreamError logs and records metrics for upstream HTTP error responses. +func recordUpstreamError(ctx context.Context, statusCode int, respBody []byte, requestID string, latencyMs float64, ri requestInfo, requestHeaders http.Header) { + errorType := gjson.GetBytes(respBody, "error.type").String() + errorMessage := gjson.GetBytes(respBody, "error.message").String() + log.Error(). + Int("status", statusCode). + Str("error_type", errorType). + Str("error_message", errorMessage). + Str("response_body", string(respBody)). + Str("request_id", requestID). + Float64("latency_ms", latencyMs). + Str("credential", ri.cred.Email). + Str("model", ri.model). + Bool("stream", ri.stream). + Str("request_body_original", string(ri.originalBody)). + Str("request_body_sanitized", string(ri.body)). + Int("request_body_size", len(ri.body)). + Str("request_headers", logging.RedactHeaders(requestHeaders)). + Msg("upstream error") + + telemetry.UpstreamErrors.Add(ctx, 1, + metric.WithAttributes( + attribute.Int("status_code", statusCode), + attribute.String("error_type", errorType), + attribute.String("credential", ri.cred.Email), + )) +} + +// recordRequestMetrics records the request counter and duration histogram. +func recordRequestMetrics(ctx context.Context, ri requestInfo, statusCode int, latencyMs float64) { + attrs := []attribute.KeyValue{ + attribute.String("model", ri.model), + attribute.Bool("stream", ri.stream), + attribute.Int("status_code", statusCode), + } + telemetry.RequestCounter.Add(ctx, 1, metric.WithAttributes(attrs...)) + telemetry.RequestDuration.Record(ctx, latencyMs, metric.WithAttributes(attrs...)) +} + +// recordTokenUsage records token consumption metrics. +func recordTokenUsage(ctx context.Context, model string, cred *auth.Credential, inputTokens, outputTokens int64) { + tokenAttrs := metric.WithAttributes( + attribute.String("model", model), + attribute.String("credential", cred.Email), + ) + telemetry.TokensInput.Add(ctx, inputTokens, tokenAttrs) + telemetry.TokensOutput.Add(ctx, outputTokens, tokenAttrs) +} diff --git a/internal/proxy/sanitize.go b/internal/proxy/sanitize.go index 3966ab6..7affaf4 100644 --- a/internal/proxy/sanitize.go +++ b/internal/proxy/sanitize.go @@ -4,6 +4,7 @@ import ( "strconv" "strings" + "github.com/rs/zerolog/log" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -11,10 +12,10 @@ import ( ) type Sanitizer struct { - toolsForward map[string]string - toolsReverse map[string]string - systemRules []config.ReplaceRule - bodyRules []config.ReplaceRule + toolsForward map[string]string + toolsReverse map[string]string + systemRules []config.ReplaceRule + bodyRules []config.ReplaceRule } func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer { @@ -49,7 +50,11 @@ func (s *Sanitizer) DesanitizeResponse(body []byte) []byte { } name := block.Get("name").String() if orig, ok := s.toolsReverse[name]; ok { - body, _ = sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig) + if b, err := sjson.SetBytes(body, "content."+strconv.Itoa(i)+".name", orig); err != nil { + log.Warn().Err(err).Str("tool", name).Msg("desanitize response: set name failed") + } else { + body = b + } } } return body @@ -64,8 +69,12 @@ func (s *Sanitizer) DesanitizeStreamEvent(line string) string { for _, path := range []string{"content_block.name", "delta.name"} { name := gjson.GetBytes(data, path).String() if orig, ok := s.toolsReverse[name]; ok { - data, _ = sjson.SetBytes(data, path, orig) - changed = true + if b, err := sjson.SetBytes(data, path, orig); err != nil { + log.Warn().Err(err).Str("tool", name).Msg("desanitize stream event: set name failed") + } else { + data = b + changed = true + } } } if changed { @@ -85,7 +94,11 @@ func (s *Sanitizer) renameTools(body []byte) []byte { for i, tool := range tools.Array() { name := tool.Get("name").String() if newName, ok := s.toolsForward[name]; ok { - body, _ = sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName) + if b, err := sjson.SetBytes(body, "tools."+strconv.Itoa(i)+".name", newName); err != nil { + log.Warn().Err(err).Str("tool", name).Msg("rename tool failed") + } else { + body = b + } } } return body @@ -104,7 +117,11 @@ func (s *Sanitizer) replaceSystem(body []byte) []byte { for _, rule := range s.systemRules { text = strings.ReplaceAll(text, rule.Match, rule.Replace) } - body, _ = sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text) + if b, err := sjson.SetBytes(body, "system."+strconv.Itoa(i)+".text", text); err != nil { + log.Warn().Err(err).Int("block", i).Msg("replace system text failed") + } else { + body = b + } } return body } diff --git a/internal/proxy/sniff.go b/internal/proxy/sniff.go index 561c2fc..e58ce32 100644 --- a/internal/proxy/sniff.go +++ b/internal/proxy/sniff.go @@ -36,6 +36,21 @@ var skipHeaders = map[string]bool{ "connection": true, } +const fakeJSONResponse = `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}` + +const fakeStreamResponse = "event: message_start\n" + + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n" + + "event: content_block_start\n" + + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" + + "event: content_block_delta\n" + + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n" + + "event: content_block_stop\n" + + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n" + + "event: message_delta\n" + + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n" + + "event: message_stop\n" + + "data: {\"type\":\"message_stop\"}\n\n" + func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -48,45 +63,7 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) { captured := make(chan struct{}, 1) mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if r.Method == "HEAD" { - w.WriteHeader(200) - return - } - if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(200) - fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`) - return - } - - body, _ := io.ReadAll(r.Body) - - mu.Lock() - if profile == nil { - profile = extractProfile(r, body) - select { - case captured <- struct{}{}: - default: - } - } - mu.Unlock() - - if strings.Contains(string(body), `"stream":true`) { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(200) - fmt.Fprint(w, "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_fake\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n") - fmt.Fprint(w, "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n") - fmt.Fprint(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"ok\"}}\n\n") - fmt.Fprint(w, "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n") - fmt.Fprint(w, "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n") - fmt.Fprint(w, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n") - } else { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(200) - fmt.Fprint(w, `{"id":"msg_fake","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-sonnet-4-6","stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`) - } - }) + mux.HandleFunc("/", sniffHandler(&mu, &profile, captured)) srv := &http.Server{Handler: mux} go srv.Serve(listener) @@ -130,8 +107,44 @@ func SniffClaudeCode(claudeBinary string) (*SniffedProfile, error) { return profile, nil } +func sniffHandler(mu *sync.Mutex, profile **SniffedProfile, captured chan<- struct{}) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.WriteHeader(200) + return + } + if r.Method != "POST" || !strings.Contains(r.URL.Path, "/v1/messages") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + fmt.Fprint(w, fakeJSONResponse) + return + } + + body, _ := io.ReadAll(r.Body) + + mu.Lock() + if *profile == nil { + *profile = extractProfile(r, body) + select { + case captured <- struct{}{}: + default: + } + } + mu.Unlock() + + if strings.Contains(string(body), `"stream":true`) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + fmt.Fprint(w, fakeStreamResponse) + } else { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + fmt.Fprint(w, fakeJSONResponse) + } + } +} + func extractProfile(r *http.Request, body []byte) *SniffedProfile { - // Capture raw headers preserving original casing. var headers [][2]string for name, vals := range r.Header { if skipHeaders[strings.ToLower(name)] { @@ -142,7 +155,6 @@ func extractProfile(r *http.Request, body []byte) *SniffedProfile { } } - // Deduplicate and strip subscription-specific betas. seen := map[string]bool{} var deduped [][2]string for _, h := range headers { diff --git a/internal/proxy/upstream.go b/internal/proxy/upstream.go index e836d8f..2507244 100644 --- a/internal/proxy/upstream.go +++ b/internal/proxy/upstream.go @@ -13,6 +13,8 @@ import ( "github.com/fujin/anthropic-proxy/internal/auth" "github.com/fujin/anthropic-proxy/internal/logging" + "github.com/fujin/anthropic-proxy/internal/transport" + "github.com/fujin/anthropic-proxy/internal/version" ) const messagesURL = "https://api.anthropic.com/v1/messages?beta=true" @@ -27,7 +29,7 @@ func NewUpstreamClient(profile *SniffedProfile) *UpstreamClient { return &UpstreamClient{ client: http.Client{ Timeout: 0, - Transport: newUtlsRoundTripper(), + Transport: transport.NewUTLS(), }, sessionID: uuid.New().String(), profile: profile, @@ -38,7 +40,7 @@ func (u *UpstreamClient) version() string { if u.profile != nil && u.profile.Version != "" { return u.profile.Version } - return "2.1.92" + return version.ClaudeCodeFallback } // applyHeaders replays sniffed headers, substituting auth + per-request IDs + accept. diff --git a/internal/ratelimit/usage.go b/internal/ratelimit/usage.go index 5aa8586..a80e7c6 100644 --- a/internal/ratelimit/usage.go +++ b/internal/ratelimit/usage.go @@ -7,8 +7,13 @@ import ( "io" "net/http" "time" + + "github.com/fujin/anthropic-proxy/internal/transport" + "github.com/fujin/anthropic-proxy/internal/version" ) +var usageClient = transport.NewHTTPClient(10 * time.Second) + const usageURL = "https://api.anthropic.com/api/oauth/usage" type RateLimit struct { @@ -17,17 +22,17 @@ type RateLimit struct { } type ExtraUsage struct { - IsEnabled bool `json:"is_enabled"` + IsEnabled bool `json:"is_enabled"` MonthlyLimit *float64 `json:"monthly_limit"` UsedCredits *float64 `json:"used_credits"` Utilization *float64 `json:"utilization"` } type UsageResponse struct { - FiveHour *RateLimit `json:"five_hour"` - SevenDay *RateLimit `json:"seven_day"` - SevenDaySonnet *RateLimit `json:"seven_day_sonnet"` - ExtraUsage *ExtraUsage `json:"extra_usage"` + FiveHour *RateLimit `json:"five_hour"` + SevenDay *RateLimit `json:"seven_day"` + SevenDaySonnet *RateLimit `json:"seven_day_sonnet"` + ExtraUsage *ExtraUsage `json:"extra_usage"` } func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) { @@ -41,9 +46,9 @@ func fetchUsage(ctx context.Context, token string) (*UsageResponse, error) { req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Content-Type", "application/json") req.Header.Set("anthropic-beta", "oauth-2025-04-20") - req.Header.Set("User-Agent", "claude-cli/2.1.92") + req.Header.Set("User-Agent", "claude-cli/"+version.ClaudeCodeFallback) - resp, err := http.DefaultClient.Do(req) + resp, err := usageClient.Do(req) if err != nil { return nil, fmt.Errorf("request: %w", err) } diff --git a/internal/server/server.go b/internal/server/server.go index 84836fe..30e925a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -138,10 +138,16 @@ func corsMiddleware() gin.HandlerFunc { } } +// authBypassPaths lists endpoints that do not require API key authentication. +var authBypassPaths = map[string]bool{ + "/healthz": true, + "/reload": true, + "/metrics": true, +} + func (s *Server) authMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - path := c.Request.URL.Path - if path == "/healthz" || path == "/reload" || path == "/metrics" { + if authBypassPaths[c.Request.URL.Path] { c.Next() return } diff --git a/internal/proxy/transport.go b/internal/transport/utls.go similarity index 58% rename from internal/proxy/transport.go rename to internal/transport/utls.go index a479711..0e55036 100644 --- a/internal/proxy/transport.go +++ b/internal/transport/utls.go @@ -1,29 +1,47 @@ -package proxy +// Package transport provides a shared uTLS HTTP/2 round-tripper with Chrome +// TLS fingerprinting and per-host connection pooling. Used by both the upstream +// proxy client and the OAuth token refresh client. +package transport import ( "net" "net/http" "sync" + "time" tls "github.com/refraction-networking/utls" - "github.com/rs/zerolog/log" "golang.org/x/net/http2" ) -type utlsRoundTripper struct { +// UTLS implements http.RoundTripper using uTLS (Chrome fingerprint) over HTTP/2. +// It maintains a per-host connection pool with coordination for concurrent +// requests to the same host. +type UTLS struct { mu sync.Mutex connections map[string]*http2.ClientConn pending map[string]*sync.Cond + dialTimeout time.Duration } -func newUtlsRoundTripper() *utlsRoundTripper { - return &utlsRoundTripper{ +// NewUTLS creates a uTLS HTTP/2 round-tripper with a 10-second dial timeout. +func NewUTLS() *UTLS { + return &UTLS{ connections: make(map[string]*http2.ClientConn), pending: make(map[string]*sync.Cond), + dialTimeout: 10 * time.Second, } } -func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) { +// NewHTTPClient returns an http.Client using uTLS transport with the given +// request timeout. Pass 0 for no timeout (streaming). +func NewHTTPClient(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + Transport: NewUTLS(), + } +} + +func (t *UTLS) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) { t.mu.Lock() if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() { @@ -59,8 +77,8 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie return h2Conn, nil } -func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) { - conn, err := net.Dial("tcp", addr) +func (t *UTLS) createConnection(host, addr string) (*http2.ClientConn, error) { + conn, err := net.DialTimeout("tcp", addr, t.dialTimeout) if err != nil { return nil, err } @@ -83,14 +101,14 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon return h2Conn, nil } -func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { +// RoundTrip implements http.RoundTripper with uTLS Chrome fingerprinting. +func (t *UTLS) RoundTrip(req *http.Request) (*http.Response, error) { hostname := req.URL.Hostname() port := req.URL.Port() if port == "" { port = "443" } addr := net.JoinHostPort(hostname, port) - log.Debug().Str("addr", addr).Msg("uTLS round trip") h2Conn, err := t.getOrCreateConnection(hostname, addr) if err != nil { diff --git a/internal/transport/utls_test.go b/internal/transport/utls_test.go new file mode 100644 index 0000000..1a34722 --- /dev/null +++ b/internal/transport/utls_test.go @@ -0,0 +1,78 @@ +package transport + +import ( + "net/http" + "testing" + "time" +) + +func TestNewUTLS(t *testing.T) { + tr := NewUTLS() + if tr == nil { + t.Fatal("NewUTLS returned nil") + } + if tr.connections == nil { + t.Error("connections map is nil") + } + if tr.pending == nil { + t.Error("pending map is nil") + } + if tr.dialTimeout != 10*time.Second { + t.Errorf("dialTimeout = %v, want 10s", tr.dialTimeout) + } +} + +func TestNewHTTPClient(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + }{ + {"zero timeout (streaming)", 0}, + {"15s timeout", 15 * time.Second}, + {"30s timeout", 30 * time.Second}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewHTTPClient(tt.timeout) + if c == nil { + t.Fatal("NewHTTPClient returned nil") + } + if c.Timeout != tt.timeout { + t.Errorf("Timeout = %v, want %v", c.Timeout, tt.timeout) + } + if c.Transport == nil { + t.Error("Transport is nil") + } + if _, ok := c.Transport.(*UTLS); !ok { + t.Errorf("Transport type = %T, want *UTLS", c.Transport) + } + }) + } +} + +func TestUTLS_ImplementsRoundTripper(t *testing.T) { + var _ http.RoundTripper = (*UTLS)(nil) +} + +func TestUTLS_RoundTrip_InvalidHost(t *testing.T) { + tr := NewUTLS() + // Use a non-routable address to test dial timeout behavior + req, err := http.NewRequest("GET", "https://192.0.2.1:443/test", nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + _, err = tr.RoundTrip(req) + if err == nil { + t.Error("expected error for non-routable address, got nil") + } +} + +func TestUTLS_ConnectionEviction(t *testing.T) { + tr := NewUTLS() + // Verify connections map starts empty + tr.mu.Lock() + if len(tr.connections) != 0 { + t.Errorf("initial connections = %d, want 0", len(tr.connections)) + } + tr.mu.Unlock() +} diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 0000000..4459df0 --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,8 @@ +// Package version provides the fallback Claude Code client version used when +// no sniffed profile is available. This constant is shared between the upstream +// proxy client and the rate limit usage poller. +package version + +// ClaudeCodeFallback is the Claude Code CLI version string used as a fallback +// when no real version is obtained from sniffing. +const ClaudeCodeFallback = "2.1.92" diff --git a/main.go b/main.go index a1884de..feb81d6 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,74 @@ import ( "github.com/rs/zerolog/log" ) +func initCredential() (*auth.Credential, error) { + creds, err := auth.LoadDefaultCredentials() + if err != nil { + return nil, fmt.Errorf("load credentials: %w", err) + } + + var cred *auth.Credential + if len(creds) > 0 { + cred = creds[0] + // If token is expired, try refresh first + if !cred.ExpiresAt.IsZero() && time.Now().After(cred.ExpiresAt) { + log.Info().Msg("token expired, attempting refresh") + refreshCtx, refreshCancel := context.WithTimeout(context.Background(), 15*time.Second) + refreshErr := auth.RefreshToken(refreshCtx, cred) + refreshCancel() + if refreshErr != nil { + log.Warn().Err(refreshErr).Msg("refresh failed, initiating login") + cred = nil // fall through to login + } else { + log.Info().Msg("token refreshed") + } + } + } + + if cred == nil { + fi, statErr := os.Stdin.Stat() + if statErr == nil && (fi.Mode()&os.ModeCharDevice) == 0 { + return nil, fmt.Errorf("no valid credentials found; run the proxy interactively for initial login") + } + log.Info().Msg("no credentials found, starting OAuth login") + cred, err = auth.Login(context.Background()) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + } + + log.Info().Str("credential", cred.Email).Msg("credential loaded") + return cred, nil +} + +func initEmbedded(cfg *config.Config) (cleanup func(), err error) { + if !cfg.Telemetry.Embedded.Enabled { + return func() {}, nil + } + + var cleanups []func() + + vm := embedded.NewVM(cfg.Telemetry.Embedded, cfg.Port) + if err := vm.Start(); err != nil { + log.Error().Err(err).Msg("failed to start victoria-metrics") + } else { + cleanups = append(cleanups, vm.Stop) + } + + perses := embedded.NewPerses(cfg.Telemetry.Embedded, cfg.Port) + if err := perses.Start(); err != nil { + log.Error().Err(err).Msg("failed to start perses") + } else { + cleanups = append(cleanups, perses.Stop) + } + + return func() { + for i := len(cleanups) - 1; i >= 0; i-- { + cleanups[i]() + } + }, nil +} + func run() error { cfg, err := config.Load("config.yaml") if err != nil { @@ -48,54 +116,13 @@ func run() error { extraWriters = append(extraWriters, logBridge) } - logging.Setup(logging.Config{ - Level: cfg.Logging.Level, - File: cfg.Logging.File, - MaxSizeMB: cfg.Logging.MaxSizeMB, - MaxBackups: cfg.Logging.MaxBackups, - MaxAgeDays: cfg.Logging.MaxAgeDays, - Compress: cfg.Logging.Compress, - }, extraWriters...) + logging.Setup(cfg.Logging, extraWriters...) - // Load credentials from ~/.claude/.credentials.json - creds, err := config.LoadDefaultCredentials() + cred, err := initCredential() if err != nil { - return fmt.Errorf("load credentials: %w", err) + return err } - var cred *auth.Credential - if len(creds) > 0 { - cred = creds[0] - // If token is expired, try refresh first - if !cred.ExpiresAt.IsZero() && time.Now().After(cred.ExpiresAt) { - log.Info().Msg("token expired, attempting refresh") - refreshCtx, refreshCancel := context.WithTimeout(context.Background(), 15*time.Second) - refreshErr := auth.RefreshToken(refreshCtx, cred) - refreshCancel() - if refreshErr != nil { - log.Warn().Err(refreshErr).Msg("refresh failed, initiating login") - cred = nil // fall through to login - } else { - log.Info().Msg("token refreshed") - } - } - } - - if cred == nil { - // Non-TTY check: if stdin is not a terminal, can't do interactive login - fi, statErr := os.Stdin.Stat() - if statErr == nil && (fi.Mode()&os.ModeCharDevice) == 0 { - return fmt.Errorf("no valid credentials found; run the proxy interactively for initial login") - } - log.Info().Msg("no credentials found, starting OAuth login") - cred, err = auth.Login(context.Background()) - if err != nil { - return fmt.Errorf("login failed: %w", err) - } - } - - log.Info().Str("credential", cred.Email).Msg("credential loaded") - credForTracker = cred pool := auth.NewPool([]*auth.Credential{cred}) @@ -116,24 +143,11 @@ func run() error { } } - // Start embedded observability stack (VM + Perses) if enabled - var vm *embedded.VM - var perses *embedded.Perses - if cfg.Telemetry.Embedded.Enabled { - vm = embedded.NewVM(cfg.Telemetry.Embedded, cfg.Port) - if err := vm.Start(); err != nil { - log.Error().Err(err).Msg("failed to start victoria-metrics") - } else { - defer vm.Stop() - } - - perses = embedded.NewPerses(cfg.Telemetry.Embedded, cfg.Port) - if err := perses.Start(); err != nil { - log.Error().Err(err).Msg("failed to start perses") - } else { - defer perses.Stop() - } + embeddedCleanup, err := initEmbedded(cfg) + if err != nil { + return err } + defer embeddedCleanup() log.Info().Int("port", cfg.Port).Msg("starting server") srv := server.New(cfg, pool, profile, tracker, metricsHandler)