Files
anthropic-proxy/internal/auth/refresh.go
T
2026-04-09 23:06:17 +02:00

287 lines
7.0 KiB
Go

package auth
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	tls "github.com/refraction-networking/utls"
	"golang.org/x/net/http2"
)
const (
	// OAuth token endpoints. RefreshToken sends credentials whose ID is
	// "claude-native" to nativeTokenEndpoint and everything else to
	// cliProxyTokenEndpoint.
	cliProxyTokenEndpoint = "https://api.anthropic.com/v1/oauth/token"
	nativeTokenEndpoint   = "https://platform.claude.com/v1/oauth/token"

	// clientID is sent as client_id in every refresh_token grant request.
	clientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
)

const (
	// refreshLead is how long before expiry a credential becomes due for refresh.
	refreshLead = 5 * time.Minute
	// refreshInterval is the pause between background refresh passes.
	refreshInterval = 30 * time.Second
	// refreshBackoff is how long a credential is skipped after a failed refresh.
	refreshBackoff = 5 * time.Minute
)
// utlsClient is the package-wide HTTP client used by RefreshToken. It is built
// once at package initialization, so every refresh shares the same transport
// (and therefore whatever connection that transport has cached).
var utlsClient = newUTLSClient()
// tokenRequest is the JSON body POSTed to the token endpoint for an OAuth
// refresh_token grant.
type tokenRequest struct {
	ClientID     string `json:"client_id"`     // always clientID
	GrantType    string `json:"grant_type"`    // "refresh_token" for refreshes
	RefreshToken string `json:"refresh_token"` // the credential's current refresh token
}
// tokenResponse is the subset of the token endpoint's JSON reply that
// RefreshToken consumes. Fields the server omits decode to their zero values.
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"` // may be empty when the token is not rotated
	ExpiresIn    int    `json:"expires_in"`    // lifetime of AccessToken, in seconds

	// Account carries the e-mail address associated with the token, if sent.
	Account struct {
		EmailAddress string `json:"email_address"`
	} `json:"account"`
}
// RefreshToken exchanges cred's refresh token for a new access token, updates
// cred in memory, and writes it back to its backing file.
//
// Credentials whose ID is "claude-native" are refreshed against
// nativeTokenEndpoint; all others use cliProxyTokenEndpoint. On success the
// access token and expiry are always replaced. The refresh token and e-mail
// are replaced only when the server returns non-empty values. The result of
// persistCredential is returned, so a persistence failure is reported even
// though the in-memory credential has already been updated.
//
// It returns an error when cred has no refresh token, when the request fails,
// when the endpoint answers with a non-200 status, or when the response cannot
// be decoded or lacks an access token.
func RefreshToken(ctx context.Context, cred *Credential) error {
	// Snapshot the inputs under the lock; other goroutines (e.g. the
	// background refresher) may be mutating cred concurrently.
	cred.mu.Lock()
	id := cred.ID
	currentRefresh := cred.RefreshToken
	cred.mu.Unlock()

	if currentRefresh == "" {
		return fmt.Errorf("no refresh token")
	}

	endpoint := cliProxyTokenEndpoint
	if id == "claude-native" {
		endpoint = nativeTokenEndpoint
	}

	reqBody, err := json.Marshal(tokenRequest{
		ClientID:     clientID,
		GrantType:    "refresh_token",
		RefreshToken: currentRefresh,
	})
	if err != nil {
		return fmt.Errorf("encode request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := utlsClient.Do(req)
	if err != nil {
		return fmt.Errorf("execute request: %w", err)
	}
	defer resp.Body.Close()

	// Token responses are tiny; cap the read so a misbehaving endpoint cannot
	// make us buffer an unbounded body.
	const maxResponseBytes = 1 << 20
	body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("refresh returned %d: %s", resp.StatusCode, string(body))
	}
	if readErr != nil {
		return fmt.Errorf("read response: %w", readErr)
	}

	var tokenResp tokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return fmt.Errorf("decode response: %w", err)
	}
	// Without this check a 200 lacking access_token would overwrite a working
	// access token with "" and persist it.
	if tokenResp.AccessToken == "" {
		return fmt.Errorf("decode response: missing access_token")
	}

	cred.mu.Lock()
	cred.AccessToken = tokenResp.AccessToken
	if tokenResp.RefreshToken != "" {
		cred.RefreshToken = tokenResp.RefreshToken
	}
	cred.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
	if tokenResp.Account.EmailAddress != "" {
		cred.Email = tokenResp.Account.EmailAddress
	}
	cred.mu.Unlock()

	return persistCredential(cred)
}
// persistCredential writes cred's current token state back to cred.FilePath.
// It does nothing when FilePath is empty. A credential whose ID is
// "claude-native" is written with persistNativeCredential; any other is
// written with persistCliProxyCredential.
func persistCredential(cred *Credential) error {
	// Copy the fields under the lock so the file I/O below runs without it.
	cred.mu.Lock()
	id, path := cred.ID, cred.FilePath
	access, refresh := cred.AccessToken, cred.RefreshToken
	expires, email := cred.ExpiresAt, cred.Email
	cred.mu.Unlock()

	switch {
	case path == "":
		return nil
	case id == "claude-native":
		return persistNativeCredential(path, access, refresh, expires)
	default:
		return persistCliProxyCredential(path, access, refresh, expires, email)
	}
}
// persistCliProxyCredential replaces the file at path with a JSON object that
// holds the given tokens, e-mail, expiry (RFC 3339), a fixed "type": "claude",
// and the current time as "last_refresh". The file is written atomically with
// mode 0600.
func persistCliProxyCredential(path, accessToken, refreshToken string, expiresAt time.Time, email string) error {
	data := map[string]string{
		"access_token":  accessToken,
		"refresh_token": refreshToken,
		"email":         email,
		"expired":       expiresAt.Format(time.RFC3339),
		"type":          "claude",
		"last_refresh":  time.Now().Format(time.RFC3339),
	}
	out, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return fmt.Errorf("encode credential: %w", err)
	}
	return writeFileAtomic(path, out, 0600)
}

// persistNativeCredential updates the "claudeAiOauth" object inside the
// existing JSON document at path with the given tokens and expiry (Unix
// milliseconds). Every other field of the document is preserved.
//
// Numbers are decoded as json.Number, so unrelated numeric fields are written
// back verbatim instead of being rounded through float64. The rewrite is
// atomic with mode 0600.
func persistNativeCredential(path, accessToken, refreshToken string, expiresAt time.Time) error {
	raw, err := os.ReadFile(path)
	if err != nil {
		return err
	}

	var doc map[string]interface{}
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	if err := dec.Decode(&doc); err != nil {
		return err
	}
	// Match json.Unmarshal's strictness: reject anything after the document.
	if _, err := dec.Token(); err != io.EOF {
		return fmt.Errorf("decode %s: unexpected data after JSON document", path)
	}
	// A file containing just `null` decodes to a nil map; writing into it
	// would panic.
	if doc == nil {
		doc = make(map[string]interface{})
	}

	oauth, _ := doc["claudeAiOauth"].(map[string]interface{})
	if oauth == nil {
		oauth = make(map[string]interface{})
	}
	oauth["accessToken"] = accessToken
	oauth["refreshToken"] = refreshToken
	oauth["expiresAt"] = expiresAt.UnixMilli()
	doc["claudeAiOauth"] = oauth

	out, err := json.MarshalIndent(doc, "", " ")
	if err != nil {
		return fmt.Errorf("encode credential: %w", err)
	}
	return writeFileAtomic(path, out, 0600)
}

// writeFileAtomic replaces the file at path with data. It writes a temporary
// file in the same directory and renames it over the target, so a crash or a
// full disk can never leave a truncated file behind. Losing that file is costly
// here because refresh tokens may be rotated on each refresh.
//
// A symlink at path is resolved first, so the link is kept and its target is
// updated, as os.WriteFile would.
func writeFileAtomic(path string, data []byte, perm os.FileMode) (err error) {
	if resolved, rerr := filepath.EvalSymlinks(path); rerr == nil {
		path = resolved
	}

	tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	defer func() {
		if err != nil {
			tmp.Close()
			os.Remove(tmpName)
		}
	}()

	if err = tmp.Chmod(perm); err != nil {
		return err
	}
	if _, err = tmp.Write(data); err != nil {
		return err
	}
	if err = tmp.Sync(); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}
// newUTLSClient returns the HTTP client used for OAuth refresh requests. It
// routes every request through a fresh utlsRefreshTransport (same Chrome TLS
// fingerprint as the proxy transport) and caps each request at 15 seconds.
func newUTLSClient() *http.Client {
	client := new(http.Client)
	client.Transport = new(utlsRefreshTransport)
	client.Timeout = 15 * time.Second
	return client
}
// utlsRefreshTransport is the http.RoundTripper behind utlsClient. It caches
// at most one HTTP/2 client connection, together with the host it was dialed
// for, so consecutive refreshes to the same host can reuse it.
type utlsRefreshTransport struct {
	mu   sync.Mutex        // guards conn and host
	conn *http2.ClientConn // cached connection; nil when none is cached
	host string            // host that conn was dialed for
}
// RoundTrip implements http.RoundTripper.
//
// It first tries the cached HTTP/2 connection when that connection was dialed
// for the request's host and can accept another stream. If there is no usable
// cached connection, or the attempt on it fails, it dials a fresh connection,
// caches it in place of any previous one, and sends the request there. Any
// connection that is dropped from the cache is shut down rather than leaked.
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	host := req.URL.Hostname()
	port := req.URL.Port()
	if port == "" {
		port = "443"
	}
	ctx := req.Context()

	t.mu.Lock()
	cached := t.conn
	reusable := cached != nil && t.host == host && cached.CanTakeNewRequest()
	t.mu.Unlock()

	if reusable {
		resp, err := cached.RoundTrip(req)
		if err == nil {
			return resp, nil
		}
		t.retire(cached)
		if ctx.Err() != nil {
			// The caller gave up; a retry cannot succeed.
			return nil, err
		}
		// The failed attempt may already have consumed the body, so retry
		// only with a freshly rewound copy of the request.
		// NOTE(review): if the server did process the first attempt, this
		// resends it (as the previous code did); with rotating refresh
		// tokens the second attempt may then be rejected.
		retry, rerr := replayableRequest(req)
		if rerr != nil {
			return nil, err
		}
		req = retry
	}

	h2Conn, err := t.dial(ctx, host, port)
	if err != nil {
		// RoundTripper must close the request body even on error.
		if req.Body != nil {
			req.Body.Close()
		}
		return nil, err
	}

	t.mu.Lock()
	old := t.conn
	t.conn = h2Conn
	t.host = host
	t.mu.Unlock()
	// The previous connection (other host, exhausted, or installed by a
	// concurrent dial) is no longer reachable from the cache; close it.
	if old != nil && old != h2Conn {
		t.retire(old)
	}

	return h2Conn.RoundTrip(req)
}

// dial opens a TCP connection to host:port, performs a uTLS handshake with a
// Chrome ClientHello, and wraps the result in an HTTP/2 client connection.
// Both the dial and the handshake are bounded by ctx's deadline and by a
// 10-second setup timeout, whichever is sooner.
func (t *utlsRefreshTransport) dial(ctx context.Context, host, port string) (*http2.ClientConn, error) {
	const setupTimeout = 10 * time.Second

	dialer := net.Dialer{Timeout: setupTimeout}
	rawConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
	if err != nil {
		return nil, err
	}

	deadline := time.Now().Add(setupTimeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		deadline = d
	}
	if err := rawConn.SetDeadline(deadline); err != nil {
		rawConn.Close()
		return nil, err
	}

	tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
	if err := tlsConn.Handshake(); err != nil {
		rawConn.Close()
		return nil, err
	}
	// Clear the setup deadline so the cached, long-lived connection is not
	// killed once it passes.
	if err := rawConn.SetDeadline(time.Time{}); err != nil {
		tlsConn.Close()
		return nil, err
	}

	h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
	if err != nil {
		tlsConn.Close()
		return nil, err
	}
	return h2Conn, nil
}

// retire removes c from the cache if it is still cached and shuts it down in
// the background. Streams other goroutines have in flight get up to 30 seconds
// to finish before the connection is closed outright.
func (t *utlsRefreshTransport) retire(c *http2.ClientConn) {
	t.mu.Lock()
	if t.conn == c {
		t.conn = nil
	}
	t.mu.Unlock()

	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if err := c.Shutdown(ctx); err != nil {
			c.Close()
		}
	}()
}

// replayableRequest returns a copy of req with a fresh body, so the request can
// be sent again after a failed attempt consumed the original body. A request
// without a body is returned unchanged. An error is returned when the body
// cannot be recreated.
func replayableRequest(req *http.Request) (*http.Request, error) {
	if req.Body == nil || req.Body == http.NoBody {
		return req, nil
	}
	if req.GetBody == nil {
		return nil, fmt.Errorf("request body cannot be replayed")
	}
	body, err := req.GetBody()
	if err != nil {
		return nil, err
	}
	clone := req.Clone(req.Context())
	clone.Body = body
	return clone, nil
}
// StartBackgroundRefresh starts a goroutine that calls refreshAll on pool
// every refreshInterval until ctx is cancelled. The wait restarts only after a
// pass finishes, so passes never overlap. It returns immediately.
func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
	go func() {
		// One timer is reused across iterations instead of allocating a new
		// one per loop with time.After.
		wait := time.NewTimer(refreshInterval)
		defer wait.Stop()

		for {
			select {
			case <-wait.C:
				refreshAll(pool)
				wait.Reset(refreshInterval)
			case <-ctx.Done():
				log.Printf("background refresh stopped")
				return
			}
		}
	}()
}
// refreshAll refreshes every credential in pool that meets three conditions:
// it has a refresh token, it has a known expiry within refreshLead of now, and
// it is not inside the backoff window set by a previous failed refresh.
// Refreshes run one after another, each with a 15-second timeout. A failure
// pushes the credential's next attempt refreshBackoff into the future; a
// success clears that backoff.
func refreshAll(pool *Pool) {
	// Copy the slice so the pool lock is not held across network calls.
	pool.mu.Lock()
	creds := make([]*Credential, len(pool.creds))
	copy(creds, pool.creds)
	pool.mu.Unlock()

	threshold := time.Now().Add(refreshLead)
	for _, cred := range creds {
		// Read every field we need under the lock; RefreshToken and other
		// goroutines write these concurrently.
		cred.mu.Lock()
		expiresAt := cred.ExpiresAt
		hasRefresh := cred.RefreshToken != ""
		nextRetry := cred.nextRefreshAfter
		email := cred.Email
		cred.mu.Unlock()

		needsRefresh := !expiresAt.IsZero() && expiresAt.Before(threshold)
		if !hasRefresh || !needsRefresh {
			continue
		}
		if !nextRetry.IsZero() && time.Now().Before(nextRetry) {
			continue
		}

		log.Printf("refreshing token for %s (expires %s)", email, expiresAt.Format(time.RFC3339))
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		err := RefreshToken(ctx, cred)
		cancel()

		if err != nil {
			log.Printf("refresh failed for %s: %v", email, err)
			cred.mu.Lock()
			cred.nextRefreshAfter = time.Now().Add(refreshBackoff)
			cred.mu.Unlock()
			continue
		}

		cred.mu.Lock()
		cred.nextRefreshAfter = time.Time{}
		newExpiry := cred.ExpiresAt
		cred.mu.Unlock()
		log.Printf("refreshed %s, new expiry %s", email, newExpiry.Format(time.RFC3339))
	}
}
// NeedsRefresh reports whether cred can and should be refreshed now. That is
// true when it has a refresh token and a known expiry that is at most
// refreshLead away; already-expired credentials count too.
func NeedsRefresh(cred *Credential) bool {
	cred.mu.Lock()
	expiresAt, canRefresh := cred.ExpiresAt, cred.RefreshToken != ""
	cred.mu.Unlock()

	switch {
	case expiresAt.IsZero(), !canRefresh:
		return false
	default:
		return time.Until(expiresAt) <= refreshLead
	}
}
// IsNativeCredential reports whether cred came from Claude's native credentials
// file. That is the case when its ID is "claude-native" or its file path
// contains ".credentials.json".
func IsNativeCredential(cred *Credential) bool {
	if cred.ID == "claude-native" {
		return true
	}
	return strings.Contains(cred.FilePath, ".credentials.json")
}