From 4abd4e68dc151567af21ef54c51adc52f2a5f50c Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Apr 2026 12:56:42 +0200 Subject: [PATCH] Fixes, readme Drop cli-proxy-api token handling, use only native Claude credentials. Simplify refresh to single endpoint (platform.claude.com) with scope. Add debug/refresh and debug/shutdown endpoints. Graceful shutdown. --- config.example.yaml | 1 - internal/auth/refresh.go | 70 ++++++---------------------- internal/auth/selector.go | 46 ++++++++++++++++++- internal/config/config.go | 97 ++++++--------------------------------- internal/server/server.go | 25 +++++++++- 5 files changed, 97 insertions(+), 142 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index d0c1806..716bed0 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,7 +1,6 @@ port: 8082 api_keys: - "your-proxy-api-key" -auth_dir: "" claude_credentials: "~/.claude/.credentials.json" claude_binary: "claude" diff --git a/internal/auth/refresh.go b/internal/auth/refresh.go index 141be19..bd8e0f1 100644 --- a/internal/auth/refresh.go +++ b/internal/auth/refresh.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "os" - "strings" "sync" "time" @@ -19,9 +18,9 @@ import ( ) const ( - cliProxyTokenEndpoint = "https://api.anthropic.com/v1/oauth/token" - nativeTokenEndpoint = "https://platform.claude.com/v1/oauth/token" - clientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + tokenEndpoint = "https://platform.claude.com/v1/oauth/token" + clientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + oauthScopes = "user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload" refreshLead = 5 * time.Minute refreshInterval = 30 * time.Second @@ -34,6 +33,7 @@ type tokenRequest struct { ClientID string `json:"client_id"` GrantType string `json:"grant_type"` RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` } type tokenResponse struct { @@ -50,23 +50,18 @@ func RefreshToken(ctx context.Context, cred *Credential) error { return fmt.Errorf("no refresh token") } - endpoint := cliProxyTokenEndpoint - if cred.ID == "claude-native" { - endpoint = nativeTokenEndpoint - } - reqBody, _ := json.Marshal(tokenRequest{ ClientID: clientID, GrantType: "refresh_token", RefreshToken: cred.RefreshToken, + Scope: oauthScopes, }) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody)) if err != nil { return fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") resp, err := utlsClient.Do(req) if err != nil { @@ -100,59 +95,36 @@ func RefreshToken(ctx context.Context, cred *Credential) error { func persistCredential(cred *Credential) error { cred.mu.Lock() - id := cred.ID filePath := cred.FilePath accessToken := cred.AccessToken refreshToken := cred.RefreshToken expiresAt := cred.ExpiresAt - email := cred.Email cred.mu.Unlock() if filePath == "" { return nil } - if id == "claude-native" { - return persistNativeCredential(filePath, accessToken, refreshToken, expiresAt) - } - return persistCliProxyCredential(filePath, accessToken, refreshToken, expiresAt, email) -} - -func persistCliProxyCredential(path, accessToken, refreshToken string, expiresAt time.Time, email string) error { - data := map[string]string{ - "access_token": accessToken, - "refresh_token": refreshToken, - "email": email, - "expired": expiresAt.Format(time.RFC3339), - "type": "claude", - "last_refresh": time.Now().Format(time.RFC3339), - } - out, _ := json.MarshalIndent(data, "", " ") - return os.WriteFile(path, out, 0600) -} - -func persistNativeCredential(path, accessToken, refreshToken string, expiresAt time.Time) error { - raw, err := os.ReadFile(path) + raw, err := os.ReadFile(filePath) if err != nil { return err } - var doc map[string]interface{} + var doc map[string]any if err := json.Unmarshal(raw, &doc); err != nil { return err } - oauth, _ := doc["claudeAiOauth"].(map[string]interface{}) + oauth, _ := doc["claudeAiOauth"].(map[string]any) if oauth == nil { - oauth = make(map[string]interface{}) + oauth = make(map[string]any) } oauth["accessToken"] = accessToken oauth["refreshToken"] = refreshToken oauth["expiresAt"] = expiresAt.UnixMilli() doc["claudeAiOauth"] = oauth out, _ := json.MarshalIndent(doc, "", " ") - return os.WriteFile(path, out, 0600) + return os.WriteFile(filePath, out, 0600) } -// Chrome TLS HTTP client for refresh requests (same as proxy transport). func newUTLSClient() *http.Client { return &http.Client{ Timeout: 15 * time.Second, @@ -214,7 +186,6 @@ func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, err return h2Conn.RoundTrip(req) } -// StartBackgroundRefresh runs a goroutine that checks and refreshes tokens periodically. func StartBackgroundRefresh(ctx context.Context, pool *Pool) { go func() { for { @@ -223,13 +194,13 @@ func StartBackgroundRefresh(ctx context.Context, pool *Pool) { log.Printf("background refresh stopped") return case <-time.After(refreshInterval): - refreshAll(pool) + refreshExpiring(pool) } } }() } -func refreshAll(pool *Pool) { +func refreshExpiring(pool *Pool) { pool.mu.Lock() creds := make([]*Credential, len(pool.creds)) copy(creds, pool.creds) @@ -269,18 +240,3 @@ func refreshAll(pool *Pool) { } } } - -// NeedsRefresh checks if a credential needs refresh within the lead time. -func NeedsRefresh(cred *Credential) bool { - cred.mu.Lock() - defer cred.mu.Unlock() - if cred.ExpiresAt.IsZero() || cred.RefreshToken == "" { - return false - } - return time.Until(cred.ExpiresAt) <= refreshLead -} - -// IsNativeCredential checks if the credential is from ~/.claude/.credentials.json. -func IsNativeCredential(cred *Credential) bool { - return cred.ID == "claude-native" || strings.Contains(cred.FilePath, ".credentials.json") -} diff --git a/internal/auth/selector.go b/internal/auth/selector.go index 9c56501..9e78d96 100644 --- a/internal/auth/selector.go +++ b/internal/auth/selector.go @@ -54,5 +54,49 @@ func (p *Pool) MarkSuccess(cred *Credential) { } func (p *Pool) RefreshExpiring(ctx context.Context) { - refreshAll(p) + refreshExpiring(p) +} + +func (p *Pool) RefreshAll(ctx context.Context) []map[string]string { + p.mu.Lock() + creds := make([]*Credential, len(p.creds)) + copy(creds, p.creds) + p.mu.Unlock() + + var results []map[string]string + for _, cred := range creds { + cred.mu.Lock() + id := cred.ID + email := cred.Email + oldExpiry := cred.ExpiresAt + hasRefresh := cred.RefreshToken != "" + cred.mu.Unlock() + + r := map[string]string{ + "id": id, + "email": email, + "old_expiry": oldExpiry.Format(time.RFC3339), + } + + if !hasRefresh { + r["status"] = "skipped" + r["reason"] = "no refresh token" + results = append(results, r) + continue + } + + err := RefreshToken(ctx, cred) + if err != nil { + r["status"] = "error" + r["error"] = err.Error() + } else { + cred.mu.Lock() + r["status"] = "ok" + r["new_expiry"] = cred.ExpiresAt.Format(time.RFC3339) + r["new_token_prefix"] = cred.AccessToken[:20] + "..." + cred.mu.Unlock() + } + results = append(results, r) + } + return results } diff --git a/internal/config/config.go b/internal/config/config.go index ba3ca22..6f1bc7c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "os" - "path/filepath" "time" "github.com/fujin/anthropic-proxy/internal/auth" @@ -12,18 +11,17 @@ import ( ) type Config struct { - Port int `yaml:"port"` - APIKeys []string `yaml:"api_keys"` - AuthDir string `yaml:"auth_dir"` - ClaudeCredentials string `yaml:"claude_credentials"` - ClaudeBinary string `yaml:"claude_binary"` - Sanitize SanitizeConfig `yaml:"sanitize"` + Port int `yaml:"port"` + APIKeys []string `yaml:"api_keys"` + ClaudeCredentials string `yaml:"claude_credentials"` + ClaudeBinary string `yaml:"claude_binary"` + Sanitize SanitizeConfig `yaml:"sanitize"` } type SanitizeConfig struct { - Tools []RenameRule `yaml:"tools"` - System []ReplaceRule `yaml:"system"` - Body []ReplaceRule `yaml:"body"` + Tools []RenameRule `yaml:"tools"` + System []ReplaceRule `yaml:"system"` + Body []ReplaceRule `yaml:"body"` } type RenameRule struct { @@ -36,14 +34,6 @@ type ReplaceRule struct { Replace string `yaml:"replace"` } -type authFileJSON struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - Email string `json:"email"` - Expired string `json:"expired"` - Type string `json:"type"` -} - type claudeCredentialsJSON struct { ClaudeAiOauth struct { AccessToken string `json:"accessToken"` @@ -68,28 +58,19 @@ func Load(path string) (*Config, error) { } func LoadCredentials(cfg *Config) ([]*auth.Credential, error) { - var creds []*auth.Credential - - if cfg.ClaudeCredentials != "" { - cred, err := loadClaudeCredentials(cfg.ClaudeCredentials) - if err != nil { - return nil, fmt.Errorf("load claude credentials: %w", err) - } - creds = append(creds, cred) + if cfg.ClaudeCredentials == "" { + return nil, fmt.Errorf("claude_credentials not set") } - if cfg.AuthDir != "" { - dirCreds, err := loadAuthDir(cfg.AuthDir) - if err != nil { - return nil, fmt.Errorf("load auth dir: %w", err) - } - creds = append(creds, dirCreds...) + cred, err := loadCredentials(cfg.ClaudeCredentials) + if err != nil { + return nil, err } - return creds, nil + return []*auth.Credential{cred}, nil } -func loadClaudeCredentials(path string) (*auth.Credential, error) { +func loadCredentials(path string) (*auth.Credential, error) { data, err := os.ReadFile(path) if err != nil { return nil, err @@ -114,51 +95,3 @@ func loadClaudeCredentials(path string) (*auth.Credential, error) { FilePath: path, }, nil } - -func loadAuthDir(authDir string) ([]*auth.Credential, error) { - pattern := filepath.Join(authDir, "*.json") - files, err := filepath.Glob(pattern) - if err != nil { - return nil, fmt.Errorf("glob auth files: %w", err) - } - - var creds []*auth.Credential - for _, f := range files { - cred, err := loadAuthFile(f) - if err != nil { - return nil, fmt.Errorf("load auth file %s: %w", f, err) - } - creds = append(creds, cred) - } - - return creds, nil -} - -func loadAuthFile(path string) (*auth.Credential, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - var af authFileJSON - if err := json.Unmarshal(data, &af); err != nil { - return nil, err - } - - expiresAt, err := time.Parse(time.RFC3339, af.Expired) - if err != nil { - expiresAt, err = time.Parse("2006-01-02T15:04:05", af.Expired) - if err != nil { - expiresAt = time.Now() - } - } - - return &auth.Credential{ - ID: filepath.Base(path), - Email: af.Email, - AccessToken: af.AccessToken, - RefreshToken: af.RefreshToken, - ExpiresAt: expiresAt, - FilePath: path, - }, nil -} diff --git a/internal/server/server.go b/internal/server/server.go index 0b844a6..ebfd08b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" "sync/atomic" + "time" "github.com/gin-gonic/gin" @@ -45,6 +46,8 @@ func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Se engine.POST("/messages", handler) engine.POST("/reload", s.handleReload()) + engine.POST("/debug/refresh", handleDebugRefresh(pool)) + engine.POST("/debug/shutdown", handleDebugShutdown(s)) engine.GET("/healthz", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) @@ -97,6 +100,26 @@ func (s *Server) handleReload() gin.HandlerFunc { } } +func handleDebugShutdown(s *Server) gin.HandlerFunc { + return func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "shutting down"}) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.Shutdown(ctx); err != nil { + log.Printf("shutdown error: %v", err) + } + }() + } +} + +func handleDebugRefresh(pool *auth.Pool) gin.HandlerFunc { + return func(c *gin.Context) { + results := pool.RefreshAll(c.Request.Context()) + c.JSON(http.StatusOK, results) + } +} + func makeKeySet(apiKeys []string) map[string]struct{} { keySet := make(map[string]struct{}, len(apiKeys)) for _, k := range apiKeys { @@ -123,7 +146,7 @@ func corsMiddleware() gin.HandlerFunc { func (s *Server) authMiddleware() gin.HandlerFunc { return func(c *gin.Context) { path := c.Request.URL.Path - if path == "/healthz" || path == "/reload" { + if path == "/healthz" || path == "/reload" || strings.HasPrefix(path, "/debug/") { c.Next() return }