refactor: modularize codebase — deduplicate, extract, clean up
- Unify duplicate uTLS transports into shared internal/transport package - Extract shared version constant into internal/version - Move LoadDefaultCredentials from config to auth (remove config→auth import) - Deduplicate handler.go: extract telemetry/error helpers (324→268 lines) - Break up main.go::run() into initCredential/initEmbedded - Eliminate logging.Config duplication (use config.LoggingConfig directly) - Extract logWriter to embedded/log.go, SSE fixtures to consts in sniff.go - Use uTLS client for usage polling (consistent TLS fingerprint) - Handle sjson.SetBytes errors in sanitize.go instead of silently swallowing - Document reverse-engineered magic values in billing.go - Unexport Credential.CooldownUntil (internal state) - Replace hardcoded auth bypass paths with map in server.go
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// claudeCredentialsJSON matches the structure of ~/.claude/.credentials.json.
|
||||
type claudeCredentialsJSON struct {
|
||||
ClaudeAiOauth struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
SubscriptionType string `json:"subscriptionType"`
|
||||
} `json:"claudeAiOauth"`
|
||||
}
|
||||
|
||||
// LoadDefaultCredentials reads credentials from ~/.claude/.credentials.json.
|
||||
// Returns nil, nil if the file does not exist.
|
||||
func LoadDefaultCredentials() ([]*Credential, error) {
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal(data, &cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oauth := cf.ClaudeAiOauth
|
||||
if oauth.AccessToken == "" {
|
||||
return nil, fmt.Errorf("no access token in %s", path)
|
||||
}
|
||||
|
||||
cred := &Credential{
|
||||
ID: "claude-native",
|
||||
Email: oauth.SubscriptionType,
|
||||
AccessToken: oauth.AccessToken,
|
||||
RefreshToken: oauth.RefreshToken,
|
||||
ExpiresAt: time.UnixMilli(oauth.ExpiresAt),
|
||||
FilePath: path,
|
||||
}
|
||||
|
||||
return []*Credential{cred}, nil
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultCredentialPath(t *testing.T) {
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
t.Fatalf("DefaultCredentialPath error: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(path, filepath.Join(".claude", ".credentials.json")) {
|
||||
t.Errorf("path = %q, want suffix .claude/.credentials.json", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultCredentials_MissingFile(t *testing.T) {
|
||||
// When credential file doesn't exist, returns nil, nil
|
||||
path, err := DefaultCredentialPath()
|
||||
if err != nil {
|
||||
t.Skip("cannot determine home dir")
|
||||
}
|
||||
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
|
||||
creds, err := LoadDefaultCredentials()
|
||||
if creds != nil {
|
||||
t.Errorf("expected nil creds for missing file, got %v", creds)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCredentialsJSON_ParsesCorrectly(t *testing.T) {
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"test-token","refreshToken":"test-refresh","expiresAt":1234567890,"subscriptionType":"pro"}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cf.ClaudeAiOauth.AccessToken != "test-token" {
|
||||
t.Errorf("AccessToken = %q, want test-token", cf.ClaudeAiOauth.AccessToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.RefreshToken != "test-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want test-refresh", cf.ClaudeAiOauth.RefreshToken)
|
||||
}
|
||||
if cf.ClaudeAiOauth.ExpiresAt != 1234567890 {
|
||||
t.Errorf("ExpiresAt = %d, want 1234567890", cf.ClaudeAiOauth.ExpiresAt)
|
||||
}
|
||||
if cf.ClaudeAiOauth.SubscriptionType != "pro" {
|
||||
t.Errorf("SubscriptionType = %q, want pro", cf.ClaudeAiOauth.SubscriptionType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCredentialsJSON_EmptyAccessToken(t *testing.T) {
|
||||
jsonData := `{"claudeAiOauth":{"accessToken":"","refreshToken":"r","expiresAt":1}}`
|
||||
|
||||
var cf claudeCredentialsJSON
|
||||
if err := json.Unmarshal([]byte(jsonData), &cf); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if cf.ClaudeAiOauth.AccessToken != "" {
|
||||
t.Errorf("expected empty access token")
|
||||
}
|
||||
}
|
||||
@@ -6,16 +6,14 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/fujin/anthropic-proxy/internal/transport"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,7 +26,7 @@ const (
|
||||
refreshBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
var utlsClient = newUTLSClient()
|
||||
var utlsClient = transport.NewHTTPClient(15 * time.Second)
|
||||
|
||||
type tokenRequest struct {
|
||||
ClientID string `json:"client_id"`
|
||||
@@ -147,67 +145,6 @@ func persistCredential(cred *Credential) error {
|
||||
return os.WriteFile(filePath, out, 0600)
|
||||
}
|
||||
|
||||
func newUTLSClient() *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: 15 * time.Second,
|
||||
Transport: &utlsRefreshTransport{},
|
||||
}
|
||||
}
|
||||
|
||||
type utlsRefreshTransport struct {
|
||||
mu sync.Mutex
|
||||
conn *http2.ClientConn
|
||||
host string
|
||||
}
|
||||
|
||||
func (t *utlsRefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if t.conn != nil && t.host == host && t.conn.CanTakeNewRequest() {
|
||||
conn := t.conn
|
||||
t.mu.Unlock()
|
||||
resp, err := conn.RoundTrip(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.conn = nil
|
||||
t.mu.Unlock()
|
||||
} else {
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(host, port)
|
||||
rawConn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.conn = h2Conn
|
||||
t.host = host
|
||||
t.mu.Unlock()
|
||||
|
||||
return h2Conn.RoundTrip(req)
|
||||
}
|
||||
|
||||
func StartBackgroundRefresh(ctx context.Context, pool *Pool) {
|
||||
go func() {
|
||||
for {
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestPool_Pick_RoundRobin(t *testing.T) {
|
||||
func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "a"},
|
||||
{ID: "b", CooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "b", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "c"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
@@ -116,8 +116,8 @@ func TestPool_Pick_SkipsCooldown(t *testing.T) {
|
||||
func TestPool_Pick_AllOnCooldown(t *testing.T) {
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
creds := []*Credential{
|
||||
{ID: "a", CooldownUntil: future},
|
||||
{ID: "b", CooldownUntil: future},
|
||||
{ID: "a", cooldownUntil: future},
|
||||
{ID: "b", cooldownUntil: future},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
@@ -203,13 +203,13 @@ func TestPool_MarkFailure(t *testing.T) {
|
||||
}
|
||||
// Verify approximate duration
|
||||
cred.mu.Lock()
|
||||
cooldownEnd := cred.CooldownUntil
|
||||
cooldownEnd := cred.cooldownUntil
|
||||
cred.mu.Unlock()
|
||||
|
||||
lower := before.Add(tt.expectedDur)
|
||||
upper := time.Now().Add(tt.expectedDur)
|
||||
if cooldownEnd.Before(lower) || cooldownEnd.After(upper) {
|
||||
t.Errorf("CooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
|
||||
t.Errorf("cooldownUntil %v not in expected range [%v, %v]", cooldownEnd, lower, upper)
|
||||
}
|
||||
} else {
|
||||
if cred.IsOnCooldown() {
|
||||
@@ -223,7 +223,7 @@ func TestPool_MarkFailure(t *testing.T) {
|
||||
func TestPool_MarkSuccess(t *testing.T) {
|
||||
cred := &Credential{
|
||||
ID: "test",
|
||||
CooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
cooldownUntil: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
p := NewPool([]*Credential{cred})
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestPool_RoundRobinCursorAdvancement(t *testing.T) {
|
||||
func TestPool_RoundRobinWithCooldownSkip(t *testing.T) {
|
||||
creds := []*Credential{
|
||||
{ID: "0"},
|
||||
{ID: "1", CooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "1", cooldownUntil: time.Now().Add(1 * time.Hour)},
|
||||
{ID: "2"},
|
||||
}
|
||||
p := NewPool(creds)
|
||||
|
||||
@@ -13,7 +13,7 @@ type Credential struct {
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
FilePath string
|
||||
CooldownUntil time.Time
|
||||
cooldownUntil time.Time
|
||||
nextRefreshAfter time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
@@ -22,21 +22,21 @@ type Credential struct {
|
||||
func (c *Credential) IsOnCooldown() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return time.Now().Before(c.CooldownUntil)
|
||||
return time.Now().Before(c.cooldownUntil)
|
||||
}
|
||||
|
||||
// SetCooldown puts the credential on cooldown for the given duration.
|
||||
func (c *Credential) SetCooldown(duration time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.CooldownUntil = time.Now().Add(duration)
|
||||
c.cooldownUntil = time.Now().Add(duration)
|
||||
}
|
||||
|
||||
// ClearCooldown removes any active cooldown on the credential.
|
||||
func (c *Credential) ClearCooldown() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.CooldownUntil = time.Time{}
|
||||
c.cooldownUntil = time.Time{}
|
||||
}
|
||||
|
||||
// Token returns the current access token.
|
||||
|
||||
+11
-11
@@ -31,7 +31,7 @@ func TestCredential_IsOnCooldown(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Credential{CooldownUntil: tt.cooldownUntil}
|
||||
c := &Credential{cooldownUntil: tt.cooldownUntil}
|
||||
got := c.IsOnCooldown()
|
||||
if got != tt.want {
|
||||
t.Errorf("IsOnCooldown() = %v, want %v", got, tt.want)
|
||||
@@ -57,12 +57,12 @@ func TestCredential_SetCooldown(t *testing.T) {
|
||||
c.SetCooldown(tt.duration)
|
||||
after := time.Now()
|
||||
|
||||
// CooldownUntil should be between before+duration and after+duration
|
||||
if c.CooldownUntil.Before(before.Add(tt.duration)) {
|
||||
t.Errorf("CooldownUntil %v is before expected lower bound %v", c.CooldownUntil, before.Add(tt.duration))
|
||||
// cooldownUntil should be between before+duration and after+duration
|
||||
if c.cooldownUntil.Before(before.Add(tt.duration)) {
|
||||
t.Errorf("cooldownUntil %v is before expected lower bound %v", c.cooldownUntil, before.Add(tt.duration))
|
||||
}
|
||||
if c.CooldownUntil.After(after.Add(tt.duration)) {
|
||||
t.Errorf("CooldownUntil %v is after expected upper bound %v", c.CooldownUntil, after.Add(tt.duration))
|
||||
if c.cooldownUntil.After(after.Add(tt.duration)) {
|
||||
t.Errorf("cooldownUntil %v is after expected upper bound %v", c.cooldownUntil, after.Add(tt.duration))
|
||||
}
|
||||
|
||||
// Should now be on cooldown
|
||||
@@ -75,7 +75,7 @@ func TestCredential_SetCooldown(t *testing.T) {
|
||||
|
||||
func TestCredential_ClearCooldown(t *testing.T) {
|
||||
t.Run("clears active cooldown", func(t *testing.T) {
|
||||
c := &Credential{CooldownUntil: time.Now().Add(1 * time.Hour)}
|
||||
c := &Credential{cooldownUntil: time.Now().Add(1 * time.Hour)}
|
||||
if !c.IsOnCooldown() {
|
||||
t.Fatal("precondition: expected credential to be on cooldown")
|
||||
}
|
||||
@@ -85,8 +85,8 @@ func TestCredential_ClearCooldown(t *testing.T) {
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown after ClearCooldown")
|
||||
}
|
||||
if !c.CooldownUntil.IsZero() {
|
||||
t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil)
|
||||
if !c.cooldownUntil.IsZero() {
|
||||
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -97,8 +97,8 @@ func TestCredential_ClearCooldown(t *testing.T) {
|
||||
if c.IsOnCooldown() {
|
||||
t.Error("expected credential to not be on cooldown")
|
||||
}
|
||||
if !c.CooldownUntil.IsZero() {
|
||||
t.Errorf("expected CooldownUntil to be zero time, got %v", c.CooldownUntil)
|
||||
if !c.cooldownUntil.IsZero() {
|
||||
t.Errorf("expected cooldownUntil to be zero time, got %v", c.cooldownUntil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user