4abd4e68dc
Drop cli-proxy-api token handling, use only native Claude credentials. Simplify refresh to single endpoint (platform.claude.com) with scope. Add debug/refresh and debug/shutdown endpoints. Graceful shutdown.
243 lines
5.5 KiB
Go
243 lines
5.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
tls "github.com/refraction-networking/utls"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// OAuth client settings used for the refresh_token grant.
const (
	tokenEndpoint = "https://platform.claude.com/v1/oauth/token"
	clientID      = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
	oauthScopes   = "user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"
)

// Background refresh timing.
const (
	// refreshLead: a credential whose ExpiresAt falls within this window from
	// now is eligible for refresh.
	refreshLead = 5 * time.Minute

	// refreshInterval: period between background scans of the pool.
	refreshInterval = 30 * time.Second

	// refreshBackoff: how long a credential is skipped after a failed refresh.
	refreshBackoff = 5 * time.Minute
)
|
|
|
|
// utlsClient is the package-level HTTP client RefreshToken uses to reach the
// token endpoint; it is built once, at package initialization, by newUTLSClient.
var utlsClient = newUTLSClient()
|
|
|
|
// tokenRequest is the JSON body POSTed to tokenEndpoint when refreshing a
// credential. The declaration order of the fields determines the key order of
// the encoded JSON, so keep it stable.
type tokenRequest struct {
	// ClientID identifies the OAuth client (set from clientID).
	ClientID string `json:"client_id"`
	// GrantType names the OAuth grant being used.
	GrantType string `json:"grant_type"`
	// RefreshToken is the credential's current refresh token.
	RefreshToken string `json:"refresh_token"`
	// Scope is the space-separated scope list requested for the new token.
	Scope string `json:"scope"`
}
|
|
|
|
// tokenResponse captures the subset of the token endpoint's JSON reply that
// RefreshToken consumes; any other keys in the reply are ignored on decode.
type tokenResponse struct {
	// AccessToken is the newly issued access token.
	AccessToken string `json:"access_token"`
	// RefreshToken is a replacement refresh token; may be empty in the reply.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is the lifetime of AccessToken, in seconds.
	ExpiresIn int `json:"expires_in"`
	// Account carries the account's e-mail address, when the reply includes it.
	Account struct {
		EmailAddress string `json:"email_address"`
	} `json:"account"`
}
|
|
|
|
func RefreshToken(ctx context.Context, cred *Credential) error {
|
|
if cred.RefreshToken == "" {
|
|
return fmt.Errorf("no refresh token")
|
|
}
|
|
|
|
reqBody, _ := json.Marshal(tokenRequest{
|
|
ClientID: clientID,
|
|
GrantType: "refresh_token",
|
|
RefreshToken: cred.RefreshToken,
|
|
Scope: oauthScopes,
|
|
})
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewReader(reqBody))
|
|
if err != nil {
|
|
return fmt.Errorf("create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := utlsClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("execute request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("refresh returned %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var tokenResp tokenResponse
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return fmt.Errorf("decode response: %w", err)
|
|
}
|
|
|
|
cred.mu.Lock()
|
|
cred.AccessToken = tokenResp.AccessToken
|
|
if tokenResp.RefreshToken != "" {
|
|
cred.RefreshToken = tokenResp.RefreshToken
|
|
}
|
|
cred.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
|
if tokenResp.Account.EmailAddress != "" {
|
|
cred.Email = tokenResp.Account.EmailAddress
|
|
}
|
|
cred.mu.Unlock()
|
|
|
|
return persistCredential(cred)
|
|
}
|
|
|
|
func persistCredential(cred *Credential) error {
|
|
cred.mu.Lock()
|
|
filePath := cred.FilePath
|
|
accessToken := cred.AccessToken
|
|
refreshToken := cred.RefreshToken
|
|
expiresAt := cred.ExpiresAt
|
|
cred.mu.Unlock()
|
|
|
|
if filePath == "" {
|
|
return nil
|
|
}
|
|
|
|
raw, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var doc map[string]any
|
|
if err := json.Unmarshal(raw, &doc); err != nil {
|
|
return err
|
|
}
|
|
oauth, _ := doc["claudeAiOauth"].(map[string]any)
|
|
if oauth == nil {
|
|
oauth = make(map[string]any)
|
|
}
|
|
oauth["accessToken"] = accessToken
|
|
oauth["refreshToken"] = refreshToken
|
|
oauth["expiresAt"] = expiresAt.UnixMilli()
|
|
doc["claudeAiOauth"] = oauth
|
|
out, _ := json.MarshalIndent(doc, "", " ")
|
|
return os.WriteFile(filePath, out, 0600)
|
|
}
|
|
|
|
func newUTLSClient() *http.Client {
|
|
return &http.Client{
|
|
Timeout: 15 * time.Second,
|
|
Transport: &utlsRefreshTransport{},
|
|
}
|
|
}
|
|
|
|
type utlsRefreshTransport struct {
|
|
mu sync.Mutex
|
|
conn *http2.ClientConn
|
|
host string
|
|
}
|
|
|
|
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
host := req.URL.Hostname()
|
|
port := req.URL.Port()
|
|
if port == "" {
|
|
port = "443"
|
|
}
|
|
|
|
t.mu.Lock()
|
|
if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() {
|
|
conn := t.conn
|
|
t.mu.Unlock()
|
|
resp, err := conn.RoundTrip(req)
|
|
if err == nil {
|
|
return resp, nil
|
|
}
|
|
t.mu.Lock()
|
|
t.conn = nil
|
|
t.mu.Unlock()
|
|
} else {
|
|
t.mu.Unlock()
|
|
}
|
|
|
|
addr := net.JoinHostPort(host, port)
|
|
rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
rawConn.Close()
|
|
return nil, err
|
|
}
|
|
|
|
h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
|
|
if err != nil {
|
|
tlsConn.Close()
|
|
return nil, err
|
|
}
|
|
|
|
t.mu.Lock()
|
|
t.conn = h2Conn
|
|
t.host = host
|
|
t.mu.Unlock()
|
|
|
|
return h2Conn.RoundTrip(req)
|
|
}
|
|
|
|
func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Printf("background refresh stopped")
|
|
return
|
|
case <-time.After(refreshInterval):
|
|
refreshExpiring(pool)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func refreshExpiring(pool *Pool) {
|
|
pool.mu.Lock()
|
|
creds := make([]*Credential, len(pool.creds))
|
|
copy(creds, pool.creds)
|
|
pool.mu.Unlock()
|
|
|
|
threshold := time.Now().Add(refreshLead)
|
|
for _, cred := range creds {
|
|
cred.mu.Lock()
|
|
needsRefresh := !cred.ExpiresAt.IsZero() && cred.ExpiresAt.Before(threshold)
|
|
hasRefresh := cred.RefreshToken != ""
|
|
nextRetry := cred.nextRefreshAfter
|
|
email := cred.Email
|
|
cred.mu.Unlock()
|
|
|
|
if !hasRefresh || !needsRefresh {
|
|
continue
|
|
}
|
|
if !nextRetry.IsZero() && time.Now().Before(nextRetry) {
|
|
continue
|
|
}
|
|
|
|
log.Printf("refreshing token for %s (expires %s)", email, cred.ExpiresAt.Format(time.RFC3339))
|
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
err := RefreshToken(ctx, cred)
|
|
cancel()
|
|
|
|
if err != nil {
|
|
log.Printf("refresh failed for %s: %v", email, err)
|
|
cred.mu.Lock()
|
|
cred.nextRefreshAfter = time.Now().Add(refreshBackoff)
|
|
cred.mu.Unlock()
|
|
} else {
|
|
log.Printf("refreshed %s, new expiry %s", email, cred.ExpiresAt.Format(time.RFC3339))
|
|
cred.mu.Lock()
|
|
cred.nextRefreshAfter = time.Time{}
|
|
cred.mu.Unlock()
|
|
}
|
|
}
|
|
}
|